Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions fbgemm_gpu/bench/verify_fp16_stochastic_benchmark.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace fbgemm_gpu {

DEVICE_INLINE half
float_to_sto_half_fbgemm_rand(float x, StochasticRoundingRNGState& state) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
uint32_t random_value = random_bits.x;
uint32_t w_int = __float_as_uint(x);
unsigned assembles = (w_int & 0xff800000) | (random_value >> 19);
Expand All @@ -41,13 +41,10 @@ __global__ void convert_float_to_half_fbgemm_rand(
half* dst,
const float* src,
int size,
at::PhiloxCudaState stochastic_rounding_philox_args) {
at::PhiloxCudaState philox_args) {
const auto idx = blockIdx.x * blockDim.x + threadIdx.x;

StochasticRoundingRNGState state;
const auto seeds = at::cuda::philox::unpack(stochastic_rounding_philox_args);
stochastic_rounding_init(
std::get<0>(seeds) ^ std::get<1>(seeds), idx, &state);
auto state = StochasticRoundingRNGState(philox_args, idx);

if (idx < size) {
dst[idx] = float_to_sto_half_fbgemm_rand(src[idx], state);
Expand Down
52 changes: 10 additions & 42 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,26 +371,15 @@ __global__ void scaleMatrix(
const int64_t numel,
const int64_t lda,
at::PhiloxCudaState stochastic_rounding_philox_args) {
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
auto input_scal = static_cast<float>(input_scale[0]);

auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]);
auto vec_input = reinterpret_cast<const bfx4*>(&input[0]);
for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel;
d += (size_t)blockDim.x * gridDim.x) {
const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);
const auto random_bits = stoc_rounding_state.rand4();
bfx4 v_in = vec_input[d];
float4 v_float;
v_float.x = stochastic_rounding_scalar_fp8(
Expand All @@ -417,25 +406,16 @@ __global__ void scaleMatrixRowwise(
const int64_t numel,
const int64_t lda,
at::PhiloxCudaState stochastic_rounding_philox_args) {
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;
stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
auto input_scal = static_cast<float>(input_scale[0]);

auto vec_output = reinterpret_cast<__nv_fp8x4_e4m3*>(&output[0]);
auto vec_input = reinterpret_cast<const bfx4*>(&input[0]);
auto vec_scale = reinterpret_cast<const float4*>(&input_scale[0]);
for (int32_t d = (threadIdx.x + blockIdx.x * blockDim.x); d * 4 < numel;
d += (size_t)blockDim.x * gridDim.x) {
const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);
const auto random_bits = stoc_rounding_state.rand4();
bfx4 v_in = vec_input[d];
float4 v_float;
float4 v_scale = vec_scale[d / lda];
Expand Down Expand Up @@ -938,21 +918,9 @@ __global__ void dynamicQuantizeMatrixRowwiseStoc(
int64_t lda,
const float* scale_ub,
at::PhiloxCudaState stochastic_rounding_philox_args) {
StochasticRoundingRNGState stoc_rounding_state;

const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(stochastic_rounding_philox_args);
const uint64_t salt_value = threadIdx.x + blockIdx.x * blockDim.x;

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state);

const uint4 random_bits = stochastic_rounding_rand4(&stoc_rounding_state);
auto stoc_rounding_state = StochasticRoundingRNGState(
stochastic_rounding_philox_args, threadIdx.x + blockIdx.x * blockDim.x);
const auto random_bits = stoc_rounding_state.rand4();

extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
Expand Down
15 changes: 11 additions & 4 deletions fbgemm_gpu/include/fbgemm_gpu/utils/host_device_buffer_pair.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/

#include <c10/cuda/CUDAException.h>

#include <cuda.h>
#include <cuda_fp16.h>
#include <curand.h>
Expand Down Expand Up @@ -86,13 +84,22 @@ struct HostDeviceBufferPair {
}

inline void syncToDevice() {
cudaMemcpy(
const auto err = cudaMemcpy(
device, host.data(), host.size() * sizeof(T), cudaMemcpyHostToDevice);
if (err != cudaSuccess) {
fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
std::exit(1);
}
}

inline void syncToHost() {
cudaMemcpy(
const auto err = cudaMemcpy(
host.data(), device, host.size() * sizeof(T), cudaMemcpyDeviceToHost);

if (err != cudaSuccess) {
fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err));
std::exit(1);
}
}

inline void free() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
Half2 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -64,7 +64,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<float>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
Half2 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -79,7 +79,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<float>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand All @@ -93,7 +93,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec2T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand Down
134 changes: 79 additions & 55 deletions fbgemm_gpu/include/fbgemm_gpu/utils/stochastic_rounding.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,78 @@
namespace fbgemm_gpu {

////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding
// Stochastic Rounding RNG State
//
// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state
// for curandStatePhilox4_32_10). It is used for generating uint4 random bits
// for stochastic rounding.
////////////////////////////////////////////////////////////////////////////////

struct StochasticRoundingRNGState {
uint64_t state = 0;

__host__ DEVICE_INLINE constexpr StochasticRoundingRNGState() = default;

__host__ DEVICE_INLINE StochasticRoundingRNGState(
const at::PhiloxCudaState& philox_state,
const uint64_t salt_value) noexcept {
init(philox_state, salt_value);
}

// From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h
__host__ DEVICE_INLINE constexpr uint64_t splitmix64_stateless(
uint64_t index) noexcept {
uint64_t z = (index + UINT64_C(0x9E3779B97F4A7C15));
z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB);
return z ^ (z >> 31);
}

__host__ DEVICE_INLINE void init(
const at::PhiloxCudaState& philox_state,
// The salt value should be different for every *run* and every
// *thread*. Passing in threadIdx.x + blockIdx.x * blockDim.x is
// recommended.
const uint64_t salt_value) noexcept {
const auto [s0, s1] = at::cuda::philox::unpack(philox_state);
state = splitmix64_stateless(s0 ^ s1) ^ splitmix64_stateless(salt_value);

// Ensure we never have a zero state (insanely low probability, but
// still...).
if (state == 0) {
state = 1;
}
}

// See https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf and
// https://en.wikipedia.org/wiki/Xorshift#xorshift*
__host__ DEVICE_INLINE constexpr uint4 rand4() noexcept {
uint4 random_bits = {0, 0, 0, 0};
uint64_t x = state; /* The state must be seeded with a nonzero value. */
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.x = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.y = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.z = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.w = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
// Update internal state
state = x;
return random_bits;
}
};

////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding Scalar
////////////////////////////////////////////////////////////////////////////////

// Correct for cases where x is not subnormal.
Expand All @@ -43,56 +114,9 @@ stochastic_rounding_scalar_uint8(float x, uint32_t random_bits) {
return lrintf(x + noise.F);
}

// This is a simple xorshift* RNG with 64 bits of state (vs 384 bits of state
// for curandStatePhilox4_32_10)
struct StochasticRoundingRNGState {
uint64_t a;
};

// From https://github.com/lemire/testingRNG/blob/master/source/splitmix64.h
__host__ DEVICE_INLINE uint64_t splitmix64_stateless(uint64_t index) {
uint64_t z = (index + UINT64_C(0x9E3779B97F4A7C15));
z = (z ^ (z >> 30)) * UINT64_C(0xBF58476D1CE4E5B9);
z = (z ^ (z >> 27)) * UINT64_C(0x94D049BB133111EB);
return z ^ (z >> 31);
}

DEVICE_INLINE void stochastic_rounding_init(
uint64_t s0,
uint64_t s1,
StochasticRoundingRNGState* state) {
state->a = splitmix64_stateless(s0) ^ splitmix64_stateless(s1);
// Ensure we never have a zero state (insanely low probability, but still...).
if (state->a == 0) {
state->a = 1;
}
}

// See https://www.pcg-random.org/pdf/hmc-cs-2014-0905.pdf and
// https://en.wikipedia.org/wiki/Xorshift#xorshift*
DEVICE_INLINE uint4
stochastic_rounding_rand4(StochasticRoundingRNGState* state) {
uint4 random_bits;
uint64_t x = state->a; /* The state must be seeded with a nonzero value. */
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.x = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.y = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.z = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
x ^= x >> 12; // a
x ^= x << 25; // b
x ^= x >> 27; // c
random_bits.w = (x * UINT64_C(0x2545F4914F6CDD1D)) >> 32;
state->a = x;
return random_bits;
}
////////////////////////////////////////////////////////////////////////////////
// Stochastic Rounding Vector
////////////////////////////////////////////////////////////////////////////////

template <typename dst_t, typename src_t>
DEVICE_INLINE void stochastic_rounding_vector(
Expand All @@ -109,7 +133,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
Half4 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -126,7 +150,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<float>& value,
StochasticRoundingRNGState& state,
const float2 /* not used */) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
Half4 v;
v.a = __halves2half2(
stochastic_rounding_scalar(value.acc.x, random_bits.x),
Expand All @@ -143,7 +167,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<float>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand All @@ -161,7 +185,7 @@ DEVICE_INLINE void stochastic_rounding_vector(
const Vec4T<at::Half>& value,
StochasticRoundingRNGState& state,
const float2 qparams) {
const uint4 random_bits = stochastic_rounding_rand4(&state);
const auto random_bits = state.rand4();
const float inv_scale = 255.0f / (qparams.x * 255.0f + kQParamEps);
output[0] = stochastic_rounding_scalar_uint8(
(value.acc.x - qparams.y) * inv_scale, random_bits.x);
Expand Down
11 changes: 1 addition & 10 deletions fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -134,16 +134,7 @@ struct WeightRow {
stoc_rounding_state_ptr_ = nullptr;
if constexpr (!std::is_same_v<emb_t, float>) {
if (stochastic_rounding) {
const auto stochastic_rounding_seeds =
at::cuda::philox::unpack(*stochastic_rounding_philox_args);

stochastic_rounding_init(
std::get<0>(stochastic_rounding_seeds) ^
std::get<1>(stochastic_rounding_seeds),
// The salt value should be different for every *run* and every
// *thread*.
salt_value,
&stoc_rounding_state_);
stoc_rounding_state_.init(*stochastic_rounding_philox_args, salt_value);
// Store the pointer here to avoid an if-else cond during load/store
stoc_rounding_state_ptr_ = &stoc_rounding_state_;
}
Expand Down
Loading
Loading