Precise reduction #125

Draft: wants to merge 23 commits into main
4 changes: 4 additions & 0 deletions .gitmodules
@@ -5,3 +5,7 @@
 [submodule "third_party/gpudma"]
     path = third_party/gpudma
     url = https://github.com/karakozov/gpudma
+
+[submodule "examples/llama/llama"]
+    path = examples/llama/llama
+    url = https://github.com/facebookresearch/llama
48 changes: 25 additions & 23 deletions ark/include/kernels/layernorm.h
@@ -4,6 +4,7 @@
 #ifndef ARK_KERNELS_LAYERNORM_H_
 #define ARK_KERNELS_LAYERNORM_H_
 
+#include "math_functions.h"
 #include "reduce.h"
 
 namespace ark {
@@ -63,21 +64,18 @@ struct LayerNorm
             (tid_c + uc * UnitOutDims::C) * InDims::HW +
             (tid_n + un * UnitOutDims::N) * InDims::CHW;
 
-        DataType reduced;
-        ReduceTypeMean::singleIdentity(&reduced);
+        DataType reduced = ReduceTypeMean::singleIdentity();
         for (int idx_in_w = tid_w; idx_in_w < InShape::W;
              idx_in_w += ThreadsPerRow) {
-            int idx_in = idx_in_base + idx_in_w;
-            ReduceTypeMean::singleReduce(&reduced, &reduced, &in[idx_in]);
+            ReduceTypeMean::singleReduce(reduced, in[idx_in_base + idx_in_w]);
         }
         UnitOp::sync_threads();
         // final reduction on shared memory using warp shuffle.
         reduced = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
             reduced, tid, smem_per_warp);
         // get the average result.
-        ReduceTypeMean::singlePostReduce(&reduced, &reduced, UnitOutDims::W);
-        DataType variance;
-        ReduceTypeMean::singleIdentity(&variance);
+        ReduceTypeMean::singlePostReduce(reduced, UnitOutDims::W);
+        DataType variance = ReduceTypeMean::singleIdentity();
         // get the variance
         UnitOp::sync_threads();
         for (int idx_in_w = tid_w; idx_in_w < InShape::W;
@@ -88,7 +86,7 @@ struct LayerNorm
         UnitOp::sync_threads();
         variance = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
             variance, tid, smem_per_warp);
-        ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W);
+        ReduceTypeMean::singlePostReduce(variance, UnitOutDims::W);
         UnitOp::sync_threads();
         // the output is (input - mean) / sqrt(variance)
         for (int idx_in_w = tid_w; idx_in_w < InShape::W;
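
The call sites above imply a changed ReduceTypeMean interface in reduce.h: singleIdentity now returns the identity by value, and singleReduce and singlePostReduce take the accumulator by reference instead of through pointers. reduce.h is not shown in this diff, so the following is only a sketch of the interface those calls appear to assume; every name, signature, and the DEVICE definition here is inferred, not taken from the actual header:

    // Sketch only: inferred from the call sites in this diff, not from reduce.h.
    #define DEVICE __device__ __forceinline__  // assumed expansion of ark's DEVICE

    // DataType is the element type; CompType is the accumulation type.
    template <typename DataType, typename CompType, int NelemPerThread>
    struct ReduceTypeMean
    {
        // Identity element of the sum (zero), returned by value.
        static DEVICE CompType singleIdentity()
        {
            return CompType(0);
        }
        // Fold one input element into the accumulator at CompType precision.
        static DEVICE void singleReduce(CompType &acc, const DataType &val)
        {
            acc += static_cast<CompType>(val);
        }
        // Turn the accumulated sum into a mean over `len` elements.
        static DEVICE void singlePostReduce(CompType &acc, int len)
        {
            acc /= CompType(len);
        }
    };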
@@ -127,7 +125,8 @@ DEVICE void layernorm(ark::half *out, const ark::half *in, int uop_idx,
 // Root Mean Square Layer Normalization: https://arxiv.org/pdf/1910.07467.pdf
 template <typename InDims, typename InShape, typename OutDims,
           typename OutShape, typename UnitOutDims, int NumThreads,
-          int SmemBytes, typename DataType, int NelemPerThread>
+          int SmemBytes, typename DataType, typename CompType,
+          int NelemPerThread>
 struct RMSNorm
 {
     using UnitOp =
@@ -138,7 +137,8 @@ struct RMSNorm
                            int smem_per_warp)
     {
         using InOutChk = LayerNormShapeChecker<InShape, OutShape>;
-        using ReduceTypeMean = ReduceTypeMean<DataType, NelemPerThread>;
+        using ReduceTypeMean =
+            ReduceTypeMean<DataType, CompType, NelemPerThread>;
 
         constexpr int NonReduceDimLength = UnitOutDims::NCH;
         // The reduction dimension of the final stage.
@@ -166,25 +166,27 @@ struct RMSNorm
             (tid_c + uc * UnitOutDims::C) * InDims::HW +
             (tid_n + un * UnitOutDims::N) * InDims::CHW;
 
-        DataType variance;
-        ReduceTypeMean::singleIdentity(&variance);
+        CompType var = ReduceTypeMean::singleIdentity();
+
         // get the variance
         UnitOp::sync_threads();
         for (int idx_in_w = tid_w; idx_in_w < InShape::W;
              idx_in_w += ThreadsPerRow) {
-            int idx_in = idx_in_base + idx_in_w;
-            variance += (in[idx_in]) * (in[idx_in]);
+            CompType data = static_cast<CompType>(in[idx_in_base + idx_in_w]);
+            var += data * data;
         }
         UnitOp::sync_threads();
-        variance = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
-            variance, tid, smem_per_warp);
-        ReduceTypeMean::singlePostReduce(&variance, &variance, UnitOutDims::W);
+        var = warpsReduce<ReduceTypeMean, UnitOp, ThreadsPerRow>(
+                  var, tid, smem_per_warp) /
+              UnitOutDims::W;
         UnitOp::sync_threads();
-        // the output is (input - mean) / sqrt(variance)
+        // the output is input * rsqrt(mean(input^2) + eps)
         for (int idx_in_w = tid_w; idx_in_w < InShape::W;
              idx_in_w += ThreadsPerRow) {
             int idx_in = idx_in_base + idx_in_w;
-            out[idx_in] = (in[idx_in]) * rsqrtf(variance + 1e-5f);
+            CompType data = static_cast<CompType>(in[idx_in]);
+            out[idx_in] = static_cast<DataType>(
+                data * Rsqrt::compute(var + CompType(1e-5f)));
         }
     }
 };
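
For reference, the rewritten loop computes RMSNorm as defined in the cited paper (arXiv:1910.07467), with the epsilon added under the square root and no learned gain applied in this kernel; W is the reduced width:

    \text{out}_i = \frac{\text{in}_i}{\sqrt{\frac{1}{W}\sum_{j=1}^{W}\text{in}_j^2 + 10^{-5}}}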
@@ -196,8 +198,8 @@ DEVICE void rmsnorm(float *out, const float *in, int uop_idx, int smem_per_warp)
 {
     constexpr int NelemPerThread = 1;
     RMSNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-            SmemBytes, float, NelemPerThread>::run(out, in, uop_idx,
-                                                   smem_per_warp);
+            SmemBytes, float, float, NelemPerThread>::run(out, in, uop_idx,
+                                                          smem_per_warp);
 }
 
 template <typename InDims, typename InShape, typename OutDims,
@@ -208,8 +210,8 @@ DEVICE void rmsnorm(ark::half *out, const ark::half *in, int uop_idx,
 {
     constexpr int NelemPerThread = 1;
     RMSNorm<InDims, InShape, OutDims, OutShape, UnitOutDims, NumThreads,
-            SmemBytes, ark::half, NelemPerThread>::run(out, in, uop_idx,
-                                                       smem_per_warp);
+            SmemBytes, ark::half, float, NelemPerThread>::run(out, in, uop_idx,
+                                                              smem_per_warp);
 }
 
 } // namespace ark
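
Both rmsnorm wrappers now pass float as CompType, so the sum of squares is accumulated in fp32 even when the data is ark::half; this appears to be the "precise reduction" the PR title refers to. The failure mode it avoids: an fp16 accumulator stops growing once the running sum reaches 2048, because fp16 spacing above 2048 is 2 and each unit-sized addend rounds away. A minimal standalone CUDA sketch of the effect; the kernel, names, and toy setup are hypothetical, not code from this PR:

    #include <cstdio>
    #include <cuda_fp16.h>

    // Sum n squares of fp16 inputs with both a float and a __half accumulator.
    // Single-thread loop, for illustration only.
    __global__ void sum_squares(const __half *in, int n, float *out32,
                                __half *out16)
    {
        float acc32 = 0.0f;                  // CompType = float: precise
        __half acc16 = __float2half(0.0f);   // CompType = half: loses low bits
        for (int i = 0; i < n; ++i) {
            float x = __half2float(in[i]);
            acc32 += x * x;
            acc16 = __hadd(acc16, __hmul(in[i], in[i]));
        }
        *out32 = acc32;
        *out16 = acc16;
    }

    int main()
    {
        const int n = 4096;
        __half *in;
        float *out32;
        __half *out16;
        cudaMallocManaged(&in, n * sizeof(__half));
        cudaMallocManaged(&out32, sizeof(float));
        cudaMallocManaged(&out16, sizeof(__half));
        for (int i = 0; i < n; ++i) in[i] = __float2half(1.0f);
        sum_squares<<<1, 1>>>(in, n, out32, out16);
        cudaDeviceSynchronize();
        // Expected: the fp32 sum reaches 4096, the fp16 sum stalls at 2048.
        printf("fp32: %f  fp16: %f\n", *out32, __half2float(*out16));
        return 0;
    }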
32 changes: 28 additions & 4 deletions ark/include/kernels/math_functions.h
@@ -10,28 +10,52 @@ namespace ark {

 struct Exp
 {
-    static DEVICE float compute(float input)
+    static DEVICE float compute(const float &input)
     {
         return expf(input);
     }
-    static DEVICE __half2 compute(__half2 input)
+    static DEVICE __half compute(const __half &input)
+    {
+        return hexp(input);
+    }
+    static DEVICE __half2 compute(const __half2 &input)
     {
         return h2exp(input);
     }
 };
 
 struct Sqrt
 {
-    static DEVICE float compute(float input)
+    static DEVICE float compute(const float &input)
     {
         return sqrtf(input);
     }
-    static DEVICE __half2 compute(__half2 input)
+    static DEVICE __half compute(const __half &input)
+    {
+        return hsqrt(input);
+    }
+    static DEVICE __half2 compute(const __half2 &input)
    {
         return h2sqrt(input);
     }
 };
 
+struct Rsqrt
+{
+    static DEVICE float compute(const float &input)
+    {
+        return rsqrtf(input);
+    }
+    static DEVICE __half compute(const __half &input)
+    {
+        return hrsqrt(input);
+    }
+    static DEVICE __half2 compute(const __half2 &input)
+    {
+        return h2rsqrt(input);
+    }
+};
+
 template <typename _MathType, typename _InShape, typename _DataType,
           int _NelemPerThread>
 struct Math;
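
The new Rsqrt struct mirrors Exp and Sqrt: a single compute name overloaded for float, __half, and __half2, so a kernel templated on the math op dispatches to rsqrtf, hrsqrt, or h2rsqrt by overload resolution alone, with the __half2 variant processing two values per call. A standalone sketch of that dispatch pattern, assuming plain CUDA without the ark headers; the apply kernel and its names are illustrative only:

    #include <cuda_fp16.h>

    #define DEVICE __device__ __forceinline__

    // Same shape as the struct added above, restated outside the ark headers.
    struct Rsqrt
    {
        static DEVICE float compute(const float &input) { return rsqrtf(input); }
        static DEVICE __half compute(const __half &input) { return hrsqrt(input); }
        static DEVICE __half2 compute(const __half2 &input) { return h2rsqrt(input); }
    };

    // One kernel serves float, __half, and __half2 buffers: the MathType
    // overload set does the per-type dispatch at compile time.
    template <typename MathType, typename DataType>
    __global__ void apply(DataType *data, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) data[i] = MathType::compute(data[i]);
    }

    // Usage: apply<Rsqrt, __half2><<<grid, block>>>(buf, n) handles two
    // half values per element via h2rsqrt.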