**`core/__init__.py`** — 4 changes: 2 additions & 2 deletions

```diff
@@ -112,7 +112,7 @@ def setup_extras_post(self, extras: Extras, models: Models, optimizers: Optimize
 
     # perform the training here
     @abstractmethod
-    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers):
+    def train(self, data: Data, extras: Extras, models: Models, optimizers: Optimizers, schedulers: Schedulers, single_gpu: bool=False):
         raise NotImplementedError("This method needs to be overriden")
     # ------------
 
@@ -357,7 +357,7 @@ def __call__(self, single_gpu=False):
             # TRAIN
             if self.is_main_node:
                 print("**TRAINING STARTING...**")
-            self.train(data, extras, models, optimizers, schedulers)
+            self.train(data, extras, models, optimizers, schedulers, single_gpu)
 
             if single_gpu is False:
                 barrier()
```
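The underlying issue is that `barrier()` is a `torch.distributed` collective and cannot be called in a single-GPU run, where no process group has been initialized; the new `single_gpu` parameter threads that knowledge from `__call__` down into `train`. As a minimal alternative sketch, not part of this PR (`barrier_if_distributed` is a hypothetical helper), the same guard could be derived from `torch.distributed` state instead of a flag:

```python
import torch.distributed as dist

def barrier_if_distributed():
    # dist.barrier() fails when no process group has been initialized,
    # so only synchronize when we are actually running distributed.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
```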
**`train/base.py`** — 9 changes: 5 additions & 4 deletions

```diff
@@ -239,7 +239,7 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext
         raise NotImplementedError("This method needs to be overriden")
 
     def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
-              schedulers: WarpCore.Schedulers):
+              schedulers: WarpCore.Schedulers, single_gpu: bool=False):
         start_iter = self.info.iter + 1
         max_iters = self.config.updates * self.config.grad_accum_steps
         if self.is_main_node:
@@ -304,13 +304,14 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
                         'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
                         'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
                     }
-                self.save_checkpoints(models, optimizers)
+                self.save_checkpoints(models, optimizers, single_gpu=single_gpu)
                 if self.is_main_node:
                     create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                     self.sample(models, data, extras)
 
-    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
-        barrier()
+    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None, single_gpu=False):
+        if single_gpu:
+            barrier()
         suffix = '' if suffix is None else suffix
         self.save_info(self.info, suffix=suffix)
         models_dict = models.to_dict()
```

Review comment on the added line `if single_gpu:`:

> shouldn't this be `if not single_gpu:`?

Author reply:

> I think that's fixed in this commit. c00a449
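As the reviewer notes, the condition is inverted as written: ranks need to synchronize exactly when the run is distributed, and `barrier()` must be skipped in single-GPU mode. Commit c00a449 is not included in this diff, so the following is only a sketch of the presumed fix:

```python
def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None, single_gpu=False):
    # Synchronize all ranks before writing checkpoints, but only when
    # running distributed; single-GPU runs have no process group.
    if not single_gpu:
        barrier()
    suffix = '' if suffix is None else suffix
    self.save_info(self.info, suffix=suffix)
    # ... rest of the method unchanged
```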
**`train/train_c_lora.py`** — 6 changes: 3 additions & 3 deletions

```diff
@@ -320,11 +320,11 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext
 
 if __name__ == '__main__':
     print("Launching Script")
+    single_gpu = bool(sys.argv[2])
     warpcore = WurstCore(
         config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
-        device=torch.device(int(os.environ.get("SLURM_LOCALID")))
+        device=torch.device(int(os.environ.get("SLURM_LOCALID")) if not single_gpu else 0)
     )
     warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
 
     # RUN TRAINING
-    warpcore()
+    warpcore(single_gpu=single_gpu)
```
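One caveat with the new argument handling, not addressed in this PR: `bool(sys.argv[2])` is `True` for any non-empty string, including the literal `"False"`, and indexing `sys.argv[2]` raises `IndexError` when the argument is omitted, even though `sys.argv[1]` is guarded with a length check. A more defensive parse might look like this (a sketch, not the author's code):

```python
import sys

# bool("False") is True because any non-empty string is truthy, and
# sys.argv[2] raises IndexError when the flag is omitted entirely.
single_gpu = len(sys.argv) > 2 and sys.argv[2].lower() in ('1', 'true', 'yes')
```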