Catastrophic Forgetting in Split Federated Learning

Discover how catastrophic forgetting affects split federated learning (SFL) when data is heterogeneous and why the order of processing influences model accuracy.


Dimitra Tsigkari

Reading time: 6 min

Deep neural networks are powerful, but they tend to forget previously learned information while learning new information. This is called catastrophic forgetting. The phenomenon has been extensively studied in the continual learning setting, where a machine learning (ML) model learns continuously from an ever-evolving environment. However, catastrophic forgetting also appears in other settings, such as distributed learning. In this article, we study Split Federated Learning (SFL), a novel paradigm of distributed learning that allows devices to collaboratively train a model while offloading part of the training to a computationally powerful server. Data heterogeneity across devices can make forgetting worse. To tackle this, we designed Hydra, a new method inspired by multi-head neural networks that keeps models accurate by reducing the effect of forgetting.

What Is Catastrophic Forgetting?

In real-life machine learning scenarios, a model needs to learn from a continuous stream of inputs (data). However, the knowledge of Task 1, e.g., recognizing vehicles or airplanes in images (for image recognition models), may be disrupted or lost upon acquiring knowledge related to the next task, Task 2, e.g., recognizing cats or dogs. This phenomenon is called catastrophic forgetting. Specifically, the parameters of the ML model are first tuned for Task 1 and then adjusted to reflect the knowledge related to Task 2; the new parameter values may overwrite the knowledge related to Task 1. This reflects the plasticity-stability trade-off: the model's parameters should be able to adapt to newly acquired knowledge (plasticity) while, at the same time, retaining old knowledge (memory stability).
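To make this mechanism concrete, here is a deliberately tiny, self-contained sketch (an illustration, not code from our study): a model with a single parameter w and a quadratic per-task loss is trained with gradient descent on Task 1 and then on Task 2. Afterwards, its loss on Task 1 is large again because the parameter value was overwritten.

```python
# Toy illustration of catastrophic forgetting: a model with one parameter w
# and a per-task loss (w - target)^2. Training on Task 2 overwrites the
# parameter value that Task 1 needed.

def sgd(w, target, steps=100, lr=0.1):
    """Minimize (w - target)^2 with plain gradient descent."""
    for _ in range(steps):
        w -= lr * 2.0 * (w - target)
    return w

def task_loss(w, target):
    return (w - target) ** 2

w = sgd(0.0, target=2.0)              # learn Task 1 (optimum at w = 2)
loss_t1_before = task_loss(w, 2.0)    # near zero: Task 1 is learned
w = sgd(w, target=-2.0)               # now learn Task 2 (optimum at w = -2)
loss_t1_after = task_loss(w, 2.0)     # near 16: Task 1 has been "forgotten"
```

Real networks have millions of parameters, but the failure mode is the same: nothing in plain gradient descent protects the parameter values that earlier tasks depend on.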

Catastrophic forgetting has been widely studied in continual learning, where an ML model learns continuously from an ever-evolving environment; mitigation methods include replaying old data, regularization, and neuro-inspired techniques. Its effects in distributed learning, however, are far less explored.

What Is Split Federated Learning?

In distributed learning, devices collaboratively train an ML model without sharing their local data. For example, our smartphones train an ML model to classify pictures, but without sharing our own pictures with others. We can think of the ML model as a factory assembly line: devices perform the first steps of processing locally and then send the partially processed results to a central "finishing station" (a central server). This is Split Federated Learning (SFL). In detail, SFL is a distributed learning method where part of the devices' training is offloaded to a server. This is particularly useful when devices' resources are insufficient for on-device training (for example, a small sensor in Internet-of-Things scenarios). In practice, the ML model (i.e., a deep neural network) is split into two parts of consecutive layers (as illustrated on the left of the figure below): part-1 is trained locally at each device, and part-2 at the server.
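The split itself can be sketched as cutting a list of layers at a chosen index; the device runs the first slice and the server runs the rest on the intermediate results. The layers below are toy stand-ins, not a real network:

```python
# Sketch of the model split: a 7-layer "network" as a list of layer
# functions, cut at an index so the device runs part-1 and the server
# runs part-2 on the intermediate results.

def make_layer(scale):
    return lambda x: [scale * v for v in x]     # stand-in for a real NN layer

layers = [make_layer(s) for s in (1, 2, 3, 4, 5, 6, 7)]   # full 7-layer model
cut = 3                                     # split point (a design choice)
part1, part2 = layers[:cut], layers[cut:]   # device-side / server-side

def forward(parts, x):
    for layer in parts:
        x = layer(x)
    return x

smashed = forward(part1, [1.0])    # device computes the intermediate result
output = forward(part2, smashed)   # server finishes the forward pass
```

Running the two slices in sequence is, by construction, equivalent to running the full model; only the location of the computation changes.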

Left: An example of a neural network of 7 layers split into part-1 and part-2 in the setting of Split Federated Learning. Right: The steps of the processing workflow of Split Federated Learning.

The processing workflow of SFL for 3 devices is depicted in the figure above (right). At the beginning of every training round, each device processes a part (batch) of its data through part-1 of the model and sends the intermediate results to the server. The server sequentially trains part-2 on the results received from each device and sends the results back, and each device then updates its local model. These steps are repeated for all batches; once all devices have processed all their data, the aggregator combines all model updates into a global model before a new round begins.
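The round described above can be sketched schematically as follows. All classes and the scalar "sub-models" here are illustrative stand-ins; the real protocol exchanges activations and gradients of actual neural network layers:

```python
# Schematic sketch of one SFL training round with toy scalar sub-models.

class Device:
    def __init__(self, data):
        self.data = data            # local batches (they never leave the device)
        self.part1_weight = 1.0     # stand-in for the device-side sub-model

    def forward_part1(self, i):
        # device-side forward pass: produce the intermediate ("smashed") result
        return self.part1_weight * self.data[i]

    def update_part1(self, grad, lr=0.1):
        self.part1_weight -= lr * grad

class Server:
    def __init__(self):
        self.part2_weight = 1.0     # stand-in for the server-side sub-model

    def train_part2(self, smashed, lr=0.1):
        grad = self.part2_weight * smashed   # toy gradient signal
        self.part2_weight -= lr * grad       # sequential server-side update
        return grad                          # sent back to the device

def sfl_round(devices, server):
    for i in range(len(devices[0].data)):
        for dev in devices:                  # devices processed in sequence
            grad = server.train_part2(dev.forward_part1(i))
            dev.update_part1(grad)
    # aggregation: average the device-side models into a global part-1
    avg = sum(d.part1_weight for d in devices) / len(devices)
    for d in devices:
        d.part1_weight = avg
    return avg

devices = [Device([0.5, 1.0]), Device([1.5, 2.0]), Device([0.2, 0.8])]
global_part1 = sfl_round(devices, Server())
```

Note the key structural point for what follows: part-1 is averaged across devices as in Federated Learning, while part-2 is updated sequentially, one device after another.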

Key Insights into Catastrophic Forgetting in Split Federated Learning

It is clear that, in SFL, part-1 of the model is trained as in Federated Learning (FL). On the other hand, part-2 is trained over intermediate results that devices send in sequential order. This dual aspect of SFL makes it particularly susceptible to catastrophic forgetting when the devices’ data is very heterogeneous (for example, one device has mostly pictures of cats and the other has pictures of dogs). In practice, our experiments (on image classification tasks) reveal that the processing order at the server has a significant impact on catastrophic forgetting in the ML model.

Accuracy per processing position (at the server) and global accuracy achieved by MobileNet on CIFAR-10 under heterogeneous data distributions among 10 devices.

The figure above illustrates catastrophic forgetting in SFL as a result of the processing order at the server when the devices' data is heterogeneous. In particular, we assume that the data of each device contains one highly represented label (e.g., a different type of animal per device). We focus on the MobileNet model (for image classification) trained on the CIFAR-10 dataset (which contains 10 labels). The blue and orange lines depict the accuracy of the labels that are highly represented in the devices processed at the first and last positions, respectively. Under heterogeneous data, the label that is highly represented at the device whose intermediate results are processed last at the server achieves higher accuracy than the label of the device processed first. This disparity between labels seen at the end of the sequence and those seen earlier is a symptom of catastrophic forgetting in part-2.

The workflow of Hydra, the proposed method to mitigate catastrophic forgetting in SFL.

Hydra: A Novel Method to Mitigate Catastrophic Forgetting in SFL

Based on these insights, we designed Hydra, a method specifically created to reduce catastrophic forgetting in SFL. Hydra leaves device-side training unchanged but modifies the server-side workflow:

  • The server-side model is split into a shared component (part-2a) updated sequentially and multiple head components (part-2b) updated in parallel and merged at the end of each round.
  • Each head is trained by a group of devices with similar highly represented labels. After each round, all heads are aggregated into a single model, reducing forgetting.
  • Unlike traditional multi-head methods in continual learning, Hydra produces one unified model, keeping it efficient and deployable.
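A minimal numeric sketch of this server-side idea follows (scalar stand-ins for the sub-models; the `+= lr * x` update is a toy, not Hydra's actual training step):

```python
# Toy sketch of Hydra's server-side workflow: a shared component (part-2a)
# updated on every device's results in sequence, plus one head (part-2b) per
# device group; heads are merged into a single model at the end of the round.

def hydra_round(group_smashed, shared=0.0, lr=0.1):
    """group_smashed maps a group id to the intermediate results of the
    devices in that group (devices grouped by similar dominant labels)."""
    heads = {g: 0.0 for g in group_smashed}
    for g, batches in group_smashed.items():
        for x in batches:
            shared += lr * x     # part-2a: sees all groups sequentially
            heads[g] += lr * x   # part-2b: each head sees only its own group
    # merge the heads so Hydra still produces one unified, deployable model
    merged_head = sum(heads.values()) / len(heads)
    return shared, merged_head

# two groups of devices, e.g. one dominated by "cats", one by "dogs"
shared, merged = hydra_round({"cats": [1.0, 1.0], "dogs": [3.0]})
```

Because each head is only ever updated by one group, no group's updates can overwrite another's within a round; the merge at the end restores a single model for deployment.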

Our evaluations show Hydra effectively improves accuracy, closes the label performance gap, and adds minimal computational overhead.

Per-position accuracy in SFL, and SFL+Hydra achieved by ResNet101 with CIFAR-10 in a setting of high data heterogeneity.

The figure above illustrates the effectiveness of Hydra when training ResNet101 on CIFAR-10 under high data heterogeneity. Each line shows the accuracy of the label that is highly represented in the devices processed at a given position in the sequence (from first to last). The left plot corresponds to standard SFL, while the right plot shows SFL enhanced with Hydra. Note that, unlike baseline SFL, Hydra leads to balanced performance across all processing positions. Moreover, the global accuracy improves significantly, reaching 44% instead of 28% after 100 training rounds.

Why Is Dealing with Catastrophic Forgetting Important?

Many real-world ML models and AI systems — in healthcare, finance, and mobile devices — are trained collaboratively across multiple organizations or devices. If ML models forget previously learned information, their predictions can become biased or unreliable. By understanding and reducing catastrophic forgetting, we help ensure that distributed ML systems remain accurate, stable, and dependable in real-world use.

Next Steps

In this work, we studied catastrophic forgetting in Split Federated Learning and proposed Hydra, an effective solution that improves global accuracy and closes label performance gaps.

Looking ahead, SFL also raises optimization challenges, related to communication overhead and resource allocation. These are being studied within the OPALS project, funded by the Marie Skłodowska-Curie Actions (MSCA).

If you would like to read the full study, click here.
