@@ -435,3 +435,126 @@ def step(svi, optimizer):
        actual.append(step(svi, optimizer))

    assert_equal(actual, expected)
+
+
+def test_centered_clipped_adam(plot):
+    """
+    Test the centered variance option of the ClippedAdam optimizer.
+    To create plots, run pytest with the plot command line option
+    set to True, i.e. by executing
+
+    'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
+
+    """
+    if not plot:
+        lr_vec = [0.1, 0.001]
+    else:
+        lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
+
+    w = torch.Tensor([1, 500])
+
+    def loss_fn(p):
+        return (1 + w * p * p).sqrt().sum() - len(w)
+
+    def fit(lr, centered_variance, num_iter=5000):
+        loss_vec = []
+        p = torch.nn.Parameter(torch.Tensor([10, 1]))
+        optim = pyro.optim.clipped_adam.ClippedAdam(
+            lr=lr, params=[p], centered_variance=centered_variance
+        )
+        for count in range(num_iter):
+            optim.zero_grad()
+            loss = loss_fn(p)
+            loss.backward()
+            optim.step()
+            loss_vec.append(loss)
+        return torch.Tensor(loss_vec)
+
+    def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
+        """
+        Calculate the number of iterations needed in order to reach the
+        ultimate loss plus a small threshold, and the convergence rate,
+        which is the mean per-iteration improvement of the gap between
+        the loss and the ultimate loss.
+        """
+        ultimate_loss = loss_vec[-tail_len:].mean()
+        convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
+        convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
+        convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
+        return ultimate_loss, convergence_rate, convergence_iter
+
+    def get_convergence_vec(lr_vec, centered_variance):
+        """
+        Fit parameters for a vector of learning rates, with or without
+        centered variance, and calculate the convergence properties for
+        each learning rate.
+        """
+        ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
+        for lr in lr_vec:
+            loss_vec = fit(lr=lr, centered_variance=centered_variance)
+            ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
+                loss_vec
+            )
+            ultimate_loss_vec.append(ultimate_loss)
+            convergence_rate_vec.append(convergence_rate)
+            convergence_iter_vec.append(convergence_iter)
+        return (
+            torch.Tensor(ultimate_loss_vec),
+            torch.Tensor(convergence_rate_vec),
+            convergence_iter_vec,
+        )
+
+    (
+        centered_ultimate_loss_vec,
+        centered_convergence_rate_vec,
+        centered_convergence_iter_vec,
+    ) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
+    ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
+        lr_vec=lr_vec, centered_variance=False
+    )
+
+    # All centered variance results should converge
+    assert (centered_ultimate_loss_vec < 0.01).all()
+    # Some uncentered variance results do not converge
+    assert (ultimate_loss_vec > 0.01).any()
+    # Verify convergence rate improvement
+    assert (
+        (centered_convergence_rate_vec / convergence_rate_vec)
+        > ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
+    ).all()
+
+    if plot:
+        from matplotlib import pyplot as plt
+
+        plt.figure(figsize=(6, 8))
+        plt.subplot(3, 1, 1)
+        plt.loglog(
+            lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Iteration")
+        plt.title("Convergence Iteration vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 2)
+        plt.loglog(
+            lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Rate")
+        plt.title("Convergence Rate vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 3)
+        plt.semilogx(
+            lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
+        )
+        plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Ultimate Loss")
+        plt.title("Ultimate Loss vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.tight_layout()
+        plt.savefig("test_centered_variance.png")
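
A note on the objective exercised above (my own reasoning, not part of the diff): with `w = [1, 500]`, `loss_fn` is a smooth but badly conditioned bowl whose minimum value is exactly 0 at `p = 0`, so an `ultimate_loss` near zero really does indicate convergence to the optimum. A quick standalone sanity check:

```python
# Standalone sanity check of the test objective (mirrors loss_fn above).
import torch

w = torch.Tensor([1, 500])

def loss_fn(p):
    return (1 + w * p * p).sqrt().sum() - len(w)

# The minimum value is 0 at p = 0 ...
assert loss_fn(torch.zeros(2)).item() == 0.0

# ... and the gradient vanishes there, since
# d/dp sqrt(1 + w p^2) = w p / sqrt(1 + w p^2) = 0 at p = 0.
p0 = torch.zeros(2, requires_grad=True)
loss_fn(p0).backward()
assert torch.allclose(p0.grad, torch.zeros(2))
```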
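For readers wondering how the new flag is reached in normal use: as I read the option, `centered_variance=True` replaces Adam's raw second-moment estimate E[g²] with a centered one, roughly E[g²] − (E[g])², i.e. an estimate of Var[g]. Below is a minimal usage sketch, assuming the `pyro.optim.ClippedAdam` wrapper forwards the flag to the underlying torch-style optimizer; the model, guide, and data here are invented for illustration:

```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam

def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    q_loc = pyro.param("q_loc", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(q_loc, torch.tensor(0.1)))

# centered_variance is the option exercised by the test above;
# lr and clip_norm are standard ClippedAdam settings.
optim = ClippedAdam({"lr": 0.01, "clip_norm": 10.0, "centered_variance": True})
svi = SVI(model, guide, optim, loss=Trace_ELBO())

data = 3.0 + torch.randn(100)
for _ in range(1000):
    svi.step(data)
```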