Scalable Collaborative Learning via Representation Sharing

Decentralization and Trustworthy ML Workshop @ NeurIPS 2022

F. Berdoz, A. Singh, M. Jaggi, R. Raskar

EPFL, Switzerland; MIT, USA

federated-learning · knowledge-distillation · privacy · collaborative-learning

Abstract

Privacy-preserving machine learning has become a key conundrum for multi-party artificial intelligence. Federated learning (FL) and split learning (SL) are two frameworks that enable collaborative learning while keeping the data private (on device). In this work, we present a novel approach for privacy-preserving machine learning, where the clients collaborate via online knowledge distillation using a contrastive loss. The goal is to ensure that the participants learn similar features on similar classes without sharing their input data. For cross-device applications, this approach increases the utility of the models compared to independent learning and other federated knowledge distillation schemes, is communication efficient, and scales with the number of clients.

Overview

Figure 1: Overview of the proposed framework. Each client trains a local model and shares per-class averaged feature representations (prototypes) with a relay server. The server aggregates prototypes across clients and broadcasts them back. Local training combines a standard cross-entropy loss with a feature-based knowledge distillation loss and a contrastive discriminator loss.

Privacy-preserving machine learning has become a key challenge for multi-party AI systems. Federated learning (FL) and split learning (SL) enable collaborative training while keeping data on device, but they face limitations in communication efficiency and scalability. This paper presents a novel approach where clients collaborate via online knowledge distillation using a contrastive loss on feature representations.

How it works

The framework has three key components:

  1. Prototype sharing: Each client computes per-class averaged representations from its model’s last hidden layer and shares them with a relay server. The server aggregates prototypes across clients and broadcasts the global averages back.
  2. Feature-based knowledge distillation: A KD loss minimizes the L2 distance between each client’s local representations and the global prototypes for the same class, encouraging clients to learn similar features.
  3. Contrastive discrimination: A discriminator loss distinguishes whether two feature vectors come from the same or different classes, maximizing mutual information between client representations.
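The three components above can be sketched in a few lines. This is a hedged illustration, not the authors' implementation: the paper trains a discriminator network for the contrastive term, which is replaced here by a simple InfoNCE-style loss against the global prototypes, and all function names (`class_prototypes`, `kd_loss`, `contrastive_loss`) are illustrative.

```python
import numpy as np

def class_prototypes(features, labels, num_classes):
    """Per-class mean of last-hidden-layer features (step 1: prototype sharing)."""
    protos = np.zeros((num_classes, features.shape[1]))
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            protos[c] = features[mask].mean(axis=0)
    return protos

def kd_loss(features, labels, global_protos):
    """Step 2: L2 distance between local features and the same-class global prototype."""
    diffs = features - global_protos[labels]
    return np.mean(np.sum(diffs ** 2, axis=1))

def contrastive_loss(features, labels, global_protos, tau=0.5):
    """Step 3 (simplified): InfoNCE-style stand-in for the learned discriminator.
    Pulls each feature toward its class prototype, pushes it from the others."""
    f = features / (np.linalg.norm(features, axis=1, keepdims=True) + 1e-8)
    p = global_protos / (np.linalg.norm(global_protos, axis=1, keepdims=True) + 1e-8)
    logits = f @ p.T / tau
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(log_probs[np.arange(len(labels)), labels])
```

In a full round, each client would upload its local prototypes, the relay server would average them across clients, and the broadcast result would be passed as `global_protos` to the two losses above, weighted and added to the standard cross-entropy term.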

A theoretical analysis shows that this objective maximizes a lower bound on the mutual information between student and teacher representations.
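The exact bound is in the paper; as a hedged sketch, contrastive objectives of this kind typically admit a lower bound of the following standard form, assuming the discriminator $q(C=1 \mid \cdot)$ classifies whether a pair of representations is congruent (same class), with one positive pair per $N$ negatives:

```latex
I(h;\,\bar{h}) \;\geq\; \log N \;+\; \mathbb{E}_{p(h,\bar{h}\,\mid\,C=1)}\!\left[\log q(C=1 \mid h, \bar{h})\right]
```

where $h$ is a client's local representation, $\bar{h}$ the corresponding global prototype, and $C=1$ labels congruent pairs. Training the discriminator to assign high likelihood to congruent pairs therefore tightens a lower bound on the mutual information $I(h;\bar{h})$.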

Communication efficiency

Communication scales as O(M * C * d’) per round, where M is the number of clients, C is the number of classes, and d’ is the feature dimension. Critically, this does not depend on the model size D, making the approach orders of magnitude more efficient than federated learning for large models.
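The scaling can be checked with back-of-the-envelope numbers. The feature dimension d' = 84 and the LeNet-5 parameter count below are assumptions based on the classic architecture, not figures quoted from the paper:

```python
M, C, d_prime = 10, 10, 84  # clients, classes, feature dim (LeNet-5's last hidden layer)
D = 61_706                  # approximate LeNet-5 parameter count (assumption)

proto_floats = M * C * d_prime  # upload per round with prototype sharing: O(M*C*d')
fl_floats = M * D               # upload per round with FL (full model per client): O(M*D)
print(proto_floats, fl_floats, fl_floats / proto_floats)
```

Even for a model as small as LeNet-5, prototype sharing moves roughly 70x fewer floats per round than FL, and the gap widens with model size since `proto_floats` is independent of D.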

Results

Figure 2: Hyperparameter ablation on MNIST with 10 clients. The optimal configuration uses a knowledge distillation weight of 10 and a discriminator weight of 1.

Key numbers

  • MNIST with 10 clients (LeNet-5): 82.07% accuracy vs. FL 70.06%, independent learning 72.86%, and federated distillation 77.90%.
  • Fashion-MNIST (ResNet-9): Outperforms independent learning and federated distillation, with competitive performance against FL.
  • Communication: Orders of magnitude less than FL, since only feature prototypes (not model parameters) are exchanged.
  • Regularization effect: With 2 clients on MNIST, the method even outperforms centralized training, suggesting the prototype exchange acts as a form of regularization.

Key takeaway: Sharing per-class feature representations instead of model parameters enables scalable, communication-efficient collaborative learning. The approach outperforms independent learning and federated distillation while requiring orders of magnitude less communication than federated learning.

Citation

@misc{berdoz2022scalable,
  author = {Berdoz, F. and Singh, A. and Jaggi, M. and Raskar, R.},
  title = {{Scalable Collaborative Learning via Representation Sharing}},
  note = {Best Paper Runner-up at NeurIPS Workshop on Decentralization and Trustworthy ML in Web3. arXiv:2211.10943},
  year = {2022}
}