diff --git a/ds_config_zero1.json b/data_example/deepspeed_config.json similarity index 100% rename from ds_config_zero1.json rename to data_example/deepspeed_config.json diff --git a/finetune/train.py b/finetune/train.py index bb969bd..dd83477 100644 --- a/finetune/train.py +++ b/finetune/train.py @@ -14,13 +14,16 @@ from piccolo.model import STEmbedder from tqdm import tqdm -def load_all_datasets(meta_paths, root_dirs, query_prefix, doc_prefix) -> list[DatsetWithInfo]: + +def load_all_datasets( + meta_paths, root_dirs, query_prefix, doc_prefix +) -> list[DatsetWithInfo]: all_datasets = [] for meta_path, root_dir in zip(meta_paths, root_dirs): CNT = 0 - meta_file = open(meta_path, 'r') + meta_file = open(meta_path, "r") for line in tqdm(meta_file.readlines()): - dataset_name, repeat_num = line.strip().split(' ') + dataset_name, repeat_num = line.strip().split(" ") dataset_dict = load_from_disk(str(os.path.join(root_dir, dataset_name))) if isinstance(dataset_dict, dict): dataset: HfDataset = concatenate_datasets(list(dataset_dict.values())) @@ -28,13 +31,18 @@ def load_all_datasets(meta_paths, root_dirs, query_prefix, doc_prefix) -> list[D dataset = dataset_dict for idx in range(int(repeat_num)): all_datasets.append( - DatsetWithInfo(hf_dataset=dataset, name=dataset_name + '_{}'.format(idx), - query_prefix=query_prefix, passage_prefix=doc_prefix) + DatsetWithInfo( + hf_dataset=dataset, + name=dataset_name + "_{}".format(idx), + query_prefix=query_prefix, + passage_prefix=doc_prefix, + ) ) CNT += 1 - print('loading {} datasets from path: {}'.format(CNT, meta_path)) + print("loading {} datasets from path: {}".format(CNT, meta_path)) return all_datasets + class MyCallback(TrainerCallback): def on_epoch_end(self, args, state, control, train_dataloader, **kwargs): train_dataloader.dataset.create_or_refresh_data() @@ -45,25 +53,43 @@ def __init__(self, efficient_save, **kwargs): super().__init__(**kwargs) self.efficient_save = efficient_save - def save_ckpt_for_sentence_transformers(self, tmp_dir, output_dir, pooling_mode: str = 'mean'): - '''convert to sentence transformer format''' + def save_ckpt_for_sentence_transformers( + self, tmp_dir, output_dir, pooling_mode: str = "mean" + ): + """convert to sentence transformer format""" import shutil from sentence_transformers import models, SentenceTransformer + word_embedding_model = models.Transformer(tmp_dir) - pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode='mean') - if os.path.exists(os.path.join(tmp_dir, 'scaling_layer.bin')): - state_dict = torch.load(os.path.join(tmp_dir, 'scaling_layer.bin')) - in_features, out_features = state_dict['linear.weight'].shape[1], state_dict['linear.weight'].shape[0] - scaling_layer = models.Dense(in_features, out_features, bias=True, activation_function=torch.nn.modules.linear.Identity()) + pooling_model = models.Pooling( + word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean" + ) + if os.path.exists(os.path.join(tmp_dir, "scaling_layer.bin")): + state_dict = torch.load(os.path.join(tmp_dir, "scaling_layer.bin")) + in_features, out_features = ( + state_dict["linear.weight"].shape[1], + state_dict["linear.weight"].shape[0], + ) + scaling_layer = models.Dense( + in_features, + out_features, + bias=True, + activation_function=torch.nn.modules.linear.Identity(), + ) scaling_layer.load_state_dict(state_dict, strict=True) - model = SentenceTransformer(modules=[word_embedding_model, pooling_model, scaling_layer], device='cpu') + model = 
SentenceTransformer( + modules=[word_embedding_model, pooling_model, scaling_layer], + device="cpu", + ) else: - model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device='cpu') + model = SentenceTransformer( + modules=[word_embedding_model, pooling_model], device="cpu" + ) model.save(output_dir, safe_serialization=False) shutil.rmtree(tmp_dir) def _save(self, output_dir: Optional[str] = None, **kwargs): - '''save the unwrap model''' + """save the unwrap model""" output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) @@ -72,19 +98,30 @@ def _save(self, output_dir: Optional[str] = None, **kwargs): unwrap_model = self.model.embedder.encoder if self.is_world_process_zero(): # first saves to the tmp dir, then converts to sentence-transformer - tmp_dir = output_dir + '-tmp' - unwrap_model.save_pretrained(tmp_dir, safe_serialization=self.args.save_safetensors) + tmp_dir = output_dir + "-tmp" + unwrap_model.save_pretrained( + tmp_dir, safe_serialization=self.args.save_safetensors + ) self.tokenizer.save_pretrained(tmp_dir) - if hasattr(self.model, 'scaling_layer'): - scaling_layer = {'linear.weight': self.model.scaling_layer.state_dict()['linear.weight'].data.cpu(), - 'linear.bias': self.model.scaling_layer.state_dict()['linear.bias'].data.cpu()} - torch.save(scaling_layer, os.path.join(tmp_dir, 'scaling_layer.bin')) - self.save_ckpt_for_sentence_transformers(tmp_dir, output_dir, self.model.embedder.pooling_strategy.value) + if hasattr(self.model, "scaling_layer"): + scaling_layer = { + "linear.weight": self.model.scaling_layer.state_dict()[ + "linear.weight" + ].data.cpu(), + "linear.bias": self.model.scaling_layer.state_dict()[ + "linear.bias" + ].data.cpu(), + } + torch.save(scaling_layer, os.path.join(tmp_dir, "scaling_layer.bin")) + self.save_ckpt_for_sentence_transformers( + tmp_dir, output_dir, self.model.embedder.pooling_strategy.value + ) def _save_checkpoint(self, model, trial, metrics=None): if self.efficient_save: - '''only save the model ckpt weights to save disk mem''' + """only save the model ckpt weights to save disk mem""" from transformers.trainer import PREFIX_CHECKPOINT_DIR + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) @@ -102,10 +139,17 @@ def main(): # DataLoader tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path) - all_datasets = load_all_datasets(data_args.meta_paths, data_args.root_dirs, data_args.query_prefix, data_args.doc_prefix) - train_dataset = UniDataset(all_datasets, batch_size=data_args.batch_size, neg_num=data_args.neg_num) + all_datasets = load_all_datasets( + data_args.meta_paths, + data_args.root_dirs, + data_args.query_prefix, + data_args.doc_prefix, + ) + train_dataset = UniDataset( + all_datasets, batch_size=data_args.batch_size, neg_num=data_args.neg_num + ) data_collator = UniCollator(tokenizer=tokenizer, max_length=model_args.max_length) - + # Model model = STEmbedder( model_name_or_path=model_args.model_name_or_path, @@ -114,10 +158,10 @@ def main(): add_scaling_layer=model_args.use_scaling_layer, use_mrl=model_args.use_mrl, extend_pe=model_args.extend_pe, - max_length=model_args.max_length + max_length=model_args.max_length, ) model.embedder.encoder.config.pad_token_id = tokenizer.pad_token_id - + # Trainer trainer = STETrainer( model=model, @@ -126,20 +170,24 @@ def main(): data_collator=data_collator, tokenizer=tokenizer, 
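# Illustrative sketch, not part of the diff: a checkpoint exported by
# save_ckpt_for_sentence_transformers above (Transformer + Pooling + optional Dense scaling
# layer) can be loaded directly with sentence-transformers; "output/piccolo-ft" is a
# hypothetical output_dir used only for illustration.
from sentence_transformers import SentenceTransformer

st_model = SentenceTransformer("output/piccolo-ft", device="cpu")
embeddings = st_model.encode(
    ["an example query", "an example passage"], normalize_embeddings=True
)
print(embeddings.shape)  # (2, 1792) when the scaling layer was exported, else the encoder dim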
callbacks=[MyCallback], - efficient_save=training_args.efficient_save + efficient_save=training_args.efficient_save, ) # save training info if trainer.is_world_process_zero(): Path(training_args.output_dir).mkdir(parents=True, exist_ok=True) - Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True) + Path(os.path.join(training_args.output_dir, "parameters")).mkdir( + parents=True, exist_ok=True + ) ## save data list info meta_paths = data_args.meta_paths - with open(os.path.join(training_args.output_dir, 'parameters','data.list'), 'w') as f: + with open( + os.path.join(training_args.output_dir, "parameters", "data.list"), "w" + ) as f: for meta_path in meta_paths: - f.writelines(f'list_name: {meta_path} \n') - f.writelines(open(meta_path, 'r').readlines()) - f.writelines('\n\n') + f.writelines(f"list_name: {meta_path} \n") + f.writelines(open(meta_path, "r").readlines()) + f.writelines("\n\n") trainer.train() @@ -147,12 +195,19 @@ def main(): if trainer.is_world_process_zero(): trainer.save_model(training_args.output_dir, _internal_call=True) ## save parameter - parameter_dict = {'model_args': asdict(model_args), 'data_args': asdict(data_args), 'train_args': asdict(training_args)} - Path(os.path.join(training_args.output_dir, 'parameters')).mkdir(parents=True, exist_ok=True) - with open(os.path.join(training_args.output_dir, 'parameters', 'param.yaml'), 'w') as yaml_file: + parameter_dict = { + "model_args": asdict(model_args), + "data_args": asdict(data_args), + "train_args": asdict(training_args), + } + Path(os.path.join(training_args.output_dir, "parameters")).mkdir( + parents=True, exist_ok=True + ) + with open( + os.path.join(training_args.output_dir, "parameters", "param.yaml"), "w" + ) as yaml_file: yaml.dump(parameter_dict, yaml_file) - if __name__ == "__main__": main() diff --git a/finetune/train_gpt.py b/finetune/train_gpt.py new file mode 100644 index 0000000..e712f9f --- /dev/null +++ b/finetune/train_gpt.py @@ -0,0 +1,218 @@ +import os +import yaml +import torch +from pathlib import Path +from optimum.bettertransformer import BetterTransformer + +from datasets import Dataset as HfDataset +from datasets import concatenate_datasets, load_from_disk +from dataclasses import asdict +from transformers import AutoTokenizer, HfArgumentParser, TrainerCallback +from transformers.trainer import Trainer +from transformers.trainer import logger, Optional + + +from piccolo.arguments import ModelArguments, DataArguments, STETrainingArguments +from piccolo.data import ( + UniCollator, + UniDataset, + DatsetWithInfo, +) +from piccolo.model import GPTEmbedder +from tqdm import tqdm + + +def load_all_datasets( + meta_paths, root_dirs, query_prefix, doc_prefix +) -> list[DatsetWithInfo]: + all_datasets = [] + for meta_path, root_dir in zip(meta_paths, root_dirs): + CNT = 0 + meta_file = open(meta_path, "r") + for line in tqdm(meta_file.readlines()): + dataset_name, repeat_num = line.strip().split(" ") + dataset_dict = load_from_disk(str(os.path.join(root_dir, dataset_name))) + if isinstance(dataset_dict, dict): + dataset: HfDataset = concatenate_datasets(list(dataset_dict.values())) + else: + dataset = dataset_dict + for idx in range(int(repeat_num)): + all_datasets.append( + DatsetWithInfo( + hf_dataset=dataset, + name=dataset_name + "_{}".format(idx), + query_prefix=query_prefix, + passage_prefix=doc_prefix, + ) + ) + CNT += 1 + print("loading {} datasets from path: {}".format(CNT, meta_path)) + return all_datasets + + +class MyCallback(TrainerCallback): + 
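# Illustrative sketch, not part of the diff: each meta file passed via meta_paths holds one
# "<dataset_name> <repeat_num>" pair per line, where <dataset_name> is a dataset saved with
# datasets.save_to_disk under the matching root_dir. The names below are made up.
example_meta = "demo_retrieval 2\ndemo_sts 1\n"
for line in example_meta.strip().splitlines():
    dataset_name, repeat_num = line.strip().split(" ")  # same parsing as load_all_datasets
    print(dataset_name, int(repeat_num))                # the dataset is appended repeat_num times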
def on_epoch_end(self, args, state, control, train_dataloader, **kwargs): + train_dataloader.dataset.create_or_refresh_data() + + +class GPTTrainer(Trainer): + def __init__(self, use_optimum, efficient_save, **kwargs): + super().__init__(**kwargs) + self.use_optimum = use_optimum + self.efficient_save = efficient_save + + def _save(self, output_dir: Optional[str] = None, **kwargs): + """save the unwrap model, bcz we use better transformer""" + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + + logger.info("Saving model checkpoint to %s", output_dir) + if self.use_optimum: + from optimum.bettertransformer import BetterTransformer + + unwrap_model = BetterTransformer.reverse(self.model.embedder.encoder) + else: + unwrap_model = self.model.embedder.encoder + if self.is_world_process_zero(): + unwrap_model.save_pretrained( + output_dir, safe_serialization=self.args.save_safetensors + ) + self.tokenizer.save_pretrained(output_dir) + if hasattr(self.model, "scaling_layer"): + scaling_layer_sd_st = { + "linear.weight": self.model.scaling_layer.state_dict()[ + "linear.weight" + ].data.cpu(), + "linear.bias": self.model.scaling_layer.state_dict()[ + "linear.bias" + ].data.cpu(), + } + torch.save( + scaling_layer_sd_st, + os.path.join(output_dir, "scaling_layer_st.bin"), + ) + + def _save_checkpoint(self, model, trial, metrics=None): + if self.efficient_save: + """only save the model ckpt weights to save disk mem""" + from transformers.trainer import PREFIX_CHECKPOINT_DIR + + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + self.save_model(output_dir, _internal_call=True) + else: + super()._save_checkpoint(model, trial, metrics) + + +def main(): + parser = HfArgumentParser((ModelArguments, DataArguments, STETrainingArguments)) + parser.parse_args() + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + model_args: ModelArguments + data_args: DataArguments + training_args: STETrainingArguments + + # DataLoader and tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + trust_remote_code=True, + ) + all_datasets = load_all_datasets( + data_args.meta_paths, + data_args.root_dirs, + data_args.query_prefix, + data_args.doc_prefix, + ) + train_dataset = UniDataset( + all_datasets, + batch_size=data_args.batch_size, + with_instruction=data_args.with_instruction, + neg_num=data_args.neg_num, + drop_last=data_args.drop_last, + ) + data_collator = UniCollator(tokenizer=tokenizer, max_length=model_args.max_length) + loss_kwargs = { + "loss_type": model_args.loss_type, + "temperature": model_args.temperature, + "neg_num": data_args.neg_num, + "use_all_pair": data_args.use_all_pair, + } + + # Model + model = GPTEmbedder( + model_name_or_path=model_args.model_name_or_path, + loss_kwargs=loss_kwargs, + embedding_strategy=model_args.embedding_strategy, + freeze_pos_emb=False, + add_scaling_layer=model_args.use_scaling_layer, + use_mrl=model_args.use_mrl, + add_cls_head=model_args.add_cls_head, + ) + model.embedder.encoder.config.pad_token_id = tokenizer.pad_token_id + + # If on A100 GPU, try this. 
+ if training_args.use_optimum: + from optimum.bettertransformer import BetterTransformer + + model.embedder.encoder = BetterTransformer.transform( + model.embedder.encoder + ) # optimum better transformer + + # Trainer + trainer = GPTTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + data_collator=data_collator, + tokenizer=tokenizer, + callbacks=[MyCallback], + use_optimum=training_args.use_optimum, + efficient_save=training_args.efficient_save, + ) + + # Save parameter model at the end + if trainer.is_world_process_zero(): + Path(training_args.output_dir).mkdir(parents=True, exist_ok=True) + Path(os.path.join(training_args.output_dir, "parameters")).mkdir( + parents=True, exist_ok=True + ) + # Save data list info + meta_paths = data_args.meta_paths + with open( + os.path.join(training_args.output_dir, "parameters", "data.list"), "w" + ) as f: + for meta_path in meta_paths: + f.writelines(f"list_name: {meta_path} \n") + f.writelines(open(meta_path, "r").readlines()) + f.writelines("\n\n") + + # Run training + if training_args.use_optimum: + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): + trainer.train() + else: + trainer.train() + + # Save parameter and model at the end + if trainer.is_world_process_zero(): + trainer.save_model(training_args.output_dir, _internal_call=True) + # Save parameter + parameter_dict = { + "model_args": asdict(model_args), + "data_args": asdict(data_args), + "train_args": asdict(training_args), + } + Path(os.path.join(training_args.output_dir, "parameters")).mkdir( + parents=True, exist_ok=True + ) + with open( + os.path.join(training_args.output_dir, "parameters", "param.yaml"), "w" + ) as yaml_file: + yaml.dump(parameter_dict, yaml_file) + + +if __name__ == "__main__": + main() diff --git a/piccolo/__pycache__/arguments.cpython-310.pyc b/piccolo/__pycache__/arguments.cpython-310.pyc deleted file mode 100644 index 28bdc27..0000000 Binary files a/piccolo/__pycache__/arguments.cpython-310.pyc and /dev/null differ diff --git a/piccolo/__pycache__/criteria.cpython-310.pyc b/piccolo/__pycache__/criteria.cpython-310.pyc deleted file mode 100644 index 9d9c455..0000000 Binary files a/piccolo/__pycache__/criteria.cpython-310.pyc and /dev/null differ diff --git a/piccolo/__pycache__/data.cpython-310.pyc b/piccolo/__pycache__/data.cpython-310.pyc deleted file mode 100644 index 75bcbb3..0000000 Binary files a/piccolo/__pycache__/data.cpython-310.pyc and /dev/null differ diff --git a/piccolo/__pycache__/data_structures.cpython-310.pyc b/piccolo/__pycache__/data_structures.cpython-310.pyc deleted file mode 100644 index 3b37dd5..0000000 Binary files a/piccolo/__pycache__/data_structures.cpython-310.pyc and /dev/null differ diff --git a/piccolo/__pycache__/model.cpython-310.pyc b/piccolo/__pycache__/model.cpython-310.pyc deleted file mode 100644 index a92342b..0000000 Binary files a/piccolo/__pycache__/model.cpython-310.pyc and /dev/null differ diff --git a/piccolo/arguments.py b/piccolo/arguments.py index 0101daf..a311877 100644 --- a/piccolo/arguments.py +++ b/piccolo/arguments.py @@ -1,31 +1,46 @@ from dataclasses import dataclass, field from transformers import TrainingArguments -from piccolo.model import PoolingStrategy +from piccolo.model import ( + InBatchNegLossType, + PoolingStrategy, +) + @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
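# Illustrative sketch, not part of the diff: these dataclasses are populated from the
# command line by HfArgumentParser, so a hypothetical launch of the GPT finetuning script
# could look like the made-up command below; flag names mirror the dataclass fields.
#
#   torchrun --nproc_per_node=8 finetune/train_gpt.py \
#       --model_name_or_path <base_model> \
#       --meta_paths <meta_file.list> --root_dirs <data_root> \
#       --output_dir <output_dir> --use_optimum True --efficient_save True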
""" - model_name_or_path: str = field() # must require - embedding_strategy: PoolingStrategy = field(default=PoolingStrategy.mean) + + model_name_or_path: str = field() # must require + temperature: float = field(default=0.01) + loss_type: InBatchNegLossType = field(default=InBatchNegLossType.softmax) + embedding_strategy: PoolingStrategy = field(default=PoolingStrategy.last_mean) extend_pe: bool = field(default=False) + use_rope: bool = field(default=False) max_length: int = field(default=512) # scaling layer and mrl Training use_scaling_layer: bool = field(default=False) use_mrl: bool = field(default=False) + add_cls_head: bool = field(default=False) + @dataclass class DataArguments: # train data - meta_paths: list[str] = field() # must require - root_dirs: list[str] = field() # must require + meta_paths: list[str] = field() # must require + root_dirs: list[str] = field() # must require batch_size: int = field(default=16) - query_prefix: str = field(default='') - doc_prefix: str = field(default='') + with_instruction: bool = field(default=False) + drop_last: bool = field(default=True) + query_prefix: str = field(default="") + doc_prefix: str = field(default="") # hard neg - neg_num: int = field(default=1) # only affects retri_contrast_loss + neg_num: int = field(default=1) # only affects retri_contrast_loss + use_all_pair: bool = field(default=False) + @dataclass class STETrainingArguments(TrainingArguments): + use_optimum: bool = field(default=False) efficient_save: bool = field(default=True) diff --git a/piccolo/criteria.py b/piccolo/criteria.py index 0886b1b..6ff5d2d 100644 --- a/piccolo/criteria.py +++ b/piccolo/criteria.py @@ -1,22 +1,105 @@ import torch + class ContrastLoss(torch.nn.Module): def __init__(self, temperature: float = 0.05): super().__init__() self.temperature = temperature +class PairInBatchNegSoftmaxContrastLoss(ContrastLoss): + def __init__(self, temperature: float = 0.05): + super().__init__() + self.temperature = temperature + self._cross_entropy_loss = torch.nn.CrossEntropyLoss() + + def forward( + self, + text_embeddings: torch.Tensor, + text_pos_embeddings: torch.Tensor, + ) -> torch.Tensor: + sim_matrix = torch.cosine_similarity( + text_embeddings.unsqueeze(1), + text_pos_embeddings.unsqueeze(0), + dim=-1, + ) + sim_matrix = sim_matrix / self.temperature + labels = torch.arange( + sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long + ) + loss = self._cross_entropy_loss(sim_matrix, labels) + return loss + + +class PairInBatchHardNegSoftmaxContrastLoss(ContrastLoss): + def __init__(self, temperature: float = 0.05, neg_num=7, use_all_pair=False): + super().__init__() + self.temperature = temperature + self.neg_num = neg_num + self._cross_entropy_loss = torch.nn.CrossEntropyLoss() + self.use_all_pair = use_all_pair + + def forward( + self, + text_embeddings: torch.Tensor, + text_pos_embeddings: torch.Tensor, + text_neg_embeddings: torch.Tensor | None, + text_neg_index: torch.Tensor | None, + ) -> torch.Tensor: + if text_neg_embeddings is None: + """For no neg""" + sim_matrix = torch.cosine_similarity( + text_embeddings.unsqueeze(1), + text_pos_embeddings.unsqueeze(0), + dim=-1, + ) + sim_matrix = sim_matrix / self.temperature + labels = torch.arange( + sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long + ) + loss = self._cross_entropy_loss(sim_matrix, labels) + return loss + + sim_neg_matrix = torch.cosine_similarity( + text_embeddings.unsqueeze(1), + text_neg_embeddings.unsqueeze(0), + dim=-1, + ) + if self.use_all_pair: + 
sim_pos_matrix = torch.cosine_similarity( + text_embeddings.unsqueeze(1), + text_pos_embeddings.unsqueeze(0), + dim=-1, + ) + sim_matrix = torch.cat([sim_pos_matrix, sim_neg_matrix], dim=1) + labels = torch.arange( + sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device + ) + else: + sim_pos_vector = torch.cosine_similarity( + text_embeddings, text_pos_embeddings, dim=-1 + ) + sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1) + labels = torch.zeros( + sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device + ) + sim_matrix = sim_matrix / self.temperature + loss = self._cross_entropy_loss(sim_matrix, labels) + return loss + + class RetriContrastLoss(ContrastLoss): - ''' + """ loss for retrieval - if use_all_pair is set to true, it will use the query-query pair as neg, + if use_all_pair is set to true, it will use the query-query pair as neg, otherwise it use query-passage as neg - ''' + """ + def __init__(self, temperature: float = 0.05, use_all_pair=False): super().__init__() self.temperature = temperature self._cross_entropy_loss = torch.nn.CrossEntropyLoss() - self.use_all_pair=use_all_pair + self.use_all_pair = use_all_pair def forward( self, @@ -32,10 +115,12 @@ def forward( dim=-1, ) sim_matrix = sim_matrix / self.temperature - labels = torch.arange(sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long) + labels = torch.arange( + sim_matrix.size(0), device=text_embeddings.device, dtype=torch.long + ) loss = self._cross_entropy_loss(sim_matrix, labels) return loss - + sim_neg_matrix = torch.cosine_similarity( text_embeddings.unsqueeze(1), text_neg_embeddings.unsqueeze(0), @@ -48,56 +133,71 @@ def forward( dim=-1, ) sim_matrix = torch.cat([sim_pos_matrix, sim_neg_matrix], dim=1) - labels = torch.arange(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) + labels = torch.arange( + sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device + ) else: - sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1) + sim_pos_vector = torch.cosine_similarity( + text_embeddings, text_pos_embeddings, dim=-1 + ) sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1) - labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) + labels = torch.zeros( + sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device + ) sim_matrix = sim_matrix / self.temperature loss = self._cross_entropy_loss(sim_matrix, labels) return loss class CoSentLoss(ContrastLoss): - ''' + """ loss for sts and pair classification. 
here we hard code the cosent loss weight to 0.04 - ''' + """ + bias: torch.Tensor def __init__(self, temperature: float = 0.05, cosent_w: float = 0.04) -> None: super().__init__(temperature) - self.register_buffer('bias', torch.tensor([0.0])) + self.register_buffer("bias", torch.tensor([0.0])) self.cosent_w = cosent_w - def forward(self, predict_similarity: torch.Tensor, true_similarity: torch.Tensor) -> torch.Tensor: + def forward( + self, predict_similarity: torch.Tensor, true_similarity: torch.Tensor + ) -> torch.Tensor: predict_similarity = predict_similarity / self.temperature - cosine_similarity_diff = -(predict_similarity.unsqueeze(0) - predict_similarity.unsqueeze(1)) + cosine_similarity_diff = -( + predict_similarity.unsqueeze(0) - predict_similarity.unsqueeze(1) + ) smaller_mask = true_similarity.unsqueeze(0) <= true_similarity.unsqueeze(1) cosine_similarity_diff = cosine_similarity_diff[~smaller_mask] cosine_diff_scores_add_bias = torch.cat((cosine_similarity_diff, self.bias)) loss = torch.logsumexp(cosine_diff_scores_add_bias, dim=0) * self.cosent_w return loss + class ClsContrastLoss(torch.nn.Module): - ''' + """ loss for clustering and classification here we hard code the cls contrast loss weight to 0.2 - ''' - def __init__(self, temperature: float = 0.05, cls_w = 0.2): + """ + + def __init__(self, temperature: float = 0.05, cls_w=0.2): super().__init__() self.temperature = temperature self._cross_entropy_loss = torch.nn.CrossEntropyLoss() self.cls_w = cls_w - + def forward( self, text_embeddings: torch.Tensor, text_pos_embeddings: torch.Tensor, text_neg_embeddings: torch.Tensor, - ) -> torch.Tensor: + ) -> torch.Tensor: bs = text_embeddings.shape[0] - assert text_neg_embeddings.shape[0] % bs == 0, 'neg num is not equal for each sample' + assert ( + text_neg_embeddings.shape[0] % bs == 0 + ), "neg num is not equal for each sample" neg_num = int(text_neg_embeddings.shape[0] // bs) sim_neg_matrix = torch.cosine_similarity( @@ -105,7 +205,9 @@ def forward( text_neg_embeddings.unsqueeze(0), dim=-1, ) - sim_pos_vector = torch.cosine_similarity(text_embeddings, text_pos_embeddings, dim=-1) + sim_pos_vector = torch.cosine_similarity( + text_embeddings, text_pos_embeddings, dim=-1 + ) # find the neg for eatch training sample neg_matrix = [] @@ -114,6 +216,8 @@ def forward( sim_neg_matrix = torch.stack(neg_matrix) sim_matrix = torch.cat([sim_pos_vector.unsqueeze(1), sim_neg_matrix], dim=1) sim_matrix = sim_matrix / self.temperature - labels = torch.zeros(sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device) + labels = torch.zeros( + sim_matrix.size(0), dtype=torch.long, device=sim_matrix.device + ) loss = self._cross_entropy_loss(sim_matrix, labels) * self.cls_w - return loss \ No newline at end of file + return loss diff --git a/piccolo/data.py b/piccolo/data.py index 195d4f8..76c3f2e 100644 --- a/piccolo/data.py +++ b/piccolo/data.py @@ -7,7 +7,13 @@ from datasets import Dataset as HfDataset from torch.utils.data import Dataset, RandomSampler -from piccolo.data_structures import PairRetriContrastRecord, PairScoredRecord, PairClsContrastRecord +from piccolo.data_structures import ( + PairRetriContrastRecord, + PairNegRecord, + PairScoredRecord, + PairClsContrastRecord, + PairCLSRecord, +) from transformers.tokenization_utils import PreTrainedTokenizer from transformers.tokenization_utils_fast import PreTrainedTokenizerFast @@ -15,18 +21,21 @@ class UniCollator: - ''' + """ Uni Data Collator for retrieval, sts, pair classification, the query max length is 64, doc max length 
is 512 - for clustering and classification, we specially set the query max length to 512, - bcz for clustering and classification task, query('text') is ofent much longer than the pos/neg ('label') - ''' - def __init__(self, tokenizer: Tokenizer, max_length: int, q_max_length: int = 64) -> None: + for clustering and classification, we specially set the query max length to 512, + bcz for clustering and classification task, query('text') is ofent much longer than the pos/neg ('label') + """ + + def __init__( + self, tokenizer: Tokenizer, max_length: int, q_max_length: int = 64 + ) -> None: self.tokenizer = tokenizer self.max_length = max_length self.q_max_length = q_max_length - + def __call__(self, records: list) -> dict[str, torch.Tensor]: records = records[0] if isinstance(records[0], PairClsContrastRecord): @@ -36,38 +45,116 @@ def __call__(self, records: list) -> dict[str, torch.Tensor]: for i, record in enumerate(records): for neg in record.text_neg: texts_neg.append(neg) - text_ids = self.tokenizer(texts, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] - text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] - text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] - + text_ids = self.tokenizer( + texts, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + text_pos_ids = self.tokenizer( + texts_pos, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + text_neg_ids = self.tokenizer( + texts_neg, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + return { - 'text_ids': cast(torch.Tensor, text_ids), - 'text_pos_ids': cast(torch.Tensor, text_pos_ids), - 'text_neg_ids': cast(torch.Tensor, text_neg_ids), - 'type': 'cls_contrast', + "text_ids": cast(torch.Tensor, text_ids), + "text_pos_ids": cast(torch.Tensor, text_pos_ids), + "text_neg_ids": cast(torch.Tensor, text_neg_ids), + "type": "cls_contrast", } elif isinstance(records[0], PairRetriContrastRecord): texts = [record.text for record in records] texts_pos = [record.text_pos for record in records] texts_neg = [] - texts_neg_index = [] # index indictates for which text the negative sample belongs + texts_neg_index = ( + [] + ) # index indictates for which text the negative sample belongs for i, record in enumerate(records): for neg in record.text_neg: texts_neg.append(neg) texts_neg_index.append(i) - - text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids'] - text_pos_ids = self.tokenizer(texts_pos, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] + + text_ids = self.tokenizer( + texts, + padding=True, + max_length=self.q_max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + text_pos_ids = self.tokenizer( + texts_pos, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] if len(texts_neg) > 0: - text_neg_ids = self.tokenizer(texts_neg, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] + text_neg_ids = self.tokenizer( + texts_neg, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] else: 
text_neg_ids = None return { - 'text_ids': cast(torch.Tensor, text_ids), - 'text_pos_ids': cast(torch.Tensor, text_pos_ids), - 'text_neg_ids': cast(torch.Tensor, text_neg_ids), - 'text_neg_index': cast(torch.Tensor, texts_neg_index), - 'type': 'retri_contrast', + "text_ids": cast(torch.Tensor, text_ids), + "text_pos_ids": cast(torch.Tensor, text_pos_ids), + "text_neg_ids": cast(torch.Tensor, text_neg_ids), + "text_neg_index": cast(torch.Tensor, texts_neg_index), + "type": "retri_contrast", + } + elif isinstance(records[0], PairNegRecord): + texts = [record.text for record in records] + texts_pos = [record.text_pos for record in records] + texts_neg = [] + texts_neg_index = [] + for i, record in enumerate(records): + for neg in record.text_neg: + texts_neg.append(neg) + texts_neg_index.append(i) + + text_ids = self.tokenizer( + texts, + padding=True, + max_length=self.q_max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + text_pos_ids = self.tokenizer( + texts_pos, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + if len(texts_neg) > 0: + text_neg_ids = self.tokenizer( + texts_neg, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + else: + text_neg_ids = None + return { + "text_ids": cast(torch.Tensor, text_ids), + "text_pos_ids": cast(torch.Tensor, text_pos_ids), + "text_neg_ids": cast(torch.Tensor, text_neg_ids), + "text_neg_index": cast(torch.Tensor, texts_neg_index), } elif isinstance(records[0], PairScoredRecord): texts = [record.text for record in records] @@ -75,14 +162,40 @@ def __call__(self, records: list) -> dict[str, torch.Tensor]: labels = [record.label for record in records] labels = torch.tensor(labels, dtype=torch.float32) - text_ids = self.tokenizer(texts, padding=True, max_length=self.q_max_length, truncation=True, return_tensors='pt',)['input_ids'] - text_pair_ids = self.tokenizer(texts_pair, padding=True, max_length=self.max_length, truncation=True, return_tensors='pt',)['input_ids'] + text_ids = self.tokenizer( + texts, + padding=True, + max_length=self.q_max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + text_pair_ids = self.tokenizer( + texts_pair, + padding=True, + max_length=self.max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] return { - 'text_ids': cast(torch.Tensor, text_ids), - 'text_pair_ids': cast(torch.Tensor, text_pair_ids), - 'labels': labels, - 'type': 'cosent', + "text_ids": cast(torch.Tensor, text_ids), + "text_pair_ids": cast(torch.Tensor, text_pair_ids), + "labels": labels, + "type": "cosent", + } + elif isinstance(records[0], PairCLSRecord): + texts = [record.text for record in records] + text_labels = [record.text_label for record in records] + text_ids = self.tokenizer( + texts, + padding=True, + max_length=self.q_max_length, + truncation=True, + return_tensors="pt", + )["input_ids"] + return { + "text_ids": cast(torch.Tensor, text_ids), + "text_labels": torch.tensor(text_labels, dtype=torch.float32), } else: raise NotImplementedError("only support pairscored and pairneg records") @@ -98,12 +211,12 @@ class TaskBatchIndex: class DatsetWithInfo: hf_dataset: HfDataset name: str - query_prefix: str = '' - passage_prefix: str = '' + query_prefix: str = "" + passage_prefix: str = "" class UniDataset(Dataset): - ''' + """ Task-Homogenous Dataset Code is modified from M3E: https://github.com/wangyuxinwhy/uniem @@ -111,30 +224,44 @@ class UniDataset(Dataset): This technique can 
ensure that the in-batch-negative samples come from the same data set, and it is generally believed that this can improve the quality of the in batch negatives. - ''' + """ + def __init__( self, hf_datasets: list[DatsetWithInfo], neg_num: int = 1, batch_size: int = 32, + with_instruction: bool = False, + drop_last: bool = True, max_samples: int | None = None, ): self.batch_size = batch_size + self.drop_last = drop_last self.hf_datasets = hf_datasets self.max_samples = max_samples - self.name_dataset_map = {dataset.name: dataset.hf_dataset for dataset in hf_datasets} + self.name_dataset_map = { + dataset.name: dataset.hf_dataset for dataset in hf_datasets + } + self.with_instruction = with_instruction self.neg_num = neg_num - self.query_prefix_map = {dataset.name: dataset.query_prefix for dataset in hf_datasets} - self.passage_prefix_map = {dataset.name: dataset.passage_prefix for dataset in hf_datasets} + if with_instruction: + self.query_prefix_map = { + dataset.name: dataset.query_prefix for dataset in hf_datasets + } + self.passage_prefix_map = { + dataset.name: dataset.passage_prefix for dataset in hf_datasets + } + else: + self.query_prefix_map, self.passage_prefix_map = None, None self.create_or_refresh_data() def __len__(self): return len(self.task_batch_index_list) - + @staticmethod def is_valid_text(text: Any) -> bool: return isinstance(text, str) and bool(text.strip()) - + def create_or_refresh_data(self): self.task_batch_index_list: list[TaskBatchIndex] = [] for dataset in self.hf_datasets: @@ -145,67 +272,144 @@ def create_or_refresh_data(self): for i in RandomSampler(dataset.hf_dataset, num_samples=num_samples): buffer.append(i) if len(buffer) == batch_size: - self.task_batch_index_list.append(TaskBatchIndex(name=dataset.name, batch_index=buffer)) + self.task_batch_index_list.append( + TaskBatchIndex(name=dataset.name, batch_index=buffer) + ) buffer = [] self.random_index_list = list(RandomSampler(self.task_batch_index_list)) + def get_clsf_records(self, records, task_name): + cls_records = [] + for record in records: + text = record["text"] + text_label = record["label"] + if not self.is_valid_text(text): + continue + cls_records.append(PairCLSRecord(text=text, text_label=text_label)) + return cls_records def get_pair_scored_records(self, records, task_name): pair_records = [] for record in records: - text = record['text'] - text_pair = record['text_pair'] - label = record['label'] + text = record["text"] + text_pair = record["text_pair"] + label = record["label"] if not (self.is_valid_text(text) and self.is_valid_text(text_pair)): continue - text = self.query_prefix_map[task_name] + text - text_pair = self.passage_prefix_map[task_name] + text_pair - pair_records.append(PairScoredRecord(text=text, text_pair=text_pair, label=label)) + if self.with_instruction: + text = self.query_prefix_map[task_name] + text + text_pair = self.passage_prefix_map[task_name] + text_pair + pair_records.append( + PairScoredRecord(text=text, text_pair=text_pair, label=label) + ) + return pair_records + + def get_pair_contrast_records(self, records, task_name, hf_dataset, batch_index): + + def process_records(record): + text = record["text"] + if isinstance(record["text_pos"], list): # random sample a positive + if len(record["text_pos"]) == 0: + text_pos = text + else: + text_pos = random.sample(record["text_pos"], 1)[0] + else: + text_pos = record["text_pos"] + + if not (self.is_valid_text(text) and self.is_valid_text(text_pos)): + # skip current sample and random sample an index, ensure right 
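# Illustrative note, not part of the diff: because UniDataset.__getitem__ returns one full
# task-homogeneous batch (a list of records) and UniCollator immediately unwraps it with
# records = records[0], the DataLoader-level batch size is expected to stay at 1; the
# effective batch size is the data_args.batch_size passed when constructing UniDataset.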
batch size + random_index = random.sample(range(len(hf_dataset)), k=1)[0] + while random_index in batch_index: + random_index = random.sample(range(len(hf_dataset)), k=1)[0] + return process_records(hf_dataset[random_index]) + # append neg list to const length + text_neg = random.sample( + record["text_neg"], min(self.neg_num, len(record["text_neg"])) + ) + + if self.with_instruction: + text = self.query_prefix_map[task_name] + text + text_pos = self.passage_prefix_map[task_name] + text_pos + text_neg = [ + self.passage_prefix_map[task_name] + neg for neg in text_neg + ] + return text, text_pos, text_neg + + pair_records = [] + for record in records: + text, text_pos, text_neg = process_records(record) + pair_records.append( + PairNegRecord(text=text, text_pos=text_pos, text_neg=text_neg) + ) + assert ( + len(pair_records) == self.batch_size + ), "error, current batch size not match !!!" return pair_records - def get_pair_retri_contrast_records(self, records, task_name, hf_dataset, batch_index): + def get_pair_retri_contrast_records( + self, records, task_name, hf_dataset, batch_index + ): def process_records(record): - text = record['text'] - if isinstance(record['text_pos'], list): # random sample a positive - assert len(record['text_pos']) >= 1, 'text pos num should be at least 1' - text_pos = random.sample(record['text_pos'], 1)[0] + text = record["text"] + if isinstance(record["text_pos"], list): # random sample a positive + assert len(record["text_pos"]) >= 1, "text pos num should be at least 1" + text_pos = random.sample(record["text_pos"], 1)[0] else: - text_pos = record['text_pos'] - + text_pos = record["text_pos"] + if not (self.is_valid_text(text) and self.is_valid_text(text_pos)): - # skip current sample and random sample an index + # skip current sample and random sample an index random_index = random.sample(range(len(hf_dataset)), k=1)[0] while random_index in batch_index: random_index = random.sample(range(len(hf_dataset)), k=1)[0] return process_records(hf_dataset[random_index]) - text_neg = random.sample(record['text_neg'], min(self.neg_num, len(record['text_neg']))) - text = self.query_prefix_map[task_name] + text - text_pos = self.passage_prefix_map[task_name] + text_pos - text_neg = [self.passage_prefix_map[task_name] + neg for neg in text_neg] + text_neg = random.sample( + record["text_neg"], min(self.neg_num, len(record["text_neg"])) + ) + if self.with_instruction: + text = self.query_prefix_map[task_name] + text + text_pos = self.passage_prefix_map[task_name] + text_pos + text_neg = [ + self.passage_prefix_map[task_name] + neg for neg in text_neg + ] return text, text_pos, text_neg pair_records = [] for record in records: text, text_pos, text_neg = process_records(record) - pair_records.append(PairRetriContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg)) - assert len(pair_records) == self.batch_size, 'error, current batch size not match !!!' + pair_records.append( + PairRetriContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg) + ) + assert ( + len(pair_records) == self.batch_size + ), "error, current batch size not match !!!" 
return pair_records def get_pair_cls_contrast_records(self, records, task_name): pair_records = [] for record in records: - text, text_pos, text_neg = record['text'], record['text_pos'], record['text_neg'] - if isinstance(record['text_pos'], list): - text_pos = random.sample(record['text_pos'], 1)[0] - elif isinstance(record['text_pos'], str): - text_pos = record['text_pos'] + text, text_pos, text_neg = ( + record["text"], + record["text_pos"], + record["text_neg"], + ) + if isinstance(record["text_pos"], list): + text_pos = random.sample(record["text_pos"], 1)[0] + elif isinstance(record["text_pos"], str): + text_pos = record["text_pos"] else: - assert False, 'type error' - text_neg = random.sample(record['text_neg'], min(10, len(record['text_neg']))) # TODO, hard code the neg num to 10 + assert False, "type error" + text_neg = random.sample( + record["text_neg"], min(10, len(record["text_neg"])) + ) # TODO, hard code the neg num to 10 if self.is_valid_text(text) and self.is_valid_text(text_pos): - pair_records.append(PairClsContrastRecord(text=text, text_pos=text_pos, text_neg=text_neg)) + pair_records.append( + PairClsContrastRecord( + text=text, text_pos=text_pos, text_neg=text_neg + ) + ) return pair_records def __getitem__(self, index: int): @@ -213,19 +417,21 @@ def __getitem__(self, index: int): task_batch_index = self.task_batch_index_list[index] task_name = task_batch_index.name batch_index = task_batch_index.batch_index - + hf_dataset = self.name_dataset_map[task_name] records = [hf_dataset[i] for i in batch_index] - if hf_dataset[0]['type'] == 'cls_contrast': + if hf_dataset[0]["type"] == "cls_contrast": pair_records = self.get_pair_cls_contrast_records(records, task_name) - elif hf_dataset[0]['type'] == 'retri_contrast': - pair_records = self.get_pair_retri_contrast_records(records, task_name, hf_dataset, batch_index) - elif hf_dataset[0]['type'] == 'cosent': + elif hf_dataset[0]["type"] == "retri_contrast": + pair_records = self.get_pair_retri_contrast_records( + records, task_name, hf_dataset, batch_index + ) + elif hf_dataset[0]["type"] == "cosent": pair_records = self.get_pair_scored_records(records, task_name) else: - raise NotImplementedError('only support pair contrast and pair scored') + raise NotImplementedError("only support pair contrast and pair scored") if not pair_records: - print(f'records is empty', records) + print(f"records is empty", records) return self.__getitem__(index + 1) return pair_records diff --git a/piccolo/data_structures.py b/piccolo/data_structures.py index 9ec2683..599093f 100644 --- a/piccolo/data_structures.py +++ b/piccolo/data_structures.py @@ -1,19 +1,41 @@ from dataclasses import dataclass + +@dataclass(slots=True) +class PairRecord: + text: str + text_pos: str + + +@dataclass(slots=True) +class PairNegRecord: + text: str + text_pos: str + text_neg: list + + @dataclass(slots=True) class PairRetriContrastRecord: text: str text_pos: str text_neg: list + @dataclass(slots=True) class PairClsContrastRecord: text: str text_pos: str text_neg: list + @dataclass(slots=True) class PairScoredRecord: text: str text_pair: str - label: float \ No newline at end of file + label: int + + +@dataclass(slots=True) +class PairCLSRecord: + text: str + text_label: str diff --git a/piccolo/model.py b/piccolo/model.py index 13e1056..7a4e4e3 100644 --- a/piccolo/model.py +++ b/piccolo/model.py @@ -2,20 +2,41 @@ import torch import os +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + PreTrainedModel, + AutoModelForCausalLM, +) from enum 
import Enum from pathlib import Path from typing import ClassVar, Literal, Type, TypeVar, cast -from transformers import AutoConfig, AutoModel, AutoTokenizer, PreTrainedModel # type: ignore -from piccolo.criteria import CoSentLoss, ClsContrastLoss, RetriContrastLoss + +# type: ignore +from piccolo.criteria import ( + CoSentLoss, + ClsContrastLoss, + RetriContrastLoss, + PairInBatchNegSoftmaxContrastLoss, + PairInBatchHardNegSoftmaxContrastLoss, +) + class PoolingStrategy(str, Enum): - cls = 'cls' - mean = 'mean' + cls = "cls" + mean = "mean" + last_mean = "last_mean" + last_mean_dropout = "last_mean_dropout" + class InBatchNegLossType(str, Enum): - cosent = 'cosent' - retri_contrast = 'retri_contrast' - cls_contrast = 'cls_contrast' + cosent = "cosent" + retri_contrast = "retri_contrast" + softmax = "softmax" + hardneg_softmax = "hardneg_softmax" + cls_contrast = "cls_contrast" + def build_loss(loss_type, temperature, **kwargs): loss_type = InBatchNegLossType(loss_type) @@ -24,29 +45,63 @@ def build_loss(loss_type, temperature, **kwargs): return CoSentLoss(temperature) case InBatchNegLossType.cls_contrast: return ClsContrastLoss(temperature) + case InBatchNegLossType.softmax: + return PairInBatchNegSoftmaxContrastLoss(temperature) + case InBatchNegLossType.hardneg_softmax: + return PairInBatchHardNegSoftmaxContrastLoss(temperature, **kwargs) case InBatchNegLossType.retri_contrast: return RetriContrastLoss(temperature, **kwargs) -def creat_attention_mask_from_input_ids(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor: + +def creat_attention_mask_from_input_ids( + input_ids: torch.Tensor, pad_token_id: int +) -> torch.Tensor: return input_ids != pad_token_id -def mean_pooling(hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: + +def mean_pooling( + hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None +) -> torch.Tensor: if attention_mask is None: return torch.mean(hidden_state, dim=1) attention_mask = attention_mask.float() - return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum(attention_mask, dim=-1, keepdim=True) + return torch.sum(hidden_state * attention_mask.unsqueeze(-1), dim=1) / torch.sum( + attention_mask, dim=-1, keepdim=True + ) + + +def last_pooling( + hidden_state: torch.Tensor, attention_mask: torch.Tensor | None = None +) -> torch.Tensor: + last_hidden = hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) + left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] + if left_padding: + emb = last_hidden[:, -1] + else: + sequence_lengths = attention_mask.sum(dim=1) - 1 + batch_size = last_hidden.shape[0] + emb = last_hidden[torch.arange(batch_size), sequence_lengths] + + return emb + def load_hf_pretrained_model(model_name_or_path: str) -> PreTrainedModel: - config = AutoConfig.from_pretrained(model_name_or_path) - if config.model_type == 't5': + config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + if config.model_type == "t5": from transformers import T5EncoderModel # type: ignore - pretrained_model = T5EncoderModel.from_pretrained(model_name_or_path) + + pretrained_model = T5EncoderModel.from_pretrained( + model_name_or_path, trust_remote_code=True + ) else: - pretrained_model = AutoModel.from_pretrained(model_name_or_path) + pretrained_model = AutoModel.from_pretrained( + model_name_or_path, trust_remote_code=True + ) return pretrained_model # type: ignore -StrategyEmbedderClsMap: dict[PoolingStrategy, Type['Embedder']] = {} 
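# Illustrative sketch, not part of the diff: last_pooling above selects the embedding of the
# last non-padding token; with right padding that is position attention_mask.sum(dim=1) - 1,
# with left padding it is simply the final position.
import torch

hidden = torch.arange(12.0).reshape(2, 3, 2)      # (batch=2, seq_len=3, dim=2) toy hidden states
mask = torch.tensor([[1, 1, 0], [1, 1, 1]])       # row 0 is right-padded by one token
lengths = mask.sum(dim=1) - 1                     # index of the last real token per row
emb = hidden[torch.arange(2), lengths]            # rows hidden[0, 1] and hidden[1, 2]
print(emb)                                        # tensor([[ 2.,  3.], [10., 11.]])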
+StrategyEmbedderClsMap: dict[PoolingStrategy, Type["Embedder"]] = {} + class Embedder(torch.nn.Module): pooling_strategy: ClassVar[PoolingStrategy] @@ -54,6 +109,7 @@ class Embedder(torch.nn.Module): def __init__(self, encoder: PreTrainedModel, pad_token_id: int | None = None): super().__init__() self.encoder = encoder + self.encoder.config.piccolo_pooling_strategy = str(self.pooling_strategy.value) if pad_token_id is None: if encoder.config.pad_token_id is not None: @@ -66,6 +122,11 @@ def __init__(self, encoder: PreTrainedModel, pad_token_id: int | None = None): def __init_subclass__(cls) -> None: StrategyEmbedderClsMap[cls.pooling_strategy] = cls + def forward( + self, input_ids: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + raise NotImplementedError + def save_pretrained(self, path: str | Path): self.encoder.save_pretrained(path) @@ -74,14 +135,42 @@ def from_pretrained(cls, model_name_or_path: str): encoder = load_hf_pretrained_model(model_name_or_path) return cls(encoder) + @property + def max_length(self): + return self.encoder.config.max_position_embeddings + + +class LastEmbedder(Embedder): + pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.last_mean + + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: + if attention_mask is None: + attention_mask = creat_attention_mask_from_input_ids( + input_ids, self.pad_token_id + ) + embeddings = self.encoder( + input_ids, attention_mask=attention_mask, output_hidden_states=True + ) + last_hidden_state = embeddings.hidden_states[-1] + embeddings = last_pooling(last_hidden_state, attention_mask) + return embeddings + class LastMeanEmbedder(Embedder): pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.mean - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: if attention_mask is None: - attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id) - embeddings = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state + attention_mask = creat_attention_mask_from_input_ids( + input_ids, self.pad_token_id + ) + embeddings = self.encoder( + input_ids, attention_mask=attention_mask + ).last_hidden_state embeddings = mean_pooling(embeddings, attention_mask) return embeddings @@ -89,10 +178,16 @@ def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = class ClsEmbedder(Embedder): pooling_strategy: ClassVar[PoolingStrategy] = PoolingStrategy.cls - def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None) -> torch.Tensor: + def forward( + self, input_ids: torch.Tensor, attention_mask: torch.Tensor | None = None + ) -> torch.Tensor: if attention_mask is None: - attention_mask = creat_attention_mask_from_input_ids(input_ids, self.pad_token_id) - embeddings = self.encoder(input_ids, attention_mask=attention_mask).last_hidden_state[:, 0] + attention_mask = creat_attention_mask_from_input_ids( + input_ids, self.pad_token_id + ) + embeddings = self.encoder( + input_ids, attention_mask=attention_mask + ).last_hidden_state[:, 0] return embeddings @@ -104,13 +199,18 @@ def __init__(self, embedder: Embedder): self.embedder = embedder def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs): - self.embedder.encoder.gradient_checkpointing_enable(gradient_checkpointing_kwargs) + 
self.embedder.encoder.gradient_checkpointing_enable( + gradient_checkpointing_kwargs + ) class ScalingLayer(torch.nn.Module): def __init__(self, origin_dim: int = 1024, scaling_dim: int = 1792): super().__init__() - self.linear = torch.nn.Linear(in_features=origin_dim, out_features=scaling_dim, bias=True) + self.linear = torch.nn.Linear( + in_features=origin_dim, out_features=scaling_dim, bias=True + ) + def forward(self, input): return self.linear(input) @@ -135,10 +235,11 @@ class STEmbedder(EmbedderForTrain): extend_pe (`bool`) extend position embedding to longer length, here we adopt a very simple method from: https://kexue.fm/archives/7947 - + Notation: - some parameter are hard-coded in this code, such as in_feature and out_feature of scaling layer, and nesting list of MRL. + some parameter are hard-coded in this code, such as in_feature and out_feature of scaling layer, and nesting list of MRL. """ + def __init__( self, model_name_or_path: str, @@ -150,45 +251,72 @@ def __init__( max_length: int = 512, ): pretrained_model = load_hf_pretrained_model(model_name_or_path) - embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)](pretrained_model) + embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)]( + pretrained_model + ) super().__init__(embedder) - self.retri_contrst_loss = build_loss('retri_contrast', temperature=0.01, use_all_pair=True) - self.cosent_loss = build_loss('cosent', temperature=0.05) - self.cls_contrast_loss = build_loss('cls_contrast', temperature=0.05) + self.retri_contrst_loss = build_loss( + "retri_contrast", temperature=0.01, use_all_pair=True + ) + self.cosent_loss = build_loss("cosent", temperature=0.05) + self.cls_contrast_loss = build_loss("cls_contrast", temperature=0.05) self.use_mrl = use_mrl self.add_scaling_layer = add_scaling_layer if add_scaling_layer: - ''' + """ Here we hard code the scaling layer pretrain path, input_dim and output_dim, you can modify it by yourself - ''' + """ self.scaling_layer = ScalingLayer(origin_dim=1024, scaling_dim=1792) - if os.path.exists(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')): - scaling_layer_state_dict = torch.load(os.path.join(model_name_or_path, '2_Dense/pytorch_model.bin')) - self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True) - print('load scaling layer successfully') + if os.path.exists( + os.path.join(model_name_or_path, "2_Dense/pytorch_model.bin") + ): + scaling_layer_state_dict = torch.load( + os.path.join(model_name_or_path, "2_Dense/pytorch_model.bin") + ) + self.scaling_layer.load_state_dict( + scaling_layer_state_dict, strict=True + ) + print("load scaling layer successfully") else: - print('not found pretrain, random init scaling layer') + print("not found pretrain, random init scaling layer") if use_mrl: - self.mrl_nesting_list = [256, 512, 768, 1024, 1280, 1536, 1792] # hard code here - + self.mrl_nesting_list = [ + 256, + 512, + 768, + 1024, + 1280, + 1536, + 1792, + ] # hard code here + if extend_pe: - sp = 0 # TODO, hard code here, for xlm roberta, this should be 2 + sp = 0 # TODO, hard code here, for xlm roberta, this should be 2 # re-init the position embeddings org_pe = self.embedder.encoder.embeddings.position_embeddings pad_idx = self.embedder.encoder.embeddings.position_embeddings.padding_idx - extended_pe = torch.nn.Embedding(max_length + sp, org_pe.embedding_dim, padding_idx=pad_idx) - for start_idx in range(0, max_length + sp, org_pe.num_embeddings): # 迭代式地去复制,从而扩增embedding + extended_pe = torch.nn.Embedding( + 
max_length + sp, org_pe.embedding_dim, padding_idx=pad_idx + ) + for start_idx in range( + 0, max_length + sp, org_pe.num_embeddings + ): # 迭代式地去复制,从而扩增embedding end_idx = min(start_idx + org_pe.num_embeddings, max_length + sp) - extended_pe.weight.data[start_idx : end_idx] = org_pe.weight.data[:end_idx - start_idx].clone() + extended_pe.weight.data[start_idx:end_idx] = org_pe.weight.data[ + : end_idx - start_idx + ].clone() self.embedder.encoder.embeddings.position_embeddings = extended_pe - self.embedder.encoder.embeddings.position_ids = torch.arange(max_length + sp).expand((1, -1)) - self.embedder.encoder.embeddings.token_type_ids = \ - torch.zeros(self.embedder.encoder.embeddings.position_ids.size(), dtype=torch.long) + self.embedder.encoder.embeddings.position_ids = torch.arange( + max_length + sp + ).expand((1, -1)) + self.embedder.encoder.embeddings.token_type_ids = torch.zeros( + self.embedder.encoder.embeddings.position_ids.size(), dtype=torch.long + ) self.embedder.encoder.config.max_position_embeddings = max_length + sp - if not extend_pe and freeze_pos_emb: # extend pe时, 不能 freeze pos emb + if not extend_pe and freeze_pos_emb: # extend pe时, 不能 freeze pos emb for name, param in self.embedder.encoder.embeddings.named_parameters(): if "position_embeddings" in name: param.requires_grad = False @@ -200,15 +328,20 @@ def get_embedding(self, text_ids): if self.add_scaling_layer: text_embeddings = self.scaling_layer(text_embeddings.half()).float() return text_embeddings - + def compute_cls_loss(self, text_ids: torch.Tensor, text_labels: torch.tensor): text_embeddings = self.get_embedding(text_ids) pred_cls = self.cls_head(text_embeddings.half()) loss = torch.nn.functional.cross_entropy(pred_cls, text_labels) - return {'loss': loss} + return {"loss": loss} - def compute_cls_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor, - text_neg_ids: torch.Tensor = None, type: str = 'cls_contrast') -> dict[str, torch.Tensor]: + def compute_cls_contrast_loss( + self, + text_ids: torch.Tensor, + text_pos_ids: torch.Tensor, + text_neg_ids: torch.Tensor = None, + type: str = "cls_contrast", + ) -> dict[str, torch.Tensor]: text_embeddings = self.get_embedding(text_ids) text_pos_embeddings = self.get_embedding(text_pos_ids) text_neg_embeddings = self.get_embedding(text_neg_ids) @@ -216,15 +349,29 @@ def compute_cls_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch. 
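# Illustrative sketch, not part of the diff: extend_pe above simply tiles the pretrained
# position-embedding table until it covers the new max_length (the trick referenced from
# https://kexue.fm/archives/7947); the sizes below are toy values.
import torch

org_pe = torch.nn.Embedding(4, 8)                       # pretend the base model has 4 positions
extended_pe = torch.nn.Embedding(10, 8)                 # target length 10
for start_idx in range(0, 10, org_pe.num_embeddings):   # copy the original table block by block
    end_idx = min(start_idx + org_pe.num_embeddings, 10)
    extended_pe.weight.data[start_idx:end_idx] = org_pe.weight.data[: end_idx - start_idx].clone()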
if self.use_mrl: loss = torch.tensor(0.0, device=text_embeddings.device) for num_feat in self.mrl_nesting_list: - emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat] - loss += self.cls_contrast_loss(emb, pos_emb, neg_emb) / len(self.mrl_nesting_list) + emb, pos_emb, neg_emb = ( + text_embeddings[..., :num_feat], + text_pos_embeddings[..., :num_feat], + text_neg_embeddings[..., :num_feat], + ) + loss += self.cls_contrast_loss(emb, pos_emb, neg_emb) / len( + self.mrl_nesting_list + ) else: - loss = self.cls_contrast_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings) - print('cls contrast loss: ', loss) - return {'loss': loss} + loss = self.cls_contrast_loss( + text_embeddings, text_pos_embeddings, text_neg_embeddings + ) + print("cls contrast loss: ", loss) + return {"loss": loss} - def compute_retri_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torch.Tensor, text_neg_ids: torch.Tensor = None, - type: str = 'retri_contrast', **kwargs) -> dict[str, torch.Tensor]: + def compute_retri_contrast_loss( + self, + text_ids: torch.Tensor, + text_pos_ids: torch.Tensor, + text_neg_ids: torch.Tensor = None, + type: str = "retri_contrast", + **kwargs, + ) -> dict[str, torch.Tensor]: text_embeddings = self.get_embedding(text_ids) text_pos_embeddings = self.get_embedding(text_pos_ids) text_neg_embeddings = self.get_embedding(text_neg_ids) @@ -233,38 +380,225 @@ def compute_retri_contrast_loss(self, text_ids: torch.Tensor, text_pos_ids: torc loss = torch.tensor(0.0, device=text_embeddings.device) for num_feat in self.mrl_nesting_list: if text_neg_embeddings is not None: - emb, pos_emb, neg_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat], text_neg_embeddings[..., :num_feat] + emb, pos_emb, neg_emb = ( + text_embeddings[..., :num_feat], + text_pos_embeddings[..., :num_feat], + text_neg_embeddings[..., :num_feat], + ) else: - emb, pos_emb = text_embeddings[..., :num_feat], text_pos_embeddings[..., :num_feat] + emb, pos_emb = ( + text_embeddings[..., :num_feat], + text_pos_embeddings[..., :num_feat], + ) neg_emb = None - loss += self.retri_contrst_loss(emb, pos_emb, neg_emb, **kwargs) / len(self.mrl_nesting_list) + loss += self.retri_contrst_loss(emb, pos_emb, neg_emb, **kwargs) / len( + self.mrl_nesting_list + ) else: - loss = self.retri_contrst_loss(text_embeddings, text_pos_embeddings, text_neg_embeddings, **kwargs) - print('triplet loss: ', loss) - return {'loss': loss} + loss = self.retri_contrst_loss( + text_embeddings, text_pos_embeddings, text_neg_embeddings, **kwargs + ) + print("triplet loss: ", loss) + return {"loss": loss} - def compute_cosent_loss(self, text_ids: torch.Tensor, text_pair_ids: torch.Tensor, labels: torch.Tensor, type: str = 'cosent'): + def compute_cosent_loss( + self, + text_ids: torch.Tensor, + text_pair_ids: torch.Tensor, + labels: torch.Tensor, + type: str = "cosent", + ): text_embeddings = self.get_embedding(text_ids) text_pair_embeddings = self.get_embedding(text_pair_ids) if self.use_mrl: loss = torch.tensor(0.0, device=text_embeddings.device) for num_feat in self.mrl_nesting_list: - emb, emb_pair = text_embeddings[..., :num_feat], text_pair_embeddings[..., :num_feat] + emb, emb_pair = ( + text_embeddings[..., :num_feat], + text_pair_embeddings[..., :num_feat], + ) predict_labels = torch.cosine_similarity(emb, emb_pair, dim=-1) - loss += self.cosent_loss(predict_labels, labels) / len(self.mrl_nesting_list) + loss += self.cosent_loss(predict_labels, 
labels) / len( + self.mrl_nesting_list + ) else: - predict_labels = torch.cosine_similarity(text_embeddings, text_pair_embeddings, dim=-1) + predict_labels = torch.cosine_similarity( + text_embeddings, text_pair_embeddings, dim=-1 + ) loss = self.cosent_loss(predict_labels, labels) - print('cosent loss: ', loss) - return {'loss': loss, 'predict_labels': predict_labels} + print("cosent loss: ", loss) + return {"loss": loss, "predict_labels": predict_labels} def forward(self, **kwargs): - if kwargs['type'] == 'cls_contrast': + if kwargs["type"] == "cls_contrast": return self.compute_cls_contrast_loss(**kwargs) - elif kwargs['type'] == 'retri_contrast': + elif kwargs["type"] == "retri_contrast": return self.compute_retri_contrast_loss(**kwargs) - elif kwargs['type'] == 'cosent': + elif kwargs["type"] == "cosent": return self.compute_cosent_loss(**kwargs) else: - raise NotImplementedError('not suuport current input kwargs') \ No newline at end of file + raise NotImplementedError("unsupported input kwargs") + + +class GPTEmbedder(EmbedderForTrain): + def __init__( + self, + model_name_or_path: str, + loss_kwargs: dict, + embedding_strategy: PoolingStrategy | str = PoolingStrategy.last_mean, + freeze_pos_emb: bool = False, + add_scaling_layer: bool = False, + use_mrl: bool = False, + add_cls_head: bool = False, + ): + pretrained_model = load_hf_pretrained_model(model_name_or_path) + embedder = StrategyEmbedderClsMap[PoolingStrategy(embedding_strategy)]( + pretrained_model + ) + super().__init__(embedder) + self.criterion = build_loss(**loss_kwargs) + self.cosent_loss = build_loss("cosent", temperature=0.05) + self.cls_contrast_loss = build_loss("cls_contrast", temperature=0.05) + self.use_mrl = use_mrl + self.add_scaling_layer = add_scaling_layer + + if add_scaling_layer: + scaling_layer_state_dict = torch.load( + os.path.join(model_name_or_path, "2_Dense/pytorch_model.bin") + ) + self.scaling_layer = ScalingLayer( + origin_dim=1024, scaling_dim=1792 + ) # hard code here + self.scaling_layer.load_state_dict(scaling_layer_state_dict, strict=True) + + if use_mrl: + self.mrl_nesting_list = [ + 256, + 512, + 768, + 1024, + 1280, + 1536, + 1792, + ] # hard code here + + if freeze_pos_emb: + for name, param in self.embedder.encoder.embeddings.named_parameters(): + if "position_embeddings" in name: + param.requires_grad = False + + if add_cls_head: + self.cls_head = torch.nn.Linear(1024, 2) # hard code here + + def get_embedding(self, text_ids): + if text_ids is None: + return None + text_embeddings = self.embedder(text_ids) + if self.add_scaling_layer: + text_embeddings = self.scaling_layer(text_embeddings.half()).float() + return text_embeddings + + def compute_cls_loss(self, text_ids: torch.Tensor, text_labels: torch.Tensor): + text_embeddings = self.get_embedding(text_ids) + pred_cls = self.cls_head(text_embeddings.half()) + loss = torch.nn.functional.cross_entropy(pred_cls, text_labels) + return {"loss": loss} + + def compute_cls_contrast_loss( + self, + text_ids: torch.Tensor, + text_pos_ids: torch.Tensor, + text_neg_ids: torch.Tensor = None, + type: str = "cls_contrast", + ) -> dict[str, torch.Tensor]: + text_embeddings = self.get_embedding(text_ids) + text_pos_embeddings = self.get_embedding(text_pos_ids) + text_neg_embeddings = self.get_embedding(text_neg_ids) + + if self.use_mrl: + loss = torch.tensor(0.0, device=text_embeddings.device) + for num_feat in self.mrl_nesting_list: + emb, pos_emb, neg_emb = ( + text_embeddings[..., :num_feat], + text_pos_embeddings[..., :num_feat], +
text_neg_embeddings[..., :num_feat], + ) + loss += self.cls_contrast_loss(emb, pos_emb, neg_emb) / len( + self.mrl_nesting_list + ) + else: + loss = self.cls_contrast_loss( + text_embeddings, text_pos_embeddings, text_neg_embeddings + ) + print("cls contrast loss: ", loss) + return {"loss": loss} + + def compute_triplet_loss( + self, + text_ids: torch.Tensor, + text_pos_ids: torch.Tensor, + text_neg_ids: torch.Tensor = None, + type: str = "triplet_loss", + **kwargs, + ) -> dict[str, torch.Tensor]: + text_embeddings = self.get_embedding(text_ids) + text_pos_embeddings = self.get_embedding(text_pos_ids) + text_neg_embeddings = self.get_embedding(text_neg_ids) + + if self.use_mrl: + loss = torch.tensor(0.0, device=text_embeddings.device) + for num_feat in self.mrl_nesting_list: + emb, pos_emb, neg_emb = ( + text_embeddings[..., :num_feat], + text_pos_embeddings[..., :num_feat], + text_neg_embeddings[..., :num_feat], + ) + loss += self.criterion(emb, pos_emb, neg_emb, **kwargs) / len( + self.mrl_nesting_list + ) + else: + loss = self.criterion( + text_embeddings, text_pos_embeddings, text_neg_embeddings, **kwargs + ) + print("triplet loss: ", loss) + return {"loss": loss} + + def compute_scored_pair_loss( + self, + text_ids: torch.Tensor, + text_pair_ids: torch.Tensor, + labels: torch.Tensor, + type: str = "cosent", + ): + text_embeddings = self.get_embedding(text_ids) + text_pair_embeddings = self.get_embedding(text_pair_ids) + if self.use_mrl: + loss = torch.tensor(0.0, device=text_embeddings.device) + for num_feat in self.mrl_nesting_list: + emb, emb_pair = ( + text_embeddings[..., :num_feat], + text_pair_embeddings[..., :num_feat], + ) + predict_labels = torch.cosine_similarity(emb, emb_pair, dim=-1) + loss += self.cosent_loss(predict_labels, labels) / len( + self.mrl_nesting_list + ) + else: + predict_labels = torch.cosine_similarity( + text_embeddings, text_pair_embeddings, dim=-1 + ) + loss = self.cosent_loss(predict_labels, labels) + print("cosent loss: ", loss) + return {"loss": loss, "predict_labels": predict_labels} + + def forward(self, **kwargs): + if "type" in kwargs and "cls_contrast" == kwargs["type"]: + return self.compute_cls_contrast_loss(**kwargs) + elif "text_ids" in kwargs and "text_pos_ids" in kwargs: + return self.compute_triplet_loss(**kwargs) + elif "text_ids" in kwargs and "text_pair_ids" in kwargs and "labels" in kwargs: + return self.compute_scored_pair_loss(**kwargs) + elif "text_ids" in kwargs and "text_labels" in kwargs: + return self.compute_cls_loss(**kwargs) + else: + raise NotImplementedError("unsupported input kwargs") diff --git a/scripts/ft.sh b/scripts/ft.sh index 3f85760..9ee6551 100644 --- a/scripts/ft.sh +++ b/scripts/ft.sh @@ -1,4 +1,4 @@ -ROOT=/mnt/lustre/huangjunqin/piccolo +ROOT=/mnt/lustre/jingzihao/piccolo-embedding export PYTHONPATH=$ROOT:${PYTHONPATH} # SLURM Parameter @@ -16,7 +16,7 @@ EPOCHS=3 BATCH_SIZE=8 LR=1e-5 NEG_NUM=1 -DS_PATH=$ROOT/ds_config_zero1.json +DS_PATH=$ROOT/data_example/deepspeed_config.json MAX_LENGTH=512 META_PATHS=( meta_lists/piccolo-ft.txt diff --git a/scripts/ft_gpt.sh b/scripts/ft_gpt.sh new file mode 100644 index 0000000..440272e --- /dev/null +++ b/scripts/ft_gpt.sh @@ -0,0 +1,96 @@ +ROOT=/mnt/lustre/jingzihao/piccolo-embedding +export PYTHONPATH=$ROOT:${PYTHONPATH} + +# SLURM Parameter +GPUS_PER_NODE=8 +if [ -z "$WORLD_SIZE" ]; then + WORLD_SIZE=1 + RANK=0 + MASTER_ADDR=127.0.0.1 + MASTER_PORT=6000 +fi + +# WorkSpace Param +JOBNAME='Fine-Tuning-GPT' +LOGDIR=$ROOT/logs +mkdir -p $LOGDIR + +# Hyper Parameter
Start +PRETRAIN_MODEL_NAME=Internlm-1_8B +MODEL_PATH=/mnt/lustre/huangjunqin/model # no trailing '/' +EPOCHS=3 +BATCH_SIZE=4 +LR=1e-5 +NEG_NUM=1 +DS_PATH=$ROOT/data_example/deepspeed_config.json +TEMPERATURE=0.01 +OUTPUT_NAME=$PRETRAIN_MODEL_NAME.e$EPOCHS.lr$LR.B$BATCH_SIZE.Neg$NEG_NUM.G$WORLD_SIZE.$JOBNAME +MAX_LENGTH=512 +META_PATHS=( +$ROOT/meta_lists/piccolo-ft.txt +) +ROOT_DIRS=( +$ROOT/data_example +) +# Hyper Parameter End + + +# Model Parameter +model_args=( + "--model_name_or_path" "$MODEL_PATH/$PRETRAIN_MODEL_NAME/" + "--loss_type=hardneg_softmax" + "--temperature=$TEMPERATURE" + "--max_length=$MAX_LENGTH" + "--query_prefix=''" + "--doc_prefix=''" + "--use_scaling_layer=False" + "--use_mrl=True" +) + +# Data Parameter +data_args=( + "--meta_paths" "${META_PATHS[@]}" + "--root_dirs" "${ROOT_DIRS[@]}" + "--neg_num=$NEG_NUM" + "--use_all_pair=True" +) + +# Train Parameter +train_args=( + "--fp16" + "--gradient_checkpointing=True" + "--with_instruction=True" + "--output_dir=$ROOT/outputs" + "--num_train_epochs=$EPOCHS" + "--dataloader_num_workers=0" + "--batch_size=$BATCH_SIZE" + "--learning_rate=$LR" + "--deepspeed=$DS_PATH" + "--logging_steps=500" + "--save_safetensors=False" + "--report_to=tensorboard" + "--save_strategy=epoch" + "--per_device_train_batch_size=1" + "--use_optimum=False" +) + +# Merged Parameters +all_args=("${model_args[@]}" "${data_args[@]}" "${train_args[@]}") + + +export LAUNCHER="python -u -m torch.distributed.run \ + --nproc_per_node $GPUS_PER_NODE \ + --nnodes $WORLD_SIZE \ + --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ + --rdzv_backend c10d \ + --max_restarts 0 \ + --tee 3 \ + " + +export CMD=" \ + $ROOT/finetune/train_gpt.py \ + " + +echo $CMD + +bash -c "$LAUNCHER $CMD ${all_args[*]}" 2>&1 | tee -a $LOGDIR/$OUTPUT_NAME.txt
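For reference, the two hard-coded tricks that both STEmbedder and GPTEmbedder in this patch rely on — tiling the pretrained position-embedding table to reach a longer max_length (the simple approach referenced from https://kexue.fm/archives/7947), and averaging a loss over nested Matryoshka (MRL) prefixes of the embedding dimension — can be illustrated with a minimal standalone PyTorch sketch. The table sizes, the nesting list, and toy_loss below are toy placeholders for illustration only, not values or functions from the patch:

import torch

# (a) Extend position embeddings by copying the pretrained table block by block
#     until it covers the new maximum length.
org_pe = torch.nn.Embedding(512, 64)      # toy pretrained table: 512 positions, dim 64
max_length, sp = 1024, 0                  # sp would be 2 for XLM-RoBERTa-style offsets
extended_pe = torch.nn.Embedding(max_length + sp, org_pe.embedding_dim)
for start in range(0, max_length + sp, org_pe.num_embeddings):
    end = min(start + org_pe.num_embeddings, max_length + sp)
    extended_pe.weight.data[start:end] = org_pe.weight.data[: end - start].clone()

# (b) MRL: compute the loss on nested prefixes of the embedding dimension and
#     average, so one checkpoint serves several output sizes.
nesting_list = [64, 128, 256]             # toy nesting; the patch uses 256..1792
emb, pos_emb = torch.randn(8, 256), torch.randn(8, 256)

def toy_loss(a, b):                       # stands in for retri_contrast / cosent / cls_contrast
    return (1 - torch.cosine_similarity(a, b, dim=-1)).mean()

loss = sum(toy_loss(emb[..., :d], pos_emb[..., :d]) for d in nesting_list) / len(nesting_list)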