Commit 66f2a2f

Add a centered variance option to the ClippedAdam optimizer (#3415)
* Add option to use centered variance in the ClippedAdam optimizer.
* Add test for the centered ClippedAdam optimizer.
* Calculate convergence iteration for the centered ClippedAdam optimizer.
* Added reference of the centered Adam optimizer.
* Add option to use the ClippedAdam optimizer with centered variance in the Latent Dirichlet Allocation example.
* Added more detailed comments on ClippedAdam with centered variance and its tests.
* Shortened the ClippedAdam centered variance test and added an option to run the full test with plots via a pytest command line option.
1 parent 86277f3 commit 66f2a2f

4 files changed (+152 -6 lines)


examples/lda.py (+4 -1)
@@ -137,7 +137,9 @@ def main(args):
     guide = functools.partial(parametrized_guide, predictor)
     Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
     elbo = Elbo(max_plate_nesting=2)
-    optim = ClippedAdam({"lr": args.learning_rate})
+    optim = ClippedAdam(
+        {"lr": args.learning_rate, "centered_variance": args.centered_variance}
+    )
     svi = SVI(model, guide, optim, elbo)
     logging.info("Step\tLoss")
     for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
     parser.add_argument("-n", "--num-steps", default=1000, type=int)
     parser.add_argument("-l", "--layer-sizes", default="100-100")
     parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
+    parser.add_argument("-cv", "--centered-variance", default=False, type=bool)
     parser.add_argument("-b", "--batch-size", default=32, type=int)
     parser.add_argument("--jit", action="store_true")
     args = parser.parse_args()
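Outside of the LDA example, the new option is passed the same way as any other entry in the argument dict of the ClippedAdam wrapper. Below is a minimal sketch of that usage; the toy model, guide, and data are placeholders for illustration and are not part of this commit.

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam


def model(data):
    # Toy model: unknown mean with a unit-normal prior (illustrative only).
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)


def guide(data):
    # Simple guide with one learnable location parameter.
    loc_q = pyro.param("loc_q", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(loc_q, 0.1))


data = torch.randn(10) + 3.0

# "centered_variance" travels alongside "lr" to the underlying optimizer,
# exactly as in the lda.py change above.
optim = ClippedAdam({"lr": 0.01, "centered_variance": True})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
for step in range(100):
    svi.step(data)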

pyro/optim/clipped_adam.py (+15 -5)
@@ -19,14 +19,21 @@ class ClippedAdam(Optimizer):
     :param weight_decay: weight decay (L2 penalty) (default: 0)
     :param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
     :param lrd: rate at which learning rate decays (default: 1.0)
+    :param centered_variance: use centered variance (default: False)
 
     Small modification to the Adam algorithm implemented in torch.optim.Adam
-    to include gradient clipping and learning rate decay.
+    to include gradient clipping and learning rate decay and an option to use
+    the centered variance (see equation 2 in [2]).
 
-    Reference
+    **References**
 
-    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
-    https://arxiv.org/abs/1412.6980
+    [1] `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
+        https://arxiv.org/abs/1412.6980
+
+    [2] `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
+        Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
+        Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
+        https://doi.org/10.3390/computation11050095
     """
 
     def __init__(
@@ -38,6 +45,7 @@ def __init__(
         weight_decay=0,
         clip_norm: float = 10.0,
         lrd: float = 1.0,
+        centered_variance: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -46,6 +54,7 @@ def __init__(
             weight_decay=weight_decay,
             clip_norm=clip_norm,
             lrd=lrd,
+            centered_variance=centered_variance,
         )
         super().__init__(params, defaults)
 
@@ -87,7 +96,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
 
                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                grad_var = (grad - exp_avg) if group["centered_variance"] else grad
+                exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)
 
                 denom = exp_avg_sq.sqrt().add_(group["eps"])
 
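To spell out the step() change above: when centered_variance=True, the squared gradient in Adam's second-moment estimate is replaced by the squared deviation of the gradient from the (already updated) first-moment estimate. In standard Adam notation this reads roughly as follows; this is a paraphrase of the diff, not an equation quoted from reference [2]:

m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, (g_t - m_t)^2

Standard (uncentered) Adam uses g_t^2 in place of (g_t - m_t)^2; the rest of the update, including gradient clipping and learning rate decay, is unchanged by this commit.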

tests/optim/conftest.py (+10)
@@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
                 item.add_marker(pytest.mark.stage("unit"))
             if "init" not in item.keywords:
                 item.add_marker(pytest.mark.init(rng_seed=123))
+
+
+def pytest_addoption(parser):
+    parser.addoption("--plot", action="store", default="FALSE")
+
+
+def pytest_generate_tests(metafunc):
+    option_value = metafunc.config.option.plot != "FALSE"
+    if "plot" in metafunc.fixturenames and option_value is not None:
+        metafunc.parametrize("plot", [option_value])
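Read together, these two hooks route a --plot command line option into any test that declares a plot fixture argument: pytest_generate_tests parametrizes such tests with a single boolean that is True whenever --plot is set to anything other than the default "FALSE". A minimal sketch of a consumer, with a hypothetical test name (the real consumer is test_centered_clipped_adam in tests/optim/test_optim.py below):

# Hypothetical test illustrating how the "plot" parametrization from
# tests/optim/conftest.py is consumed; not part of this commit.
def test_plot_option_demo(plot):
    if plot:
        # Reached only when pytest is run with e.g. --plot True
        print("plotting enabled")
    else:
        print("plotting disabled (the default)")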

tests/optim/test_optim.py (+123)
@@ -435,3 +435,126 @@ def step(svi, optimizer):
         actual.append(step(svi, optimizer))
 
     assert_equal(actual, expected)
+
+
+def test_centered_clipped_adam(plot):
+    """
+    Test the centered variance option of the ClippedAdam optimizer.
+    In order to create plots run pytest with the plot command line
+    option set to True, i.e. by executing
+
+    'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
+
+    """
+    if not plot:
+        lr_vec = [0.1, 0.001]
+    else:
+        lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
+
+    w = torch.Tensor([1, 500])
+
+    def loss_fn(p):
+        return (1 + w * p * p).sqrt().sum() - len(w)
+
+    def fit(lr, centered_variance, num_iter=5000):
+        loss_vec = []
+        p = torch.nn.Parameter(torch.Tensor([10, 1]))
+        optim = pyro.optim.clipped_adam.ClippedAdam(
+            lr=lr, params=[p], centered_variance=centered_variance
+        )
+        for count in range(num_iter):
+            optim.zero_grad()
+            loss = loss_fn(p)
+            loss.backward()
+            optim.step()
+            loss_vec.append(loss)
+        return torch.Tensor(loss_vec)
+
+    def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
+        """
+        Calculate the number of iterations needed in order to reach the
+        ultimate loss plus a small threshold, and the convergence rate
+        which is the mean per iteration improvement of the gap between
+        the loss and the ultimate loss.
+        """
+        ultimate_loss = loss_vec[-tail_len:].mean()
+        convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
+        convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
+        convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
+        return ultimate_loss, convergence_rate, convergence_iter
+
+    def get_convergence_vec(lr_vec, centered_variance):
+        """
+        Fit parameters for a vector of learning rates, with or without centered variance,
+        and calculate the convergence properties for each learning rate.
+        """
+        ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
+        for lr in lr_vec:
+            loss_vec = fit(lr=lr, centered_variance=centered_variance)
+            ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
+                loss_vec
+            )
+            ultimate_loss_vec.append(ultimate_loss)
+            convergence_rate_vec.append(convergence_rate)
+            convergence_iter_vec.append(convergence_iter)
+        return (
+            torch.Tensor(ultimate_loss_vec),
+            torch.Tensor(convergence_rate_vec),
+            convergence_iter_vec,
+        )
+
+    (
+        centered_ultimate_loss_vec,
+        centered_convergence_rate_vec,
+        centered_convergence_iter_vec,
+    ) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
+    ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
+        lr_vec=lr_vec, centered_variance=False
+    )
+
+    # All centered variance results should converge
+    assert (centered_ultimate_loss_vec < 0.01).all()
+    # Some uncentered variance results do not converge
+    assert (ultimate_loss_vec > 0.01).any()
+    # Verify convergence rate improvement
+    assert (
+        (centered_convergence_rate_vec / convergence_rate_vec)
+        > ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
+    ).all()
+
+    if plot:
+        from matplotlib import pyplot as plt
+
+        plt.figure(figsize=(6, 8))
+        plt.subplot(3, 1, 1)
+        plt.loglog(
+            lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Iteration")
+        plt.title("Convergence Iteration vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 2)
+        plt.loglog(
+            lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Rate")
+        plt.title("Convergence Rate vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 3)
+        plt.semilogx(
+            lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
+        )
+        plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Ultimate Loss")
+        plt.title("Ultimate Loss vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.tight_layout()
+        plt.savefig("test_centered_variance.png")
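As a reading aid for calc_convergence above (a summary of the code, not text from the commit): writing L_t for the loss at iteration t and L_\infty for the mean loss over the last tail_len iterations, the convergence iteration T is the first t with L_t < L_\infty + threshold, and the convergence rate is the mean per-iteration log-reduction of the remaining gap:

\text{rate} = \frac{1}{T-1} \sum_{t=1}^{T-1} \log \frac{L_t - L_\infty}{L_{t+1} - L_\infty}

Larger values mean faster geometric decay toward the ultimate loss, which is what the final assertion relies on when comparing the centered and uncentered runs across learning rates.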
