11import functools
22import logging
33import operator
4+ import numpy as np
45
56from habitat .analysis import SPECIAL_OPERATIONS
67from habitat .analysis .operation import PredictedOperation
1112from habitat .utils import ms_to_ns , name_all_arguments
1213
1314from habitat .analysis .mlp .mlp import RuntimePredictor
14-
1515logger = logging .getLogger (__name__ )
1616
1717CONV2D_PARAMS = [
5252
5353MATMUL_PARAMS = ['input' , 'other' , 'out' ]
5454
55+ BATCH_NORM = [
56+ 'input' ,
57+ 'running_mean' ,
58+ 'running_var' ,
59+ 'weight' ,
60+ 'bias' ,
61+ 'training' ,
62+ 'momentum' ,
63+ 'eps'
64+ ]
5565
5666class Predictor :
5767 def __init__ (
@@ -86,6 +96,10 @@ def __init__(
8696 "conv_transpose2d" , 8 , 1024 ,
8797 path_to_data ("conv_transpose2d/model.pth" ),
8898 )
99+ self .batch_norm_pred = RuntimePredictor (
100+ "batch_norm" , 8 , 1024 ,
101+ path_to_data ("batch_norm/model.pth" ),
102+ )
89103
90104
91105 def predict_operation (self , operation , dest_device , unscaled = False ):
@@ -108,6 +122,8 @@ def predict_operation(self, operation, dest_device, unscaled=False):
108122 return self ._special_scale (operation , dest_device , self ._bmm_scale , unscaled )
109123 elif operation .name == 'conv_transpose2d' :
110124 return self ._special_scale (operation , dest_device , self ._conv_transpose2d_scale , unscaled )
125+ elif operation .name == "batch_norm" :
126+ return self ._special_scale (operation , dest_device , self ._batch_norm_scale , unscaled )
111127
112128 logger .warn ('Unhandled special operation: %s' , operation .name )
113129 return PredictedOperation (
@@ -354,3 +370,34 @@ def _lstm_scale(self, operation, dest_device, unscaled=False):
354370 return pred_orig
355371
356372 return operation .run_time_ms * pred_dest / pred_orig
373+
374+ def _batch_norm_scale (self , operation , dest_device , unscaled = False ):
375+ merged = name_all_arguments (
376+ BATCH_NORM ,
377+ operation .arguments .args ,
378+ operation .arguments .kwargs ,
379+ )
380+
381+ # 2. Construct arguments that the predictor expects
382+ arguments = dict (
383+ batch = merged ['input' ][0 ],
384+ channels = merged ['input' ][1 ],
385+ # batch_norm can be called by BatchNorm1d, BatchNorm2d, BatchNorm3d
386+ # so we need to collapse all features after channels into a single int
387+ image_size = np .mean (merged ['input' ][2 :]),
388+ )
389+
390+ # 3. Call model to make prediction
391+ arguments = [arguments [x ] for x in self .batch_norm_pred .model .features ]
392+
393+ pred_dest = self .batch_norm_pred .predict (arguments , dest_device .name )
394+ pred_orig = self .batch_norm_pred .predict (arguments , operation .device .name )
395+
396+ if unscaled :
397+ return pred_dest
398+
399+ if dest_device .name == operation .device .name : #local prediction
400+ return pred_orig
401+
402+ return operation .run_time_ms * pred_dest / pred_orig
403+
0 commit comments