@@ -17,6 +17,11 @@ pub mod ops;
1717mod intrinsics;
1818use intrinsics:: * ;
1919
20+ // Stride validation module
21+ pub mod stride;
22+ // Re-export for convenience
23+ pub use stride:: { StrideValidator , ValidStride } ;
24+
2025// ============================================================================
2126// Shape Types with Compile-Time Validation
2227// ============================================================================
@@ -1749,80 +1754,7 @@ impl MmaWithShapeAndLayout<f64, f64, f64, dims::Shape<8, 8, 4>, layout::Row, lay
17491754 }
17501755}
17511756
1752- // ============================================================================
1753- // Stride Validation
1754- // ============================================================================
1755-
1756- /// Compile-time stride validation
1757- pub struct StrideValidator < T , const STRIDE : usize > ( PhantomData < T > ) ;
1758-
1759- /// Trait for valid strides
1760- #[ diagnostic:: on_unimplemented(
1761- message = "`{Self}` is not a valid stride configuration" ,
1762- label = "invalid stride for tensor core operations" ,
1763- note = "f16/bf16 require stride to be a multiple of 8" ,
1764- note = "f32/i32 require stride to be a multiple of 4" ,
1765- note = "i8/u8/bool require stride to be a multiple of 16" ,
1766- note = "f64 requires stride to be a multiple of 2"
1767- ) ]
1768- pub trait ValidStride : sealed:: Sealed { }
1769-
1770- // f16 requires stride to be multiple of 8
1771- impl ValidStride for StrideValidator < f16 , 8 > { }
1772- impl ValidStride for StrideValidator < f16 , 16 > { }
1773- impl ValidStride for StrideValidator < f16 , 24 > { }
1774- impl ValidStride for StrideValidator < f16 , 32 > { }
1775- impl ValidStride for StrideValidator < f16 , 40 > { }
1776- impl ValidStride for StrideValidator < f16 , 48 > { }
1777- impl ValidStride for StrideValidator < f16 , 56 > { }
1778- impl ValidStride for StrideValidator < f16 , 64 > { }
1779-
1780- // f32 requires stride to be multiple of 4
1781- impl ValidStride for StrideValidator < f32 , 4 > { }
1782- impl ValidStride for StrideValidator < f32 , 8 > { }
1783- impl ValidStride for StrideValidator < f32 , 12 > { }
1784- impl ValidStride for StrideValidator < f32 , 16 > { }
1785- impl ValidStride for StrideValidator < f32 , 20 > { }
1786- impl ValidStride for StrideValidator < f32 , 24 > { }
1787- impl ValidStride for StrideValidator < f32 , 28 > { }
1788- impl ValidStride for StrideValidator < f32 , 32 > { }
1789-
1790- // bf16 requires stride to be multiple of 8 (same as f16)
1791- impl ValidStride for StrideValidator < bf16 , 8 > { }
1792- impl ValidStride for StrideValidator < bf16 , 16 > { }
1793- impl ValidStride for StrideValidator < bf16 , 24 > { }
1794- impl ValidStride for StrideValidator < bf16 , 32 > { }
1795- impl ValidStride for StrideValidator < bf16 , 40 > { }
1796- impl ValidStride for StrideValidator < bf16 , 48 > { }
1797- impl ValidStride for StrideValidator < bf16 , 56 > { }
1798- impl ValidStride for StrideValidator < bf16 , 64 > { }
1799-
1800- // i8/u8 require stride to be multiple of 16
1801- impl ValidStride for StrideValidator < i8 , 16 > { }
1802- impl ValidStride for StrideValidator < i8 , 32 > { }
1803- impl ValidStride for StrideValidator < i8 , 48 > { }
1804- impl ValidStride for StrideValidator < i8 , 64 > { }
1805-
1806- impl ValidStride for StrideValidator < u8 , 16 > { }
1807- impl ValidStride for StrideValidator < u8 , 32 > { }
1808- impl ValidStride for StrideValidator < u8 , 48 > { }
1809- impl ValidStride for StrideValidator < u8 , 64 > { }
1810-
1811- // i32 requires stride to be multiple of 4 (same as f32)
1812- impl ValidStride for StrideValidator < i32 , 4 > { }
1813- impl ValidStride for StrideValidator < i32 , 8 > { }
1814- impl ValidStride for StrideValidator < i32 , 12 > { }
1815- impl ValidStride for StrideValidator < i32 , 16 > { }
1816- impl ValidStride for StrideValidator < i32 , 20 > { }
1817- impl ValidStride for StrideValidator < i32 , 24 > { }
1818- impl ValidStride for StrideValidator < i32 , 28 > { }
1819- impl ValidStride for StrideValidator < i32 , 32 > { }
1820-
1821- // bool uses u8 stride requirements (multiple of 16)
1822- impl ValidStride for StrideValidator < bool , 16 > { }
1823- impl ValidStride for StrideValidator < bool , 32 > { }
1824- impl ValidStride for StrideValidator < bool , 48 > { }
1825- impl ValidStride for StrideValidator < bool , 64 > { }
1757+ // Stride validation types and implementations are now in the stride module
18261758
18271759// ============================================================================
18281760// Ergonomic Builder API
@@ -1990,5 +1922,5 @@ mod sealed {
19901922 impl Sealed for bool { }
19911923
19921924 // Seal stride validators
1993- impl < T , const S : usize > Sealed for StrideValidator < T , S > { }
1925+ // StrideValidator sealing is now in stride module
19941926}
0 commit comments