diff --git a/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json
new file mode 100644
index 0000000..5ca5da3
--- /dev/null
+++ b/subspace_decoder/configs/seqlen128_mla-on_q96_k64_o96.json
@@ -0,0 +1,106 @@
+{
+  "shorthand": "seqlen.128 - mla-on.96.64.96 - mlp.1024 - model.256.lyr.6 - ah.8.32",
+  "notes": "MLA-o with sequential decomp and norm. Sequence length is 128.",
+  "model": {
+    "hidden_size": 256,
+    "num_hidden_layers": 6,
+    "num_nextn_predict_layers": 1,
+    "moe_intermediate_size": 256,
+    "intermediate_size": 1024,
+
+    "n_shared_experts": 0,
+    "n_routed_experts": 4,
+    "num_experts_per_tok": 2,
+    "moe_layer_freq": 0,
+    "first_k_dense_replace": 0,
+
+    "ep_size": 1,
+    "routed_scaling_factor": 1,
+    "topk_method": "softmax_aux",
+    "n_group": 1,
+    "topk_group": 1,
+    "norm_topk_prob": true,
+    "scoring_func": "softmax",
+    "hidden_act": "silu",
+    "use_cache": false,
+    "pad_token_id": 50256,
+    "bos_token_id": 50256,
+    "eos_token_id": 50256,
+    "tie_word_embeddings": true,
+    "attention_dropout": 0.0,
+    "hidden_dropout_prob": 0.1,
+    "classifier_dropout": null,
+    "initializer_range": 0.02,
+    "rms_norm_eps": 1e-06,
+    "vocab_size": 50257,
+    "rope_theta": 10000.0,
+    "rope_scaling": null,
+    "max_position_embeddings": 1024,
+    "kv_lora_rank": 64,
+    "q_lora_rank": 96,
+    "qk_rope_head_dim": 32,
+    "v_head_dim": 32,
+    "qk_nope_head_dim": 0,
+    "num_attention_heads": 8,
+    "num_key_value_heads": 8,
+    "attention_bias": false,
+    "use_output_subspace": true,
+    "o_proj_variant": "sequential_norm",
+    "o_latent_dim": 96,
+    "attention_backend": "flash_attention_2"
+  },
+  "pre_train": {
+    "output_dir": "checkpoints/seqlen128_mla-on_q96_k64_o96",
+    "seed": 42,
+    "train_batch_size": 128,
+    "gradient_accumulation_steps": 8,
+    "learning_rate": 0.0005,
+    "num_train_steps": 12500,
+    "eval_steps": 1000,
+    "weight_decay": 0.01,
+    "num_workers": 8,
+    "pin_memory": true,
+    "dataset_name": "wikitext",
+    "dataset_config": "wikitext-103-raw-v1",
+    "max_seq_length": 128,
+    "eval_batch_size": 32,
+    "bf16": true,
+    "fp16": false,
+    "torch_compile": true,
+    "torch_compile_backend": "inductor",
+    "torch_compile_mode": "default"
+  },
+  "fine_tune": {
+    "task": "sst2",
+    "tokenizer_name_or_path": "gpt2",
+    "method": "lm_label_words",
+    "label_words": {
+      "0": " negative",
+      "1": " positive"
+    },
+    "train_batch_size": 256,
+    "gradient_accumulation_steps": 1,
+    "eval_batch_size": 256,
+    "learning_rate": 5e-05,
+    "weight_decay": 0.05,
+    "max_steps": 1500,
+    "warmup_ratio": 0.1,
+    "eval_steps": 150,
+    "logging_steps": 20,
+    "save_strategy": "no",
+    "save_total_limit": 0,
+    "report_to_wandb": true,
+    "bf16": true,
+    "fp16": false,
+    "torch_compile": false,
+    "torch_compile_backend": null,
+    "torch_compile_mode": null,
+    "lora": {
+      "enabled": false
+    },
+    "seed": 42,
+    "max_seq_length": 128,
+    "output_dir": "checkpoints/seqlen128_mla-on_q96_k64_o96/ft_sst2",
+    "run_name": "ft-sst2 - seqlen.128 - mla-on.96.64.96 - mlp.1024 - model.256.lyr.6 - ah.8.32"
+  }
+}
diff --git a/subspace_decoder/configs/seqlen128_mla_q96_k64.json b/subspace_decoder/configs/seqlen128_mla_q96_k64.json
new file mode 100644
index 0000000..ace0c01
--- /dev/null
+++ b/subspace_decoder/configs/seqlen128_mla_q96_k64.json
@@ -0,0 +1,106 @@
+{
+  "shorthand": "seqlen.128 - mla.96.64 - mlp.1024 - model.256.lyr.6 - ah.8.32",
+  "notes": "Baseline MLA, Sequence length is 128.",
+  "model": {
+    "hidden_size": 256,
+    "num_hidden_layers": 6,
+    "num_nextn_predict_layers": 1,
+    "moe_intermediate_size": 256,
"intermediate_size": 1024, + + "n_shared_experts": 0, + "n_routed_experts": 4, + "num_experts_per_tok": 2, + "moe_layer_freq": 0, + "first_k_dense_replace": 0, + + "ep_size": 1, + "routed_scaling_factor": 1, + "topk_method": "softmax_aux", + "n_group": 1, + "topk_group": 1, + "norm_topk_prob": true, + "scoring_func": "softmax", + "hidden_act": "silu", + "use_cache": false, + "pad_token_id": 50256, + "bos_token_id": 50256, + "eos_token_id": 50256, + "tie_word_embeddings": true, + "attention_dropout": 0.0, + "hidden_dropout_prob": 0.1, + "classifier_dropout": null, + "initializer_range": 0.02, + "rms_norm_eps": 1e-06, + "vocab_size": 50257, + "rope_theta": 10000.0, + "rope_scaling": null, + "max_position_embeddings": 1024, + "kv_lora_rank": 64, + "q_lora_rank": 96, + "qk_rope_head_dim": 32, + "v_head_dim": 32, + "qk_nope_head_dim": 0, + "num_attention_heads": 8, + "num_key_value_heads": 8, + "attention_bias": false, + "use_output_subspace": false, + "o_proj_variant": "vanilla", + "o_latent_dim": null, + "attention_backend": "flash_attention_2" + }, + "pre_train": { + "output_dir": "checkpoints/seqlen128_mla_q96_k64", + "seed": 42, + "train_batch_size": 128, + "gradient_accumulation_steps": 8, + "learning_rate": 0.0005, + "num_train_steps": 12500, + "eval_steps": 1000, + "weight_decay": 0.01, + "num_workers": 8, + "pin_memory": true, + "dataset_name": "wikitext", + "dataset_config": "wikitext-103-raw-v1", + "max_seq_length": 128, + "eval_batch_size": 32, + "bf16": true, + "fp16": false, + "torch_compile": true, + "torch_compile_backend": "inductor", + "torch_compile_mode": "default" + }, + "fine_tune": { + "task": "sst2", + "tokenizer_name_or_path": "gpt2", + "method": "lm_label_words", + "label_words": { + "0": " negative", + "1": " positive" + }, + "train_batch_size": 256, + "gradient_accumulation_steps": 1, + "eval_batch_size": 256, + "learning_rate": 5e-05, + "weight_decay": 0.05, + "max_steps": 1500, + "warmup_ratio": 0.1, + "eval_steps": 150, + "logging_steps": 20, + "save_strategy": "no", + "save_total_limit": 0, + "report_to_wandb": true, + "bf16": true, + "fp16": false, + "torch_compile": false, + "torch_compile_backend": null, + "torch_compile_mode": null, + "lora": { + "enabled": false + }, + "seed": 42, + "max_seq_length": 128, + "output_dir": "checkpoints/seqlen128_mla_q96_k64/ft_sst2", + "run_name": "ft-sst2 - seqlen.128 - mla.96.64 - mlp.1024 - model.256.lyr.6 - ah.8.32" + } +} diff --git a/subspace_decoder/scripts/finetune_sst2.py b/subspace_decoder/scripts/finetune_sst2.py index 2505558..5f6c745 100644 --- a/subspace_decoder/scripts/finetune_sst2.py +++ b/subspace_decoder/scripts/finetune_sst2.py @@ -26,13 +26,13 @@ set_seed, ) -from utils import summarize_parameters, format_size - # Project import path (same pattern as your train.py) PROJECT_ROOT = Path(__file__).resolve().parents[1] if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) +from utils import summarize_parameters, format_size + from layers.patch_o_proj import load_checkpoint_state_dict, load_and_patch_model, Variant from transformers import DeepseekV3Config, DeepseekV3ForCausalLM diff --git a/subspace_decoder/scripts/train.py b/subspace_decoder/scripts/train.py index 68e6c24..772371f 100644 --- a/subspace_decoder/scripts/train.py +++ b/subspace_decoder/scripts/train.py @@ -35,10 +35,6 @@ set_seed, ) -from utils import summarize_parameters, format_size -# To disable a warning. 
-os.environ["TOKENIZERS_PARALLELISM"] = "false" - # Make sure we can import modules from the decoder package PROJECT_ROOT = Path(__file__).resolve().parents[1] @@ -47,6 +43,10 @@ if str(PROJECT_ROOT) not in sys.path: sys.path.insert(0, str(PROJECT_ROOT)) +from utils import summarize_parameters, format_size +# To disable a warning. +os.environ["TOKENIZERS_PARALLELISM"] = "false" + from layers.patch_o_proj import patch_o_proj_implementation from transformers import DeepseekV3Config, DeepseekV3ForCausalLM