forked from tile-ai/tilelang
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_convolution.py
More file actions
123 lines (101 loc) · 4.31 KB
/
Copy pathexample_convolution.py
File metadata and controls
123 lines (101 loc) · 4.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import torch
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
import argparse
def ref_program(stride, padding, dilation):
    """Build a PyTorch reference convolution for channels-last operands.

    Args:
        stride: Forwarded to ``torch.conv2d`` as ``stride``.
        padding: Forwarded to ``torch.conv2d`` as ``padding``.
        dilation: Forwarded to ``torch.conv2d`` as ``dilation``.

    Returns:
        A callable ``main(A, B)`` that takes ``A`` laid out as (N, H, W, C)
        and ``B`` laid out as (H, W, C, F), and returns the convolution
        result permuted back to a channels-last layout.
    """

    def main(A, B):
        # torch.conv2d works on channels-first tensors, so reorder both operands.
        nchw_input = A.permute(0, 3, 1, 2)  # N, H, W, C -> N, C, H, W
        fchw_weight = B.permute(3, 2, 0, 1)  # H, W, C, F -> F, C, H, W
        nchw_output = torch.conv2d(
            nchw_input,
            fchw_weight,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        # Hand the result back in the same channels-last layout as the input.
        return nchw_output.permute(0, 2, 3, 1)  # N, C, H, W -> N, H, W, C

    return main
# 2-D convolution expressed as im2col + GEMM in the TileLang DSL.
#
# Parameters:
#   data, kernel  - input and filter tensors; their shapes are declared by the
#                   annotation statements in the body as (N, H, W, C) and
#                   (KH, KW, C, F) respectively.
#   S, D, P       - stride, dilation and padding (the caller in this file pairs
#                   them with torch.conv2d's stride / dilation / padding).
#   block_M, block_N, block_K - tile sizes for the output-position, output-channel
#                   and reduction dimensions of the flattened problem.
#   num_stages    - forwarded to T.Pipelined.
#   threads       - forwarded to T.Kernel.
#   dtype, accum_dtype - NOTE(review): both are unconditionally reassigned in the
#                   body, so values supplied by a caller are currently ignored;
#                   confirm whether that is intentional.
#
# Returns `out`, declared with shape (N, OH, OW, F) and element type `dtype`.
@tilelang.jit
def convolution(data, kernel, S, D, P, block_M, block_N, block_K, num_stages, threads, dtype=T.float16, accum_dtype=T.float32):
    # Problem dimensions obtained from T.const
    # (presumably symbolic / compile-time sizes - TODO confirm against TileLang docs).
    N, C, H, W, F, K = T.const("N, C, H, W, F, K")
    # Square filter: the same extent K is used for height and width.
    KH, KW = K, K
    # Output spatial extents from input size, padding, dilation, filter size and stride.
    OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
    OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
    # These two assignments override the keyword parameters of the same name.
    dtype = T.float16
    accum_dtype = T.float32
    # Shape/dtype declarations for the two inputs.
    data: T.Tensor((N, H, W, C), dtype)
    kernel: T.Tensor((KH, KW, C, F), dtype)
    out = T.empty((N, OH, OW, F), dtype)
    # Launch grid: bx indexes tiles of block_N output channels, by indexes tiles of
    # block_M flattened (N * OH * OW) output positions.
    with T.Kernel(T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M), threads=threads) as (bx, by):
        # Per-tile staging buffers for the two GEMM operands.
        data_shared = T.alloc_shared((block_M, block_K), dtype)
        kernel_shared = T.alloc_shared((block_K, block_N), dtype)
        # Accumulator tile, kept in accum_dtype rather than dtype.
        out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        # Staging buffer in the output dtype used for the final write-back.
        out_shared = T.alloc_shared((block_M, block_N), dtype)

        # 2-D views over the same storage (kernel.data / out.data) so the filter and
        # output can be indexed as plain matrices.
        kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
        out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

        T.clear(out_local)
        # Walk the reduction dimension (KH * KW * C) in chunks of block_K.
        for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
            # Presumably fills data_shared with the im2col patch tile for output-row
            # tile `by` and reduction chunk `k_iter` - TODO confirm T.im2col semantics.
            T.im2col(data, data_shared, by, k_iter, KH, S, D, P)
            T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
            # Presumably accumulates data_shared @ kernel_shared into out_local
            # (out_local is cleared once before the loop and never reset inside it).
            T.gemm(data_shared, kernel_shared, out_local)

        # Write the accumulator back through out_shared into this tile's
        # position in the flattened output.
        T.copy(out_local, out_shared)
        T.copy(out_shared, out_flat[by * block_M, bx * block_N])

    return out
def main(argv=None):
    """Run the TileLang convolution once and check it against PyTorch.

    Args:
        argv: Argument list handed to ``argparse``; ``None`` lets argparse
            read the process command line.

    Raises:
        AssertionError: If the kernel output and the reference differ beyond
            ``rtol=1e-2`` / ``atol=1e-2``.
    """
    cli = argparse.ArgumentParser()
    # Every option is an integer flag whose help text is just its own name.
    int_flags = (
        ("n", 128),
        ("c", 128),
        ("h", 64),
        ("w", 64),
        ("f", 128),
        ("k", 3),
        ("s", 1),
        ("d", 1),
        ("p", 1),
    )
    for flag, default in int_flags:
        cli.add_argument(f"--{flag}", type=int, default=default, help=flag)
    opts = cli.parse_args(argv)

    # Random half-precision operands on the GPU: input is (N, H, W, C),
    # filter is (K, K, C, F).
    activations = torch.randn(opts.n, opts.h, opts.w, opts.c).cuda().half()
    weights = torch.randn(opts.k, opts.k, opts.c, opts.f).cuda().half()

    # Fixed tiling / launch configuration used by this example.
    tile_m, tile_n, tile_k = 64, 128, 32
    pipeline_stages = 3
    thread_count = 256

    result = convolution(
        activations,
        weights,
        opts.s,
        opts.d,
        opts.p,
        tile_m,
        tile_n,
        tile_k,
        pipeline_stages,
        thread_count,
    )
    # ref_program takes (stride, padding, dilation), hence the different order.
    expected = ref_program(opts.s, opts.p, opts.d)(activations, weights)
    torch.testing.assert_close(result, expected, rtol=1e-2, atol=1e-2)
    print("All checks passed.✅")
def run_regression_perf(argv=None):
    """Compile the convolution for one problem size and benchmark it.

    Args:
        argv: Argument list handed to ``argparse``; ``None`` lets argparse
            read the process command line.

    Returns:
        Whatever ``profiler.do_bench(backend="cupti")`` returns
        (presumably a timing measurement; confirm against the profiler API).
    """
    cli = argparse.ArgumentParser()
    # Every option is an integer flag whose help text is just its own name.
    int_flags = (
        ("n", 128),
        ("c", 128),
        ("h", 64),
        ("w", 64),
        ("f", 128),
        ("k", 3),
        ("s", 1),
        ("d", 1),
        ("p", 1),
    )
    for flag, default in int_flags:
        cli.add_argument(f"--{flag}", type=int, default=default, help=flag)
    opts = cli.parse_args(argv)

    # Problem sizes come from the command line; tiling and launch settings
    # are the fixed configuration used by this example.
    compile_kwargs = dict(
        N=opts.n,
        C=opts.c,
        H=opts.h,
        W=opts.w,
        F=opts.f,
        K=opts.k,
        S=opts.s,
        D=opts.d,
        P=opts.p,
        block_M=64,
        block_N=128,
        block_K=32,
        num_stages=3,
        threads=256,
    )
    compiled = convolution.compile(**compile_kwargs)
    bench = compiled.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    return bench.do_bench(backend="cupti")
# Script entry point: when executed directly, run the correctness check with
# options taken from the command line.
if __name__ == "__main__":
    main()