This project focuses on applying continual learning methods to an Encoder-only transformer model for Named Entity Recognition (NER). The model is trained across multiple datasets, each representing a distinct NER task. The approach is highly flexible, allowing for the number of tasks (datasets) to be configured. Elastic Weight Consolidation (EWC) is employed to mitigate catastrophic forgetting.
- Continual Learning: The model sequentially learns multiple NER tasks on different datasets.
- Encoder-only Transformer: Utilizes an Encoder-only transformer model for NER tasks.
- Configurable Training: Users can configure the number of datasets, entity set, and preprocessing mechanisms.
- EWC Mechanism: Fisher Information is computed to retain knowledge from previous tasks.
Continual Learning is crucial because it allows a single model to retain old knowledge while acquiring new knowledge: once trained on previous tasks, the model can still remember and perform well on them while also learning to handle new tasks. This is particularly important in dynamic environments where new data and tasks are continuously introduced.
Elastic Weight Consolidation (EWC) is a method used for continual learning, where a neural network must learn new tasks without forgetting previously learned ones. EWC helps to mitigate catastrophic forgetting by regularizing the model’s weights.
The idea is to identify the important weights for previously learned tasks and then "protect" them while learning new tasks. Fisher Information plays a crucial role here: it is used to measure the importance of each weight for a given task.
- Fisher Information Matrix: For each task, EWC computes the Fisher Information, which indicates how much the loss will change if a particular weight is altered. Higher Fisher Information means that weight is more important for the task.
- Regularization Term: During learning of new tasks, EWC adds a regularization term to the loss function. This term penalizes changes to important weights (those with high Fisher Information), preventing the model from forgetting what it learned on the previous tasks.
- Objective: The goal is to allow the model to learn new tasks while preserving important weights from previous tasks. By using the Fisher Information, EWC ensures that the model focuses on adapting only the necessary weights for new tasks, while minimizing disruption to previously learned knowledge.
- Train on Task 1 (data1): The model is initially trained on the first dataset (data1). During this training, the model learns to recognize entities specific to this dataset.
- Identify Important Parameters: After training on data1, the model identifies which parameters (weights) are most important for the task. This is done by computing the Fisher Information Matrix (FIM), which measures the sensitivity of the loss function to changes in each parameter.
- Retain a Sample from Task 1: A subset of data1 is stored to help the model remember the first task.
- Train on Task 2 (data2) with Task 1 Sample: The model is then trained on the second dataset (data2). During this training, the model uses the stored sample from data1 and the FIM to ensure that important parameters from the first task are not significantly altered. This helps the model retain knowledge from the first task while learning the second task.
- Repeat for All Tasks: The process is repeated for subsequent datasets (data3, data4, etc.). For each new task, the model uses the FIM and stored samples from previous tasks to maintain its performance on earlier tasks.
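In code, this control flow might look roughly like the sketch below. It only illustrates the steps above; the helpers `train_one_task`, `compute_fisher`, and `sample_subset` are hypothetical placeholders, not the project's actual functions:

```python
def continual_training(model, datasets, keep_sample_size=200):
    """Sequentially train on each dataset while carrying forward the
    Fisher estimates, the previous optimal weights, and a rehearsal sample."""
    fisher, old_params, rehearsal = None, None, []
    for task_id, dataset in enumerate(datasets):
        # Mix the retained sample from the previous task into the current data
        train_data = list(dataset) + rehearsal
        model = train_one_task(model, train_data, fisher=fisher, old_params=old_params)

        # After finishing this task, record what is needed to protect it later
        fisher = compute_fisher(model, dataset)                        # parameter importance (FIM diagonal)
        old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
        rehearsal = sample_subset(dataset, keep_sample_size)           # small replay buffer
    return model
```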
- Train the model on a dataset and compute the gradients of the loss with respect to the model parameters.
- Estimate Fisher Information Matrix (FIM) using these gradients, capturing parameter importance.
- Store learned weights along with FIM values.
- When training on a new task, add an EWC loss term that penalizes changes to important parameters based on their FIM scores.
- Iterate through tasks, ensuring previous knowledge is preserved.
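A minimal PyTorch sketch of the Fisher estimation step is shown below; `model`, `data_loader`, and `loss_fn` are placeholder names, not the project's actual API:

```python
import torch

def estimate_diagonal_fisher(model, data_loader, loss_fn, device="cuda"):
    """Approximate the diagonal of the Fisher Information Matrix as the
    average of squared gradients of the loss w.r.t. each parameter."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters() if p.requires_grad}
    model.eval()
    num_batches = 0
    for inputs, labels in data_loader:
        model.zero_grad()
        loss = loss_fn(model(inputs.to(device)), labels.to(device))
        loss.backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2   # squared gradient ~ parameter importance
        num_batches += 1
    return {n: f / max(num_batches, 1) for n, f in fisher.items()}
```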
Elastic Weight Consolidation (EWC) helps prevent catastrophic forgetting in continual learning. The method works by regularizing the weights that are important for previously learned tasks using the Fisher Information.
Fisher Information ( F_i ) measures the importance of each weight ( \theta_i ) to the learned task. It is computed as the expected value of the second derivative of the loss with respect to the weight:

$$
F_i = \mathbb{E}\left[ \frac{\partial^2 L}{\partial \theta_i^2} \right]
$$
In practice, Fisher Information is estimated using the diagonal of the Fisher Information Matrix, which can be approximated from squared first-order gradients as:

$$
F_i \approx \frac{1}{N} \sum_{n=1}^{N} \left( \frac{\partial L_n}{\partial \theta_i} \right)^2
$$
where:
- ( L_n ) is the loss for the ( n )-th training example,
- ( \theta_i ) is the ( i )-th weight in the model,
- ( N ) is the number of training examples.
The total loss function in EWC consists of two terms:
- The loss for the new task.
- A regularization term that penalizes large changes in important weights based on Fisher Information.
The EWC objective function is:

$$
L(\theta) = L_{\text{new task}}(\theta) + \frac{\lambda}{2} \sum_i F_i \left( \theta_i - \hat{\theta}_i \right)^2
$$
Where:
- ( L_{\text{new task}} ) is the loss function for the new task.
- ( \lambda ) is a regularization strength parameter that controls the importance of the EWC term.
- ( F_i ) is the Fisher Information for weight ( \theta_i ) related to the previous task.
- ( \hat{\theta}_i ) is the optimal value of weight ( \theta_i ) after training on the previous task.
The goal of EWC is to allow the model to learn new tasks while preserving important weights from previous tasks by adding this regularization. The term ( (\theta_i - \hat{\theta}_i)^2 ) ensures that weights important for previous tasks don't change drastically, while the Fisher Information ( F_i ) quantifies how sensitive the loss is to changes in each weight.
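Translated directly into code, the regularization term might be computed as in the sketch below, assuming the Fisher values and the previous-task weights ( \hat{\theta}_i ) are stored in dictionaries keyed by parameter name (the function name `ewc_penalty` is illustrative, not part of the project's API):

```python
import torch

def ewc_penalty(model, fisher, old_params, lam=0.4):
    """Compute (lambda / 2) * sum_i F_i * (theta_i - theta_hat_i)^2."""
    penalty = torch.zeros(1, device=next(model.parameters()).device)
    for name, param in model.named_parameters():
        if name in fisher:
            penalty += (fisher[name] * (param - old_params[name]) ** 2).sum()
    return (lam / 2.0) * penalty

# During training on a new task, the total loss becomes:
# total_loss = new_task_loss + ewc_penalty(model, fisher, old_params, lam)
```

With ( \lambda = 0 ) this reduces to plain fine-tuning; larger values keep the weights that matter for earlier tasks closer to their previous values.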
In this project, we have provided three datasets (data1, data2, data3) in the data/ folder. These datasets are used to train a Named Entity Recognition (NER) model on medical text. Each dataset corresponds to a different task:
- T1: Task 1, corresponding to `data1`
- T2: Task 2, corresponding to `data2`
- T3: Task 3, corresponding to `data3`
We have conducted experiments to compare the performance of continual learning against training on an aggregated dataset (T1+T2+T3). The results show that continual learning performs better than training on the combined dataset.
| Entity | After T1 | After T1 → T2 | After T1 → T2 → T3 | T1+T2+T3 combined |
|---|---|---|---|---|
| allergy_name | 0.738386 | 0.824524 | 0.901130 | 0.797745 |
| cancer | 0.726179 | 0.793356 | 0.833593 | 0.742930 |
| chronic_disease | 0.779630 | 0.805484 | 0.854898 | 0.783958 |
| treatment | 0.777369 | 0.840783 | 0.881324 | 0.801621 |
| micro avg | 0.769029 | 0.820854 | 0.865570 | 0.786875 |
| macro avg | 0.755391 | 0.816037 | 0.867736 | 0.781564 |
| weighted avg | 0.768511 | 0.820849 | 0.865853 | 0.786826 |
For more detailed experimentation and metrics, please refer to the notebooks directory (in the `experiment` branch), where you can find Jupyter notebooks demonstrating the experiments.
Note: The provided datasets are for medical NER tasks, but you can replace them with any other datasets for different use cases.
- Create and Activate Virtual Environment:

  ```bash
  python -m venv .venv
  source .venv/bin/activate   # macOS/Linux
  .venv\Scripts\activate      # Windows
  ```

- Install Dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Run the Project:

  ```bash
  python -m src.main --data-dir data --output-dir output --wandb
  ```
- Build Docker Image:

  ```bash
  docker build -t ewc_ner .
  ```

- Run without Saving Output:

  ```bash
  docker run -it --rm ewc_ner --data-dir /app/data --output-dir /app/output --wandb
  ```

- Run with Output Saved to Host:

  ```bash
  docker run -it \
    -v /path/to/host/data:/app/data \
    -v /path/to/host/output:/app/output \
    ewc_ner
  ```
Modify src/config.py to adapt the model to different datasets and training setups. It is important to go through the config file and customize it according to your specific needs and data:
```python
@dataclass
class TrainingConfig:
    max_len: int
    train_batch_size: int
    valid_batch_size: int
    keep_batch_size: int
    epochs: int
    learning_rate: float
    max_grad_norm: float
    keep_sample_size: int  # Number of data points to keep from Task N-1 (~Dataset N-1)
    train_size: float
    num_datasets: int  # Number of NER tasks (datasets)
    random_state: int
    drop_columns: list  # Redundant columns to drop (from datasets)
    wandb_project_name: str

@dataclass
class ModelConfig:
    model_name: str  # Any encoder-only Transformer can be used, since this is token classification
    num_labels: int
    device: str

ENTITY_SET = []  # The comprehensive list of all possible labels
```

- Modify `num_datasets` to change the number of NER tasks.
- Update `ENTITY_SET` to redefine entity categories.
- Preprocessing is configurable, allowing adaptation to different data formats.
- Datasets should be placed in the `data/` folder and follow the naming convention `data1`, `data2`, `data3`, etc. You can add as many datasets as needed, named sequentially.
- Note: When running this project with Docker, ensure that the container is allocated sufficient memory and GPU access to avoid any issues. Adjust your Docker runtime settings accordingly before starting the container.
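For example, a filled-in configuration for the three bundled medical datasets might look roughly like the sketch below, as it could appear inside src/config.py. The values are illustrative placeholders, not the project's shipped defaults, so adjust them to your data:

```python
training_config = TrainingConfig(
    max_len=128,
    train_batch_size=16,
    valid_batch_size=16,
    keep_batch_size=8,
    epochs=5,
    learning_rate=2e-5,
    max_grad_norm=1.0,
    keep_sample_size=200,          # data points carried over from the previous task
    train_size=0.8,
    num_datasets=3,                # data1, data2, data3
    random_state=42,
    drop_columns=[],
    wandb_project_name="ewc-ner",
)

model_config = ModelConfig(
    model_name="bert-base-cased",  # any encoder-only model usable for token classification
    num_labels=9,                  # should match the label scheme built from ENTITY_SET (e.g., BIO tags)
    device="cuda",
)
```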
This project is licensed under the MIT License. See the LICENSE file for details.