You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
size(shape.batch * shape.num_heads_q)); // (h,b) -- split later
118
124
int num_heads = shape.num_heads_q;
119
-
120
-
auto total_wg = grid.x * grid.y * grid.z;
121
-
// FIXME: replace with runtime check
122
-
assert(shape.batch == 1);
123
-
assert((grid.z <= hw_info.sm_count / 2) && "XeFHMAIndividualPersistentTileScheduler only enabled for decode case where num batch heads samller than SM count");
124
-
125
-
// how many partitions each KV seq is split into
126
-
int num_partitions = hw_info.sm_count / grid.z;
127
-
// this is for the case where sm_count cannot be divisible by num_batch_heads,
128
-
// for some head/work group, the KV seq need to split into `num_partitions+1`
129
-
// partitions to occupy all xecores, here we assme first `tail_wg` work groups
130
-
// will handle one more partition
131
-
// for eample, num head is 8, sm_count is 20, so first 20%8=4 work groups
132
-
// will handle 3 partitions, the rest 4 work groups will handle 2 partitions
133
-
int num_tail_wg = hw_info.sm_count % grid.z;
134
-
135
-
// assume grid shape (1, 1, hw_info.sm_count) to use all xecores
136
125
grid.z = hw_info.sm_count;
137
-
// int num_partitions = 4; // for 5/1
138
-
// grid.z *= num_partitions;
139
-
// num_heads *= num_partitions;
140
-
141
-
// FIXME: add fallback mechanism if given problem size doesn't meet requirement
0 commit comments