11#include < metal_stdlib>
22using namespace metal ;
33
4- // In Grid, SiteSpinor and SiteHalfSpinor are heavily templated.
5- // For Metal, we define the memory layout for SU(Nc) Nc=3, Nd=4.
6- // Spinor has 4 spin components, each with 3 color components.
7- // Each component is a complex number (2 floats or 2 doubles).
8- // For simplicity in Phase 2, we assume single precision floats (or we can use macros for precision).
9-
10- // Complex number structure
11- struct Complex {
12- float real;
13- float imag;
// One stencil record, as consumed by the Dslash kernel.
// Four 32-bit words; the layout is presumably mirrored by the host-side
// stencil table (TODO confirm against the host packing code).
struct StencilEntry {
    uint32_t offset;           // index used to fetch the neighbour site
    uint32_t is_local;         // locality flag supplied by the host
    uint32_t permute;          // non-zero: swap the two SIMD lanes after projection
    uint32_t around_the_world; // fourth 32-bit word (original note: padding added by Grid)
};
1510
16- // HalfSpinor: 2 spin components * 3 colors = 6 Complex numbers
17- struct SiteHalfSpinor {
18- Complex data[6 ];
19- };
// SIMD complex vector for the macOS (M-series) target, where Grid uses NEON
// with Nsimd = 2 for single-precision complex.
// Layout of v: (lane0.real, lane0.imag, lane1.real, lane1.imag).
struct vComplexF {
    float4 v;
};
2014
21- // Spinor: 4 spin components * 3 colors = 12 Complex numbers
22- struct SiteSpinor {
23- Complex data[12 ];
24- };
// Half spinor: 2 spin x 3 colour entries, stored as data[spin * 3 + colour].
// Each float4 packs two complex SIMD lanes.
struct SiteHalfSpinor {
    float4 data[6];
};
// Full spinor: 4 spin x 3 colour entries, stored as data[spin * 3 + colour].
// Each float4 packs two complex SIMD lanes.
struct SiteSpinor {
    float4 data[12];
};
// 3x3 complex colour matrix, row-major: data[row * 3 + col].
// Each float4 packs two complex SIMD lanes.
struct SU3Matrix {
    float4 data[9];
};
2518
26- // SU(3) Matrix: 3x3 Complex numbers
27- struct SU3Matrix {
28- Complex data[9 ];
29- };
19+ // SIMD Algebra Math
// Multiply both complex lanes by +i: (re, im) -> (-im, re).
inline float4 timesI(float4 a) {
    float4 rotated = a.yxwz;   // swap re/im inside each lane
    rotated.x = -rotated.x;    // new real part of lane 0 is -im
    rotated.z = -rotated.z;    // new real part of lane 1 is -im
    return rotated;
}
// Multiply both complex lanes by -i: (re, im) -> (im, -re).
inline float4 timesMinusI(float4 a) {
    float4 rotated = a.yxwz;   // swap re/im inside each lane
    rotated.y = -rotated.y;    // new imaginary part of lane 0 is -re
    rotated.w = -rotated.w;    // new imaginary part of lane 1 is -re
    return rotated;
}
// Lane-wise complex product a * b.
// Lane 0 lives in (x, y), lane 1 in (z, w); each lane is (real, imag).
inline float4 multComplex(float4 a, float4 b) {
    const float2 lane0 = float2(a.x * b.x - a.y * b.y,
                                a.x * b.y + a.y * b.x);
    const float2 lane1 = float2(a.z * b.z - a.w * b.w,
                                a.z * b.w + a.w * b.z);
    return float4(lane0, lane1);
}
// Exchange the two complex SIMD lanes: (x, y, z, w) -> (z, w, x, y).
inline float4 permute_lanes(float4 a) {
    return float4(a.zw, a.xy);
}
3029
31- // Gauge field link has 4 directions * SU(3) Matrix per site
32- // Wait, DoubledGaugeField stores U[site][dir] in a specific layout.
33- // StencilEntry contains offsets and permutation flags.
34- struct StencilEntry {
35- uint32_t offset;
36- uint32_t is_local;
37- uint32_t permute;
38- };
// Apply a 3x3 colour matrix to each of the two spin components of a half spinor.
//   res[spin][row] = sum over col of U[row][col] * chi[spin][col]
// The sum starts from zero and accumulates col = 0, 1, 2 in that order.
inline SiteHalfSpinor multLink(SU3Matrix U, SiteHalfSpinor chi) {
    SiteHalfSpinor res;
    for (int spin = 0; spin < 2; ++spin) {
        const int base = spin * 3;   // first colour slot of this spin component
        for (int row = 0; row < 3; ++row) {
            float4 acc = float4(0.0f);
            for (int col = 0; col < 3; ++col) {
                acc += multComplex(U.data[row * 3 + col], chi.data[base + col]);
            }
            res.data[base + row] = acc;
        }
    }
    return res;
}
44+
// Xp spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 + i*s3,  h1 = s1 + i*s2
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjXp(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     + timesI(fspin.data[9 + c]);
        const float4 lower = fspin.data[3 + c] + timesI(fspin.data[6 + c]);
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Xm spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 - i*s3,  h1 = s1 - i*s2
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjXm(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     - timesI(fspin.data[9 + c]);
        const float4 lower = fspin.data[3 + c] - timesI(fspin.data[6 + c]);
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Yp spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 - s3,  h1 = s1 + s2
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjYp(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     - fspin.data[9 + c];
        const float4 lower = fspin.data[3 + c] + fspin.data[6 + c];
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Ym spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 + s3,  h1 = s1 - s2
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjYm(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     + fspin.data[9 + c];
        const float4 lower = fspin.data[3 + c] - fspin.data[6 + c];
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Zp spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 + i*s2,  h1 = s1 - i*s3
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjZp(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     + timesI(fspin.data[6 + c]);
        const float4 lower = fspin.data[3 + c] - timesI(fspin.data[9 + c]);
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Zm spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 - i*s2,  h1 = s1 + i*s3
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjZm(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     - timesI(fspin.data[6 + c]);
        const float4 lower = fspin.data[3 + c] + timesI(fspin.data[9 + c]);
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Tp spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 + s2,  h1 = s1 + s3
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjTp(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     + fspin.data[6 + c];
        const float4 lower = fspin.data[3 + c] + fspin.data[9 + c];
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
// Tm spin projection (4 spin -> 2 spin), per colour:
//   h0 = s0 - s2,  h1 = s1 - s3
// When perm is non-zero the two SIMD lanes of every result are swapped.
inline SiteHalfSpinor spProjTm(SiteSpinor fspin, uint perm) {
    SiteHalfSpinor hspin;
    for (int c = 0; c < 3; ++c) {
        const float4 upper = fspin.data[c]     - fspin.data[6 + c];
        const float4 lower = fspin.data[3 + c] - fspin.data[9 + c];
        hspin.data[c]     = perm ? permute_lanes(upper) : upper;
        hspin.data[3 + c] = perm ? permute_lanes(lower) : lower;
    }
    return hspin;
}
118+
// Reconstructors: expand a half spinor back to four spin components.
//
// spReconXp overwrites `out` (it does not accumulate), per colour:
//   s0 = h0,  s1 = h1,  s2 = -i*h1,  s3 = -i*h0
inline void spReconXp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     = h0;
        out.data[3 + c] = h1;
        out.data[6 + c] = timesMinusI(h1);
        out.data[9 + c] = timesMinusI(h0);
    }
}
// Accumulate the Xp reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 -= i*h1,  s3 -= i*h0
inline void accumReconXp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] -= timesI(h1);
        out.data[9 + c] -= timesI(h0);
    }
}
// Accumulate the Yp reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 += h1,  s3 -= h0
inline void accumReconYp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] += h1;
        out.data[9 + c] -= h0;
    }
}
// Accumulate the Zp reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 -= i*h0,  s3 += i*h1
inline void accumReconZp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] -= timesI(h0);
        out.data[9 + c] += timesI(h1);
    }
}
// Accumulate the Tp reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 += h0,  s3 += h1
inline void accumReconTp(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] += h0;
        out.data[9 + c] += h1;
    }
}
// Accumulate the Xm reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 += i*h1,  s3 += i*h0
inline void accumReconXm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] += timesI(h1);
        out.data[9 + c] += timesI(h0);
    }
}
// Accumulate the Ym reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 -= h1,  s3 += h0
inline void accumReconYm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] -= h1;
        out.data[9 + c] += h0;
    }
}
// Accumulate the Zm reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 += i*h0,  s3 -= i*h1
inline void accumReconZm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] += timesI(h0);
        out.data[9 + c] -= timesI(h1);
    }
}
// Accumulate the Tm reconstruction into `out`, per colour:
//   s0 += h0,  s1 += h1,  s2 -= h0,  s3 -= h1
inline void accumReconTm(thread SiteSpinor& out, SiteHalfSpinor hspin) {
    for (int c = 0; c < 3; ++c) {
        const float4 h0 = hspin.data[c];
        const float4 h1 = hspin.data[3 + c];
        out.data[c]     += h0;
        out.data[3 + c] += h1;
        out.data[6 + c] -= h0;
        out.data[9 + c] -= h1;
    }
}
39192
40193// Kernel to execute the Wilson Dslash
41194kernel void GenericDhopSite (
@@ -45,12 +198,82 @@ kernel void GenericDhopSite(
45198 device const StencilEntry* stencil [[buffer(3 )]],
46199 constant uint32_t& Ls [[buffer(4 )]],
47200 constant uint32_t& Nsite [[buffer(5 )]],
201+ device const SiteHalfSpinor* buf [[buffer(6 )]],
48202 uint id [[thread_position_in_grid]]
49203) {
50204 if (id >= Nsite * Ls) return ;
51205
52- // TODO: Implement the 8-way stencil hops mapping to spProj, multLink, Recon
53- // For now, this serves as the foundational shader compile target.
// Body of GenericDhopSite: one thread produces one output spinor.
//
// Each of the eight stencil legs (Xp, Yp, Zp, Tp, Xm, Ym, Zm, Tm) does:
//   1. look up its StencilEntry,
//   2. obtain the neighbour half spinor,
//   3. multiply by the gauge link for that direction,
//   4. reconstruct into / accumulate onto `result`.
//
// Step 2 honours SE.is_local, following Grid's GENERIC_STENCIL_LEG:
//   - local entry:     project in_spinor[SE.offset], applying the lane permute;
//   - non-local entry: read the half spinor from buf[SE.offset] as-is. This
//     assumes the host fills `buf` with already-projected halo data, indexed
//     by SE.offset, as in Grid -- TODO confirm against the host code.
// Previously `buf` and `is_local` were ignored and every leg read in_spinor.
const uint sF = id;        // spinor site index (output slot)
const uint sU = id / Ls;   // gauge / stencil site index (equals sF when Ls == 1)

SiteSpinor result;
for (int i = 0; i < 12; i++) result.data[i] = float4(0.0f);

// Dir = 0 (Xp): the first leg overwrites `result`, the remaining legs accumulate.
{
    const StencilEntry SE = stencil[0 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjXp(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 0]; // links per site: Xp, Yp, Zp, Tp, Xm, Ym, Zm, Tm
    spReconXp(result, multLink(U, hs));
}
// Dir = 1 (Yp)
{
    const StencilEntry SE = stencil[1 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjYp(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 1];
    accumReconYp(result, multLink(U, hs));
}
// Dir = 2 (Zp)
{
    const StencilEntry SE = stencil[2 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjZp(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 2];
    accumReconZp(result, multLink(U, hs));
}
// Dir = 3 (Tp)
{
    const StencilEntry SE = stencil[3 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjTp(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 3];
    accumReconTp(result, multLink(U, hs));
}
// Dir = 4 (Xm)
{
    const StencilEntry SE = stencil[4 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjXm(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 4];
    accumReconXm(result, multLink(U, hs));
}
// Dir = 5 (Ym)
{
    const StencilEntry SE = stencil[5 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjYm(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 5];
    accumReconYm(result, multLink(U, hs));
}
// Dir = 6 (Zm)
{
    const StencilEntry SE = stencil[6 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjZm(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 6];
    accumReconZm(result, multLink(U, hs));
}
// Dir = 7 (Tm)
{
    const StencilEntry SE = stencil[7 * Nsite + sU];
    SiteHalfSpinor hs;
    if (SE.is_local) {
        hs = spProjTm(in_spinor[SE.offset], SE.permute);
    } else {
        hs = buf[SE.offset];
    }
    const SU3Matrix U = gauge_field[sU * 8 + 7];
    accumReconTm(result, multLink(U, hs));
}

out_spinor[sF] = result;
56279}
0 commit comments