diff --git a/src/dgemm_kernel.rs b/src/dgemm_kernel.rs index 4f1e6a7..eee30dd 100644 --- a/src/dgemm_kernel.rs +++ b/src/dgemm_kernel.rs @@ -24,15 +24,14 @@ macro_rules! loop_n { } impl GemmKernel for Gemm { - type Elem = T; + type ElemIn = T; + type ElemOut = T; - #[inline(always)] - fn align_to() -> usize { 0 } + const MR: usize = MR; + const NR: usize = NR; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 0 } #[inline(always)] fn always_masked() -> bool { true } diff --git a/src/gemm.rs b/src/gemm.rs index 4dc11bd..c3337cc 100644 --- a/src/gemm.rs +++ b/src/gemm.rs @@ -19,7 +19,7 @@ use kernel::GemmKernel; use kernel::Element; use sgemm_kernel; use dgemm_kernel; -use igemm_kernel; +use i8gemm_kernel; use rawpointer::PointerExt; /// General matrix multiplication (f32) @@ -88,15 +88,15 @@ pub unsafe fn dgemm( c, rsc, csc) } -pub unsafe fn igemm( +pub unsafe fn i8gemm( m: usize, k: usize, n: usize, - alpha: i32, - a: *const i32, rsa: isize, csa: isize, - b: *const i32, rsb: isize, csb: isize, - beta: i32, - c: *mut i32, rsc: isize, csc: isize) + alpha: i16, + a: *const i8, rsa: isize, csa: isize, + b: *const i8, rsb: isize, csb: isize, + beta: i16, + c: *mut i16, rsc: isize, csc: isize) { - gemm_loop::( + gemm_loop::( m, k, n, alpha, a, rsa, csa, @@ -113,14 +113,14 @@ pub unsafe fn igemm( fn ensure_kernel_params() where K: GemmKernel { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; assert!(mr > 0 && mr <= 8); assert!(nr > 0 && nr <= 8); - assert!(mr * nr * size_of::() <= 8 * 4 * 8); + assert!(mr * nr * size_of::() <= 8 * 4 * 8); assert!(K::align_to() <= 32); // one row/col of the kernel is limiting the max align we can provide - let max_align = size_of::() * min(mr, nr); + let max_align = size_of::() * min(mr, nr); assert!(K::align_to() <= max_align); } @@ -128,11 +128,11 @@ fn ensure_kernel_params() /// strategy, the type parameter `K` is the gemm microkernel. 
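///
/// The loop nest follows the packed blocking scheme used throughout this crate:
/// `n` is split into `nc`-wide panels, `k` into `kc`-deep slabs and `m` into
/// `mc`-tall blocks; A and B are copied into the packing buffer before the
/// microkernel is invoked on the packed panels.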
unsafe fn gemm_loop( m: usize, k: usize, n: usize, - alpha: K::Elem, - a: *const K::Elem, rsa: isize, csa: isize, - b: *const K::Elem, rsb: isize, csb: isize, - beta: K::Elem, - c: *mut K::Elem, rsc: isize, csc: isize) + alpha: K::ElemOut, + a: *const K::ElemIn, rsa: isize, csa: isize, + b: *const K::ElemIn, rsb: isize, csb: isize, + beta: K::ElemOut, + c: *mut K::ElemOut, rsc: isize, csc: isize) where K: GemmKernel { debug_assert!(m <= 1 || n == 0 || rsc != 0); @@ -146,7 +146,7 @@ unsafe fn gemm_loop( let knc = K::nc(); let kkc = K::kc(); let kmc = K::mc(); - ensure_kernel_params::(); + // ensure_kernel_params::(); let (mut packing_buffer, bp_offset) = make_packing_buffer::(m, k, n); let app = packing_buffer.ptr_mut(); @@ -165,7 +165,7 @@ unsafe fn gemm_loop( let a = a.stride_offset(csa, kkc * l4); // Pack B -> B~ - pack(kc, nc, K::nr(), bpp, b, csb, rsb); + pack(kc, nc, K::NR, bpp, b, csb, rsb); // LOOP 3: split m into mc parts for (l3, mc) in range_chunk(m, kmc) { @@ -174,7 +174,7 @@ unsafe fn gemm_loop( let c = c.stride_offset(rsc, kmc * l3); // Pack A -> A~ - pack(kc, mc, K::mr(), app, a, rsa, csa); + pack(kc, mc, K::MR, app, a, rsa, csa); // First time writing to C, use user's `beta`, else accumulate let betap = if l4 == 0 { beta } else { <_>::one() }; @@ -198,18 +198,19 @@ unsafe fn gemm_loop( /// + kc: columns of packed A / rows of packed B /// + mc: rows of packed A unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, - alpha: K::Elem, - app: *const K::Elem, bpp: *const K::Elem, - beta: K::Elem, - c: *mut K::Elem, rsc: isize, csc: isize) + alpha: K::ElemOut, + app: *const K::ElemIn, bpp: *const K::ElemIn, + beta: K::ElemOut, + c: *mut K::ElemOut, rsc: isize, csc: isize) where K: GemmKernel, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // make a mask buffer that fits 8 x 8 f32 and 8 x 4 f64 kernels and alignment - assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); - let mut mask_buf = [0u8; 256 + 31]; - let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::Elem; + // assert!(mr * nr * size_of::() <= 256 && K::align_to() <= 32); + // let mut mask_buf = [0u8; 256 + 31]; + let mut mask_buf = [0u8; 16*32*2 + 31]; + let mask_ptr = align_ptr(32, mask_buf.as_mut_ptr()) as *mut K::ElemOut; // LOOP 2: through micropanels in packed `b` for (l2, nr_) in range_chunk(nc, nr) { @@ -225,7 +226,7 @@ unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, // NOTE: For the rust kernels, it performs better to simply // always use the masked kernel function! if K::always_masked() || nr_ < nr || mr_ < mr { - masked_kernel::<_, K>(kc, alpha, &*app, &*bpp, + masked_kernel::<_, _, K>(kc, alpha, &*app, &*bpp, beta, &mut *c, rsc, csc, mr_, nr_, mask_ptr); continue; @@ -244,7 +245,7 @@ unsafe fn gemm_packed(nc: usize, kc: usize, mc: usize, /// we have rounded up to a multiple of the kernel size). 
/// /// Return packing buffer and offset to start of b -unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc, usize) +unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc, usize) where K: GemmKernel, { // max alignment requirement is a multiple of min(MR, NR) * sizeof @@ -254,8 +255,8 @@ unsafe fn make_packing_buffer(m: usize, k: usize, n: usize) -> (Alloc(kc: usize, mc: usize, mr: usize, pack: *mut T, /// + rows: rows of kernel unmasked /// + cols: cols of kernel unmasked #[inline(never)] -unsafe fn masked_kernel(k: usize, alpha: T, - a: *const T, - b: *const T, - beta: T, - c: *mut T, rsc: isize, csc: isize, +unsafe fn masked_kernel(k: usize, alpha: Tout, + a: *const Tin, + b: *const Tin, + beta: Tout, + c: *mut Tout, rsc: isize, csc: isize, rows: usize, cols: usize, - mask_buf: *mut T) - where K: GemmKernel, T: Element, + mask_buf: *mut Tout) + where K: GemmKernel, + Tin: Element, + Tout: Element, { - let mr = K::mr(); - let nr = K::nr(); + let mr = K::MR; + let nr = K::NR; // use column major order for `mask_buf` - K::kernel(k, T::one(), a, b, T::zero(), mask_buf, 1, mr as isize); + K::kernel(k, Tout::one(), a, b, Tout::zero(), mask_buf, 1, mr as isize); let mut ab = mask_buf; for j in 0..nr { for i in 0..mr { @@ -369,7 +372,7 @@ unsafe fn masked_kernel(k: usize, alpha: T, let cptr = c.stride_offset(rsc, i) .stride_offset(csc, j); if beta.is_zero() { - *cptr = T::zero(); // initialize C + *cptr = Tout::zero(); // initialize C } else { (*cptr).scale_by(beta); } diff --git a/src/i8gemm_kernel.rs b/src/i8gemm_kernel.rs new file mode 100644 index 0000000..532e5c3 --- /dev/null +++ b/src/i8gemm_kernel.rs @@ -0,0 +1,566 @@ +// Copyright 2016 - 2018 Ulrik Sverdrup "bluss" +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use kernel::GemmKernel; +use archparam; + + +#[cfg(target_arch="x86")] +use std::arch::x86::*; +#[cfg(target_arch="x86_64")] +use std::arch::x86_64::*; + +pub enum Gemm { } + +pub type Tin = i8; +pub type Tout = i16; + +const MR: usize = 16; +const NR: usize = 32; + +macro_rules! loop_m { ($i:ident, $e:expr) => { loop16!($i, $e) }; } +macro_rules! loop_n { ($j:ident, $e:expr) => { loop32!($j, $e) }; } + +impl GemmKernel for Gemm { + type ElemIn = Tin; + type ElemOut = Tout; + + const MR: usize = MR; + const NR: usize = NR; + + #[inline(always)] + fn align_to() -> usize { 16 } + + #[inline(always)] + fn always_masked() -> bool { true } + + #[inline(always)] + fn nc() -> usize { archparam::S_NC } + #[inline(always)] + fn kc() -> usize { archparam::S_KC } + #[inline(always)] + fn mc() -> usize { archparam::S_MC } + + #[inline(always)] + unsafe fn kernel( + k: usize, + alpha: Tout, + a: *const Tin, + b: *const Tin, + beta: Tout, + c: *mut Tout, rsc: isize, csc: isize) { + kernel(k, alpha, a, b, beta, c, rsc, csc) + } +} + +/// Multiply two 128-bit vectors of 16 8-bit integers each,by sign-extending them to 256-bit +/// vectors of 16-bit integers, and then multiplying these temporaries. 
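+///
+/// Since both operands are sign-extended from `i8`, every 16-bit product is at
+/// most 128 * 128 = 16384 in magnitude, so `_mm256_mullo_epi16` computes it
+/// exactly (no overflow of the i16 lanes is possible).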
+#[inline(always)] +unsafe fn _mm256_mulepi8_epi16(a: __m128i, b: __m128i) -> __m256i +{ + let tmp0 = _mm256_cvtepi8_epi16(a); + let tmp1 = _mm256_cvtepi8_epi16(b); + + _mm256_mullo_epi16(tmp0, tmp1) +} + +/// matrix multiplication kernel +/// +/// This does the matrix multiplication: +/// +/// C ← α A B + β C +/// +/// + k: length of data in a, b +/// + a, b are packed +/// + c has general strides +/// + rsc: row stride of c +/// + csc: col stride of c +/// + if beta is 0, then c does not need to be initialized +#[inline(never)] +pub unsafe fn kernel(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + // dispatch to specific compiled versions + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + { + if is_x86_feature_detected_!("avx2") { + return kernel_target_avx2(k, alpha, a, b, beta, c, rsc, csc); + } else if is_x86_feature_detected_!("avx") { + return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); + } else if is_x86_feature_detected_!("sse2") { + return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); + } + } + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); +} + +#[inline] +#[target_feature(enable="avx2")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_avx2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_x86_avx2(k, alpha, a, b, beta, c, rsc, csc) +} + +#[inline] +#[target_feature(enable="avx")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_avx(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) +} + +#[inline] +#[target_feature(enable="sse2")] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_target_sse2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) +} + + +#[inline(always)] +unsafe fn kernel_fallback_impl(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + let mut ab: [[Tout; NR]; MR] = [[0; NR]; MR]; + let mut a = a; + let mut b = b; + debug_assert_eq!(beta, 0); + + // Compute A B into ab[i][j] + unroll_by!(4 => k, { + loop_m!(i, loop_n!(j, { + ab[i][j] = ab[i][j].saturating_add( + (at(a, i) as i16) + .saturating_mul( + at(b, j) as i16 + ));})); + + a = a.offset(MR as isize); + b = b.offset(NR as isize); + }); + + macro_rules! c { + ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // set C = α A B + β C + loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); +} + +#[inline(always)] +#[cfg(any(target_arch="x86", target_arch="x86_64"))] +unsafe fn kernel_x86_avx2(k: usize, alpha: Tout, a: *const Tin, b: *const Tin, + beta: Tout, c: *mut Tout, rsc: isize, csc: isize) +{ + debug_assert_ne!(k, 0); + + let mut ab = [_mm256_setzero_si256(); NR]; + + let (mut a, mut b) = (a, b); + + let mut a_col = _mm_loadu_si128(a as *const __m128i); + + // Load two rows from b at a time. + let mut b_row = _mm256_loadu_si256(b as *const __m256i); + + // FIXME: Is this k a meaningful number in this context? 
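+    // Each iteration consumes one packed column of A (16 i8 values in `a_col`)
+    // and one packed row of B (32 i8 values in `b_row`). `_mm256_shuffle_epi8`
+    // shuffles within each 128-bit lane, so a control vector filled with the
+    // byte index `j` broadcasts b[j] in the low lane and b[j + 16] in the high
+    // lane in one instruction; each broadcast is widened to i16, multiplied
+    // with the widened `a_col`, and accumulated into the per-column
+    // accumulators ab[j] / ab[j + 16] with saturating adds.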
+ unroll_by_with_last!(4 => k, is_last, { + let b0_b16 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ) + ); + + let b0 = _mm256_extracti128_si256(b0_b16, 0); + let b16 = _mm256_extracti128_si256(b0_b16, 1); + + let b1_b17 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, 0x1, + ) + ); + + let b1 = _mm256_extracti128_si256(b1_b17, 0); + let b17 = _mm256_extracti128_si256(b1_b17, 1); + + let b2_b18 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, + 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, 0x2, + ) + ); + + let b2 = _mm256_extracti128_si256(b2_b18, 0); + let b18 = _mm256_extracti128_si256(b2_b18, 1); + + let b3_b19 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, + ) + ); + + let b3 = _mm256_extracti128_si256(b3_b19, 0); + let b19 = _mm256_extracti128_si256(b3_b19, 1); + + let b4_b20 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, + 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, 0x4, + ) + ); + + let b4 = _mm256_extracti128_si256(b4_b20, 0); + let b20 = _mm256_extracti128_si256(b4_b20, 1); + + let b5_b21 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, + 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, 0x5, + ) + ); + + let b5 = _mm256_extracti128_si256(b5_b21, 0); + let b21 = _mm256_extracti128_si256(b5_b21, 1); + + let b6_b22 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, + 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, 0x6, + ) + ); + + let b6 = _mm256_extracti128_si256(b6_b22, 0); + let b22 = _mm256_extracti128_si256(b6_b22, 1); + + let b7_b23 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, + 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, 0x7, + ) + ); + + let b7 = _mm256_extracti128_si256(b7_b23, 0); + let b23 = _mm256_extracti128_si256(b7_b23, 1); + + let b8_b24 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, + 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, + ) + ); + + let b8 = _mm256_extracti128_si256(b8_b24, 0); + let b24 = _mm256_extracti128_si256(b8_b24, 1); + + let b9_b25 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, + 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, 0x9, + ) + ); + + let b9 = _mm256_extracti128_si256(b9_b25, 0); + let b25 = _mm256_extracti128_si256(b9_b25, 1); + + let b10_b26 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 
0xa, 0xa, 0xa, + 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, 0xa, + ) + ); + + let b10 = _mm256_extracti128_si256(b10_b26, 0); + let b26 = _mm256_extracti128_si256(b10_b26, 1); + + let b11_b27 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, + 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, 0xb, + ) + ); + + let b11 = _mm256_extracti128_si256(b11_b27, 0); + let b27 = _mm256_extracti128_si256(b11_b27, 1); + + let b12_b28 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, + 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, 0xc, + ) + ); + + let b12 = _mm256_extracti128_si256(b12_b28, 0); + let b28 = _mm256_extracti128_si256(b12_b28, 1); + + let b13_b29 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, + 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, 0xd, + ) + ); + + let b13 = _mm256_extracti128_si256(b13_b29, 0); + let b29 = _mm256_extracti128_si256(b13_b29, 1); + + let b14_b30 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, + 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, 0xe, + ) + ); + + let b14 = _mm256_extracti128_si256(b14_b30, 0); + let b30 = _mm256_extracti128_si256(b14_b30, 1); + + let b15_b31 = _mm256_shuffle_epi8( + b_row, + _mm256_set_epi8( + 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, + 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, 0xf, + ) + ); + + let b15 = _mm256_extracti128_si256(b15_b31, 0); + let b31 = _mm256_extracti128_si256(b15_b31, 1); + + // Multiplication and addition with the first row. 
+        ab[0] = _mm256_adds_epi16(ab[0], _mm256_mulepi8_epi16(a_col, b0));
+        ab[1] = _mm256_adds_epi16(ab[1], _mm256_mulepi8_epi16(a_col, b1));
+        ab[2] = _mm256_adds_epi16(ab[2], _mm256_mulepi8_epi16(a_col, b2));
+        ab[3] = _mm256_adds_epi16(ab[3], _mm256_mulepi8_epi16(a_col, b3));
+
+        ab[4] = _mm256_adds_epi16(ab[4], _mm256_mulepi8_epi16(a_col, b4));
+        ab[5] = _mm256_adds_epi16(ab[5], _mm256_mulepi8_epi16(a_col, b5));
+        ab[6] = _mm256_adds_epi16(ab[6], _mm256_mulepi8_epi16(a_col, b6));
+        ab[7] = _mm256_adds_epi16(ab[7], _mm256_mulepi8_epi16(a_col, b7));
+
+        ab[8] = _mm256_adds_epi16(ab[8], _mm256_mulepi8_epi16(a_col, b8));
+        ab[9] = _mm256_adds_epi16(ab[9], _mm256_mulepi8_epi16(a_col, b9));
+        ab[10] = _mm256_adds_epi16(ab[10], _mm256_mulepi8_epi16(a_col, b10));
+        ab[11] = _mm256_adds_epi16(ab[11], _mm256_mulepi8_epi16(a_col, b11));
+
+        ab[12] = _mm256_adds_epi16(ab[12], _mm256_mulepi8_epi16(a_col, b12));
+        ab[13] = _mm256_adds_epi16(ab[13], _mm256_mulepi8_epi16(a_col, b13));
+        ab[14] = _mm256_adds_epi16(ab[14], _mm256_mulepi8_epi16(a_col, b14));
+        ab[15] = _mm256_adds_epi16(ab[15], _mm256_mulepi8_epi16(a_col, b15));
+
+        // Multiplication and addition with the second row.
+        ab[16] = _mm256_adds_epi16(ab[16], _mm256_mulepi8_epi16(a_col, b16));
+        ab[17] = _mm256_adds_epi16(ab[17], _mm256_mulepi8_epi16(a_col, b17));
+        ab[18] = _mm256_adds_epi16(ab[18], _mm256_mulepi8_epi16(a_col, b18));
+        ab[19] = _mm256_adds_epi16(ab[19], _mm256_mulepi8_epi16(a_col, b19));
+
+        ab[20] = _mm256_adds_epi16(ab[20], _mm256_mulepi8_epi16(a_col, b20));
+        ab[21] = _mm256_adds_epi16(ab[21], _mm256_mulepi8_epi16(a_col, b21));
+        ab[22] = _mm256_adds_epi16(ab[22], _mm256_mulepi8_epi16(a_col, b22));
+        ab[23] = _mm256_adds_epi16(ab[23], _mm256_mulepi8_epi16(a_col, b23));
+
+        ab[24] = _mm256_adds_epi16(ab[24], _mm256_mulepi8_epi16(a_col, b24));
+        ab[25] = _mm256_adds_epi16(ab[25], _mm256_mulepi8_epi16(a_col, b25));
+        ab[26] = _mm256_adds_epi16(ab[26], _mm256_mulepi8_epi16(a_col, b26));
+        ab[27] = _mm256_adds_epi16(ab[27], _mm256_mulepi8_epi16(a_col, b27));
+
+        ab[28] = _mm256_adds_epi16(ab[28], _mm256_mulepi8_epi16(a_col, b28));
+        ab[29] = _mm256_adds_epi16(ab[29], _mm256_mulepi8_epi16(a_col, b29));
+        ab[30] = _mm256_adds_epi16(ab[30], _mm256_mulepi8_epi16(a_col, b30));
+        ab[31] = _mm256_adds_epi16(ab[31], _mm256_mulepi8_epi16(a_col, b31));
+
+        if !is_last {
+            a = a.add(MR);
+            b = b.add(NR);
+
+            a_col = _mm_loadu_si128(a as _);
+            b_row = _mm256_loadu_si256(b as _);
+        }
+    });
+
+    // Compute α (A B)
+    let alpha_v = _mm256_set1_epi16(alpha);
+    loop_m!(i, ab[i] = _mm256_mullo_epi16(alpha_v, ab[i]));
+
+    macro_rules!
c { + ($i:expr, $j:expr) => + (c.offset(rsc * $i as isize + csc * $j as isize)); + } + + // C ← α A B + β C + let mut cv = [_mm256_setzero_si256(); MR]; + + if beta != 0 { + let beta_v = _mm256_set1_epi16(beta); + + // Read C + if rsc == 1 { + loop_m!(i, cv[i] = _mm256_loadu_si256(c![0, i] as _)); + // } else if csc == 1 { + // loop4!(i, cv[i] = _mm256_loadu_pd(c![i, 0])); + // loop4!(i, cv[i+4] = _mm256_loadu_pd(c![i+4, 0])); + } else { + loop_m!(i, cv[i] = + _mm256_setr_epi16( + *c![0, i], + *c![1, i], + *c![2, i], + *c![3, i], + *c![4, i], + *c![5, i], + *c![6, i], + *c![7, i], + *c![8, i], + *c![9, i], + *c![10, i], + *c![11, i], + *c![12, i], + *c![13, i], + *c![14, i], + *c![15, i], + )); + } + // Compute β C + loop_m!(i, cv[i] = _mm256_mullo_epi16(cv[i], beta_v)); + } + + // Compute (α A B) + (β C) + loop_m!(i, cv[i] = _mm256_add_epi32(cv[i], ab[i])); + + if rsc == 1 { + loop_m!(i, _mm256_storeu_si256(c![0, i] as _, cv[i])); + // } else if csc == 1 { + // loop4!(i, _mm256_storeu_pd(c![i, 0], cv[i])); + // loop4!(i, _mm256_storeu_pd(c![i+4, 0], cv[i + 4])); + } else { + // TODO: This inner unrolled loop should be replaced by + // `loop_n!(j, *c![i, j] = _mm256_extract_epi32(cv[i], j);` + // However, rustc currently errors with: + // > error: argument 2 is required to be a constant + // Some reading: + // + https://internals.rust-lang.org/t/pre-rfc-const-function-arguments/6709/12 + // + https://www.reddit.com/r/rust/comments/9pxuoj/simd_instructions_requiring_a_constant_parameter/ + loop_m!(i, { + *c![i, 0] = _mm256_extract_epi16(cv[i], 0); + *c![i, 1] = _mm256_extract_epi16(cv[i], 1); + *c![i, 2] = _mm256_extract_epi16(cv[i], 2); + *c![i, 3] = _mm256_extract_epi16(cv[i], 3); + *c![i, 4] = _mm256_extract_epi16(cv[i], 4); + *c![i, 5] = _mm256_extract_epi16(cv[i], 5); + *c![i, 6] = _mm256_extract_epi16(cv[i], 6); + *c![i, 7] = _mm256_extract_epi16(cv[i], 7); + }) + } +} + +#[inline(always)] +unsafe fn at(ptr: *const Tin, i: usize) -> Tin { + *ptr.offset(i as isize) +} + +#[cfg(test)] +mod tests { + use super::*; + use aligned_alloc::Alloc; + + fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy + { + unsafe { + Alloc::new(n, Gemm::align_to()).init_with(elt) + } + } + + use super::Tin; + use super::Tout; + type KernelFn = unsafe fn(usize, Tout, *const Tin, *const Tin, Tout, *mut Tout, isize, isize); + + fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { + const K: usize = 4; + let mut a = aligned_alloc(1, MR * K); + let mut b = aligned_alloc(0, NR * K); + for (i, x) in a.iter_mut().enumerate() { + *x = i as _; + } + + for i in 0..K { + b[i + i * NR] = 1; + } + let mut c = [0; MR * NR]; + unsafe { + kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize); + // col major C + } + let a: Vec<_> = a.iter().map(|x| *x as i16).collect(); + assert_eq!(&a[..], &c[..a.len()]); + } + + #[test] + fn test_native_kernel() { + test_a_kernel("kernel", kernel); + } + + #[test] + fn test_kernel_fallback_impl() { + test_a_kernel("kernel", kernel_fallback_impl); + } + + #[test] + fn test_loop_m_n() { + let mut m = [[0; NR]; MR]; + loop_m!(i, loop_n!(j, m[i][j] += 1)); + for arr in &m[..] { + for elt in &arr[..] { + assert_eq!(*elt, 1); + } + } + } + + mod test_arch_kernels { + use super::test_a_kernel; + macro_rules! 
test_arch_kernels_x86 { + ($($feature_name:tt, $function_name:ident),*) => { + $( + #[test] + fn $function_name() { + if is_x86_feature_detected_!($feature_name) { + test_a_kernel(stringify!($function_name), super::super::$function_name); + } else { + println!("Skipping, host does not have feature: {:?}", $feature_name); + } + } + )* + } + } + + #[cfg(any(target_arch="x86", target_arch="x86_64"))] + test_arch_kernels_x86! { + "avx2", kernel_target_avx2, + "avx", kernel_target_avx, + "sse2", kernel_target_sse2 + } + } +} diff --git a/src/igemm_kernel.rs b/src/igemm_kernel.rs deleted file mode 100644 index c91bfec..0000000 --- a/src/igemm_kernel.rs +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright 2016 - 2018 Ulrik Sverdrup "bluss" -// -// Licensed under the Apache License, Version 2.0 or the MIT license -// , at your -// option. This file may not be copied, modified, or distributed -// except according to those terms. - -use kernel::GemmKernel; -use kernel::Element; -use archparam; - - -#[cfg(target_arch="x86")] -use std::arch::x86::*; -#[cfg(target_arch="x86_64")] -use std::arch::x86_64::*; - -pub enum Gemm { } - -pub type T = i32; - -const MR: usize = 8; -const NR: usize = 4; - -macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; } -macro_rules! loop_n { ($j:ident, $e:expr) => { loop4!($j, $e) }; } - -impl GemmKernel for Gemm { - type Elem = T; - - #[inline(always)] - fn align_to() -> usize { 16 } - - #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } - - #[inline(always)] - fn always_masked() -> bool { true } - - #[inline(always)] - fn nc() -> usize { archparam::S_NC } - #[inline(always)] - fn kc() -> usize { archparam::S_KC } - #[inline(always)] - fn mc() -> usize { archparam::S_MC } - - #[inline(always)] - unsafe fn kernel( - k: usize, - alpha: T, - a: *const T, - b: *const T, - beta: T, - c: *mut T, rsc: isize, csc: isize) { - kernel(k, alpha, a, b, beta, c, rsc, csc) - } -} - -/// matrix multiplication kernel -/// -/// This does the matrix multiplication: -/// -/// C ← α A B + β C -/// -/// + k: length of data in a, b -/// + a, b are packed -/// + c has general strides -/// + rsc: row stride of c -/// + csc: col stride of c -/// + if beta is 0, then c does not need to be initialized -#[inline(never)] -pub unsafe fn kernel(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - // dispatch to specific compiled versions - #[cfg(any(target_arch="x86", target_arch="x86_64"))] - { - if is_x86_feature_detected_!("avx") { - return kernel_target_avx(k, alpha, a, b, beta, c, rsc, csc); - } else if is_x86_feature_detected_!("sse2") { - return kernel_target_sse2(k, alpha, a, b, beta, c, rsc, csc); - } - } - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc); -} - -#[inline] -#[target_feature(enable="avx")] -#[cfg(any(target_arch="x86", target_arch="x86_64"))] -unsafe fn kernel_target_avx(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) -} - -#[inline] -#[target_feature(enable="sse2")] -#[cfg(any(target_arch="x86", target_arch="x86_64"))] -unsafe fn kernel_target_sse2(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - kernel_fallback_impl(k, alpha, a, b, beta, c, rsc, csc) -} - - -#[inline(always)] -unsafe fn kernel_fallback_impl(k: usize, alpha: T, a: *const T, b: *const T, - beta: T, c: *mut T, rsc: isize, csc: isize) -{ - let mut ab: [[T; NR]; MR] = 
[[0; NR]; MR]; - let mut a = a; - let mut b = b; - debug_assert_eq!(beta, 0); - - // Compute A B into ab[i][j] - unroll_by!(4 => k, { - loop_m!(i, loop_n!(j, { - ab[i][j] = ab[i][j].wrapping_add(at(a, i).wrapping_mul(at(b, j))); - })); - - a = a.offset(MR as isize); - b = b.offset(NR as isize); - }); - - macro_rules! c { - ($i:expr, $j:expr) => (c.offset(rsc * $i as isize + csc * $j as isize)); - } - - // set C = α A B + β C - loop_n!(j, loop_m!(i, *c![i, j] = alpha.wrapping_mul(ab[i][j]))); -} - -#[inline(always)] -unsafe fn at(ptr: *const T, i: usize) -> T { - *ptr.offset(i as isize) -} - -#[cfg(test)] -mod tests { - use super::*; - use aligned_alloc::Alloc; - - fn aligned_alloc(elt: T, n: usize) -> Alloc where T: Copy - { - unsafe { - Alloc::new(n, Gemm::align_to()).init_with(elt) - } - } - - use super::T; - type KernelFn = unsafe fn(usize, T, *const T, *const T, T, *mut T, isize, isize); - - fn test_a_kernel(_name: &str, kernel_fn: KernelFn) { - const K: usize = 4; - let mut a = aligned_alloc(1, MR * K); - let mut b = aligned_alloc(0, NR * K); - for (i, x) in a.iter_mut().enumerate() { - *x = i as _; - } - - for i in 0..K { - b[i + i * NR] = 1; - } - let mut c = [0; MR * NR]; - unsafe { - kernel_fn(K, 1, &a[0], &b[0], 0, &mut c[0], 1, MR as isize); - // col major C - } - assert_eq!(&a[..], &c[..a.len()]); - } - - #[test] - fn test_native_kernel() { - test_a_kernel("kernel", kernel); - } - - #[test] - fn test_kernel_fallback_impl() { - test_a_kernel("kernel", kernel_fallback_impl); - } - - #[test] - fn test_loop_m_n() { - let mut m = [[0; NR]; MR]; - loop_m!(i, loop_n!(j, m[i][j] += 1)); - for arr in &m[..] { - for elt in &arr[..] { - assert_eq!(*elt, 1); - } - } - } - - mod test_arch_kernels { - use super::test_a_kernel; - macro_rules! test_arch_kernels_x86 { - ($($feature_name:tt, $function_name:ident),*) => { - $( - #[test] - fn $function_name() { - if is_x86_feature_detected_!($feature_name) { - test_a_kernel(stringify!($function_name), super::super::$function_name); - } else { - println!("Skipping, host does not have feature: {:?}", $feature_name); - } - } - )* - } - } - - #[cfg(any(target_arch="x86", target_arch="x86_64"))] - test_arch_kernels_x86! { - "avx", kernel_target_avx, - "sse2", kernel_target_sse2 - } - } -} diff --git a/src/kernel.rs b/src/kernel.rs index 801b3d9..5b1c233 100644 --- a/src/kernel.rs +++ b/src/kernel.rs @@ -8,16 +8,18 @@ /// General matrix multiply kernel pub trait GemmKernel { - type Elem: Element; + type ElemIn: Element; + type ElemOut: Element; + + /// Number of kernel rows + const MR: usize; + + /// Number of kernel columns + const NR: usize; /// align inputs to this fn align_to() -> usize; - /// Kernel rows - fn mr() -> usize; - /// Kernel cols - fn nr() -> usize; - /// Whether to always use the masked wrapper around the kernel. /// /// If masked, the kernel is always called with α=1, β=0 @@ -41,11 +43,11 @@ pub trait GemmKernel { /// + if `beta` is `0.`, then c does not need to be initialized unsafe fn kernel( k: usize, - alpha: Self::Elem, - a: *const Self::Elem, - b: *const Self::Elem, - beta: Self::Elem, - c: *mut Self::Elem, rsc: isize, csc: isize); + alpha: Self::ElemOut, + a: *const Self::ElemIn, + b: *const Self::ElemIn, + beta: Self::ElemOut, + c: *mut Self::ElemOut, rsc: isize, csc: isize); } pub trait Element : Copy { @@ -56,38 +58,91 @@ pub trait Element : Copy { fn scaled_add(&mut self, alpha: Self, a: Self); } -impl Element for f32 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn is_zero(&self) -> bool { *self == 0. 
} - fn scale_by(&mut self, x: Self) { - *self *= x; - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self += alpha * a; - } -} +// impl Element for f32 { +// fn zero() -> Self { 0. } +// fn one() -> Self { 1. } +// fn is_zero(&self) -> bool { *self == 0. } +// fn scale_by(&mut self, x: Self) { +// *self *= x; +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self += alpha * a; +// } +// } -impl Element for f64 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn is_zero(&self) -> bool { *self == 0. } - fn scale_by(&mut self, x: Self) { - *self *= x; - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self += alpha * a; - } -} +// impl Element for f64 { +// fn zero() -> Self { 0. } +// fn one() -> Self { 1. } +// fn is_zero(&self) -> bool { *self == 0. } +// fn scale_by(&mut self, x: Self) { +// *self *= x; +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self += alpha * a; +// } +// } -impl Element for i32 { - fn zero() -> Self { 0 } - fn one() -> Self { 1 } - fn is_zero(&self) -> bool { *self == 0 } - fn scale_by(&mut self, x: Self) { - *self = self.wrapping_mul(x); - } - fn scaled_add(&mut self, alpha: Self, a: Self) { - *self = self.wrapping_add(alpha.wrapping_mul(a)); - } -} +// impl Element for i32 { +// fn zero() -> Self { 0 } +// fn one() -> Self { 1 } +// fn is_zero(&self) -> bool { *self == 0 } +// fn scale_by(&mut self, x: Self) { +// *self = self.wrapping_mul(x); +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self = self.wrapping_add(alpha.wrapping_mul(a)); +// } +// } + +// impl Element for i32 { +// fn zero() -> Self { 0 } +// fn one() -> Self { 1 } +// fn is_zero(&self) -> bool { *self == 0 } +// fn scale_by(&mut self, x: Self) { +// *self = self.wrapping_mul(x); +// } +// fn scaled_add(&mut self, alpha: Self, a: Self) { +// *self = self.wrapping_add(alpha.wrapping_mul(a)); +// } +// } + +macro_rules! impl_element_f { + ($($t:ty),+) => { + $( + impl Element for $t { + fn zero() -> Self { 0.0 } + fn one() -> Self { 1.0 } + fn is_zero(&self) -> bool { *self == 0.0 } + fn scale_by(&mut self, x: Self) { + // TODO: Change the semantics + *self *= x; + } + // TODO: Change the semantics + fn scaled_add(&mut self, alpha: Self, a: Self) { + *self += alpha * a; + } + } + )+ +};} + +macro_rules! impl_element_i { + ($($t:ty),+) => { + $( + impl Element for $t { + fn zero() -> Self { 0 } + fn one() -> Self { 1 } + fn is_zero(&self) -> bool { *self == 0 } + fn scale_by(&mut self, x: Self) { + // TODO: Change the semantics + *self = self.saturating_mul(x); + } + // TODO: Change the semantics + fn scaled_add(&mut self, alpha: Self, a: Self) { + *self = self.saturating_add(alpha.saturating_mul(a)); + } + } + )+ +};} + +impl_element_f!(f32, f64); +impl_element_i!(i8, i16, i32); diff --git a/src/lib.rs b/src/lib.rs index f263bd8..9ab1695 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -52,6 +52,8 @@ #![doc(html_root_url = "https://docs.rs/matrixmultiply/0.2/")] +#![feature(const_fn)] + extern crate rawpointer; #[macro_use] mod archmacros_x86; @@ -62,10 +64,10 @@ mod kernel; mod gemm; mod sgemm_kernel; mod dgemm_kernel; -mod igemm_kernel; +mod i8gemm_kernel; mod util; mod aligned_alloc; pub use gemm::sgemm; pub use gemm::dgemm; -pub use gemm::igemm; +pub use gemm::i8gemm; diff --git a/src/loopmacros.rs b/src/loopmacros.rs index 8e40d81..917d2c8 100644 --- a/src/loopmacros.rs +++ b/src/loopmacros.rs @@ -50,6 +50,80 @@ macro_rules! loop8 { }} } +#[cfg(debug_assertions)] +macro_rules! 
loop16 { + ($i:ident, $e:expr) => { + for $i in 0..16 { $e } + } +} + +#[cfg(not(debug_assertions))] +macro_rules! loop16 { + ($i:ident, $e:expr) => {{ + let $i = 0; $e; + let $i = 1; $e; + let $i = 2; $e; + let $i = 3; $e; + let $i = 4; $e; + let $i = 5; $e; + let $i = 6; $e; + let $i = 7; $e; + let $i = 8; $e; + let $i = 9; $e; + let $i = 10; $e; + let $i = 11; $e; + let $i = 12; $e; + let $i = 13; $e; + let $i = 14; $e; + let $i = 15; $e; + }} +} + +#[cfg(debug_assertions)] +macro_rules! loop32 { + ($i:ident, $e:expr) => { + for $i in 0..32 { $e } + } +} + +#[cfg(not(debug_assertions))] +macro_rules! loop32 { + ($i:ident, $e:expr) => {{ + let $i = 0; $e; + let $i = 1; $e; + let $i = 2; $e; + let $i = 3; $e; + let $i = 4; $e; + let $i = 5; $e; + let $i = 6; $e; + let $i = 7; $e; + let $i = 8; $e; + let $i = 9; $e; + let $i = 10; $e; + let $i = 11; $e; + let $i = 12; $e; + let $i = 13; $e; + let $i = 14; $e; + let $i = 15; $e; + let $i = 16; $e; + let $i = 17; $e; + let $i = 18; $e; + let $i = 19; $e; + let $i = 20; $e; + let $i = 21; $e; + let $i = 22; $e; + let $i = 23; $e; + let $i = 24; $e; + let $i = 25; $e; + let $i = 26; $e; + let $i = 27; $e; + let $i = 28; $e; + let $i = 29; $e; + let $i = 30; $e; + let $i = 31; $e; + }} +} + #[cfg(debug_assertions)] macro_rules! unroll_by { ($by:tt => $ntimes:expr, $e:expr) => { diff --git a/src/sgemm_kernel.rs b/src/sgemm_kernel.rs index 6064869..b23f490 100644 --- a/src/sgemm_kernel.rs +++ b/src/sgemm_kernel.rs @@ -26,15 +26,14 @@ macro_rules! loop_m { ($i:ident, $e:expr) => { loop8!($i, $e) }; } macro_rules! loop_n { ($j:ident, $e:expr) => { loop8!($j, $e) }; } impl GemmKernel for Gemm { - type Elem = T; + type ElemIn = T; + type ElemOut = T; - #[inline(always)] - fn align_to() -> usize { 32 } + const MR: usize = MR; + const NR: usize = NR; #[inline(always)] - fn mr() -> usize { MR } - #[inline(always)] - fn nr() -> usize { NR } + fn align_to() -> usize { 32 } #[inline(always)] fn always_masked() -> bool { false } diff --git a/tests/sgemm.rs b/tests/sgemm.rs index b4314a0..e15bbb5 100644 --- a/tests/sgemm.rs +++ b/tests/sgemm.rs @@ -1,7 +1,11 @@ extern crate itertools; extern crate matrixmultiply; -use matrixmultiply::{sgemm, dgemm, igemm}; +use matrixmultiply::{ + sgemm, + dgemm, + i8gemm, +}; use itertools::Itertools; use itertools::{ @@ -11,93 +15,129 @@ use itertools::{ }; use std::fmt::{Display, Debug}; -trait Float : Copy + Display + Debug + PartialEq { - fn zero() -> Self; - fn one() -> Self; - fn from(x: i64) -> Self; - fn nan() -> Self; - fn is_nan(self) -> bool; +trait GemmElement : Copy + Display + Debug + PartialEq { + // TODO: Provide default associated types once the following RFCs are merged and implemented: + // https://github.com/rust-lang/rfcs/pull/2532 + // https://github.com/rust-lang/rust/issues/29661 + // I.e., then we can do something like: + // type Output = Self; + // + // XXX: Is it somehow possible to already provide default impls for the _out functions in terms + // of the _in functions, where we assume that Input = Output? + type Output: Copy + Display + Debug + PartialEq; + + fn zero_in() -> Self; + fn one_in() -> Self; + fn nan_in() -> Self; + fn from_in(x: i64) -> Self; + fn is_nan_in(Self) -> bool; + + fn zero_out() -> Self::Output; + fn one_out() -> Self::Output; + fn nan_out() -> Self::Output; + fn from_out(x: i64) -> Self::Output; + fn is_nan_out(Self::Output) -> bool; + + fn to_out(Self) -> Self::Output; } -impl Float for f32 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. 
} - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { 0./0. } - fn is_nan(self) -> bool { self.is_nan() } +macro_rules! impl_gemm_element_f { + ($($t:ty),+) => { + $( + impl GemmElement for $t { + type Output = Self; + + fn zero_in() -> Self { 0. } + fn one_in() -> Self { 1. } + fn from_in(x: i64) -> Self { x as Self } + fn nan_in() -> Self { 0./0. } + fn is_nan_in(var: Self) -> bool { var.is_nan() } + + fn zero_out() -> Self::Output { 0. } + fn one_out() -> Self::Output { 1. } + fn from_out(x: i64) -> Self::Output { x as Self::Output } + fn nan_out() -> Self { 0./0. } + fn is_nan_out(var: Self::Output) -> bool { var.is_nan() } + + fn to_out(var: Self) -> Self::Output { + var + } + } + )+ + }; } -impl Float for f64 { - fn zero() -> Self { 0. } - fn one() -> Self { 1. } - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { 0./0. } - fn is_nan(self) -> bool { self.is_nan() } -} +impl_gemm_element_f!(f32, f64); + +impl GemmElement for i8 { + type Output = i16; -impl Float for i32 { - fn zero() -> Self { 0 } - fn one() -> Self { 1 } - fn from(x: i64) -> Self { x as Self } - fn nan() -> Self { i32::min_value() } // hack - fn is_nan(self) -> bool { self == i32::min_value() } + fn zero_in() -> Self { 0 } + fn one_in() -> Self { 1 } + fn from_in(x: i64) -> Self { x as Self } + fn nan_in() -> Self { i8::min_value() } // hack + fn is_nan_in(var: Self) -> bool { var == Self::nan_in() } + + fn zero_out() -> Self::Output { 0 } + fn one_out() -> Self::Output { 1 } + fn from_out(x: i64) -> Self::Output { x as Self::Output } + fn nan_out() -> Self::Output { i16::min_value() } // hack + fn is_nan_out(var: Self::Output) -> bool { var == Self::nan_out() } + + fn to_out(var: Self) -> Self::Output { + var as Self::Output + } } trait Gemm : Sized { - unsafe fn gemm( - m: usize, k: usize, n: usize, - alpha: Self, - a: *const Self, rsa: isize, csa: isize, - b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize); -} + type Output; -impl Gemm for f32 { unsafe fn gemm( m: usize, k: usize, n: usize, - alpha: Self, + alpha: Self::Output, a: *const Self, rsa: isize, csa: isize, b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - sgemm( - m, k, n, - alpha, - a, rsa, csa, - b, rsb, csb, - beta, - c, rsc, csc) - } + beta: Self::Output, + c: *mut Self::Output, rsc: isize, csc: isize); } -impl Gemm for i32 { - unsafe fn gemm( - m: usize, k: usize, n: usize, - alpha: Self, - a: *const Self, rsa: isize, csa: isize, - b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - igemm( - m, k, n, - alpha, - a, rsa, csa, - b, rsb, csb, - beta, - c, rsc, csc) - } +macro_rules! 
impl_gemm_f { + ($(($t:ty, $f:ident)),+) => { + $( + impl Gemm for $t { + type Output = Self; + unsafe fn gemm( + m: usize, k: usize, n: usize, + alpha: Self, + a: *const Self, rsa: isize, csa: isize, + b: *const Self, rsb: isize, csb: isize, + beta: Self, + c: *mut Self, rsc: isize, csc: isize) { + $f( + m, k, n, + alpha, + a, rsa, csa, + b, rsb, csb, + beta, + c, rsc, csc) + } + } + )+ + }; } -impl Gemm for f64 { +impl_gemm_f!((f32, sgemm), (f64, dgemm)); + +impl Gemm for i8 { + type Output = i16; unsafe fn gemm( m: usize, k: usize, n: usize, - alpha: Self, + alpha: i16, a: *const Self, rsa: isize, csa: isize, b: *const Self, rsb: isize, csb: isize, - beta: Self, - c: *mut Self, rsc: isize, csc: isize) { - dgemm( + beta: i16, + c: *mut i16, rsc: isize, csc: isize) { + i8gemm( m, k, n, alpha, a, rsa, csa, @@ -109,105 +149,112 @@ impl Gemm for f64 { #[test] fn test_sgemm() { - test_gemm::(); + test_gemm::(); } #[test] fn test_dgemm() { - test_gemm::(); + test_gemm::(); } #[test] fn test_sgemm_strides() { - test_gemm_strides::(); + test_gemm_strides::(); } #[test] fn test_dgemm_strides() { - test_gemm_strides::(); + test_gemm_strides::(); } #[test] -fn test_i32gemm_strides() { - test_gemm_strides::(); +fn test_i8gemm_strides() { + test_gemm_strides::(); } -fn test_gemm_strides() where F: Gemm + Float { +fn test_gemm_strides() + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq +{ for n in 0..10 { - test_strides::(n, n, n); + test_strides::(n, n, n); } for n in (3..12).map(|x| x * 7) { - test_strides::(n, n, n); + test_strides::(n, n, n); } - test_strides::(8, 12, 16); - test_strides::(8, 0, 10); + test_strides::(8, 12, 16); + test_strides::(8, 0, 10); } -fn test_gemm() where F: Gemm + Float { - test_mul_with_id::(4, 4, true); - test_mul_with_id::(8, 8, true); - test_mul_with_id::(32, 32, false); - test_mul_with_id::(128, 128, false); - test_mul_with_id::(17, 128, false); +fn test_gemm() + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq +{ + test_mul_with_id::(4, 4, true); + test_mul_with_id::(8, 8, true); + test_mul_with_id::(32, 32, false); + test_mul_with_id::(128, 128, false); + test_mul_with_id::(17, 128, false); for i in 0..12 { for j in 0..12 { - test_mul_with_id::(i, j, true); + test_mul_with_id::(i, j, true); } } /* */ - test_mul_with_id::(17, 257, false); - test_mul_with_id::(24, 512, false); + test_mul_with_id::(17, 257, false); + test_mul_with_id::(24, 512, false); for i in 0..10 { for j in 0..10 { - test_mul_with_id::(i * 4, j * 4, true); + test_mul_with_id::(i * 4, j * 4, true); } } - test_mul_with_id::(266, 265, false); - test_mul_id_with::(4, 4, true); + test_mul_with_id::(266, 265, false); + test_mul_id_with::(4, 4, true); for i in 0..12 { for j in 0..12 { - test_mul_id_with::(i, j, true); + test_mul_id_with::(i, j, true); } } - test_mul_id_with::(266, 265, false); - test_scale::(0, 4, 4, true); - test_scale::(4, 0, 4, true); - test_scale::(4, 4, 0, true); - test_scale::(4, 4, 4, true); - test_scale::(19, 20, 16, true); - test_scale::(150, 140, 128, false); + test_mul_id_with::(266, 265, false); + test_scale::(0, 4, 4, true); + test_scale::(4, 0, 4, true); + test_scale::(4, 4, 0, true); + test_scale::(4, 4, 4, true); + test_scale::(19, 20, 16, true); + test_scale::(150, 140, 128, false); } /// multiply a M x N matrix with an N x N id matrix #[cfg(test)] -fn test_mul_with_id(m: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_mul_with_id(m: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + 
Display + Debug + PartialEq { let (m, k, n) = (m, n, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c = vec![F::zero(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c = vec![F::zero_out(); m * n]; println!("test matrix with id input M={}, N={}", m, n); for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for i in 0..k { - b[i + i * k] = F::one(); + b[i + i * k] = F::one_in(); } unsafe { F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c.as_mut_ptr(), n as isize, 1, ) } for (i, (x, y)) in a.iter().zip(&c).enumerate() { - if x != y { + if F::to_out(*x) != *y { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -228,33 +275,34 @@ fn test_mul_with_id(m: usize, n: usize, small: bool) /// multiply a K x K id matrix with an K x N matrix #[cfg(test)] -fn test_mul_id_with(k: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_mul_id_with(k: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (k, k, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c = vec![F::zero(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c = vec![F::zero_out(); m * n]; for i in 0..k { - a[i + i * k] = F::one(); + a[i + i * k] = F::one_in(); } for (i, elt) in b.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } unsafe { F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c.as_mut_ptr(), n as isize, 1, ) } for (i, (x, y)) in b.iter().zip(&c).enumerate() { - if x != y { + if F::to_out(*x) != *y { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -274,55 +322,56 @@ fn test_mul_id_with(k: usize, n: usize, small: bool) } #[cfg(test)] -fn test_scale(m: usize, k: usize, n: usize, small: bool) - where F: Gemm + Float +fn test_scale(m: usize, k: usize, n: usize, small: bool) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); - let mut a = vec![F::zero(); m * k]; - let mut b = vec![F::zero(); k * n]; - let mut c1 = vec![F::one(); m * n]; - let mut c2 = vec![F::nan(); m * n]; + let mut a = vec![F::zero_in(); m * k]; + let mut b = vec![F::zero_in(); k * n]; + let mut c1 = vec![F::one_out(); m * n]; + let mut c2 = vec![F::nan_out(); m * n]; // init c2 with NaN to test the overwriting behavior when beta = 0. 
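+    // For the i8/i16 instantiation there is no real NaN; nan_out() is the
+    // i16::min_value() stand-in defined in the GemmElement impl above.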
for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for (i, elt) in b.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } unsafe { // C1 = 3 A B F::gemm( m, k, n, - F::from(3), + F::from_out(3), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c1.as_mut_ptr(), n as isize, 1, ); // C2 = A B F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::zero(), + F::zero_out(), c2.as_mut_ptr(), n as isize, 1, ); // C2 = A B + 2 C2 F::gemm( m, k, n, - F::one(), + F::one_out(), a.as_ptr(), k as isize, 1, b.as_ptr(), n as isize, 1, - F::from(2), + F::from_out(2), c2.as_mut_ptr(), n as isize, 1, ); } for (i, (x, y)) in c1.iter().zip(&c2).enumerate() { - if x != y || x.is_nan() || y.is_nan() { + if x != y || F::is_nan_out(*x) || F::is_nan_out(*y) { if k != 0 && n != 0 && small { for row in a.chunks(k) { println!("{:?}", row); @@ -369,8 +418,9 @@ impl Default for Layout { #[cfg(test)] -fn test_strides(m: usize, k: usize, n: usize) - where F: Gemm + Float +fn test_strides(m: usize, k: usize, n: usize) + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); @@ -383,15 +433,16 @@ fn test_strides(m: usize, k: usize, n: usize) for elt in layouts_iter { let layouts = [elt[0], elt[1], elt[2], elt[3]]; let (m0, m1, m2, m3) = multipliers_iter.next_tuple().unwrap(); - test_strides_inner::(m, k, n, [m0, m1, m2, m3], layouts); + test_strides_inner::(m, k, n, [m0, m1, m2, m3], layouts); } } -fn test_strides_inner(m: usize, k: usize, n: usize, +fn test_strides_inner(m: usize, k: usize, n: usize, stride_multipliers: [[usize; 2]; 4], layouts: [Layout; 4]) - where F: Gemm + Float + where F: Gemm + GemmElement, + Tout: Copy + Display + Debug + PartialEq { let (m, k, n) = (m, k, n); @@ -401,16 +452,16 @@ fn test_strides_inner(m: usize, k: usize, n: usize, let mstridec = stride_multipliers[2]; let mstridec2 = stride_multipliers[3]; - let mut a = vec![F::zero(); m * k * mstridea[0] * mstridea[1]]; - let mut b = vec![F::zero(); k * n * mstrideb[0] * mstrideb[1]]; - let mut c1 = vec![F::nan(); m * n * mstridec[0] * mstridec[1]]; - let mut c2 = vec![F::nan(); m * n * mstridec2[0] * mstridec2[1]]; + let mut a = vec![F::zero_in(); m * k * mstridea[0] * mstridea[1]]; + let mut b = vec![F::zero_in(); k * n * mstrideb[0] * mstrideb[1]]; + let mut c1 = vec![F::nan_out(); m * n * mstridec[0] * mstridec[1]]; + let mut c2 = vec![F::nan_out(); m * n * mstridec2[0] * mstridec2[1]]; for (i, elt) in a.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } for (i, elt) in b.iter_mut().enumerate() { - *elt = F::from(i as i64); + *elt = F::from_in(i as i64); } let la = layouts[0]; @@ -439,30 +490,30 @@ fn test_strides_inner(m: usize, k: usize, n: usize, // C1 = A B F::gemm( m, k, n, - F::from(1), + F::from_out(1), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::zero(), + F::zero_out(), c1.as_mut_ptr(), rs_c1, cs_c1, ); - + // C1 += 2 A B F::gemm( m, k, n, - F::from(2), + F::from_out(2), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::from(1), + F::from_out(1), c1.as_mut_ptr(), rs_c1, cs_c1, ); // C2 = 3 A B F::gemm( m, k, n, - F::from(3), + F::from_out(3), a.as_ptr(), rs_a, cs_a, b.as_ptr(), rs_b, cs_b, - F::zero(), + F::zero_out(), c2.as_mut_ptr(), rs_c2, cs_c2, ); } @@ -488,7 +539,7 @@ fn test_strides_inner(m: usize, k: usize, n: usize, let irem = index % rs_c1 as usize; let jrem = index 
% cs_c1 as usize; if irem != 0 && jrem != 0 { - assert!(elt.is_nan(), + assert!(F::is_nan_out(*elt), "Element at index={} ({}, {}) should be NaN, but was {}\n\ c1: {:?}\n", index, i, j, elt,
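For reference, a minimal usage sketch of the new public `i8gemm` entry point (i8 inputs, i16 scaling and output). This is not part of the patch; the matrix contents, strides and expected values below are illustrative only and assume plain row-major layout.

extern crate matrixmultiply;

use matrixmultiply::i8gemm;

fn main() {
    // A is 2x3, B is 3x2, both row major with i8 elements; C is the 2x2 i16 result.
    let a: [i8; 6] = [1, 2, 3,
                      4, 5, 6];
    let b: [i8; 6] = [ 7,  8,
                       9, 10,
                      11, 12];
    let mut c: [i16; 4] = [0; 4];
    unsafe {
        i8gemm(2, 3, 2,
               1,                       // alpha: i16
               a.as_ptr(), 3, 1,        // rsa = k, csa = 1 (row major)
               b.as_ptr(), 2, 1,        // rsb = n, csb = 1 (row major)
               0,                       // beta: i16
               c.as_mut_ptr(), 2, 1);   // rsc = n, csc = 1 (row major)
    }
    // The plain integer product is [[58, 64], [139, 154]].
    println!("{:?}", c);
}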