Skip to content

Commit fa3e031

Browse files
SS-JIA
authored and committed
[ET-VK] Add alignment fields to PackedDimInfo for padded size calculation
Pull Request resolved: #17170 This change introduces separate alignment fields to PackedDimInfo, decoupling the alignment used for padding tensor dimensions from the block size used for packing. Previously, `calculate_padded_sizes` used `packed_dim_block_size` and `outer_packed_dim_block_size` directly to determine how much to pad tensor dimensions. This works but limits flexibility - there are scenarios where we want to pad dimensions to a larger alignment than the block size for performance reasons, such as ensuring loads are aligned to cache lines or removing the need for bounds checking in shaders. The new fields `packed_dim_align` and `outer_packed_dim_align` allow specifying the alignment independently. For now, these are initialized to match the corresponding block sizes, preserving existing behavior. Future changes can set larger alignment values when beneficial for performance. Authored with Claude. ghstack-source-id: 338638551 @exported-using-ghexport Differential Revision: [D92196649](https://our.internmc.facebook.com/intern/diff/D92196649/)
1 parent 9b19a91 commit fa3e031

3 files changed

Lines changed: 116 additions & 126 deletions

File tree

backends/vulkan/runtime/api/containers/Tensor.cpp

Lines changed: 106 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,17 @@ namespace api {
1717
PackedDimInfo::PackedDimInfo(
1818
const int32_t dim,
1919
const int32_t dim_block_size,
20+
const int32_t dim_align,
2021
const int32_t outer_dim,
2122
const int32_t outer_dim_block_size,
23+
const int32_t outer_dim_align,
2224
const bool is_block_transposed)
2325
: packed_dim(dim),
2426
packed_dim_block_size(dim_block_size),
27+
packed_dim_align(dim_align),
2528
outer_packed_dim(outer_dim),
2629
outer_packed_dim_block_size(outer_dim_block_size),
30+
outer_packed_dim_align(outer_dim_align),
2731
block_transposed(is_block_transposed),
2832
block_numel(packed_dim_block_size * outer_packed_dim_block_size) {
2933
// Packed dims must be different
@@ -33,32 +37,105 @@ PackedDimInfo::PackedDimInfo(
3337
PackedDimInfo calculate_packed_dim_info(
3438
const utils::GPUMemoryLayout memory_layout,
3539
const utils::StorageType storage_type) {
36-
const int32_t packed_dim = utils::to_packed_dim<int32_t>(memory_layout);
37-
const int32_t outer_packed_dim =
38-
utils::to_outer_packed_dim<int32_t>(memory_layout);
39-
const int32_t packed_dim_block_size =
40-
utils::to_packed_dim_block_size<int32_t>(memory_layout, storage_type);
41-
const int32_t outer_packed_dim_block_size =
42-
utils::to_outer_packed_dim_block_size<int32_t>(memory_layout);
43-
const bool is_block_transposed =
44-
utils::is_block_transposed_layout(memory_layout);
45-
46-
const int32_t block_numel =
47-
packed_dim_block_size * outer_packed_dim_block_size;
48-
if (storage_type != utils::kBuffer) {
40+
const bool is_buffer = storage_type == utils::kBuffer;
41+
42+
PackedDimInfo packed_dim_info(0, 1, 1, 1, 1, 1, false);
43+
switch (memory_layout) {
44+
case utils::kWidthPacked:
45+
packed_dim_info = PackedDimInfo(
46+
/*dim=*/0,
47+
/*dim_block_size=*/is_buffer ? 1 : 4,
48+
/*dim_align=*/is_buffer ? 1 : 4,
49+
/*outer_dim=*/1,
50+
/*outer_dim_block_size=*/1,
51+
/*outer_dim_align=*/1,
52+
/*is_block_transposed=*/false);
53+
break;
54+
case utils::kHeightPacked:
55+
packed_dim_info = PackedDimInfo(
56+
/*dim=*/1,
57+
/*dim_block_size=*/is_buffer ? 1 : 4,
58+
/*dim_align=*/is_buffer ? 1 : 4,
59+
/*outer_dim=*/0,
60+
/*outer_dim_block_size=*/1,
61+
/*outer_dim_align=*/1,
62+
/*is_block_transposed=*/false);
63+
break;
64+
case utils::kChannelsPacked:
65+
packed_dim_info = PackedDimInfo(
66+
/*dim=*/2,
67+
/*dim_block_size=*/is_buffer ? 1 : 4,
68+
/*dim_align=*/is_buffer ? 1 : 4,
69+
/*outer_dim=*/0,
70+
/*outer_dim_block_size=*/1,
71+
/*outer_dim_align=*/1,
72+
/*is_block_transposed=*/false);
73+
break;
74+
case utils::kPackedInt8_4W:
75+
packed_dim_info = PackedDimInfo(
76+
/*dim=*/0,
77+
/*dim_block_size=*/is_buffer ? 4 : 16,
78+
/*dim_align=*/is_buffer ? 4 : 16,
79+
/*outer_dim=*/1,
80+
/*outer_dim_block_size=*/1,
81+
/*outer_dim_align=*/1,
82+
/*is_block_transposed=*/false);
83+
break;
84+
case utils::kPackedInt8_4C:
85+
packed_dim_info = PackedDimInfo(
86+
/*dim=*/2,
87+
/*dim_block_size=*/is_buffer ? 4 : 16,
88+
/*dim_align=*/is_buffer ? 4 : 16,
89+
/*outer_dim=*/0,
90+
/*outer_dim_block_size=*/1,
91+
/*outer_dim_align=*/1,
92+
/*is_block_transposed=*/false);
93+
break;
94+
case utils::kPackedInt8_4W4C:
95+
packed_dim_info = PackedDimInfo(
96+
/*dim=*/2,
97+
/*dim_block_size=*/4,
98+
/*dim_align=*/4,
99+
/*outer_dim=*/0,
100+
/*outer_dim_block_size=*/4,
101+
/*outer_dim_align=*/4,
102+
/*is_block_transposed=*/false);
103+
break;
104+
case utils::kPackedInt8_4H4W:
105+
packed_dim_info = PackedDimInfo(
106+
/*dim=*/0,
107+
/*dim_block_size=*/4,
108+
/*dim_align=*/4,
109+
/*outer_dim=*/1,
110+
/*outer_dim_block_size=*/4,
111+
/*outer_dim_align=*/4,
112+
/*is_block_transposed=*/false);
113+
break;
114+
case utils::kPackedInt8_4C1W:
115+
packed_dim_info = PackedDimInfo(
116+
/*dim=*/2,
117+
/*dim_block_size=*/is_buffer ? 4 : 16,
118+
/*dim_align=*/is_buffer ? 4 : 16,
119+
/*outer_dim=*/0,
120+
/*outer_dim_block_size=*/1,
121+
/*outer_dim_align=*/1,
122+
/*is_block_transposed=*/true);
123+
break;
124+
default:
125+
VK_THROW("Unknown GPUMemoryLayout");
126+
}
127+
128+
if (!is_buffer) {
129+
const int32_t block_numel = packed_dim_info.packed_dim_block_size *
130+
packed_dim_info.outer_packed_dim_block_size;
49131
if (is_packed_int8_layout(memory_layout)) {
50132
VK_CHECK_COND(block_numel == 16);
51133
} else {
52134
VK_CHECK_COND(block_numel == 4);
53135
}
54136
}
55137

56-
return PackedDimInfo(
57-
packed_dim,
58-
packed_dim_block_size,
59-
outer_packed_dim,
60-
outer_packed_dim_block_size,
61-
is_block_transposed);
138+
return packed_dim_info;
62139
}
63140

64141
/*
@@ -297,7 +374,8 @@ utils::ivec4 flip_and_unsqueeze_ivec4(
297374
* for GPU storage in the following ways:
298375
*
299376
* 1. The dimensionality of the tensor will be padded to a multiple of 4.
300-
* 2. The size of the packed dimension will be padded to a multiple of 4.
377+
* 2. The size of the packed dimension will be padded to a multiple of the
378+
* packed dimension's alignment value.
301379
*
302380
* The "packed dimension" is determined based on the utils::GPUMemoryLayout
303381
* argument.
@@ -317,23 +395,23 @@ std::vector<int64_t> calculate_padded_sizes(
317395
padded_sizes.at(i) = utils::val_at(i - ndim_up4, sizes);
318396
}
319397

320-
// Pad the packed dim to the block size
321-
if (packed_dim_info.packed_dim_block_size > 1) {
398+
// Pad the packed dim to the alignment
399+
if (packed_dim_info.packed_dim_align > 1) {
322400
const int64_t dim_offset = packed_dim_info.packed_dim + 1;
323401
const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes);
324402
padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up(
325403
padded_dim_size,
326-
static_cast<int64_t>(packed_dim_info.packed_dim_block_size));
404+
static_cast<int64_t>(packed_dim_info.packed_dim_align));
327405
}
328406

329-
// Also pad the outer packed dimension if it's different from the inner packed
330-
// dimension and is marked as padded.
331-
if (packed_dim_info.outer_packed_dim_block_size > 1) {
407+
// Also pad the outer packed dimension if it has alignment > 1.
408+
if (packed_dim_info.outer_packed_dim_align > 1) {
332409
const int64_t outer_dim_offset = packed_dim_info.outer_packed_dim + 1;
333410
const int64_t outer_padded_dim_size =
334411
utils::val_at(-outer_dim_offset, sizes);
335-
padded_sizes.at(ndim_up4 - outer_dim_offset) =
336-
utils::align_up_4(outer_padded_dim_size);
412+
padded_sizes.at(ndim_up4 - outer_dim_offset) = utils::align_up(
413+
outer_padded_dim_size,
414+
static_cast<int64_t>(packed_dim_info.outer_packed_dim_align));
337415
}
338416

339417
return padded_sizes;

backends/vulkan/runtime/api/containers/Tensor.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,12 @@ struct PackedDimInfo {
6767
// In physical memory, the size of the packed dim is aligned to this size to
6868
// ensure that data for the packed dim aligns with texel/block boundaries.
6969
int32_t packed_dim_block_size;
70+
// In physical memory, the size of the packed dimension will be aligned to be
71+
// a multiple of this value. This value must be a multiple of the packed_dim's
72+
// block size, and is selected for performance reasons i.e. to ensure loads
73+
// along the packed dim are aligned to cache lines, or to enable performance
74+
// optimizations in shaders, i.e. remove the need for bounds checking.
75+
int32_t packed_dim_align;
7076
// For block-packed layouts, represents the second tensor dimension that forms
7177
// the "width" dimension of the MxN square that is kept contiguous in memory.
7278
// For non block-packed layouts, represent the dimension with the next lowest
@@ -77,6 +83,8 @@ struct PackedDimInfo {
7783
// 4H4W, represents the "height" of the square block that is kept contiguous
7884
// in memory.
7985
int32_t outer_packed_dim_block_size;
86+
// See packed_dim_align
87+
int32_t outer_packed_dim_align;
8088
// Typically the blocks of the tensor will be arranged such that the inner
8189
// dim of the block (i.e. the packed dim) has the lowest stride, and the
8290
// outer dim of the block (i.e. the outer packed dim) has the next lowest
@@ -94,8 +102,10 @@ struct PackedDimInfo {
94102
PackedDimInfo(
95103
const int32_t dim,
96104
const int32_t dim_block_size,
105+
const int32_t dim_align,
97106
const int32_t outer_dim,
98107
const int32_t outer_dim_block_size,
108+
const int32_t outer_dim_align,
99109
const bool is_block_transposed);
100110
};
101111

backends/vulkan/runtime/utils/StorageUtils.h

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -139,104 +139,6 @@ static constexpr GPUMemoryLayout kPackedInt8_4H4W =
139139
static constexpr GPUMemoryLayout kPackedInt8_4C1W =
140140
GPUMemoryLayout::TENSOR_PACKED_INT8_4C1W;
141141

142-
template <typename T>
143-
T to_packed_dim(const GPUMemoryLayout layout) {
144-
switch (layout) {
145-
case kWidthPacked:
146-
return 0;
147-
case kHeightPacked:
148-
return 1;
149-
case kChannelsPacked:
150-
return 2;
151-
case kPackedInt8_4W:
152-
return 0;
153-
case kPackedInt8_4C:
154-
return 2;
155-
case kPackedInt8_4W4C:
156-
return 2;
157-
case kPackedInt8_4H4W:
158-
return 0;
159-
case kPackedInt8_4C1W:
160-
return 2;
161-
};
162-
// Should be unreachable
163-
return 0;
164-
}
165-
166-
template <typename T>
167-
T to_outer_packed_dim(const GPUMemoryLayout layout) {
168-
switch (layout) {
169-
case kWidthPacked:
170-
return 1;
171-
case kHeightPacked:
172-
return 0;
173-
case kChannelsPacked:
174-
return 0;
175-
case kPackedInt8_4W:
176-
return 1;
177-
case kPackedInt8_4C:
178-
return 0;
179-
case kPackedInt8_4W4C:
180-
return 0;
181-
case kPackedInt8_4H4W:
182-
return 1;
183-
case kPackedInt8_4C1W:
184-
return 0;
185-
};
186-
// Should be unreachable
187-
return 1;
188-
}
189-
190-
template <typename T>
191-
T to_packed_dim_block_size(
192-
const GPUMemoryLayout layout,
193-
const StorageType storage) {
194-
switch (layout) {
195-
case kWidthPacked:
196-
return storage == kBuffer ? 1 : 4;
197-
case kHeightPacked:
198-
return storage == kBuffer ? 1 : 4;
199-
case kChannelsPacked:
200-
return storage == kBuffer ? 1 : 4;
201-
case kPackedInt8_4W:
202-
return storage == kBuffer ? 4 : 16;
203-
case kPackedInt8_4C:
204-
return storage == kBuffer ? 4 : 16;
205-
case kPackedInt8_4W4C:
206-
return 4;
207-
case kPackedInt8_4H4W:
208-
return 4;
209-
case kPackedInt8_4C1W:
210-
return storage == kBuffer ? 4 : 16;
211-
};
212-
// Should be unreachable
213-
return 1;
214-
}
215-
216-
template <typename T>
217-
T to_outer_packed_dim_block_size(const GPUMemoryLayout layout) {
218-
switch (layout) {
219-
case kWidthPacked:
220-
return 1;
221-
case kHeightPacked:
222-
return 1;
223-
case kChannelsPacked:
224-
return 1;
225-
case kPackedInt8_4W:
226-
return 1;
227-
case kPackedInt8_4C:
228-
return 1;
229-
case kPackedInt8_4W4C:
230-
return 4;
231-
case kPackedInt8_4H4W:
232-
return 4;
233-
case kPackedInt8_4C1W:
234-
return 1;
235-
};
236-
// Should be unreachable
237-
return 1;
238-
}
239-
240142
bool is_block_transposed_layout(const GPUMemoryLayout layout);
241143

242144
bool is_packed_int8_layout(const GPUMemoryLayout layout);

0 commit comments

Comments (0)