Skip to content

Commit 5dd94bc

Browse files
Petrus PennanenPetrus Pennanen
authored andcommitted
feat: complete MSL Phase 3 Wilson Dslash kernel 8-way stencil logic
1 parent 5b3e3a1 commit 5dd94bc

1 file changed

Lines changed: 256 additions & 33 deletions

File tree

Lines changed: 256 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,194 @@
11
#include <metal_stdlib>
22
using namespace metal;
33

4-
// In Grid, SiteSpinor and SiteHalfSpinor are heavily templated.
5-
// For Metal, we define the memory layout for SU(Nc) Nc=3, Nd=4.
6-
// Spinor has 4 spin components, each with 3 color components.
7-
// Each component is a complex number (2 floats or 2 doubles).
8-
// For simplicity in Phase 2, we assume single precision floats (or we can use macros for precision).
9-
10-
// Complex number structure
11-
struct Complex {
12-
float real;
13-
float imag;
4+
struct StencilEntry {
5+
uint offset;
6+
uint is_local;
7+
uint permute;
8+
uint around_the_world; // Grid adds a fourth 32-bit int padding
149
};
1510

16-
// HalfSpinor: 2 spin components * 3 colors = 6 Complex numbers
17-
struct SiteHalfSpinor {
18-
Complex data[6];
19-
};
11+
// Target OS is macOS (M-series), Grid utilizes NEON SIMD (Nsimd=2 for float complex).
12+
// 1 vComplexF = float4(lane0_real, lane0_imag, lane1_real, lane1_imag).
13+
struct vComplexF { float4 v; };
2014

21-
// Spinor: 4 spin components * 3 colors = 12 Complex numbers
22-
struct SiteSpinor {
23-
Complex data[12];
24-
};
15+
struct SiteHalfSpinor { float4 data[6]; };
16+
struct SiteSpinor { float4 data[12]; };
17+
struct SU3Matrix { float4 data[9]; };
2518

26-
// SU(3) Matrix: 3x3 Complex numbers
27-
struct SU3Matrix {
28-
Complex data[9];
29-
};
19+
// SIMD Algebra Math
20+
inline float4 timesI(float4 a) { return float4(-a.y, a.x, -a.w, a.z); }
21+
inline float4 timesMinusI(float4 a) { return float4(a.y, -a.x, a.w, -a.z); }
22+
inline float4 multComplex(float4 a, float4 b) {
23+
float4 r;
24+
r.x = a.x*b.x - a.y*b.y; r.y = a.x*b.y + a.y*b.x;
25+
r.z = a.z*b.z - a.w*b.w; r.w = a.z*b.w + a.w*b.z;
26+
return r;
27+
}
28+
inline float4 permute_lanes(float4 a) { return a.zwxy; }
3029

31-
// Gauge field link has 4 directions * SU(3) Matrix per site
32-
// Wait, DoubledGaugeField stores U[site][dir] in a specific layout.
33-
// StencilEntry contains offsets and permutation flags.
34-
struct StencilEntry {
35-
uint32_t offset;
36-
uint32_t is_local;
37-
uint32_t permute;
38-
};
30+
// SU(3) multiplies a 6-component Half Spinor
31+
inline SiteHalfSpinor multLink(SU3Matrix U, SiteHalfSpinor chi) {
32+
SiteHalfSpinor res;
33+
for(int s=0; s<2; s++) {
34+
for(int c=0; c<3; c++) {
35+
float4 sum = float4(0.0f);
36+
for(int k=0; k<3; k++) {
37+
sum += multComplex(U.data[c*3 + k], chi.data[s*3 + k]);
38+
}
39+
res.data[s*3 + c] = sum;
40+
}
41+
}
42+
return res;
43+
}
44+
45+
// Xp projector (1 - gamma_x)
46+
inline SiteHalfSpinor spProjXp(SiteSpinor fspin, uint perm) {
47+
SiteHalfSpinor hspin;
48+
for(int c=0; c<3; ++c) {
49+
hspin.data[0*3+c] = fspin.data[0*3+c] + timesI(fspin.data[3*3+c]);
50+
hspin.data[1*3+c] = fspin.data[1*3+c] + timesI(fspin.data[2*3+c]);
51+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
52+
}
53+
return hspin;
54+
}
55+
inline SiteHalfSpinor spProjXm(SiteSpinor fspin, uint perm) {
56+
SiteHalfSpinor hspin;
57+
for(int c=0; c<3; ++c) {
58+
hspin.data[0*3+c] = fspin.data[0*3+c] - timesI(fspin.data[3*3+c]);
59+
hspin.data[1*3+c] = fspin.data[1*3+c] - timesI(fspin.data[2*3+c]);
60+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
61+
}
62+
return hspin;
63+
}
64+
inline SiteHalfSpinor spProjYp(SiteSpinor fspin, uint perm) {
65+
SiteHalfSpinor hspin;
66+
for(int c=0; c<3; ++c) {
67+
hspin.data[0*3+c] = fspin.data[0*3+c] - fspin.data[3*3+c];
68+
hspin.data[1*3+c] = fspin.data[1*3+c] + fspin.data[2*3+c];
69+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
70+
}
71+
return hspin;
72+
}
73+
inline SiteHalfSpinor spProjYm(SiteSpinor fspin, uint perm) {
74+
SiteHalfSpinor hspin;
75+
for(int c=0; c<3; ++c) {
76+
hspin.data[0*3+c] = fspin.data[0*3+c] + fspin.data[3*3+c];
77+
hspin.data[1*3+c] = fspin.data[1*3+c] - fspin.data[2*3+c];
78+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
79+
}
80+
return hspin;
81+
}
82+
inline SiteHalfSpinor spProjZp(SiteSpinor fspin, uint perm) {
83+
SiteHalfSpinor hspin;
84+
for(int c=0; c<3; ++c) {
85+
hspin.data[0*3+c] = fspin.data[0*3+c] + timesI(fspin.data[2*3+c]);
86+
hspin.data[1*3+c] = fspin.data[1*3+c] - timesI(fspin.data[3*3+c]);
87+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
88+
}
89+
return hspin;
90+
}
91+
inline SiteHalfSpinor spProjZm(SiteSpinor fspin, uint perm) {
92+
SiteHalfSpinor hspin;
93+
for(int c=0; c<3; ++c) {
94+
hspin.data[0*3+c] = fspin.data[0*3+c] - timesI(fspin.data[2*3+c]);
95+
hspin.data[1*3+c] = fspin.data[1*3+c] + timesI(fspin.data[3*3+c]);
96+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
97+
}
98+
return hspin;
99+
}
100+
inline SiteHalfSpinor spProjTp(SiteSpinor fspin, uint perm) {
101+
SiteHalfSpinor hspin;
102+
for(int c=0; c<3; ++c) {
103+
hspin.data[0*3+c] = fspin.data[0*3+c] + fspin.data[2*3+c];
104+
hspin.data[1*3+c] = fspin.data[1*3+c] + fspin.data[3*3+c];
105+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
106+
}
107+
return hspin;
108+
}
109+
inline SiteHalfSpinor spProjTm(SiteSpinor fspin, uint perm) {
110+
SiteHalfSpinor hspin;
111+
for(int c=0; c<3; ++c) {
112+
hspin.data[0*3+c] = fspin.data[0*3+c] - fspin.data[2*3+c];
113+
hspin.data[1*3+c] = fspin.data[1*3+c] - fspin.data[3*3+c];
114+
if(perm) { hspin.data[0*3+c] = permute_lanes(hspin.data[0*3+c]); hspin.data[1*3+c] = permute_lanes(hspin.data[1*3+c]); }
115+
}
116+
return hspin;
117+
}
118+
119+
// Reconstructors
120+
inline void spReconXp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
121+
for(int c=0; c<3; ++c) {
122+
out.data[0*3+c] = hspin.data[0*3+c];
123+
out.data[1*3+c] = hspin.data[1*3+c];
124+
out.data[2*3+c] = timesMinusI(hspin.data[1*3+c]);
125+
out.data[3*3+c] = timesMinusI(hspin.data[0*3+c]);
126+
}
127+
}
128+
inline void accumReconXp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
129+
for(int c=0; c<3; ++c) {
130+
out.data[0*3+c] += hspin.data[0*3+c];
131+
out.data[1*3+c] += hspin.data[1*3+c];
132+
out.data[2*3+c] -= timesI(hspin.data[1*3+c]);
133+
out.data[3*3+c] -= timesI(hspin.data[0*3+c]);
134+
}
135+
}
136+
inline void accumReconYp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
137+
for(int c=0; c<3; ++c) {
138+
out.data[0*3+c] += hspin.data[0*3+c];
139+
out.data[1*3+c] += hspin.data[1*3+c];
140+
out.data[2*3+c] += hspin.data[1*3+c];
141+
out.data[3*3+c] -= hspin.data[0*3+c];
142+
}
143+
}
144+
inline void accumReconZp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
145+
for(int c=0; c<3; ++c) {
146+
out.data[0*3+c] += hspin.data[0*3+c];
147+
out.data[1*3+c] += hspin.data[1*3+c];
148+
out.data[2*3+c] -= timesI(hspin.data[0*3+c]);
149+
out.data[3*3+c] += timesI(hspin.data[1*3+c]);
150+
}
151+
}
152+
inline void accumReconTp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
153+
for(int c=0; c<3; ++c) {
154+
out.data[0*3+c] += hspin.data[0*3+c];
155+
out.data[1*3+c] += hspin.data[1*3+c];
156+
out.data[2*3+c] += hspin.data[0*3+c];
157+
out.data[3*3+c] += hspin.data[1*3+c];
158+
}
159+
}
160+
inline void accumReconXm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
161+
for(int c=0; c<3; ++c) {
162+
out.data[0*3+c] += hspin.data[0*3+c];
163+
out.data[1*3+c] += hspin.data[1*3+c];
164+
out.data[2*3+c] += timesI(hspin.data[1*3+c]);
165+
out.data[3*3+c] += timesI(hspin.data[0*3+c]);
166+
}
167+
}
168+
inline void accumReconYm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
169+
for(int c=0; c<3; ++c) {
170+
out.data[0*3+c] += hspin.data[0*3+c];
171+
out.data[1*3+c] += hspin.data[1*3+c];
172+
out.data[2*3+c] -= hspin.data[1*3+c];
173+
out.data[3*3+c] += hspin.data[0*3+c];
174+
}
175+
}
176+
inline void accumReconZm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
177+
for(int c=0; c<3; ++c) {
178+
out.data[0*3+c] += hspin.data[0*3+c];
179+
out.data[1*3+c] += hspin.data[1*3+c];
180+
out.data[2*3+c] += timesI(hspin.data[0*3+c]);
181+
out.data[3*3+c] -= timesI(hspin.data[1*3+c]);
182+
}
183+
}
184+
inline void accumReconTm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
185+
for(int c=0; c<3; ++c) {
186+
out.data[0*3+c] += hspin.data[0*3+c];
187+
out.data[1*3+c] += hspin.data[1*3+c];
188+
out.data[2*3+c] -= hspin.data[0*3+c];
189+
out.data[3*3+c] -= hspin.data[1*3+c];
190+
}
191+
}
39192

40193
// Kernel to execute the Wilson Dslash
41194
kernel void GenericDhopSite(
@@ -45,12 +198,82 @@ kernel void GenericDhopSite(
45198
device const StencilEntry* stencil [[buffer(3)]],
46199
constant uint32_t& Ls [[buffer(4)]],
47200
constant uint32_t& Nsite [[buffer(5)]],
201+
device const SiteHalfSpinor* buf [[buffer(6)]],
48202
uint id [[thread_position_in_grid]]
49203
) {
50204
if (id >= Nsite * Ls) return;
51205

52-
// TODO: Implement the 8-way stencil hops mapping to spProj, multLink, Recon
53-
// For now, this serves as the foundational shader compile target.
206+
uint sF = id; // Spinor site index
207+
uint sU = id / Ls; // Gauge field site index (if Ls=1 this is the same)
208+
209+
SiteSpinor result;
210+
for(int i=0; i<12; i++) result.data[i] = float4(0.0f);
54211

55-
// out_spinor[id] = in_spinor[id]; // Basic passthrough for testing pipeline
212+
// 8-Way Stencil Execution (Xp, Yp, Zp, Tp, Xm, Ym, Zm, Tm)
213+
// Dir = 0 (Xp)
214+
{
215+
StencilEntry SE = stencil[0 * Nsite + sU];
216+
SiteHalfSpinor hs = spProjXp(in_spinor[SE.offset], SE.permute);
217+
SU3Matrix U = gauge_field[sU * 8 + 0]; // +0 for Xp (Grid layout: U[Xp, Yp, Zp, Tp, Xm, Ym, Zm, Tm])
218+
SiteHalfSpinor chi = multLink(U, hs);
219+
spReconXp(result, chi);
220+
}
221+
// Dir = 1 (Yp)
222+
{
223+
StencilEntry SE = stencil[1 * Nsite + sU];
224+
SiteHalfSpinor hs = spProjYp(in_spinor[SE.offset], SE.permute);
225+
SU3Matrix U = gauge_field[sU * 8 + 1];
226+
SiteHalfSpinor chi = multLink(U, hs);
227+
accumReconYp(result, chi);
228+
}
229+
// Dir = 2 (Zp)
230+
{
231+
StencilEntry SE = stencil[2 * Nsite + sU];
232+
SiteHalfSpinor hs = spProjZp(in_spinor[SE.offset], SE.permute);
233+
SU3Matrix U = gauge_field[sU * 8 + 2];
234+
SiteHalfSpinor chi = multLink(U, hs);
235+
accumReconZp(result, chi);
236+
}
237+
// Dir = 3 (Tp)
238+
{
239+
StencilEntry SE = stencil[3 * Nsite + sU];
240+
SiteHalfSpinor hs = spProjTp(in_spinor[SE.offset], SE.permute);
241+
SU3Matrix U = gauge_field[sU * 8 + 3];
242+
SiteHalfSpinor chi = multLink(U, hs);
243+
accumReconTp(result, chi);
244+
}
245+
// Dir = 4 (Xm)
246+
{
247+
StencilEntry SE = stencil[4 * Nsite + sU];
248+
SiteHalfSpinor hs = spProjXm(in_spinor[SE.offset], SE.permute);
249+
SU3Matrix U = gauge_field[sU * 8 + 4];
250+
SiteHalfSpinor chi = multLink(U, hs);
251+
accumReconXm(result, chi);
252+
}
253+
// Dir = 5 (Ym)
254+
{
255+
StencilEntry SE = stencil[5 * Nsite + sU];
256+
SiteHalfSpinor hs = spProjYm(in_spinor[SE.offset], SE.permute);
257+
SU3Matrix U = gauge_field[sU * 8 + 5];
258+
SiteHalfSpinor chi = multLink(U, hs);
259+
accumReconYm(result, chi);
260+
}
261+
// Dir = 6 (Zm)
262+
{
263+
StencilEntry SE = stencil[6 * Nsite + sU];
264+
SiteHalfSpinor hs = spProjZm(in_spinor[SE.offset], SE.permute);
265+
SU3Matrix U = gauge_field[sU * 8 + 6];
266+
SiteHalfSpinor chi = multLink(U, hs);
267+
accumReconZm(result, chi);
268+
}
269+
// Dir = 7 (Tm)
270+
{
271+
StencilEntry SE = stencil[7 * Nsite + sU];
272+
SiteHalfSpinor hs = spProjTm(in_spinor[SE.offset], SE.permute);
273+
SU3Matrix U = gauge_field[sU * 8 + 7];
274+
SiteHalfSpinor chi = multLink(U, hs);
275+
accumReconTm(result, chi);
276+
}
277+
278+
out_spinor[sF] = result;
56279
}

0 commit comments

Comments
 (0)