File tree Expand file tree Collapse file tree 3 files changed +24
-2
lines changed
src/pytorch_metric_learning Expand file tree Collapse file tree 3 files changed +24
-2
lines changed Original file line number Diff line number Diff line change 1- __version__ = "2.1.0 "
1+ __version__ = "2.1.1 "
Original file line number Diff line number Diff line change @@ -79,7 +79,9 @@ def set_default_stats(
7979 ):
8080 if self .collect_stats :
8181 with torch .no_grad ():
82- self .initial_avg_query_norm : torch .mean (self .get_norm (query_emb )).item ()
82+ self .initial_avg_query_norm = torch .mean (
83+ self .get_norm (query_emb )
84+ ).item ()
8385 self .initial_avg_ref_norm = torch .mean (self .get_norm (ref_emb )).item ()
8486 self .final_avg_query_norm = torch .mean (
8587 self .get_norm (query_emb_normalized )
Original file line number Diff line number Diff line change 1+ import unittest
2+
3+ import torch
4+
5+ from pytorch_metric_learning .distances import LpDistance
6+
7+ from .. import WITH_COLLECT_STATS
8+
9+
10+ class TestCollectedStats (unittest .TestCase ):
11+ @unittest .skipUnless (WITH_COLLECT_STATS , "WITH_COLLECT_STATS is false" )
12+ def test_collected_stats (self ):
13+ x = torch .randn (32 , 128 )
14+ d = LpDistance ()
15+ d (x )
16+
17+ self .assertNotEqual (d .initial_avg_query_norm , 0 )
18+ self .assertNotEqual (d .initial_avg_ref_norm , 0 )
19+ self .assertNotEqual (d .final_avg_query_norm , 0 )
20+ self .assertNotEqual (d .final_avg_ref_norm , 0 )
You can’t perform that action at this time.
0 commit comments