diff --git a/train.py b/train.py
index c4a6ad9..45e0606 100644
--- a/train.py
+++ b/train.py
@@ -1,23 +1,24 @@
-from collections.abc import Callable
 import json
-from pathlib import Path
 import random
 import re
+from collections.abc import Callable
+from pathlib import Path
 from typing import Any, Iterator, Optional
-import wandb
+
 import torch
-import torch.optim as optim
 import torch.nn.functional as F
+import torch.optim as optim
+import wandb
+from loss import GRPOLoss
+from replay_buffer import Experience, join_experience_batch, ReplayBuffer
 from torch.nn.utils import clip_grad_norm_
 from torch.utils.data import DataLoader
 from transformers import (
     AutoTokenizer,
-    PreTrainedTokenizer,
-    LlamaForCausalLM,
     GenerationConfig,
+    LlamaForCausalLM,
+    PreTrainedTokenizer,
 )
-from loss import approx_kl_divergence, GRPOLoss
-from replay_buffer import ReplayBuffer, Experience, join_experience_batch
 
 
 def load_model(
@@ -193,7 +194,7 @@ def read_prompts(
 
 def main():
     seed = 42
-    wandb_project = None  # "tiny_grpo"
+    wandb_project = "tiny-grpo"  # "tiny_grpo"
     device_index = 0
     model_name = "meta-llama/Llama-3.2-1B-Instruct"
     checkpoint_path = Path("./output")
@@ -291,11 +292,12 @@ def main():
                     sequence_ids=sequence_ids,
                     attention_mask=attention_mask,
                 )
-                kl = approx_kl_divergence(
-                    log_probs=log_probs,
-                    log_probs_ref=log_probs_ref,
-                    action_mask=action_mask,
-                )
+                # kl = approx_kl_divergence(
+                #     log_probs=log_probs,
+                #     log_probs_ref=log_probs_ref,
+                #     action_mask=action_mask,
+                # )
+                kl = 0
 
                 experience = Experience(
                     sequences=sequence_ids,