diff --git a/pytorch_modelsize.py b/pytorch_modelsize.py index 213406c..5b2ea7c 100644 --- a/pytorch_modelsize.py +++ b/pytorch_modelsize.py @@ -12,7 +12,7 @@ def __init__(self, model, input_size=(1,1,32,32), bits=32): ''' self.model = model self.input_size = input_size - self.bits = 32 + self.bits = bits return def get_parameter_sizes(self):