Merge pull request #12 from neph1/update-v0.10.0

settings for fp8 training

neph1 authored Jan 16, 2025
2 parents d065821 + 2ef4089 commit 4e9f0b2

Showing 10 changed files with 27 additions and 15 deletions.
config/config_categories.yaml (7 changes: 4 additions & 3 deletions)

@@ -1,5 +1,6 @@
-Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p
-Training: training_type, seed, mixed_precision, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
+Dataset: data_root, video_column, caption_column, dataset_file, id_token, image_resolution_buckets, video_resolution_buckets, caption_dropout_p, precompute_conditions
+Training: training_type, seed, train_steps, rank, lora_alpha, target_modules, gradient_accumulation_steps, checkpointing_steps, checkpointing_limit, enable_slicing, enable_tiling, batch_size, resume_from_checkpoint
 Optimizer: optimizer, lr, beta1, beta2, epsilon, weight_decay, max_grad_norm, lr_scheduler, lr_num_cycles, lr_warmup_steps
 Validation: validation_steps, validation_epochs, num_validation_videos, validation_prompts, validation_prompt_separator
-Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Accelerate: gpu_ids, nccl_timeout, gradient_checkpointing, allow_tf32, dataloader_num_workers, report_to, accelerate_config
+Model: model_name, pretrained_model_name_or_path, text_encoder_dtype, text_encoder_2_dtype, text_encoder_3_dtype, vae_dtype, layerwise_upcasting_modules, layerwise_upcasting_storage_dtype, layerwise_upcasting_granularity
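For orientation: this YAML maps each settings category in the UI to the comma-separated list of config keys it groups (tabs/tab.py walks these lists in update_form). A minimal parsing sketch, assuming the file keeps the 'Category: key, key, ...' shape shown above:

import yaml  # PyYAML

# Parse config_categories.yaml into {category: [keys]}; each YAML value
# is one comma-separated string of config keys.
with open('config/config_categories.yaml') as f:
    raw = yaml.safe_load(f)

categories = {name: [key.strip() for key in keys.split(',')]
              for name, keys in raw.items()}

# After this commit, dtype and layerwise-upcasting keys live under 'Model'
# and mixed_precision is gone from 'Training'.
assert 'layerwise_upcasting_storage_dtype' in categories['Model']
assert 'mixed_precision' not in categories['Training']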
config/config_template.yaml (5 changes: 4 additions & 1 deletion)

@@ -20,14 +20,16 @@ gpu_ids: '0'
 gradient_accumulation_steps: 4
 gradient_checkpointing: true
 id_token: afkx
+layerwise_upcasting_modules: [none, transformer]
+layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
+layerwise_upcasting_storage_dtype: [float8_e4m3fn, float8_e5m2]
 image_resolution_buckets: 512x768
 lora_alpha: 128
 lr: 0.0001
 lr_num_cycles: 1
 lr_scheduler: ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup']
 lr_warmup_steps: 400
 max_grad_norm: 1.0
-mixed_precision: [bf16, fp16, 'no']
 model_name: ltx_video
 nccl_timeout: 1800
 num_validation_videos: 0
@@ -45,6 +47,7 @@ text_encoder_dtype: [bf16, fp16, fp32]
 text_encoder_2_dtype: [bf16, fp16, fp32]
 text_encoder_3_dtype: [bf16, fp16, fp32]
 tracker_name: finetrainers
+transformer_dtype: [bf16, fp16, fp32]
 train_steps: 3000
 training_type: lora
 use_8bit_bnb: false
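These template keys are the fp8 settings the commit title refers to: the bracketed lists enumerate the allowed choices, and, as the option names suggest, modules listed in layerwise_upcasting_modules are stored in an fp8 dtype and upcast per layer for compute. A minimal sketch of one concrete selection (illustrative values picked from the option lists above, not repository defaults):

import yaml  # PyYAML

# One concrete pick per option list from the template above
# (illustrative values, not repository defaults).
fp8_settings = yaml.safe_load("""
layerwise_upcasting_modules: transformer
layerwise_upcasting_storage_dtype: float8_e4m3fn
layerwise_upcasting_skip_modules_pattern: 'patch_embed pos_embed x_embedder context_embedder ^proj_in$ ^proj_out$ norm'
transformer_dtype: bf16
""")
print(fp8_settings['layerwise_upcasting_storage_dtype'])  # float8_e4m3fn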
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "finetrainers-ui"
-version = "0.9.3"
+version = "0.10.0"
 dependencies = [
     "gradio",
     "torch>=2.4.1"
run_trainer.py (13 changes: 11 additions & 2 deletions)

@@ -20,7 +20,16 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
         assert config.get('pretrained_model_name_or_path'), "pretrained_model_name_or_path required"

         model_cmd = ["--model_name", config.get('model_name'),
-                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path')]
+                     "--pretrained_model_name_or_path", config.get('pretrained_model_name_or_path'),
+                     "--text_encoder_dtype", config.get('text_encoder_dtype'),
+                     "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
+                     "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
+                     "--vae_dtype", config.get('vae_dtype')]
+
+        if config.get('layerwise_upcasting_modules') != 'none':
+            model_cmd +=["--layerwise_upcasting_modules", config.get('layerwise_upcasting_modules'),
+                         "--layerwise_upcasting_storage_dtype", config.get('layerwise_upcasting_storage_dtype'),
+                         "--layerwise_upcasting_skip_modules_pattern", config.get('layerwise_upcasting_skip_modules_pattern')]

         dataset_cmd = ["--data_root", config.get('data_root'),
                        "--video_column", config.get('video_column'),
@@ -36,6 +45,7 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):
                        "--text_encoder_2_dtype", config.get('text_encoder_2_dtype'),
                        "--text_encoder_3_dtype", config.get('text_encoder_3_dtype'),
                        "--vae_dtype", config.get('vae_dtype'),
+                       "--transformer_dtype", config.get('transformer_dtype'),
                        '--precompute_conditions' if config.get('precompute_conditions') else '']
         if config.get('dataset_file'):
             dataset_cmd += ["--dataset_file", config.get('dataset_file')]
@@ -47,7 +57,6 @@ def run(self, config: Config, finetrainers_path: str, log_file: str):

         training_cmd = ["--training_type", config.get('training_type'),
                         "--seed", config.get('seed'),
-                        "--mixed_precision", config.get('mixed_precision'),
                         "--batch_size", config.get('batch_size'),
                         "--train_steps", config.get('train_steps'),
                         "--rank", config.get('rank'),
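Net effect on the assembled command: the dtype flags now always accompany the model arguments, the layerwise-upcasting flags are appended only when layerwise_upcasting_modules is not 'none', and --mixed_precision is gone. A standalone sketch of that assembly logic (a plain dict stands in for the repo's Config object; all values are assumed examples):

# Plain-dict stand-in for the repo's Config object; values are assumed.
config = {
    'model_name': 'ltx_video',
    'pretrained_model_name_or_path': 'Lightricks/LTX-Video',  # assumed example path
    'text_encoder_dtype': 'bf16',
    'text_encoder_2_dtype': 'bf16',
    'text_encoder_3_dtype': 'fp32',
    'vae_dtype': 'bf16',
    'layerwise_upcasting_modules': 'transformer',
    'layerwise_upcasting_storage_dtype': 'float8_e4m3fn',
    'layerwise_upcasting_skip_modules_pattern': 'patch_embed pos_embed norm',
}

# Mirrors the model_cmd construction in the diff above.
model_cmd = ['--model_name', config.get('model_name'),
             '--pretrained_model_name_or_path', config.get('pretrained_model_name_or_path'),
             '--text_encoder_dtype', config.get('text_encoder_dtype'),
             '--text_encoder_2_dtype', config.get('text_encoder_2_dtype'),
             '--text_encoder_3_dtype', config.get('text_encoder_3_dtype'),
             '--vae_dtype', config.get('vae_dtype')]

if config.get('layerwise_upcasting_modules') != 'none':
    model_cmd += ['--layerwise_upcasting_modules', config.get('layerwise_upcasting_modules'),
                  '--layerwise_upcasting_storage_dtype', config.get('layerwise_upcasting_storage_dtype'),
                  '--layerwise_upcasting_skip_modules_pattern', config.get('layerwise_upcasting_skip_modules_pattern')]

print(' '.join(model_cmd))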
tabs/general_tab.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
tabs/prepare_tab.py (2 changes: 1 addition & 1 deletion)

@@ -20,7 +20,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
tabs/tab.py (6 changes: 3 additions & 3 deletions)

@@ -71,10 +71,10 @@ def add_buttons(self):
             outputs=[self.save_status, self.config_file_box, *self.get_properties().values()]
         )

-    def update_form(self, config):
+    def update_form(self):
         inputs = dict()

-        for key, value in config.items():
+        for key, value in self.config.items():
             category = 'Other'
             for categories in self.config_categories.keys():
                 if key in self.config_categories[categories]:
@@ -114,6 +114,6 @@ def update_properties(self, *args):

             properties_values[index] = value
             #properties[key].value = value
-            return ["Config loaded. Edit below:", config_file_box, *properties_values]
+            return ["Config loaded.", config_file_box, *properties_values]
         except Exception as e:
             return [f"Error loading config: {e}", config_file_box, *properties_values]
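The signature change removes a redundant parameter: every caller was passing the tab's own self.config, so update_form now reads it directly. A minimal stand-in sketch of the new call pattern (not repo code; the real method also buckets keys by category):

from collections import OrderedDict

class Tab:
    # Minimal stand-in for tabs/tab.py, reduced to the relevant parts.
    def __init__(self, config):
        self.config = config

    def update_form(self):
        # Reads self.config directly instead of receiving it as an argument.
        return {key: value for key, value in self.config.items()}

tab = Tab({'lr': 0.0001, 'rank': 64})
components = OrderedDict(tab.update_form())  # new call style: no argument
print(list(components))  # ['lr', 'rank']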
tabs/training_tab.py (2 changes: 1 addition & 1 deletion)

@@ -30,7 +30,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                inputs = self.update_form(self.config)
+                inputs = self.update_form()
                 self.components = OrderedDict(inputs)
                 children = []
                 for child in self.settings_column.children:
tabs/training_tab_legacy.py (2 changes: 1 addition & 1 deletion)

@@ -17,7 +17,7 @@ def __init__(self, title, config_file_path, allow_load=False):

         try:
             with self.settings_column:
-                self.components = OrderedDict(self.update_form(self.config))
+                self.components = OrderedDict(self.update_form())
                 for i in range(len(self.settings_column.children)):
                     keys = list(self.components.keys())
                     properties[keys[i]] = self.settings_column.children[i]
trainer_config_validator.py (1 change: 0 additions & 1 deletion)

@@ -29,7 +29,6 @@ def validate(self):
             'lr_scheduler',
             'lr_warmup_steps',
             'max_grad_norm',
-            'mixed_precision',
             'model_name',
             'nccl_timeout',
             'optimizer',
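With 'mixed_precision' dropped from this required-key list, configs that omit it now pass validation. A minimal sketch of the kind of check such a list typically feeds (assumed shape; the repo's validate() may differ in detail):

REQUIRED_KEYS = [
    'lr_scheduler',
    'lr_warmup_steps',
    'max_grad_norm',
    # 'mixed_precision',  # no longer required as of this commit
    'model_name',
    'nccl_timeout',
    'optimizer',
]

def validate(config: dict) -> None:
    # Raise if any required key is missing or empty.
    missing = [key for key in REQUIRED_KEYS if not config.get(key)]
    if missing:
        raise ValueError(f"Missing required config values: {', '.join(missing)}")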
