|
180 | 180 | " metric = evaluate.load(\"accuracy\")\n",
|
181 | 181 | "\n",
|
182 | 182 | " def compute_metrics(eval_pred):\n",
|
183 |
| - " print(\"COMPUTE METRICS CALLED!\")\n", |
184 | 183 | " logits, labels = eval_pred\n",
|
185 | 184 | " predictions = np.argmax(logits, axis=-1)\n",
|
186 | 185 | "\n",
|
|
192 | 191 | "\n",
|
193 | 192 | " # Hugging Face Trainer\n",
|
194 | 193 | " training_args = TrainingArguments(\n",
|
195 |
| - " seed=SEED,\n", |
| 194 | + " do_eval=True,\n", |
| 195 | + " do_train=True,\n", |
| 196 | + " eval_strategy=\"epoch\",\n", |
| 197 | + " num_train_epochs=config[\"epochs\"],\n", |
196 | 198 | " output_dir=\"./results\",\n",
|
197 | 199 | " overwrite_output_dir=True,\n",
|
198 |
| - " num_train_epochs=config[\"epochs\"],\n", |
199 |
| - " eval_strategy=\"epoch\",\n", |
200 |
| - " do_train=True,\n", |
201 |
| - " do_eval=True,\n", |
| 200 | + " per_device_eval_batch_size=4,\n", |
| 201 | + " per_device_train_batch_size=4,\n", |
202 | 202 | " report_to=[\"comet_ml\"],\n",
|
| 203 | + " seed=SEED,\n", |
203 | 204 | " )\n",
|
204 | 205 | " trainer = Trainer(\n",
|
205 | 206 | " model=model,\n",
|
|
0 commit comments