
Commit 08a280a

[NEW DEVICES AND MLP] H100 and batch_norm mlp (#56)
* mlp models for h100
* register new mlp batch_norm
* fixed limited record number
* new batch_norm model
* changed how we collapse batch_norm2d
* reverted to original batch sizes
* update batch norm MLP with a40 and a4000 data
* changed order of image_size and channels

---------

Co-authored-by: John Calderon <[email protected]>
1 parent 8bba737 commit 08a280a

18 files changed: +212 −13 lines changed

analyzer/habitat/analysis/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -12,6 +12,9 @@
     '__matmul__',  # calls the same kernel as linear
     'bmm',

+    # batch normalization
+    'batch_norm',
+
     # Recurrent operations
     'lstm',
     'gru',

analyzer/habitat/analysis/mlp/devices.csv

Lines changed: 2 additions & 1 deletion
@@ -10,4 +10,5 @@ RTX3090,24,GDDR6X,936.2,82,556.0,35.58,35.58
 A40,48,GDDR6,614.9,84,1.168,37.4,299.4
 A4000,16,GDDR6,378.1,48,0.599,19.17,19.17
 RTX4000,8,GDDR6,364.1,36,0.2225,7.119,7.119
-L4,24,GDDR6,254,60,0.473,30.29,30.29
+L4,24,GDDR6,254,60,0.473,30.29,30.29
+H100,80,HBM,2090,132,33.45,66.91,267.6

analyzer/habitat/analysis/mlp/mlp.py

Lines changed: 17 additions & 0 deletions
@@ -128,6 +128,22 @@ def forward(self, x):

         return x

+class BatchNorm(nn.Module):
+    def __init__(self, layers, layer_size):
+        super().__init__()
+
+        self.features = ["batch","channels","image_size"]
+        self.fc1 = nn.Linear(len(self.features) + 4, layer_size)
+        self.mlp = MLPBase(layers, layer_size)
+        self.fc2 = nn.Linear(layer_size, 1)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.mlp(x)
+        x = self.fc2(x)
+
+        return x

 class RuntimePredictor:
     def __init__(self, model_name, layers, layer_size, model_path=None):
@@ -141,6 +157,7 @@ def __init__(self, model_name, layers, layer_size, model_path=None):
             "conv2d": Conv2DMLP,
             "conv_transpose2d": ConvTranspose2DMLP,
             "bmm": BMMMLP,
+            "batch_norm": BatchNorm,
         }[self.model_name](layers, layer_size)

         self.device_params = ['mem', 'mem_bw', 'num_sm', 'single']
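For context on the new head's input layout: the three operation features in `BatchNorm.features` are concatenated with the four device parameters listed in `self.device_params` (`mem`, `mem_bw`, `num_sm`, `single`), which is why `fc1` takes `len(self.features) + 4 = 7` inputs. A minimal sketch of that shape contract, with illustrative sample values that are not taken from the commit:

import torch
import torch.nn as nn

features = ["batch", "channels", "image_size"]          # operation features
device_params = ["mem", "mem_bw", "num_sm", "single"]   # appended per target GPU

# Mirrors BatchNorm.fc1: each sample is a 7-dimensional feature vector.
fc1 = nn.Linear(len(features) + len(device_params), 1024)

# One hypothetical sample: a BatchNorm2d call plus the target device's parameters.
x = torch.tensor([[32.0, 64.0, 56.0, 80.0, 2090.0, 132.0, 66.9]])
print(fc1(x).shape)  # torch.Size([1, 1024])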

analyzer/habitat/analysis/predictor.py

Lines changed: 48 additions & 1 deletion
@@ -1,6 +1,7 @@
 import functools
 import logging
 import operator
+import numpy as np

 from habitat.analysis import SPECIAL_OPERATIONS
 from habitat.analysis.operation import PredictedOperation
@@ -11,7 +12,6 @@
 from habitat.utils import ms_to_ns, name_all_arguments

 from habitat.analysis.mlp.mlp import RuntimePredictor
-
 logger = logging.getLogger(__name__)

 CONV2D_PARAMS = [
@@ -52,6 +52,16 @@

 MATMUL_PARAMS = ['input', 'other', 'out']

+BATCH_NORM = [
+    'input',
+    'running_mean',
+    'running_var',
+    'weight',
+    'bias',
+    'training',
+    'momentum',
+    'eps'
+]

 class Predictor:
     def __init__(
@@ -86,6 +96,10 @@ def __init__(
             "conv_transpose2d", 8, 1024,
             path_to_data("conv_transpose2d/model.pth"),
         )
+        self.batch_norm_pred = RuntimePredictor(
+            "batch_norm", 8, 1024,
+            path_to_data("batch_norm/model.pth"),
+        )


     def predict_operation(self, operation, dest_device, unscaled=False):
@@ -108,6 +122,8 @@ def predict_operation(self, operation, dest_device, unscaled=False):
             return self._special_scale(operation, dest_device, self._bmm_scale, unscaled)
         elif operation.name == 'conv_transpose2d':
             return self._special_scale(operation, dest_device, self._conv_transpose2d_scale, unscaled)
+        elif operation.name == "batch_norm":
+            return self._special_scale(operation, dest_device, self._batch_norm_scale, unscaled)

         logger.warn('Unhandled special operation: %s', operation.name)
         return PredictedOperation(
@@ -354,3 +370,34 @@ def _lstm_scale(self, operation, dest_device, unscaled=False):
             return pred_orig

         return operation.run_time_ms * pred_dest / pred_orig
+
+    def _batch_norm_scale(self, operation, dest_device, unscaled=False):
+        merged = name_all_arguments(
+            BATCH_NORM,
+            operation.arguments.args,
+            operation.arguments.kwargs,
+        )
+
+        # 2. Construct arguments that the predictor expects
+        arguments = dict(
+            batch=merged['input'][0],
+            channels=merged['input'][1],
+            # batch_norm can be called by BatchNorm1d, BatchNorm2d, BatchNorm3d
+            # so we need to collapse all features after channels into a single int
+            image_size=np.mean(merged['input'][2:]),
+        )
+
+        # 3. Call model to make prediction
+        arguments = [arguments[x] for x in self.batch_norm_pred.model.features]
+
+        pred_dest = self.batch_norm_pred.predict(arguments, dest_device.name)
+        pred_orig = self.batch_norm_pred.predict(arguments, operation.device.name)
+
+        if unscaled:
+            return pred_dest
+
+        if dest_device.name == operation.device.name: #local prediction
+            return pred_orig
+
+        return operation.run_time_ms * pred_dest / pred_orig
+
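As the inline comments note, `batch_norm` can originate from BatchNorm1d, BatchNorm2d, or BatchNorm3d, so every dimension after (batch, channels) is collapsed into a single `image_size` via `np.mean`. A small sketch of what that produces, assuming (as the indexing above implies) that `merged['input']` records the input tensor's shape; the shapes below are hypothetical examples:

import numpy as np

# Hypothetical batch_norm input shapes; the first two entries are kept as-is,
# the remaining dimensions are averaged into one value.
for shape in [(32, 64, 100),         # BatchNorm1d: (N, C, L)
              (32, 64, 56, 56),      # BatchNorm2d: (N, C, H, W)
              (32, 64, 8, 56, 56)]:  # BatchNorm3d: (N, C, D, H, W)
    batch, channels = shape[0], shape[1]
    image_size = np.mean(shape[2:])
    print(shape, "->", [batch, channels, image_size])
# (32, 64, 100)       -> [32, 64, 100.0]
# (32, 64, 56, 56)    -> [32, 64, 56.0]
# (32, 64, 8, 56, 56) -> [32, 64, 40.0]

The scaling step at the end follows the same pattern as the other special operations: the measured runtime is multiplied by the ratio of the MLP's predictions on the two devices. For example, a call measured at 2.0 ms on the origin device with pred_orig = 1.5 and pred_dest = 0.5 would be reported as 2.0 × 0.5 / 1.5 ≈ 0.67 ms on the destination device.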

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:786bf25f8e13164adb455502897b68ca8c2031894d76087f20aa56507c72607b
+size 33630314

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1f0249c126f0fb7959f2397dc52b241e3756f8035bb34ffb31e58dbec411345c
+oid sha256:628cd9ecca8cda59e0b5277580c996a72bae9b29bf3c5bdabccd9dfa6fc34389
 size 33634474

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d7b119c75ca91f55fb0541a925811f2799d8e81c8e9175bddec58840a6e0831d
+oid sha256:3974db996896911deb43bfc57711ce8a8b5875e37712d1d6e9511322da6e6f7b
 size 33650922

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ea25ab6d7cbfe462aebfea9f32e9a8491d6885a714c34235a57792472db447fe
+oid sha256:647e6fb1b31328ed52ba1803c0c14bbac2b0856c51600d533dffaf2e3bd64644
 size 33650922

analyzer/habitat/data/devices.yml

Lines changed: 16 additions & 0 deletions
@@ -224,3 +224,19 @@ L4:
   mem_bandwidth_gb: 254
   base_clock_mhz: 795
   peak_gflops_per_second: 15130
+
+H100:
+  compute_major: 9
+  compute_minor: 0
+  max_threads_per_block: 1024
+  max_threads_per_multiprocessor: 2048
+  regs_per_block: 65536
+  regs_per_multiprocessor: 65536
+  warp_size: 32
+  shared_mem_per_block: 49152
+  shared_mem_per_multiprocessor: 233472
+  num_sms: 132
+  shared_mem_per_block_optin: 232448
+  mem_bandwidth_gb: 2090
+  base_clock_mhz: 1590
+  peak_gflops_per_second: 33425

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:203a7f1d8d6837055490ba33573fd8f422ead79d2b41e00d0f68400742ec81f4
+oid sha256:e65df5de655bf09f97a4ecd6a3e3c942fcef53fd81e9d2b83f9c43c6ae6c0e3a
 size 33634474
