From 55e00b467e5176fc3c723435252492b5b751dd5e Mon Sep 17 00:00:00 2001
From: k8tems
Date: Thu, 15 Feb 2024 02:39:14 +0000
Subject: [PATCH 1/5] Add single gpu support for LoRAs

---
 core/__init__.py      | 4 ++--
 train/base.py         | 9 +++++----
 train/train_c_lora.py | 6 +++---
 3 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/core/__init__.py b/core/__init__.py
index 03af283..efaf22d 100644
--- a/core/__init__.py
+++ b/core/__init__.py
@@ -112,7 +112,7 @@ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimize

     # perform the training here
     @abstractmethod
-    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers, single_gpu: bool=False):
         raise NotImplementedError("This method needs to be overriden")

     # ------------
@@ -357,7 +357,7 @@ def __call__(self, single_gpu=False):
         # TRAIN
         if self.is_main_node:
             print("**TRAINING STARTING...**")
-        self.train(data, extras, models, optimizers, schedulers)
+        self.train(data, extras, models, optimizers, schedulers, single_gpu)

         if single_gpu is False:
             barrier()
diff --git a/train/base.py b/train/base.py
index 4e8a6ef..f1c4dcf 100755
--- a/train/base.py
+++ b/train/base.py
@@ -239,7 +239,7 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext
         raise NotImplementedError("This method needs to be overriden")

     def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
-              schedulers: WarpCore.Schedulers):
+              schedulers: WarpCore.Schedulers, single_gpu: bool=False):
         start_iter = self.info.iter + 1
         max_iters = self.config.updates * self.config.grad_accum_steps
         if self.is_main_node:
@@ -304,13 +304,14 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
                         'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
                         'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
                     }
-                self.save_checkpoints(models, optimizers)
+                self.save_checkpoints(models, optimizers, single_gpu=single_gpu)
                 if self.is_main_node:
                     create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                 self.sample(models, data, extras)

-    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
-        barrier()
+    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None, single_gpu=False):
+        if single_gpu:
+            barrier()
         suffix = '' if suffix is None else suffix
         self.save_info(self.info, suffix=suffix)
         models_dict = models.to_dict()
diff --git a/train/train_c_lora.py b/train/train_c_lora.py
index 8b83eee..b9a54d0 100755
--- a/train/train_c_lora.py
+++ b/train/train_c_lora.py
@@ -320,11 +320,11 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

 if __name__ == '__main__':
     print("Launching Script")
+    single_gpu = bool(sys.argv[2])
     warpcore = WurstCore(
         config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
-        device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+        device=torch.device(int(os.environ.get("SLURM_LOCALID")) if not single_gpu else 0)
     )
     warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

-    # RUN TRAINING
-    warpcore()
+    warpcore(single_gpu=single_gpu)

From c00a44908dac95c9f2f456a989400d1ec95e143d Mon Sep 17 00:00:00 2001
From: k8tems
Date: Thu, 15 Feb 2024 02:47:21 +0000
Subject: [PATCH 2/5] Call barrier when single_gpu is False

---
 train/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train/base.py b/train/base.py
index f1c4dcf..7a2eb32 100755
--- a/train/base.py
+++ b/train/base.py
@@ -310,7 +310,7 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
                 self.sample(models, data, extras)

     def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None, single_gpu=False):
-        if single_gpu:
+        if not single_gpu:
             barrier()
         suffix = '' if suffix is None else suffix
         self.save_info(self.info, suffix=suffix)

From e6cc6555a5fcbdf88426aa605cc148f49d3147e5 Mon Sep 17 00:00:00 2001
From: k8tems
Date: Thu, 15 Feb 2024 12:20:35 +0000
Subject: [PATCH 3/5] Also relay single_gpu to the recursive save_checkpoints
 invocation

---
 train/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train/base.py b/train/base.py
index 7a2eb32..e2c84bb 100755
--- a/train/base.py
+++ b/train/base.py
@@ -326,7 +326,7 @@ def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None,
                 self.save_optimizer(optimizer, f'{key}_optim{suffix}',
                                     fsdp_model=models_dict[key] if self.config.use_fsdp else None)
         if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
-            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k")
+            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k", single_gpu=single_gpu)
         torch.cuda.empty_cache()

     def sample(self, models: Models, data: WarpCore.Data, extras: Extras):

From fd33bf798e60c8c7f34948637a1501c58fb62f03 Mon Sep 17 00:00:00 2001
From: k8tems
Date: Thu, 15 Feb 2024 12:42:23 +0000
Subject: [PATCH 4/5] Only reference argv[2] when there are sufficient
 parameters to preserve original CLI interface

---
 train/train_c_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train/train_c_lora.py b/train/train_c_lora.py
index b9a54d0..3adb863 100755
--- a/train/train_c_lora.py
+++ b/train/train_c_lora.py
@@ -320,7 +320,7 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

 if __name__ == '__main__':
     print("Launching Script")
-    single_gpu = bool(sys.argv[2])
+    single_gpu = bool(args[2]) if len(args) > 2 else False
     warpcore = WurstCore(
         config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
         device=torch.device(int(os.environ.get("SLURM_LOCALID")) if not single_gpu else 0)

From e72b9705a1948cba36b7e30b23fd15ca56c4a72c Mon Sep 17 00:00:00 2001
From: k8tems
Date: Fri, 16 Feb 2024 16:51:58 +0900
Subject: [PATCH 5/5] args -> sys.argv

---
 train/train_c_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train/train_c_lora.py b/train/train_c_lora.py
index 3adb863..c3f7c49 100755
--- a/train/train_c_lora.py
+++ b/train/train_c_lora.py
@@ -320,7 +320,7 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

 if __name__ == '__main__':
     print("Launching Script")
-    single_gpu = bool(args[2]) if len(args) > 2 else False
+    single_gpu = bool(sys.argv[2]) if len(sys.argv) > 2 else False
     warpcore = WurstCore(
         config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
         device=torch.device(int(os.environ.get("SLURM_LOCALID")) if not single_gpu else 0)
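
Usage note for anyone applying the series: a minimal sketch of the CLI behaviour it converges on after PATCH 5/5. The launch command in the comment is an assumption for illustration; only the sys.argv parsing, the SLURM_LOCALID device selection, and the bool() semantics come from the patches themselves.

    # Sketch of the final flag parsing from PATCH 5/5 (train/train_c_lora.py).
    # Assumed invocation: python3 train/train_c_lora.py <config.yaml> [single_gpu]
    import sys

    single_gpu = bool(sys.argv[2]) if len(sys.argv) > 2 else False

    # Caveat: bool() is True for ANY non-empty string, so passing "0" or "False"
    # as the second argument still enables single-GPU mode. Omitting the argument
    # entirely keeps the original multi-GPU path, where the device index is read
    # from the SLURM_LOCALID environment variable instead of defaulting to GPU 0.

A stricter parse (e.g. comparing the argument against an explicit set of accepted strings) would avoid the truthiness surprise, but the series deliberately keeps the original positional-argument interface intact.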