Commit 0fa2133

Integrate BYOC preprocess in Collage and benchmark (#26)
Integrate an implicit call to the BYOC preprocessing module into the Collage tuning module, and enable the benchmark script for Adreno targets. Benchmark results:

| Networks | OpenCL texture | OpenCLML | Collage |
| --- | --- | --- | --- |
| resnet-18-float32 | 0.010584622 | 0.00720695 | 0.007289728 |
| resnet-18-float16 | 0.007052029 | 0.0045642 | 0.004857585 |
| resnet-34-float32 | 0.016259185 | 0.01242092 | 0.013071063 |
| resnet-34-float16 | 0.011350326 | 0.0073473 | 0.00796802 |
| resnet-50-float32 | 0.019188419 | 0.02085548 | 0.018910226 |
| resnet-50-float16 | 0.01338978 | 0.01199576 | 0.011089206 |
| densenet-121-float32 | 0.025430062 | 0.01798478 | 0.013212844 |
| densenet-121-float16 | 0.012384599 | 0.01101491 | 0.008722716 |
| inception_v3-float32 | 0.040408253 | 0.02229727 | 0.022636675 |
| inception_v3-float16 | 0.029910533 | 0.01368941 | 0.014519823 |
| mobilenet-float32 | 0.004093148 | 0.00367917 | 0.003189258 |
| mobilenet-float16 | 0.00280268 | 0.00244494 | 0.002101514 |

Co-authored-by: krishnaraj36 <[email protected]>
1 parent e3665ae commit 0fa2133
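For context, a minimal sketch of what the implicit preprocessing changes for a Collage driver. It mirrors the `collage()` driver in the diff below; the RPC values are the script's defaults, and the exact name of the CLML preprocessing hook is not taken from this commit, so treat the "before" step as an assumption. Previously the caller had to apply the CLML preprocessing pass to the module by hand before `CollagePartition`; with this change the partitioner invokes each BYOC backend's registered preprocessing itself:

```python
import tvm
from tvm.relay import testing
from tvm.relay.collage.collage import CostEstimator

# A small workload stands in for the benchmarked networks.
mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float16")

HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android")
targets = [
    tvm.target.Target("opencl -device=adreno", HOST),
    tvm.target.Target("clml", HOST),
]

ctxt = tvm.transform.PassContext(
    config={"relay.collage.byoc_fusion_style": ["clml.NoFusion"]}
)
config = tvm.target.make_compilation_config(ctxt, targets)
with ctxt:
    # No explicit CLML preprocessing call here any more: CollagePartition
    # now runs the backend's preprocessing implicitly while estimating
    # candidate partitions (RPC values below are the script's defaults).
    mod = tvm.relay.transform.CollagePartition(
        config,
        cost_estimator=CostEstimator(
            host="127.0.0.1",
            port=9190,
            rpc_key="android",
            ndk_cc="aarch64-linux-android-g++",
        ),
    )(mod)
```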

File tree

9 files changed: +559 −276 lines changed
@@ -0,0 +1,389 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Compares Collage with various other baselines."""
import argparse
import tvm
from tvm import relay
import os
import sys
import numpy as np
from tvm.relay import testing
from tvm.contrib.utils import tempdir
from tvm import rpc
from tvm.relay.build_module import bind_params_by_name
from tvm import autotvm
from tvm.runtime.vm import VirtualMachine
import tvm.contrib.graph_executor as runtime
from tvm.contrib import utils, ndk
from tvm.relay.collage.collage import *
from tvm.relay.op.contrib import clml

###
### How aggressively to look for candidates?
###
TVM_MAX_DEPTH = 8
BYOC_MAX_DEPTH = 8

##
## Default config definition
##
HOST = tvm.target.Target("llvm -mtriple=arm64-linux-android")
OPENCL = tvm.target.Target("opencl -device=adreno", HOST)
NDK_CC = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")

def print_progress(msg):
    """print progress message

    Parameters
    ----------
    msg: str
        The message to print
    """
    sys.stdout.write(msg + "\r")
    sys.stdout.flush()

def tune_tasks(
    tasks,
    measure_option,
    tuner="xgb",
    n_trial=1024,
    early_stopping=None,
    log_filename="tuning.log",
):
    # GATuner, RandomTuner and GridSearchTuner are imported alongside XGBTuner
    # so that every branch of the tuner selection below resolves.
    from tvm.autotvm.tuner import GATuner, GridSearchTuner, RandomTuner, XGBTuner

    tmp_log_file = log_filename + ".tmp"

    for i, tsk in enumerate(reversed(tasks)):
        print("Task: ", tsk)
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # create tuner
        if tuner == "xgb":
            tuner_obj = XGBTuner(tsk, loss_type="reg")
        elif tuner == "xgb_knob":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="knob")
        elif tuner == "xgb_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="itervar")
        elif tuner == "xgb_curve":
            tuner_obj = XGBTuner(tsk, loss_type="reg", feature_type="curve")
        elif tuner == "xgb_rank":
            tuner_obj = XGBTuner(tsk, loss_type="rank")
        elif tuner == "xgb_rank_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob")
        elif tuner == "xgb_rank_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="itervar")
        elif tuner == "xgb_rank_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="curve")
        elif tuner == "xgb_rank_binary":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary")
        elif tuner == "xgb_rank_binary_knob":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="knob")
        elif tuner == "xgb_rank_binary_itervar":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="itervar")
        elif tuner == "xgb_rank_binary_curve":
            tuner_obj = XGBTuner(tsk, loss_type="rank-binary", feature_type="curve")
        elif tuner == "ga":
            tuner_obj = GATuner(tsk, pop_size=50)
        elif tuner == "random":
            tuner_obj = RandomTuner(tsk)
        elif tuner == "gridsearch":
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(
            n_trial=tsk_trial,
            early_stopping=early_stopping,
            measure_option=measure_option,
            callbacks=[
                autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                autotvm.callback.log_to_file(tmp_log_file),
            ],
        )

    autotvm.record.pick_best(tmp_log_file, log_filename)

########### Collage Drivers ###########


def compile_and_run(label, model, targets, inputs):
    """Compile model for target and run it with profiling."""
    print(f"Compiling {model['name']} using {label} with {targets}...")
    mod = model["mod"]
    exe = tvm.relay.vm.compile(mod, target=targets, params=model["params"])
    lib = exe.mod
    temp = utils.tempdir()
    dso_binary = "dev_lib_cl.so"
    dso_binary_path = temp.relpath(dso_binary)
    print(f"Exporting library to {dso_binary_path}...")
    lib.export_library(dso_binary_path, cc=NDK_CC)
    tracker = rpc.connect_tracker(args.host, args.port)
    remote = tracker.request(args.rpc_key, priority=0, session_timeout=600)
    ctx = remote.cl(0)
    remote.upload(dso_binary_path)
    rlib = remote.load_module(dso_binary)
    vm_factory = tvm.runtime.vm.VirtualMachine(rlib, ctx, "naive")
    func_name = "main"
    main_args = {v.name_hint: arg_for(v.checked_type, ctx) for v in mod[func_name].params}
    profile = vm_factory.benchmark(
        ctx, repeat=5, number=20, min_repeat_ms=0, func_name=func_name, **main_args
    )
    print(profile)
    return profile.median * 1e3

def collage(model, input_data, tune_log=""):
    """Run the Collage partitioner for a set of OpenCL Adreno related targets and profile the result."""
    print(f"collage | {model['name']}")
    print("-------------- BEGIN ORIGINAL --------------")
    print(model["mod"])
    print("-------------- END ORIGINAL ----------------")
    targets = []
    targets.append(OPENCL)
    targets.append(tvm.target.Target("clml", HOST))
    # Register byoc fusion style for compiler with available
    # options [compiler.NoFusion | compiler.TVMFusion | compiler.MaxDepthFusion]
    config = {
        "relay.collage.tvm_max_depth": TVM_MAX_DEPTH,
        "relay.collage.byoc_max_depth": BYOC_MAX_DEPTH,
        "relay.collage.byoc_fusion_style": ["clml.NoFusion"],
    }
    print(f"Using PassContext(config={config})")
    ctxt = tvm.transform.PassContext(config=config)
    config = tvm.target.make_compilation_config(ctxt, targets)
    with ctxt:
        mod = model["mod"]
        # Collage partition with the TVM OpenCL and CLML targets on the RPC device.
        mod = tvm.relay.transform.CollagePartition(
            config,
            cost_estimator=CostEstimator(
                host=args.host, port=args.port, rpc_key=args.rpc_key, ndk_cc=NDK_CC
            ),
        )(mod)
        partitioned_model = model.copy()
        partitioned_model["mod"] = mod
        print("-------------- BEGIN PARTITIONED --------------")
        print(partitioned_model["mod"])
        print("-------------- END PARTITIONED ----------------")
        return compile_and_run("collage", partitioned_model, targets, input_data)

def just_clml(model, input_data, tune_log=""):
    """Run partition_for_clml, complete the compilation with TVM, and profile the result."""
    print(f"just_clml | {model['name']}")
    print("-------------- BEGIN ORIGINAL --------------")
    print(model["mod"])
    print("-------------- END ORIGINAL ----------------")
    with autotvm.apply_history_best(tune_log):
        with tvm.transform.PassContext(opt_level=3):
            print("Partitioning for CLML...")
            mod = tvm.relay.op.contrib.clml.partition_for_clml(model["mod"], model["params"])
            partitioned_model = model.copy()
            partitioned_model["mod"] = mod
            print("-------------- BEGIN PARTITIONED --------------")
            print(partitioned_model["mod"])
            print("-------------- END PARTITIONED ----------------")
            targets = []
            targets.append(OPENCL)
            targets.append(tvm.target.Target("clml", HOST))
            # Compile with both the OpenCL and CLML targets so the partitioned
            # regions and the remaining TVM regions are both covered.
            return compile_and_run("just_clml", partitioned_model, targets, input_data)

def just_tvm(model, input_data, tune_log=""):
    """Compile and profile using vanilla TVM."""
    print(f"just_tvm | {model['name']}")
    print("-------------- BEGIN ORIGINAL --------------")
    print(model["mod"])
    print("-------------- END ORIGINAL ----------------")
    if tune_log:
        with autotvm.apply_history_best(tune_log):
            with tvm.transform.PassContext(opt_level=3):
                return compile_and_run("just_tvm", model, OPENCL, input_data)
    else:
        with tvm.transform.PassContext(opt_level=3):
            return compile_and_run("just_tvm", model, OPENCL, input_data)

def get_model(model_name, dtype):
    """Build a Relay workload and freshly initialized parameters for the named model."""
    input_shape = (1, 3, 224, 224)
    if "mobilenet" in model_name:
        mod, params = testing.mobilenet.get_workload(batch_size=1, dtype=dtype)
    elif "resnet" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.resnet.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype)
    elif model_name == "inception_v3":
        input_shape = (1, 3, 299, 299)
        mod, params = testing.inception_v3.get_workload(batch_size=1, dtype=dtype)
    elif "vgg" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.vgg.get_workload(num_layers=n_layer, batch_size=1, dtype=dtype)
    elif "densenet" in model_name:
        n_layer = int(model_name.split("-")[1])
        mod, params = testing.densenet.get_workload(
            densenet_size=n_layer, batch_size=1, dtype=dtype
        )
    elif "squeezenet" in model_name:
        version = model_name.split("_v")[1]
        mod, params = testing.squeezenet.get_workload(batch_size=1, version=version, dtype=dtype)
    else:
        raise ValueError("Unsupported network: " + model_name)

    # Replace the random test parameters with Xavier-initialized weights and
    # zero biases.
    initializer = tvm.relay.testing.init.Xavier()
    for param_name in list(params.keys()):
        filter_data = np.zeros(params[param_name].shape).astype(params[param_name].dtype)
        if len(filter_data.shape) > 1:
            initializer("weight", filter_data)
        else:
            initializer("bias", filter_data)
        params[param_name] = tvm.nd.array(filter_data)

    if params:
        mod["main"] = bind_params_by_name(mod["main"], params)
        mod = tvm.relay.transform.FoldConstant()(mod)
    return {
        "name": model_name,
        "input_shapes": {"data": list(input_shape)},
        "input_dtypes": {"data": dtype},
        "mod": mod,
        "params": params,
        "main_dtype": dtype,
    }

########### Runners ###########
def evaluate_network(model_name, dtype):
    print("Network evaluating .. " + model_name + " " + dtype)
    np.random.seed(0)
    model = get_model(model_name, dtype)
    tune_log = ""
    if args.tune:
        # Auto tuning
        tune_log = "adreno-" + model_name + "-" + dtype + ".log"
        tuning_options = {
            "log_filename": tune_log,
            "early_stopping": None,
            "measure_option": autotvm.measure_option(
                builder=autotvm.LocalBuilder(build_func=ndk.create_shared, timeout=15),
                runner=autotvm.RPCRunner(
                    args.rpc_key,
                    host=args.host,
                    port=args.port,
                    number=3,
                    timeout=600,
                ),
            ),
        }
        tasks = autotvm.task.extract_from_program(
            model["mod"]["main"], target=OPENCL, target_host=HOST, params=model["params"]
        )
        tune_tasks(tasks, **tuning_options)

    print_progress("%-20s building..." % model_name)
    input_data = {}
    for name, shape in model["input_shapes"].items():
        input_data[name] = np.random.uniform(-1.0, 1.0, shape).astype(model["input_dtypes"][name])
    clml_time = just_clml(model, input_data, tune_log)
    tvm_time = just_tvm(model, input_data, tune_log)

    # Run Collage for the TVM and CLML compiler targets.
    if tune_log:
        with autotvm.apply_history_best(tune_log):
            collage_time = collage(model, input_data, tune_log)
    else:
        collage_time = collage(model, input_data, tune_log)
    return (tvm_time, clml_time, collage_time)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--network",
        type=str,
        choices=[
            "resnet-18",
            "resnet-34",
            "resnet-50",
            "vgg-16",
            "vgg-19",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ],
        help="The name of the neural network",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=9190)
    parser.add_argument("--rpc-key", type=str, default="android")
    parser.add_argument(
        "--dtype",
        type=str,
        choices=["float32", "float16"],
        help="The data type of the neural network",
    )
    # argparse's type=bool treats any non-empty string (including "False") as
    # True, so tuning is exposed as a flag instead.
    parser.add_argument("--tune", action="store_true")
    args = parser.parse_args()

    if args.network is None:
        networks = [
            "resnet-18",
            "resnet-34",
            "resnet-50",
            # "vgg-16",
            # "vgg-19",
            "densenet-121",
            "inception_v3",
            "mobilenet",
            "squeezenet_v1.0",
            "squeezenet_v1.1",
        ]
    else:
        networks = [args.network]

    if args.dtype is None:
        dtypes = ["float32", "float16"]
    else:
        dtypes = [args.dtype]

    results = {}
    net_results = []
    for network in networks:
        for dtype in dtypes:
            ftime = evaluate_network(network, dtype)
            results[network + "-" + dtype] = ftime
            # net_results.append([network + "-" + dtype] + list(ftime))
            # np.savetxt("results.txt", np.array(net_results), fmt="%s")

    print("----------------------------------------------------------------------")
    print(
        "%-30s %-20s %-20s %-20s"
        % ("Network Name", "TVM Opencl Time", "CLML Time", "Collage - TVM/CLML Time")
    )
    print("----------------------------------------------------------------------")
    for key, val in results.items():
        print(
            "%-30s %-20s %-20s %-20s"
            % (key, "%.2f ms" % val[0], "%.2f ms" % val[1], "%.2f ms" % val[2])
        )
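For completeness, a sketch of how this benchmark is typically driven: a running RPC tracker with an Adreno device registered under the script's default key is assumed, and the script filename below is a placeholder, not taken from this commit.

```python
# Hypothetical invocation (filename and NDK path assumed):
#
#   export TVM_NDK_CC=/path/to/aarch64-linux-android-g++
#   python3 -m tvm.exec.rpc_tracker --host 0.0.0.0 --port 9190
#   # ...start an RPC server on the device with key "android"...
#   python3 adreno_collage_benchmark.py --network resnet-18 --dtype float16 \
#       --host 127.0.0.1 --port 9190 --rpc-key android --tune
```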
