diff --git a/src/alignrl/grpo.py b/src/alignrl/grpo.py index 3e1407b..3eaf7da 100644 --- a/src/alignrl/grpo.py +++ b/src/alignrl/grpo.py @@ -120,7 +120,7 @@ def train(self) -> TrainResult: save_strategy="steps", save_steps=50, ) - if self.config.reward_weights: + if self.config.reward_weights is not None: grpo_args.reward_weights = self.config.reward_weights trainer = GRPOTrainer(