
Commit 924f4cc

Merge branch 'master' into mhof_dev
2 parents: ae907b2 + 163ca91

File tree: 6 files changed, +63 additions, -4 deletions

domainlab/algos/trainers/a_trainer.py (+10, -2)

@@ -5,6 +5,7 @@
 import torch
 from torch import optim
+from torch.optim import lr_scheduler

 from domainlab.compos.pcr.p_chain_handler import AbstractChainNodeHandler

@@ -27,7 +28,12 @@ def mk_opt(model, aconf):
     # {'params': model._decoratee.parameters()}
     # ], lr=aconf.lr)
     optimizer = optim.Adam(list_par, lr=aconf.lr)
-    return optimizer
+    if aconf.lr_scheduler is not None:
+        class_lr_scheduler = getattr(lr_scheduler, aconf.lr_scheduler)
+        scheduler = class_lr_scheduler(optimizer, T_max=aconf.epos)
+    else:
+        scheduler = None
+    return optimizer, scheduler


 class AbstractTrainer(AbstractChainNodeHandler, metaclass=abc.ABCMeta):

@@ -102,6 +108,8 @@ def __init__(self, successor_node=None, extend=None):
         self.dict_multiplier = {}
         # MIRO
         self.input_tensor_shape = None
+        # LR-scheduler
+        self.lr_scheduler = None

     @property
     def model(self):

@@ -178,7 +186,7 @@ def reset(self):
         """
         make a new optimizer to clear internal state
         """
-        self.optimizer = mk_opt(self.model, self.aconf)
+        self.optimizer, self.lr_scheduler = mk_opt(self.model, self.aconf)

     @abc.abstractmethod
     def tr_epoch(self, epoch):
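
Note that mk_opt resolves the scheduler class by its string name from torch.optim.lr_scheduler via getattr and always passes T_max, so the name given on the command line must be a class in that module whose constructor accepts T_max (CosineAnnealingLR does; a scheduler with a different signature, such as StepLR, would raise a TypeError here). A minimal standalone sketch of the same lookup, with a throwaway linear model standing in for the DomainLab model:

import torch
from torch import optim
from torch.optim import lr_scheduler

net = torch.nn.Linear(4, 2)  # placeholder model, not part of this commit
optimizer = optim.Adam(net.parameters(), lr=1e-3)

# resolve the scheduler class from its string name, as mk_opt does
class_lr_scheduler = getattr(lr_scheduler, "CosineAnnealingLR")
scheduler = class_lr_scheduler(optimizer, T_max=10)  # T_max plays the role of aconf.epos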

domainlab/algos/trainers/train_basic.py (+2)

@@ -83,6 +83,8 @@ def tr_batch(self, tensor_x, tensor_y, tensor_d, others, ind_batch, epoch):
         loss, *_ = self.cal_loss(tensor_x, tensor_y, tensor_d, others)
         loss.backward()
         self.optimizer.step()
+        if self.lr_scheduler:
+            self.lr_scheduler.step()
         self.after_batch(epoch, ind_batch)
         self.counter_batch += 1
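
tr_batch follows the ordering PyTorch expects, optimizer.step() before scheduler.step(), and advances the scheduler once per batch rather than once per epoch. A runnable sketch of the same loop shape, with random data as a stand-in for the DomainLab batches:

import torch
from torch import nn, optim
from torch.optim import lr_scheduler

net = nn.Linear(4, 2)
optimizer = optim.Adam(net.parameters(), lr=1e-3)
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

for _ in range(10):  # stands in for the batch loop in tr_batch
    x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()   # update the weights first,
    scheduler.step()   # then advance the learning-rate schedule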

domainlab/arg_parser.py (+7)

@@ -264,6 +264,13 @@ def mk_parser_main():
         help="name of pytorch optimizer",
     )

+    parser.add_argument(
+        "--lr_scheduler",
+        type=str,
+        default=None,
+        help="name of pytorch learning rate scheduler",
+    )
+
     parser.add_argument(
         "--param_idx",
         type=bool,
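
Since the flag defaults to None, runs without it keep the previous behaviour: mk_opt returns scheduler = None and tr_batch skips the step. A hypothetical invocation, reusing the task and network names from the test added below:

python main_out.py --te_d=2 --tr_d 0 1 --task=mnistcolor10 --bs=100 \
    --model=erm --nname=conv_bn_pool_2 --lr_scheduler=CosineAnnealingLR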

scripts/generate_latex_table.py (+27, new file)

@@ -0,0 +1,27 @@
+"""
+aggregate benchmark csv file to generate latex table
+"""
+import argparse
+import pandas as pd
+
+
+def gen_latex_table(raw_df, fname="table_perf.tex",
+                    group="method", str_perf="acc"):
+    """
+    aggregate benchmark csv file to generate latex table
+    """
+    df_result = raw_df.groupby(group)[str_perf].agg(["mean", "std", "count"])
+    latex_table = df_result.to_latex(float_format="%.3f")
+    str_table = df_result.to_string()
+    print(str_table)
+    with open(fname, 'w') as file:
+        file.write(latex_table)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Read a CSV file")
+    parser.add_argument("filename", help="Name of the CSV file to read")
+    args = parser.parse_args()
+
+    df = pd.read_csv(args.filename, index_col=False, skipinitialspace=True)
+    gen_latex_table(df)
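
The script assumes the benchmark CSV contains a method column to group by and an acc column to aggregate; both are overridable through gen_latex_table's keyword arguments. It prints the mean/std/count summary and writes the LaTeX version to table_perf.tex. A hypothetical run, where results.csv is a placeholder file name:

python scripts/generate_latex_table.py results.csv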

scripts/sh_genplot.sh (+3, -2)

@@ -1,2 +1,3 @@
-mkdir $2
-python main_out.py --gen_plots $1 --outp_dir $2
+# mkdir $2
+sh scripts/merge_csvs.sh $1
+python main_out.py --gen_plots merged_data.csv --outp_dir partial_agg_plots
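
Presumably scripts/merge_csvs.sh (not part of this commit) merges the partial benchmark CSVs referenced by $1 into merged_data.csv, which the plotting call then consumes; with the output directory hard-coded to partial_agg_plots, the second positional argument, and hence the mkdir, is no longer needed.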

tests/test_lr_scheduler.py (+14, new file)

@@ -0,0 +1,14 @@
+
+"""
+unit and end-end test for lr scheduler
+"""
+from tests.utils_test import utils_test_algo
+
+
+def test_lr_scheduler():
+    """
+    train
+    """
+    args = "--te_d=2 --tr_d 0 1 --task=mnistcolor10 --debug --bs=100 --model=erm \
+        --nname=conv_bn_pool_2 --no_dump --lr_scheduler=CosineAnnealingLR"
+    utils_test_algo(args)
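
This test drives the full chain end to end: flag parsing in arg_parser.py, scheduler construction in mk_opt, and per-batch stepping in tr_batch. Assuming the repository's usual pytest setup, it runs with:

pytest tests/test_lr_scheduler.py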
