diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000000..eab4576fe0 --- /dev/null +++ b/.clang-format @@ -0,0 +1,87 @@ +--- +AccessModifierOffset: -1 +AlignAfterOpenBracket: AlwaysBreak +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlinesLeft: true +AlignOperands: false +AlignTrailingComments: false +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: true +BinPackArguments: false +BinPackParameters: false +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + BeforeCatch: false + BeforeElse: false + IndentBraces: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Attach +BreakBeforeTernaryOperators: true +BreakConstructorInitializersBeforeComma: false +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: false +ColumnLimit: 80 +CommentPragmas: '^ IWYU pragma:' +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ] +IncludeCategories: + - Regex: '^<.*\.h(pp)?>' + Priority: 1 + - Regex: '^<.*' + Priority: 2 + - Regex: '.*' + Priority: 3 +IndentCaseLabels: true +IndentWidth: 2 +IndentWrappedFunctionNames: false +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: false +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SpaceAfterCStyleCast: false +SpaceBeforeAssignmentOperators: true +SpaceBeforeParens: ControlStatements +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +Standard: Cpp11 +TabWidth: 8 +UseTab: Never +... 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..3d116a5115 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,3 @@ +include CMakeLists.txt +recursive-include mlx/ * +include python/src/* diff --git a/benchmarks/cpp/irregular_strides.cpp b/benchmarks/cpp/irregular_strides.cpp new file mode 100644 index 0000000000..e6a570995c --- /dev/null +++ b/benchmarks/cpp/irregular_strides.cpp @@ -0,0 +1,198 @@ +#include +#include + +#include "mlx/mlx.h" +#include "time_utils.h" + +using namespace mlx::core; + +void time_irregular_binary_ops_1D() { + auto device = default_device(); + int size = 1000000; + int step = 2; + auto a = random::uniform({size}); + auto b = random::uniform({size}); + eval(a, b); + a = slice(a, {0}, {size}, {step}); + b = slice(b, {0}, {size}, {step}); + TIMEM("1D strided", add, a, b, device); +} + +void time_irregular_binary_ops_2D() { + auto device = default_device(); + int size = 2048; + auto a = random::uniform({size, size}); + auto b = random::uniform({size, size}); + eval(a, b); + TIMEM("2D regular", add, a, b, device); + + b = transpose(b); + eval(b); + TIMEM("2D transpose", add, a, b, device); + + b = random::uniform({size}); + eval(b); + TIMEM("2D broadcast dim 0", add, a, b, device); + + b = reshape(b, {size, 1}); + eval(b); + TIMEM("2D broadcast dim 1", add, a, b, device); +} + +void time_irregular_binary_ops_3D() { + auto device = default_device(); + int d0 = 32; + int d1 = 512; + int d2 = 512; + auto a = random::uniform({d0, d1, d2}); + auto b = random::uniform({d0, d1, d2}); + TIMEM("3D regular", add, a, b, device); + + b = transpose(b, {0, 2, 1}); + TIMEM("3D transpose", add, a, b, device); + + b = random::uniform({d1, d2}); + TIMEM("3D broadcast dim 0", add, a, b, device); + + b = random::uniform({d0, 1, d2}); + TIMEM("3D broadcast dim 1", add, a, b, device); + + b = random::uniform({d0, d1, 1}); + TIMEM("3D broadcast dim 2", add, a, b, device); + + b = random::uniform({d2}); + TIMEM("3D broadcast dims 0, 1", add, a, b, device); + + b = random::uniform({d1, 1}); + TIMEM("3D broadcast dims 0, 2", add, a, b, device); + + b = random::uniform({d0, 1, 1}); + TIMEM("3D broadcast dims 1, 2", add, a, b, device); +} + +void time_irregular_binary_ops_4D() { + auto device = default_device(); + std::vector shape = {8, 8, 512, 512}; + auto a = random::uniform(shape); + auto b = random::uniform(shape); + + TIMEM("4D regular", add, a, b, device); + + b = transpose(b, {0, 1, 3, 2}); + TIMEM("4D transpose", add, a, b, device); + + std::string om = "4D broadcast dims "; + for (int i = 0; i < shape.size(); ++i) { + shape[i] = 1; + b = random::uniform(shape); + std::ostringstream msg; + msg << om << i; + TIMEM(msg.str(), add, a, b, device); + + for (int j = i + 1; j < shape.size(); ++j) { + shape[j] = 1; + std::ostringstream msg; + msg << om << i << ", " << j; + b = random::uniform(shape); + TIMEM(msg.str(), add, a, b, device); + shape[j] = a.shape(j); + + for (int k = j + 1; k < shape.size(); ++k) { + shape[k] = 1; + std::ostringstream msg; + msg << om << i << ", " << j << ", " << k; + b = random::uniform(shape); + TIMEM(msg.str(), add, a, b, device); + shape[k] = a.shape(k); + } + } + shape[i] = a.shape(i); + } +} + +void time_irregular_reshape() { + auto device = default_device(); + std::vector shape; + auto reshape_fn = [&shape, device](const array& a) { + return reshape(a, shape, device); + }; + + int size = 64; + int d = 2 * size; + + auto a = random::uniform({d, d, d}); + + shape = {8 * size, size, size}; + TIMEM("3D contiguous", reshape_fn, a); + + a = 
transpose(a); + shape = {8 * size, size, size}; + TIMEM("3D transpose", reshape_fn, a); + + a = transpose(a, {1, 2, 0}); + shape = {8 * size, size, size}; + TIMEM("3D transpose dims 1 2", reshape_fn, a); + + a = broadcast_to(random::uniform({d, d}), {d, d, d}); + TIMEM("3D broadcast dim 0", reshape_fn, a); + + a = broadcast_to(random::uniform({d, 1, d}), {d, d, d}); + TIMEM("3D broadcast dim 1", reshape_fn, a); + + a = broadcast_to(random::uniform({d, d, 1}), {d, d, d}); + TIMEM("3D broadcast dim 2", reshape_fn, a); + + a = broadcast_to(random::uniform({d}), {d, d, d}); + TIMEM("3D broadcast dims 0, 1", reshape_fn, a); + + a = broadcast_to(random::uniform({d, 1}), {d, d, d}); + TIMEM("3D broadcast dims 0, 2", reshape_fn, a); + + a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d}); + TIMEM("3D broadcast dims 1, 2", reshape_fn, a); + + a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d}); + TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a); +} + +void time_irregular_astype_1D() { + auto device = default_device(); + int size = 1000000; + int step = 2; + auto a = random::uniform({size}); + a = slice(a, {0}, {size}, {step}); + TIMEM("1D strided", astype, a, int32, device); +} + +void time_irregular_astype_2D() { + auto device = default_device(); + int size = 2048; + std::vector shape = {size, size}; + + auto a = random::uniform(shape); + TIMEM("2D regular", astype, a, int32, device); + + a = transpose(a); + TIMEM("2D transpose", astype, a, int32, device); + + a = broadcast_to(random::uniform({size}), shape); + TIMEM("2D broadcast dim 0", astype, a, int32, device); + + a = broadcast_to(random::uniform({size, 1}), shape); + TIMEM("2D broadcast dim 1", astype, a, int32, device); +} + +int main(int argc, char** argv) { + if (argc > 1) { + bool use_gpu = !strcmp(argv[1], "gpu"); + set_default_device(use_gpu ? 
Device::gpu : Device::cpu); + } + std::cout << "Benchmarks for " << default_device() << std::endl; + time_irregular_binary_ops_1D(); + time_irregular_binary_ops_2D(); + time_irregular_binary_ops_3D(); + time_irregular_binary_ops_4D(); + time_irregular_reshape(); + time_irregular_astype_1D(); + time_irregular_astype_2D(); +} diff --git a/benchmarks/cpp/single_ops.cpp b/benchmarks/cpp/single_ops.cpp new file mode 100644 index 0000000000..aa0fa2b363 --- /dev/null +++ b/benchmarks/cpp/single_ops.cpp @@ -0,0 +1,247 @@ +#include "mlx/mlx.h" +#include "time_utils.h" + +using namespace mlx::core; + +void time_creation_ops() { + int M = 2000; + int N = 500; + auto shape = {M, N}; + auto full_fp32 = [&]() { return full(shape, 3.3f); }; + TIME(full_fp32); + auto zeros_fp32 = [&]() { return zeros(shape, float32); }; + TIME(zeros_fp32); + auto ones_fp32 = [&]() { return ones(shape, float32); }; + TIME(ones_fp32); + + auto arange_fp32 = [&]() { return arange(0.0, 10.0, 1e-4); }; + TIME(arange_fp32); +} + +void time_type_conversions() { + int M = 2000; + int N = 500; + auto shape = {M, N}; + auto device = default_device(); + + auto a = zeros(shape, float32); + eval(a); + TIMEM("float32 to int32", astype, a, int32, device); + TIMEM("float32 to uint32", astype, a, uint32, device); + + a = zeros(shape, int32); + eval(a); + TIMEM("int32 to float32", astype, a, float32, device); + + a = zeros(shape, bool_); + eval(a); + TIMEM("bool to float32", astype, a, float32, device); + TIMEM("bool to int32", astype, a, int32, device); + TIMEM("bool to uint32", astype, a, uint32, device); +} + +void time_random_generation() { + int M = 2000; + int N = 500; + + auto uniform = [&]() { return random::uniform({M, N}, float32); }; + TIME(uniform); + auto normal = [&]() { return random::normal({M, N}, float32); }; + TIME(normal); +} + +void time_unary_ops() { + int M = 2000; + int N = 500; + auto device = default_device(); + + auto a = random::normal({M, N}); + eval(a); + TIME(mlx::core::abs, a, device); + TIME(negative, a, device); + TIME(sign, a, device); + TIME(square, a, device); + TIME(mlx::core::sqrt, a, device); + TIME(rsqrt, a, device); + TIME(mlx::core::exp, a, device); + + a = random::uniform({M, N}); + TIME(mlx::core::log, a, device); +} + +void time_binary_ops() { + int M = 1000, N = 100, K = 10; + auto a = random::uniform({M, N, K}); + auto b = random::uniform({M, N, K}); + auto device = default_device(); + eval(a, b); + + TIME(add, a, b, device); + TIME(subtract, a, b, device); + TIME(multiply, a, b, device); + TIME(divide, a, b, device); + TIME(maximum, a, b, device); + TIME(minimum, a, b, device); + + b = random::uniform({1}); + eval(b); + TIMEM("scalar", add, a, b, device); + TIMEM("vector-scalar", subtract, a, b, device); + TIMEM("scalar-vector", subtract, b, a, device); + TIMEM("scalar", multiply, a, b, device); + TIMEM("vector-scalar", divide, a, b, device); + TIMEM("scalar-vector", divide, b, a, device); + + a = broadcast_to(random::uniform({1}), {1000, 100}); + b = broadcast_to(random::uniform({1}), {1000, 100}); + eval(a, b); + TIMEM("scalar-scalar broadcast", add, a, b, device); + TIMEM("scalar-scalar broadcast", subtract, a, b, device); + TIMEM("scalar-scalar broadcast", multiply, a, b, device); + TIMEM("scalar-scalar broadcast", divide, a, b, device); +} + +void time_strided_ops() { + int M = 50, N = 50, O = 50, P = 50; + auto a = random::uniform({M, N, O, P}); + auto b = random::uniform({M, N, O, P}); + auto device = default_device(); + eval(a, b); + TIMEM("non-strided", add, a, b, device); + a = 
transpose(a, {1, 0, 2, 3}); + b = transpose(b, {3, 2, 0, 1}); + eval(a, b); + TIMEM("strided", add, a, b, device); +} + +void time_comparisons() { + int M = 1000, N = 100, K = 10; + auto a = random::uniform({M, N, K}); + auto b = random::uniform({M, N, K}); + auto device = default_device(); + eval(a, b); + TIME(equal, a, b, device); + TIME(greater, a, b, device); + TIME(greater_equal, a, b, device); + TIME(less, a, b, device); + TIME(less_equal, a, b, device); +} + +void time_matvec() { + int M = 2000, N = 200; + auto a = random::uniform({M, N}); + auto b = random::uniform({N}); + auto c = random::uniform({M}); + eval(a, b, c); + auto matvec = [&]() { return matmul(a, b); }; + TIME(matvec); + + auto matvec_transpose = [&]() { return matmul(transpose(a), c); }; + TIME(matvec_transpose); +} + +void time_matmul() { + int M = 1000, N = 1000, K = 1000; + auto a = random::uniform({M, K}); + auto b = random::uniform({K, N}); + auto device = default_device(); + eval(a, b); + TIME(matmul, a, b, device); + + auto transpose_matmul = [&]() { return matmul(transpose(a), b); }; + TIME(transpose_matmul); +} + +void time_reductions() { + auto a = random::normal({10000, 1000}); + eval(a); + auto sum_all = [&a]() { return sum(a, false); }; + TIME(sum_all); + + auto sum_along_0 = [&a]() { return sum(a, 0, false); }; + TIME(sum_along_0); + + auto sum_along_1 = [&a]() { return sum(a, 1, false); }; + TIME(sum_along_1); + + auto prod_all = [&a]() { return prod(a, false); }; + TIME(prod_all); + + auto all_true = [&a]() { return all(a, false); }; + TIME(all_true); + + auto all_along_0 = [&a]() { return all(a, 0, false); }; + TIME(all_along_0); + + auto all_along_1 = [&a]() { return all(a, 1, false); }; + TIME(all_along_1); + + auto any_true = [&a]() { return any(a, false); }; + TIME(any_true); + + auto argmin_along_0 = [&a]() { return argmin(a, 0, false); }; + TIME(argmin_along_0); + + auto argmin_along_1 = [&a]() { return argmin(a, 1, false); }; + TIME(argmin_along_1); +} + +void time_gather_scatter() { + auto a = random::normal({1000, 768}); + eval(a); + auto indices = random::randint(0, 1000, {256}); + eval(indices); + + auto embedding_lookup = [&a, &indices]() { return take(a, indices, 0); }; + TIME(embedding_lookup); + + indices = random::randint(0, 768 * 1000, {256 * 768}); + eval(indices); + + auto single_element_lookup = [&a, &indices]() { return take(a, indices); }; + TIME(single_element_lookup); + + indices = random::randint(0, 1000, {256}); + auto updates = random::normal({256, 1, 768}); + eval(indices, updates); + + auto embedding_update = [&a, &indices, &updates]() { + return scatter(a, indices, updates, 0); + }; + TIME(embedding_update); + + auto embedding_add = [&a, &indices, &updates]() { + return scatter_add(a, indices, updates, 0); + }; + TIME(embedding_add); + + a = reshape(a, {-1}); + indices = random::randint(0, 768 * 1000, {768 * 256}); + updates = random::normal({256 * 768, 1}); + eval(a, indices, updates); + + auto single_element_update = [&a, &indices, &updates]() { + return scatter(a, indices, updates, 0); + }; + TIME(single_element_update); + + auto single_element_add = [&a, &indices, &updates]() { + return scatter_add(a, indices, updates, 0); + }; + TIME(single_element_add); +} + +int main() { + std::cout << "Benchmarks for " << default_device() << std::endl; + time_creation_ops(); + time_type_conversions(); + time_unary_ops(); + time_binary_ops(); + time_strided_ops(); + time_random_generation(); + time_comparisons(); + time_matvec(); + time_matmul(); + time_reductions(); + 
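+  // Gather/scatter benchmarks: embedding-style take lookups plus scatter and scatter_add updates.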
time_gather_scatter(); +} diff --git a/benchmarks/python/comparative/README.md b/benchmarks/python/comparative/README.md new file mode 100644 index 0000000000..dc0e8306a7 --- /dev/null +++ b/benchmarks/python/comparative/README.md @@ -0,0 +1,15 @@ +Microbenchmarks comparing MLX to PyTorch +======================================== + +Implement the same microbenchmarks in MLX and PyTorch to compare and make a +list of the biggest possible performance improvements and/or regressions. + +Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for +instance to measure the times it takes to sum across the 3rd axis of the above +tensor on the cpu. + +`compare.py` runs several benchmarks and compares the speed-up or lack thereof +in comparison to PyTorch. + +Each bench script can be run with `--print-pid` to print the PID and wait for a +key in order to ease attaching a debugger. diff --git a/benchmarks/python/comparative/bench_mlx.py b/benchmarks/python/comparative/bench_mlx.py new file mode 100644 index 0000000000..a897742704 --- /dev/null +++ b/benchmarks/python/comparative/bench_mlx.py @@ -0,0 +1,313 @@ +import argparse +import math +import os +import time + +import mlx.core as mx + + +def int_or_list(x): + try: + return int(x) + except ValueError: + return [int(xi) for xi in x.split(",")] + + +def none_or_list(x): + if x == "": + return None + else: + return [int(xi) for xi in x.split(",")] + + +def bench(f, *args): + for i in range(10): + f(*args) + + s = time.time() + for i in range(100): + f(*args) + e = time.time() + return e - s + + +def matmul_square(x): + y = x + for i in range(10): + y = y @ x + mx.eval(y) + return y + + +def matmul(x, y): + ys = [] + for i in range(10): + ys.append(x @ y) + mx.eval(ys) + + +def conv1d(x, y): + ys = [] + for i in range(10): + ys.append(mx.conv1d(x, y)) + mx.eval(ys) + + +def conv2d(x, y): + ys = [] + for i in range(10): + ys.append(mx.conv2d(x, y)) + mx.eval(ys) + + +def binary(op, x, y): + for i in range(100): + y = getattr(mx, op)(x, y) + mx.eval(y) + + +def reduction(op, axis, x): + ys = [] + for i in range(100): + ys.append(getattr(mx, op)(x, axis=axis)) + mx.eval(ys) + + +def softmax(axis, x): + ys = [] + for i in range(100): + ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True)) + y = ex / mx.sum(ex, axis=axis, keepdims=True) + ys.append(y) + mx.eval(ys) + + +def softmax_fused(axis, x): + ys = [] + for i in range(100): + y = mx.softmax(x, axis=axis) + ys.append(y) + mx.eval(ys) + + +def relu(x): + y = x + for i in range(100): + y = mx.maximum(y, 0) + mx.eval(y) + + +def scalar_mult(x): + y = x + for i in range(100): + y = y * (1.0 / (1 + i)) + mx.eval(y) + + +def cross_entropy(targets, x): + ys = [] + for i in range(100): + y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis( + x, mx.reshape(targets, (-1, 1)), axis=-1 + ) + ys.append(mx.mean(y)) + mx.eval(ys) + + +def logsumexp(axis, x): + ys = [] + for i in range(100): + ys.append(mx.logsumexp(x, axis=axis)) + mx.eval(ys) + + +def linear(w, b, x): + ys = [] + for i in range(10): + ys.append(x @ mx.transpose(w, (1, 0)) + b) + mx.eval(ys) + + +def rope(x): + *_, N, D = x.shape + ys = [] + for i in range(10): + shape = x.shape + x = mx.reshape(x, (-1, N, D)) + positions = mx.arange(N) + freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1))) + theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) + costheta = mx.cos(theta) + sintheta = mx.sin(theta) + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * 
sintheta + x2 * costheta + y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) + y = mx.reshape(y, (-1, N, D)) + ys.append(y) + mx.eval(ys) + + +def concatenate(axis, x, y): + ys = [] + for i in range(10): + ys.append(mx.concatenate([x, y], axis=axis)) + mx.eval(ys) + + +def cumsum(axis, x): + ys = [] + for i in range(10): + ys.append(mx.cumsum(x, axis)) + mx.eval(ys) + + +def sort(axis, x): + ys = [] + for i in range(10): + ys.append(mx.sort(x, axis)) + mx.eval(ys) + + +def topk(axis, x): + k = x.shape[axis] // 3 + ys = [] + for i in range(10): + ys.append(mx.topk(x, k, axis)) + mx.eval(ys) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("benchmark", help="Choose the benchmark to run") + parser.add_argument( + "--size", + default=[(1024, 1024)], + type=lambda x: list(map(int, x.split("x"))), + help="Set the matrix size", + action="append", + ) + parser.add_argument( + "--axis", + default=[1], + type=int_or_list, + help="Set a reduction axis", + action="append", + ) + parser.add_argument( + "--transpose", + type=none_or_list, + default=[], + help="Permute the matrix", + action="append", + ) + parser.add_argument( + "--print-pid", action="store_true", help="Print the PID and pause" + ) + parser.add_argument("--cpu", action="store_true", help="Use the CPU") + parser.add_argument( + "--fused", action="store_true", help="Use fused functions where possible" + ) + parser.add_argument( + "--dtype", choices=["float32", "float16", "bfloat16"], default="float32" + ) + + args = parser.parse_args() + + if len(args.size) > 1: + args.size.pop(0) + if len(args.axis) > 1: + args.axis.pop(0) + + if args.print_pid: + print(os.getpid()) + input("Press enter to run") + + if args.cpu: + mx.set_default_device(mx.cpu) + else: + mx.set_default_device(mx.gpu) + dtype = dict(float32=mx.float32, float16=mx.float16, bfloat16=mx.bfloat16)[ + args.dtype + ] + xs = [] + for size in args.size: + xs.append(mx.random.normal(size).astype(dtype)) + for i, t in enumerate(args.transpose): + if t is None: + continue + xs[i] = mx.transpose(xs[i], t) + mx.eval(xs) + x = xs[0] + axis = args.axis[0] + + if args.benchmark == "matmul_square": + print(bench(matmul_square, x)) + + elif args.benchmark == "matmul": + print(bench(matmul, *xs)) + + elif args.benchmark == "linear": + print(bench(linear, *xs)) + + elif args.benchmark == "sum_axis": + print(bench(reduction, "sum", axis, x)) + + elif args.benchmark == "sum_all": + print(bench(reduction, "sum", None, x)) + + elif args.benchmark == "argmax": + print(bench(reduction, "argmax", axis, x)) + + elif args.benchmark == "add": + print(bench(binary, "add", *xs)) + + elif args.benchmark == "mul": + print(bench(binary, "multiply", *xs)) + + elif args.benchmark == "softmax": + if args.fused: + print(bench(softmax_fused, axis, x)) + else: + print(bench(softmax, axis, x)) + + elif args.benchmark == "relu": + print(bench(relu, x)) + + elif args.benchmark == "scalar_mul": + print(bench(scalar_mult, x)) + + elif args.benchmark == "cross_entropy": + if len(size) != 2: + raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size") + + targets = mx.zeros((len(x),), dtype=mx.uint32) + print(bench(cross_entropy, targets, x)) + + elif args.benchmark == "logsumexp": + print(bench(logsumexp, axis, x)) + + elif args.benchmark == "rope": + print(bench(rope, x)) + + elif args.benchmark == "concatenate": + print(bench(concatenate, axis, *xs)) + + elif args.benchmark == "cumsum": + print(bench(cumsum, axis, *xs)) + + elif args.benchmark == 
"conv1d": + print(bench(conv1d, *xs)) + + elif args.benchmark == "conv2d": + print(bench(conv2d, *xs)) + + elif args.benchmark == "sort": + print(bench(sort, axis, x)) + + elif args.benchmark == "topk": + print(bench(topk, axis, x)) + + else: + raise ValueError("Unknown benchmark") diff --git a/benchmarks/python/comparative/bench_torch.py b/benchmarks/python/comparative/bench_torch.py new file mode 100644 index 0000000000..444ae6ce15 --- /dev/null +++ b/benchmarks/python/comparative/bench_torch.py @@ -0,0 +1,338 @@ +import argparse +import os +import time + +import torch +import torch.mps + + +def int_or_list(x): + try: + return int(x) + except ValueError: + return [int(xi) for xi in x.split(",")] + + +def none_or_list(x): + if x == "": + return None + else: + return [int(xi) for xi in x.split(",")] + + +def bench(f, *args): + for i in range(10): + f(*args) + + s = time.time() + for i in range(100): + f(*args) + e = time.time() + return e - s + + +def sync_if_needed(x): + if x.device != torch.device("cpu"): + torch.mps.synchronize() + + +@torch.no_grad() +def matmul_square(x): + y = x + for i in range(10): + y = y @ x + sync_if_needed(x) + + +@torch.no_grad() +def matmul(x, y): + ys = [] + for i in range(10): + ys.append(x @ y) + sync_if_needed(x) + + +@torch.no_grad() +def conv1d(x, y): + x = torch.transpose(x, -1, -2) + y = torch.transpose(y, -1, -2) + ys = [] + for i in range(10): + ys.append(torch.nn.functional.conv1d(x, y)) + sync_if_needed(x) + + +@torch.no_grad() +def conv2d(x, y): + x = torch.permute(x, (0, 3, 1, 2)) + y = torch.permute(y, (0, 3, 1, 2)) + ys = [] + for i in range(10): + ys.append(torch.nn.functional.conv2d(x, y)) + sync_if_needed(x) + + +@torch.no_grad() +def binary(op, x, y): + for i in range(100): + y = getattr(torch, op)(x, y) + sync_if_needed(x) + + +@torch.no_grad() +def reduction(op, axis, x): + ys = [] + for i in range(100): + ys.append(getattr(x, op)(axis)) + sync_if_needed(x) + + +@torch.no_grad() +def softmax(axis, x): + ys = [] + for i in range(100): + ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values) + y = ex / torch.sum(ex, dim=axis, keepdims=True) + ys.append(y) + sync_if_needed(x) + + +@torch.no_grad() +def softmax_fused(axis, x): + ys = [] + for i in range(100): + ys.append(torch.nn.functional.softmax(x, dim=axis)) + sync_if_needed(x) + + +@torch.no_grad() +def relu(x): + y = x + for i in range(100): + y = torch.nn.functional.relu(y) + sync_if_needed(x) + + +@torch.no_grad() +def scalar_mult(x): + y = x + for i in range(100): + y = y * (1.0 / (1 + i)) + sync_if_needed(x) + + +@torch.no_grad() +def cross_entropy(targets, x): + ys = [] + for i in range(100): + ys.append(torch.nn.functional.cross_entropy(x, targets)) + sync_if_needed(x) + + +@torch.no_grad() +def logsumexp(axis, x): + ys = [] + for i in range(100): + ys.append(torch.logsumexp(x, dim=axis)) + sync_if_needed(x) + + +@torch.no_grad() +def linear_fused(w, b, x): + ys = [] + for i in range(10): + ys.append(torch.nn.functional.linear(x, w, b)) + sync_if_needed(x) + + +@torch.no_grad() +def linear(w, b, x): + ys = [] + for i in range(10): + ys.append((x @ torch.transpose(w, -2, -1)) + b) + sync_if_needed(x) + + +@torch.no_grad() +def rope(x): + *_, N, D = x.shape + ys = [] + for i in range(10): + x = x.view(-1, N, D) + positions = torch.arange(N, device=x.device) + freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device) + theta = positions[:, None] * freqs[None] + costheta = torch.cos(theta) + sintheta = torch.sin(theta) + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 
* costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1) + y = y.reshape(-1, N, D) + ys.append(y) + sync_if_needed(x) + + +@torch.no_grad() +def concatenate(axis, x, y): + ys = [] + for i in range(10): + ys.append(torch.cat([x, y], dim=axis)) + sync_if_needed(x) + + +@torch.no_grad() +def cumsum(axis, x): + ys = [] + for i in range(10): + ys.append(x.cumsum(axis)) + sync_if_needed(x) + + +@torch.no_grad() +def sort(axis, x): + ys = [] + for i in range(10): + ys.append(torch.sort(x, dim=axis)[0]) + sync_if_needed(x) + + +@torch.no_grad() +def topk(axis, x): + k = x.shape[axis] // 3 + ys = [] + for i in range(10): + ys.append(torch.topk(x, k, dim=axis)[0]) + sync_if_needed(x) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("benchmark", help="Choose the benchmark to run") + parser.add_argument( + "--size", + default=[(1024, 1024)], + type=lambda x: list(map(int, x.split("x"))), + help="Set the matrix size", + action="append", + ) + parser.add_argument( + "--axis", + default=[1], + type=int_or_list, + help="Set a reduction axis", + action="append", + ) + parser.add_argument( + "--transpose", + type=none_or_list, + default=[], + help="Permute the matrix", + action="append", + ) + parser.add_argument( + "--print-pid", action="store_true", help="Print the PID and pause" + ) + parser.add_argument("--cpu", action="store_true", help="Use the CPU") + parser.add_argument( + "--fused", action="store_true", help="Use fused functions where possible" + ) + parser.add_argument("--dtype", choices=["float32", "float16"], default="float32") + + args = parser.parse_args() + + if len(args.size) > 1: + args.size.pop(0) + if len(args.axis) > 1: + args.axis.pop(0) + + if args.print_pid: + print(os.getpid()) + input("Press enter to run") + + torch.set_num_threads(1) + device = "cpu" if args.cpu else "mps" + dtype = dict(float32=torch.float32, float16=torch.float16)[args.dtype] + xs = [] + for size in args.size: + xs.append(torch.randn(*size).to(device).to(dtype)) + for i, t in enumerate(args.transpose): + if t is None: + continue + xs[i] = xs[i].permute(*t) + x = xs[0] + axis = args.axis[0] + + if args.benchmark == "matmul_square": + print(bench(matmul_square, x)) + + elif args.benchmark == "matmul": + print(bench(matmul, *xs)) + + elif args.benchmark == "linear": + if args.fused: + print(bench(linear_fused, *xs)) + else: + print(bench(linear, *xs)) + + elif args.benchmark == "sum_axis": + print(bench(reduction, "sum", axis, x)) + + elif args.benchmark == "sum_all": + print(bench(reduction, "sum", None, x)) + + elif args.benchmark == "argmax": + print(bench(reduction, "argmax", axis, x)) + + elif args.benchmark == "add": + print(bench(binary, "add", *xs)) + + elif args.benchmark == "mul": + print(bench(binary, "mul", *xs)) + + elif args.benchmark == "softmax": + if args.fused: + print(bench(softmax_fused, axis, x)) + else: + print(bench(softmax, axis, x)) + + elif args.benchmark == "relu": + print(bench(relu, x)) + + elif args.benchmark == "scalar_mul": + print(bench(scalar_mult, x)) + + elif args.benchmark == "cross_entropy": + if len(size) != 2: + raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size") + + targets = torch.zeros(len(x), dtype=torch.long).to(x.device) + print(bench(cross_entropy, targets, x)) + + elif args.benchmark == "logsumexp": + print(bench(logsumexp, axis, x)) + + elif args.benchmark == "rope": + print(bench(rope, x)) + + elif args.benchmark == "concatenate": + 
print(bench(concatenate, axis, *xs)) + + elif args.benchmark == "cumsum": + print(bench(cumsum, axis, *xs)) + + elif args.benchmark == "conv1d": + print(bench(conv1d, *xs)) + + elif args.benchmark == "conv2d": + print(bench(conv2d, *xs)) + + elif args.benchmark == "sort": + print(bench(sort, axis, x)) + + elif args.benchmark == "topk": + print(bench(topk, axis, x)) + + else: + raise ValueError("Unknown benchmark") diff --git a/benchmarks/python/comparative/compare.py b/benchmarks/python/comparative/compare.py new file mode 100644 index 0000000000..00b9e40713 --- /dev/null +++ b/benchmarks/python/comparative/compare.py @@ -0,0 +1,253 @@ +#!/usr/bin/env python + +import argparse +import re +from pathlib import Path +from subprocess import run + +BENCH_MLX = Path(__file__).parent / "bench_mlx.py" +BENCH_TORCH = Path(__file__).parent / "bench_torch.py" + + +def run_or_raise(*args, **kwargs): + try: + result = run(*args, capture_output=True, **kwargs) + return float(result.stdout) + except ValueError: + raise ValueError(f"stdout: {result.stdout}\nstderr: {result.stderr}") + + +def compare(args): + t_mlx = run_or_raise(["python", BENCH_MLX] + args) + t_torch = run_or_raise(["python", BENCH_TORCH] + args) + + print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t") + + +def compare_mlx_dtypes(args, dt1, dt2): + t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1]) + t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2]) + + print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t") + + +def make_regex_search(regexes): + compiled_regexes = list(map(re.compile, regexes)) + + def search(x): + return (c.search(x) is not None for c in compiled_regexes) + + return search + + +def make_predicate(positive_filter, negative_filter): + if positive_filter is not None: + positive_filter_search = make_regex_search(positive_filter) + positive_filter = lambda x: all(positive_filter_search(x)) + else: + positive_filter = lambda x: True + + if negative_filter is not None: + negative_filter_search = make_regex_search(negative_filter) + negative_filter = lambda x: not any(negative_filter_search(x)) + else: + negative_filter = lambda x: True + + def predicate(x): + return positive_filter(x) and negative_filter(x) + + return predicate + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run comparisons agains PyTorch") + parser.add_argument( + "--filter", "-f", help="Regex filter to select benchmarks", nargs="+" + ) + parser.add_argument( + "--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+" + ) + parser.add_argument( + "--mlx_dtypes", + "-d", + help="Compare mlx benchmarks between the 2 provided data types", + nargs=2, + ) + args, rest = parser.parse_known_args() + + _filter = make_predicate(args.filter, args.negative_filter) + + if args.mlx_dtypes: + compare_filtered = ( + lambda x: compare_mlx_dtypes( + x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1] + ) + if _filter(x) + else None + ) + else: + compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None + + # Binary ops + compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu") + compare_filtered("add --size 10x1024x128 --size 1x1024x128") + compare_filtered("add --size 1024x128 --size 1x128 --cpu") + compare_filtered("add --size 1024x128 --size 1x128") + compare_filtered("add --size 1024x4096 --size 1x4096 --cpu") + compare_filtered("add --size 1024x4096 --size 1x4096") + compare_filtered("add --size 1024x4096 --size 
1x1024 --transpose 1,0 --cpu") + compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0") + compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu") + compare_filtered("add --size 1024x1024 --size 1024x1024") + compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu") + compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0") + compare_filtered( + "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu" + ) + compare_filtered( + "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0" + ) + + # Reduction ops + compare_filtered("sum_all --size 10x1024x128 --cpu") + compare_filtered("sum_all --size 10x1024x128") + compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu") + compare_filtered("sum_axis --size 16x1024x128 --axis 2") + compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 2") + compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu") + compare_filtered("sum_axis --size 1024x1024 --axis 1") + compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu") + compare_filtered("sum_axis --size 1024x1024 --axis 0") + compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 1") + compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu") + compare_filtered("sum_axis --size 16x128x1024 --axis 0") + compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu") + compare_filtered("argmax --size 10x1024x128 --axis 1") + compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu") + compare_filtered("argmax --size 10x1024x128 --axis 2") + compare_filtered("argmax --size 1024x1024 --axis 1 --cpu") + compare_filtered("argmax --size 1024x1024 --axis 1") + + # Matmul ops + compare_filtered("matmul_square --size 1024x1024") + compare_filtered("matmul_square --size 1024x1024 --cpu") + compare_filtered("matmul_square --size 16x1024x1024") + compare_filtered("matmul_square --size 16x1024x1024 --cpu") + compare_filtered( + "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1" + ) + compare_filtered( + "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu" + ) + compare_filtered( + "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1" + ) + compare_filtered( + "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu" + ) + compare_filtered("matmul --size 512x8192 --size 8192x512") + compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu") + # compare_filtered("matmul --size 512x131072 --size 131072x512") + # compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu") + compare_filtered("matmul --size 8192x512 --size 512x8192") + compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu") + # compare_filtered("matmul --size 131072x512 --size 512x512") + # compare_filtered("matmul --size 131072x512 --size 512x512 --cpu") + compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024") + compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu") + compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused") + compare_filtered( + "linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu" + ) + + # Matvec ops + compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu") + compare_filtered("matmul --size 1x1x4096 --size 4096x4096") + compare_filtered( + "matmul --size 1x1x4096 --size 4096x4096 
--transpose= --transpose 1,0 --cpu" + ) + compare_filtered( + "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0" + ) + compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu") + compare_filtered("matmul --size 32x1x1000 --size 32x1000x128") + compare_filtered( + "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu" + ) + compare_filtered( + "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1" + ) + + # Various ops + compare_filtered("softmax --size 32x16x1024 --axis 2") + compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu") + compare_filtered("softmax --size 32x16x1024 --axis 2 --fused") + compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu") + compare_filtered("softmax --size 2x1024x1024 --axis 1") + compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu") + compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused") + compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu") + compare_filtered("relu --size 32x16x1024") + compare_filtered("relu --size 32x16x1024 --cpu") + compare_filtered("scalar_mul --size 32x16x1024") + compare_filtered("scalar_mul --size 32x16x1024 --cpu") + compare_filtered("cross_entropy --size 256x1024") + compare_filtered("cross_entropy --size 256x1024 --cpu") + compare_filtered("logsumexp --size 1024x1024 --axis 1") + compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu") + compare_filtered("logsumexp --size 1024x1024 --axis 0") + compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0") + compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu") + compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1") + compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu") + compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1") + compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu") + compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2") + compare_filtered( + "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu" + ) + compare_filtered("conv1d --size 1x1000x80 --size 128x11x80") + compare_filtered("conv1d --size 1x1000x80 --size 128x11x80 --cpu") + compare_filtered("conv1d --size 16x1000x80 --size 128x11x80") + compare_filtered("conv1d --size 4x1000x80 --size 128x11x80 --cpu") + compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3") + compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu") + compare_filtered("conv2d --size 16x256x256x3 --size 8x3x3x3") + compare_filtered("conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu") + compare_filtered("cumsum --size 1024x1024 --axis 1 --cpu") + compare_filtered("cumsum --size 1024x1024 --axis 0 --cpu") + compare_filtered("cumsum --size 1024x1024 --axis 1") + compare_filtered("cumsum --size 1024x1024 --axis 0") + compare_filtered("cumsum --size 128x1024 --axis 1") + compare_filtered("cumsum --size 128x1024 --axis 0") + compare_filtered("cumsum --size 1024x4096 --axis 1") + compare_filtered("cumsum --size 1024x4096 
--axis 0") + compare_filtered("cumsum --size 128x4096 --axis 1") + compare_filtered("cumsum --size 128x4096 --axis 0") + compare_filtered("cumsum --size 1024x7777 --axis 1") + compare_filtered("cumsum --size 1024x7777 --axis 0") + compare_filtered("cumsum --size 128x7777 --axis 1") + compare_filtered("cumsum --size 128x7777 --axis 0") + compare_filtered("cumsum --size 32768x128 --axis 1") + compare_filtered("cumsum --size 32768x128 --axis 0") + + compare_filtered("sort --size 1024x1024 --axis 0") + compare_filtered("sort --size 1024x1024 --axis 1") + compare_filtered("sort --size 32768x128 --axis 0") + compare_filtered("sort --size 32768x128 --axis 1") + compare_filtered("sort --size 128x128 --axis 0 --cpu") + compare_filtered("sort --size 128x128 --axis 1 --cpu") + + compare_filtered("topk --size 1024x1024 --axis 0") + compare_filtered("topk --size 1024x1024 --axis 1") + compare_filtered("topk --size 32768x128 --axis 0") + compare_filtered("topk --size 32768x128 --axis 1") + compare_filtered("topk --size 128x128 --axis 0 --cpu") + compare_filtered("topk --size 128x128 --axis 1 --cpu") diff --git a/benchmarks/python/llama_jax_bench.py b/benchmarks/python/llama_jax_bench.py new file mode 100644 index 0000000000..e7539eca36 --- /dev/null +++ b/benchmarks/python/llama_jax_bench.py @@ -0,0 +1,196 @@ +import math +import time + +import jax +import jax.numpy as jnp +from flax import linen as nn + + +class RoPE(nn.Module): + dims: int + traditional: bool = False + + def _compute_rope(self, costheta, sintheta, x): + x1 = x[..., : self.dims // 2] + x2 = x[..., self.dims // 2 : self.dims] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + rx = jnp.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) + else: + rx = jnp.concatenate([rx1, rx2], axis=-1) + + return rx + + def _compute_traditional_rope(self, costheta, sintheta, x): + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + raise NotImplementedError( + "RoPE doesn't implement partial traditional application" + ) + + rx = jnp.concatenate([rx1[..., None], rx2[..., None]], axis=-1) + + return rx + + @staticmethod + def create_cos_sin_theta( + N: int, + D: int, + offset: int = 0, + base: float = 10000, + dtype=jnp.float32, + ): + D = D // 2 + positions = jnp.arange(offset, N, dtype=dtype) + freqs = jnp.exp(-jnp.arange(0, D, dtype=dtype) * (math.log(base) / D)) + theta = positions.reshape((-1, 1)) * freqs.reshape((1, -1)) + costheta = jnp.cos(theta) + sintheta = jnp.sin(theta) + + return costheta, sintheta + + @nn.compact + def __call__(self, x, offset: int = 0): + shape = x.shape + x = x.reshape((-1, shape[-2], shape[-1])) + N = x.shape[1] + offset + costheta, sintheta = RoPE.create_cos_sin_theta( + N, self.dims, offset=offset, dtype=x.dtype + ) + + rope = ( + self._compute_traditional_rope if self.traditional else self._compute_rope + ) + rx = rope(costheta, sintheta, x) + + return rx.reshape(shape) + + +class LlamaAttention(nn.Module): + dims: int + num_heads: int + dtype: jnp.dtype + + def setup(self): + num_heads = self.num_heads + dims = self.dims + + self.rope = RoPE(dims // num_heads, True) + self.query_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype) + self.key_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype) + self.value_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype) + self.out_proj = nn.Dense(dims, use_bias=False, param_dtype=self.dtype) + + def 
__call__(self, queries, keys, values, mask=None, cache=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + queries = queries.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3)) + keys = keys.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3)) + values = values.reshape((B, L, num_heads, -1)).transpose((0, 2, 1, 3)) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = jnp.concatenate([key_cache, keys], axis=2) + values = jnp.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.transpose((0, 1, 3, 2)) + if mask is not None: + scores = scores + mask + scores = jax.nn.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose((0, 2, 1, 3)).reshape((B, L, -1)) + + return self.out_proj(values_hat), (keys, values) + + +class LlamaEncoderLayer(nn.Module): + dims: int + mlp_dims: int + num_heads: int + dtype: jnp.dtype + + def setup(self): + dims = self.dims + mlp_dims = self.mlp_dims + num_heads = self.num_heads + + self.attention = LlamaAttention(dims, num_heads, dtype) + + self.norm1 = nn.RMSNorm(param_dtype=self.dtype) + self.norm2 = nn.RMSNorm(param_dtype=self.dtype) + + self.linear1 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype) + self.linear2 = nn.Dense(mlp_dims, use_bias=False, param_dtype=self.dtype) + self.linear3 = nn.Dense(dims, use_bias=False, param_dtype=self.dtype) + + def __call__(self, x, mask=None, cache=None): + y = self.norm1(x) + y, cache = self.attention(y, y, y, mask, cache) + x = x + y + + y = self.norm2(x) + a = self.linear1(y) + b = self.linear2(y) + y = jax.nn.silu(a) * b + y = self.linear3(y) + x = x + y + + return x, cache + + +def measure(model, x, cache): + for i in range(5): + y, c = model(x, mask=None, cache=cache) + jax.block_until_ready((y, c)) + + start = time.time() + for i in range(5): + y, c = model(x, mask=None, cache=cache) + jax.block_until_ready((y, c)) + + end = time.time() + return (end - start) * 1000 / 5 + + +if __name__ == "__main__": + H = 32 + D = 4096 + F = 43 * 256 + C = 1000 + dtype = jnp.float16 + + k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4) + + x = jax.random.normal(k1, (1, 1, D), dtype) + cache = [ + jax.random.normal(k2, [1, H, C, D // H], dtype), + jax.random.normal(k3, [1, H, C, D // H], dtype), + ] + + layer = LlamaEncoderLayer(D, F, H, dtype=dtype) + params = layer.init(k4, x, mask=None, cache=cache)["params"] + + @jax.jit + def model_fn(x, mask, cache): + return layer.apply({"params": params}, x, mask=mask, cache=cache) + + T = measure(model_fn, x, cache) + + print("Time per layer per token:", T, "ms") + print("Lower bound total time per token:", T * 32, "ms") diff --git a/benchmarks/python/llama_torch_bench.py b/benchmarks/python/llama_torch_bench.py new file mode 100644 index 0000000000..25b24484b1 --- /dev/null +++ b/benchmarks/python/llama_torch_bench.py @@ -0,0 +1,197 @@ +import math +import time + +import torch +import torch.nn as nn +import torch.mps + + +def sync_if_needed(x): + if x.device != torch.device("cpu"): + torch.mps.synchronize() + + +class RoPE(nn.Module): + def __init__(self, dims: int, traditional: bool = False): + super().__init__() + 
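+        # dims is the per-head feature size to rotate; traditional=True uses the
+        # interleaved RoPE formulation that pairs x[..., ::2] with x[..., 1::2].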
self.dims = dims + self.traditional = traditional + + def _compute_rope(self, costheta, sintheta, x): + x1 = x[..., : self.dims // 2] + x2 = x[..., self.dims // 2 : self.dims] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + rx = torch.cat([rx1, rx2, x[..., self.dims :]], dim=-1) + else: + rx = torch.cat([rx1, rx2], dim=-1) + + return rx + + def _compute_traditional_rope(self, costheta, sintheta, x): + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + raise NotImplementedError( + "RoPE doesn't implement partial traditional application" + ) + + rx = torch.cat([rx1[..., None], rx2[..., None]], dim=-1) + + return rx + + def forward(self, x, offset: int = 0): + shape = x.shape + x = x.view(-1, shape[-2], shape[-1]) + N = x.shape[1] + offset + costheta, sintheta = RoPE.create_cos_sin_theta( + N, self.dims, offset=offset, device=x.device, dtype=x.dtype + ) + + rope = ( + self._compute_traditional_rope if self.traditional else self._compute_rope + ) + rx = rope(costheta, sintheta, x) + + return rx.view(*shape) + + @staticmethod + def create_cos_sin_theta( + N: int, + D: int, + offset: int = 0, + base: float = 10000, + device="cpu", + dtype=torch.float32, + ): + D = D // 2 + positions = torch.arange(offset, N, dtype=dtype, device=device) + freqs = torch.exp( + -torch.arange(0, D, dtype=dtype, device=device) * (math.log(base) / D) + ) + theta = positions.view(-1, 1) * freqs.view(1, -1) + costheta = torch.cos(theta) + sintheta = torch.sin(theta) + + return costheta, sintheta + + +class RMSNorm(nn.Module): + def __init__(self, dims: int, epsilon: float = 1e-6): + super().__init__() + self.gamma = nn.Parameter(torch.ones((dims,))) + self.epsilon = epsilon + + def forward(self, x): + n = torch.rsqrt(x.square().mean(dim=-1, keepdims=True) + self.epsilon) + return self.gamma * x * n + + +class LlamaAttention(nn.Module): + def __init__(self, dims: int, num_heads: int): + super().__init__() + self.num_heads = num_heads + self.rope = RoPE(dims // num_heads, True) + self.query_proj = nn.Linear(dims, dims, bias=False) + self.key_proj = nn.Linear(dims, dims, bias=False) + self.value_proj = nn.Linear(dims, dims, bias=False) + self.out_proj = nn.Linear(dims, dims, bias=False) + + def forward(self, queries, keys, values, mask=None, cache=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + queries = queries.view(B, L, num_heads, -1).permute(0, 2, 1, 3) + keys = keys.view(B, L, num_heads, -1).permute(0, 2, 1, 3) + values = values.view(B, L, num_heads, -1).permute(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = torch.cat([key_cache, keys], dim=2) + values = torch.cat([value_cache, values], dim=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys.permute(0, 1, 3, 2) + if mask is not None: + scores = scores + mask + scores = torch.softmax(scores, dim=-1) + values_hat = (scores @ values).permute(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat), (keys, values) + + +class LlamaEncoderLayer(nn.Module): + def __init__(self, dims: int, 
mlp_dims: int, num_heads: int): + super().__init__() + + self.attention = LlamaAttention(dims, num_heads) + + self.norm1 = RMSNorm(dims) + self.norm2 = RMSNorm(dims) + + self.linear1 = nn.Linear(dims, mlp_dims, bias=False) + self.linear2 = nn.Linear(dims, mlp_dims, bias=False) + self.linear3 = nn.Linear(mlp_dims, dims, bias=False) + + def forward(self, x, mask=None, cache=None): + y = self.norm1(x) + y, cache = self.attention(y, y, y, mask, cache) + x = x + y + + y = self.norm2(x) + a = self.linear1(y) + b = self.linear2(y) + y = torch.nn.functional.silu(a) * b + y = self.linear3(y) + x = x + y + + return x, cache + + +@torch.no_grad() +def measure(model, x, cache): + for i in range(5): + y, c = model(x, mask=None, cache=cache) + sync_if_needed(x) + + start = time.time() + for i in range(5): + y, c = model(x, mask=None, cache=cache) + sync_if_needed(x) + end = time.time() + return (end - start) * 1000 / 5 + + +if __name__ == "__main__": + H = 32 + D = 4096 + F = 43 * 256 + C = 1000 + device = torch.device("mps") + dtype = torch.float16 + + layer = LlamaEncoderLayer(D, F, H).to(device).to(dtype) + x = torch.randn(1, 1, D).to(device).to(dtype) + cache = [ + torch.randn(1, H, C, D // H).to(device).to(dtype), + torch.randn(1, H, C, D // H).to(device).to(dtype), + ] + + T = measure(layer, x, cache) + + print("Time per layer per token:", T, "ms") + print("Lower bound total time per token:", T * 32, "ms") diff --git a/benchmarks/python/single_ops.py b/benchmarks/python/single_ops.py new file mode 100644 index 0000000000..f2977175b5 --- /dev/null +++ b/benchmarks/python/single_ops.py @@ -0,0 +1,106 @@ +import argparse +import mlx.core as mx + +from time_utils import time_fn + + +def time_add(): + a = mx.random.uniform(shape=(32, 1024, 1024)) + b = mx.random.uniform(shape=(32, 1024, 1024)) + mx.eval(a, b) + time_fn(mx.add, a, b) + + aT = mx.transpose(a, [0, 2, 1]) + mx.eval(aT) + + def transpose_add(a, b): + return mx.add(a, b) + + time_fn(transpose_add, aT, b) + + b = mx.random.uniform(shape=(1024,)) + mx.eval(b) + + def slice_add(a, b): + return mx.add(a, b) + + time_fn(slice_add, a, b) + + b = mx.reshape(b, (1, 1024, 1)) + mx.eval(b) + + def mid_slice_add(a, b): + return mx.add(a, b) + + time_fn(mid_slice_add, a, b) + + +def time_matmul(): + a = mx.random.uniform(shape=(1024, 1024)) + b = mx.random.uniform(shape=(1024, 1024)) + mx.eval(a, b) + time_fn(mx.matmul, a, b) + + +def time_negative(): + a = mx.random.uniform(shape=(10000, 1000)) + mx.eval(a) + + def negative(a): + return -a + + mx.eval(a) + + time_fn(negative, a) + + +def time_exp(): + a = mx.random.uniform(shape=(1000, 100)) + mx.eval(a) + time_fn(mx.exp, a) + + +def time_logsumexp(): + a = mx.random.uniform(shape=(64, 10, 10000)) + mx.eval(a) + time_fn(mx.logsumexp, a, axis=-1) + + +def time_take(): + a = mx.random.uniform(shape=(10000, 500)) + ids = mx.random.randint(low=0, high=10000, shape=(20, 10)) + ids = [mx.reshape(idx, (-1,)) for idx in ids] + mx.eval(ids) + + def random_take(): + return [mx.take(a, idx, 0) for idx in ids] + + time_fn(random_take) + + +def time_reshape_transposed(): + x = mx.random.uniform(shape=(256, 256, 128)) + mx.eval(x) + + def reshape_transposed(): + return mx.reshape(mx.transpose(x, (1, 0, 2)), (-1,)) + + time_fn(reshape_transposed) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("MLX benchmarks.") + parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.") + args = parser.parse_args() + if args.gpu: + mx.set_default_device(mx.gpu) + else: + 
mx.set_default_device(mx.cpu) + + time_add() + time_matmul() + time_exp() + time_negative() + time_logsumexp() + time_take() + time_reshape_transposed()
diff --git a/docs/.gitignore b/docs/.gitignore new file mode 100644 index 0000000000..27834a90de --- /dev/null +++ b/docs/.gitignore @@ -0,0 +1 @@ +src/python/_autosummary*/
diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000000..d0c35a31c6 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,36 @@ +## Build the Docs + +### Setup (do once) + +Install [sphinx](https://www.sphinx-doc.org/en/master/usage/installation.html) +for example with `conda`: + +``` +conda install sphinx +pip install sphinx-rtd-theme +``` + +### Build + +Build the docs from `mlx/docs/` + +``` +make html +``` + +View the docs by running a server in `mlx/docs/build/html/`: + +``` +python -m http.server +``` + +and point your browser to `http://localhost:8000` (the default port). + +### Push to GitHub Pages + +Check out the `gh-pages` branch (`git switch gh-pages`) and build +the docs. Then force add the `build/html` directory: + +`git add -f build/html` + +Commit and push the changes to the `gh-pages` branch.
diff --git a/docs/src/_templates/optimizers-template.rst b/docs/src/_templates/optimizers-template.rst new file mode 100644 index 0000000000..80f049b4af --- /dev/null +++ b/docs/src/_templates/optimizers-template.rst @@ -0,0 +1,20 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + + {% block methods %} + + {% if methods %} + .. rubric:: {{ _('Methods') }} + + .. autosummary:: + {% for item in methods %} + {%- if item not in inherited_members %} + ~{{ name }}.{{ item }} + {%- endif %} + {%- endfor %} + {% endif %} + {% endblock %} +
diff --git a/docs/src/conf.py b/docs/src/conf.py new file mode 100644 index 0000000000..9018c231fa --- /dev/null +++ b/docs/src/conf.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +import os +import subprocess + +# -- Project information ----------------------------------------------------- + +project = "MLX" +copyright = "2023, MLX Contributors" +author = "MLX Contributors" +version = "0.0.0" +release = "0.0.0" + +# -- General configuration --------------------------------------------------- + +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.napoleon", +] + +python_use_unqualified_type_names = True +autosummary_generate = True + +intersphinx_mapping = { + "https://docs.python.org/3": None, + "https://numpy.org/doc/stable/": None, +} + +templates_path = ["_templates"] +html_static_path = ["_static"] +source_suffix = ".rst" +master_doc = "index" +highlight_language = "python" +pygments_style = "sphinx" + +# -- Options for HTML output ------------------------------------------------- + +html_theme = "sphinx_rtd_theme" + +# -- Options for HTMLHelp output --------------------------------------------- + +htmlhelp_basename = "mlx_doc"
diff --git a/docs/src/examples/llama-inference.rst b/docs/src/examples/llama-inference.rst new file mode 100644 index 0000000000..c1a6a7e84f --- /dev/null +++ b/docs/src/examples/llama-inference.rst @@ -0,0 +1,382 @@ +LLM inference +============== + +MLX enables efficient inference of large-ish transformers on Apple silicon +without compromising on ease of use. In this example we will create an +inference script for the Llama family of transformer models in which the model +is defined in less than 200 lines of Python.
+
+Implementing the model
+----------------------
+
+We will use the neural network building blocks defined in the :mod:`mlx.nn`
+module to concisely define the model architecture.
+
+Attention layer
+^^^^^^^^^^^^^^^^
+
+We will start with the Llama attention layer which notably uses the RoPE
+positional encoding. [1]_ In addition, our attention layer will optionally use a
+key/value cache that will be concatenated with the provided keys and values to
+support efficient inference.
+
+Our implementation uses :class:`mlx.nn.Linear` for all the projections and
+:class:`mlx.nn.RoPE` for the positional encoding.
+
+.. code-block:: python
+
+   import math
+
+   import mlx.core as mx
+   import mlx.nn as nn
+
+   class LlamaAttention(nn.Module):
+       def __init__(self, dims: int, num_heads: int):
+           super().__init__()
+
+           self.num_heads = num_heads
+
+           self.rope = nn.RoPE(dims // num_heads, traditional=True)
+           self.query_proj = nn.Linear(dims, dims, bias=False)
+           self.key_proj = nn.Linear(dims, dims, bias=False)
+           self.value_proj = nn.Linear(dims, dims, bias=False)
+           self.out_proj = nn.Linear(dims, dims, bias=False)
+
+       def __call__(self, queries, keys, values, mask=None, cache=None):
+           queries = self.query_proj(queries)
+           keys = self.key_proj(keys)
+           values = self.value_proj(values)
+
+           # Extract some shapes
+           num_heads = self.num_heads
+           B, L, D = queries.shape
+
+           # Prepare the queries, keys and values for the attention computation
+           queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+           keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+           values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+
+           # Add RoPE to the queries and keys and combine them with the cache
+           if cache is not None:
+               key_cache, value_cache = cache
+               queries = self.rope(queries, offset=key_cache.shape[2])
+               keys = self.rope(keys, offset=key_cache.shape[2])
+               keys = mx.concatenate([key_cache, keys], axis=2)
+               values = mx.concatenate([value_cache, values], axis=2)
+           else:
+               queries = self.rope(queries)
+               keys = self.rope(keys)
+
+           # Finally perform the attention computation
+           scale = math.sqrt(1 / queries.shape[-1])
+           scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
+           if mask is not None:
+               scores = scores + mask
+           scores = mx.softmax(scores, axis=-1)
+           values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+           # Note that we return the keys and values to possibly be used as a cache
+           return self.out_proj(values_hat), (keys, values)
+
+Encoder layer
+^^^^^^^^^^^^^
+
+The other component of the Llama model is the encoder layer which uses RMS
+normalization [2]_ and SwiGLU. [3]_ For RMS normalization we will use
+:class:`mlx.nn.RMSNorm` that is already provided in :mod:`mlx.nn`.
+
+..
code-block:: python + + class LlamaEncoderLayer(nn.Module): + def __init__(self, dims: int, mlp_dims: int, num_heads: int): + super().__init__() + + self.attention = LlamaAttention(dims, num_heads) + + self.norm1 = nn.RMSNorm(dims) + self.norm2 = nn.RMSNorm(dims) + + self.linear1 = nn.Linear(dims, mlp_dims, bias=False) + self.linear2 = nn.Linear(dims, mlp_dims, bias=False) + self.linear3 = nn.Linear(mlp_dims, dims, bias=False) + + def __call__(self, x, mask=None, cache=None): + y = self.norm1(x) + y, cache = self.attention(y, y, y, mask, cache) + x = x + y + + y = self.norm2(x) + a = self.linear1(y) + b = self.linear2(y) + y = a * mx.sigmoid(a) * b + y = self.linear3(y) + x = x + y + + return x, cache + +Full model +^^^^^^^^^^ + +To implement any Llama model we simply have to combine ``LlamaEncoderLayer`` +instances with an :class:`mlx.nn.Embedding` to embed the input tokens. + +.. code-block:: python + + class Llama(nn.Module): + def __init__( + self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int + ): + super().__init__() + + self.embedding = nn.Embedding(vocab_size, dims) + self.layers = [ + LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers) + ] + self.norm = nn.RMSNorm(dims) + self.out_proj = nn.Linear(dims, vocab_size, bias=False) + + def __call__(self, x): + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.embedding.weight.dtype) + + x = self.embedding(x) + for l in self.layers: + x, _ = l(x, mask) + x = self.norm(x) + return self.out_proj(x) + +Note that in the implementation above we use a simple list to hold the encoder +layers but using ``model.parameters()`` will still consider these layers. + +Generation +^^^^^^^^^^^ + +Our ``Llama`` module can be used for training but not inference as the +``__call__`` method above processes one input, completely ignores the cache and +performs no sampling whatsoever. In the rest of this subsection, we will +implement the inference function as a python generator that processes the +prompt and then autoregressively yields tokens one at a time. + +.. code-block:: python + + class Llama(nn.Module): + ... + + def generate(self, x, temp=1.0): + cache = [] + + # Make an additive causal mask. We will need that to process the prompt. + mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1]) + mask = mask.astype(self.embedding.weight.dtype) + + # First we process the prompt x the same way as in __call__ but + # save the caches in cache + x = self.embedding(x) + for l in self.layers: + x, c = l(x, mask=mask) + cache.append(c) # <--- we store the per layer cache in a + # simple python list + x = self.norm(x) + y = self.out_proj(x[:, -1]) # <--- we only care about the last logits + # that generate the next token + y = mx.random.categorical(y * (1/temp)) + + # y now has size [1] + # Since MLX is lazily evaluated nothing is computed yet. + # Calling y.item() would force the computation to happen at + # this point but we can also choose not to do that and let the + # user choose when to start the computation. + yield y + + # Now we parsed the prompt and generated the first token we + # need to feed it back into the model and loop to generate the + # rest. + while True: + # Unsqueezing the last dimension to add a sequence length + # dimension of 1 + x = y[:, None] + + x = self.embedding(x) + for i in range(len(cache)): + # We are overwriting the arrays in the cache list. 
When + # the computation will happen, MLX will be discarding the + # old cache the moment it is not needed anymore. + x, cache[i] = self.layers[i](x, mask=None, cache=cache[i]) + x = self.norm(x) + y = self.out_proj(x[:, -1]) + y = mx.random.categorical(y * (1/temp)) + + yield y + +Putting it all together +^^^^^^^^^^^^^^^^^^^^^^^ + +We now have everything we need to create a Llama model and sample tokens from +it. In the following code, we randomly initialize a small Llama model, process +6 tokens of prompt and generate 10 tokens. + +.. code-block:: python + + model = Llama(num_layers=12, vocab_size=8192, dims=512, mlp_dims=1024, num_heads=8) + + # Since MLX is lazily evaluated nothing has actually been materialized yet. + # We could have set the `dims` to 20_000 on a machine with 8GB of RAM and the + # code above would still run. Let's actually materialize the model. + mx.eval(model.parameters()) + + prompt = mx.array([[1, 10, 8, 32, 44, 7]]) # <-- Note the double brackets because we + # have a batch dimension even + # though it is 1 in this case + + generated = [t for i, t in zip(range(10), model.generate(prompt, 0.8))] + + # Since we haven't evaluated anything, nothing is computed yet. The list + # `generated` contains the arrays that hold the computation graph for the + # full processing of the prompt and the generation of 10 tokens. + # + # We can evaluate them one at a time, or all together. Concatenate them or + # print them. They would all result in very similar runtimes and give exactly + # the same results. + mx.eval(generated) + +Converting the weights +---------------------- + +This section assumes that you have access to the original Llama weights and the +SentencePiece model that comes with them. We will write a small script to +convert the PyTorch weights to MLX compatible ones and write them in a NPZ file +that can be loaded directly by MLX. + +.. code-block:: python + + import argparse + from itertools import starmap + + import numpy as np + import torch + + def map_torch_to_mlx(key, value): + if "tok_embedding" in key: + key = "embedding.weight" + + elif "norm" in key: + key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2") + + elif "wq" in key or "wk" in key or "wv" in key or "wo" in key: + key = key.replace("wq", "query_proj") + key = key.replace("wk", "key_proj") + key = key.replace("wv", "value_proj") + key = key.replace("wo", "out_proj") + + elif "w1" in key or "w2" in key or "w3" in key: + # The FFN is a separate submodule in PyTorch + key = key.replace("feed_forward.w1", "linear1") + key = key.replace("feed_forward.w3", "linear2") + key = key.replace("feed_forward.w2", "linear3") + + elif "output" in key: + key = key.replace("output", "out_proj") + + elif "rope" in key: + return None, None + + return key, value.numpy() + + + if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert Llama weights to MLX") + parser.add_argument("torch_weights") + parser.add_argument("output_file") + args = parser.parse_args() + + state = torch.load(args.torch_weights) + np.savez( + args.output_file, + **{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None} + ) + + +Weight loading and benchmarking +------------------------------- + +After converting the weights to be compatible to our implementation, all that is +left is to load them from disk and we can finally use the LLM to generate text. +We can load numpy format files using the :func:`mlx.core.load` operation. 
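+
+As a small, illustrative sketch (it assumes the ``llama-7B.mlx.npz`` file
+produced by the conversion script above), :func:`mlx.core.load` returns a flat
+dictionary that maps parameter names to arrays:
+
+.. code-block:: python
+
+   import mlx.core as mx
+
+   weights = mx.load("llama-7B.mlx.npz")
+   # Keys follow the names written by the conversion script, e.g.
+   # "layers.2.attention.query_proj.weight" maps to an mlx.core.array
+   print(len(weights), list(weights)[:3])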
+ +To create a parameter dictionary from the key/value representation of NPZ files +we will use the :func:`mlx.utils.tree_unflatten` helper method as follows: + +.. code-block:: python + + from mlx.utils import tree_unflatten + + model.update(tree_unflatten(list(mx.load(weight_file).items()))) + +:meth:`mlx.utils.tree_unflatten` will take keys from the NPZ file that look +like ``layers.2.attention.query_proj.weight`` and will transform them to + +.. code-block:: python + + {"layers": [..., ..., {"attention": {"query_proj": {"weight": ...}}}]} + +which can then be used to update the model. Note that the method above incurs +several unnecessary copies from disk to numpy and then from numpy to MLX. It +will be replaced in the future with direct loading to MLX. + +You can download the full example code in `mlx-examples `_. Assuming, the +existence of ``weights.pth`` and ``tokenizer.model`` in the current working +directory we can play around with our inference script as follows (the timings +are representative of an M1 Ultra and the 7B parameter Llama model): + +.. code-block:: bash + + $ python convert.py weights.pth llama-7B.mlx.npz + $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely' + [INFO] Loading model from disk: 5.247 s + Press enter to start generation + ------ + , having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, + ------ + [INFO] Prompt processing: 0.437 s + [INFO] Full generation: 4.330 s + +We observe that 4.3 seconds are required to generate 100 tokens and 0.4 seconds +of those are spent processing the prompt. This amounts to a little over **39 ms +per token**. + +By running with a much bigger prompt we can see that the per token generation +time as well as the prompt processing time remains almost constant. + +.. code-block:: bash + + $ python llama.py llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. 
Still he made no immediate reply, but again put his face down upon his hands, and did not' + [INFO] Loading model from disk: 5.247 s + Press enter to start generation + ------ + take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not + ------ + [INFO] Prompt processing: 0.579 s + [INFO] Full generation: 4.690 s + $ python llama.py --num-tokens 500 llama-7B.mlx.npz tokenizer.model 'Call me Ishmael. Some years ago never mind how long precisely, having little or no money in my purse, and nothing of greater consequence in my mind, I happened to be walking down Gower Street in the afternoon, in the heavy rain, and I saw a few steps off, a man in rags, who sat upon his bundle and looked hard into the wet as if he were going to cry. I watched him attentively for some time, and could not but observe that, though a numerous crowd was hurrying up and down, nobody took the least notice of him. I stopped at last, at a little distance, as if I had been in doubt, and after looking on a few minutes, walked straight up to him. He slowly raised his eyes, and fixed them upon me for a moment, without speaking, and then resumed his place and posture as before. I stood looking at him for a while, feeling very much pain at heart, and then said to him, “What are you doing there?” Something like a smile passed over his face, as he said slowly, “I am waiting for someone; but it has been three quarters of an hour now, and he has not come.” “What is it you are waiting for?” said I. Still he made no immediate reply, but again put his face down upon his hands, and did not' + [INFO] Loading model from disk: 5.628 s + Press enter to start generation + ------ + take his eyes from the ground. “What is it you are waiting for?” said I. “I am not accustomed to be thus questioned,” said he. “You look like a reasonable man—tell me, then, what are you waiting for?” “You would not understand,” he replied; “and how could you help me, if I were to tell you?” “I should not only understand, but would do all that I could,” said I. He did not reply, but still went on looking at the ground, and took hold of his bundle with a nervous trembling. I waited some time, and then resumed. “It is of no use to say you would not understand, if I were to tell you,” said he. “I have not told you why I am waiting for him,” said I. “And I am sure I should not understand,” replied he. “I will tell you then,” said I, “and, perhaps, you would not be surprised.” “No matter,” said he, “I shall be surprised anyhow; so tell me why you are waiting for him.” “He is my friend,” said I. “Yes,” said he, with a slight smile, “I know.” “He has been kind to me,” said I, “and I am waiting for him. I want to see him, and could have waited as I am now, for a much longer time.” “He will not soon come,” said he. “Unless he sees you here, he will not know of your having waited, and he will be very unlikely to come.” “No matter,” said I, “I shall wait for him.” “This is a strange thing,” said he, still with the same amused smile. “How did you know,” said I, “that he was coming? How should you be waiting?” “That is my secret,” said he. “And you expect him?” “Yes,” said I. 
“Are you disappointed then, if he does not come?” “No,” said I, “it is his secret, not mine.” “If he comes,” said he, “do you mean to go straight away?” “Yes,” said I, “I cannot be happy if I do not go straight away after him.” “Did you know this place before?” asked he. “Yes,” said I. “Is there any shop to buy food here?” “ + ------ + [INFO] Prompt processing: 0.633 s + [INFO] Full generation: 21.475 s + +Scripts +------- + +.. admonition:: Download the code + + The full example code is available in `mlx-examples `_. + +.. code: `https://github.com/ml-explore/mlx-examples/tree/main/llama`_ + +.. [1] Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B. and Liu, Y., 2021. + Roformer: Enhanced transformer with rotary position embedding. arXiv + preprint arXiv:2104.09864. +.. [2] Zhang, B. and Sennrich, R., 2019. Root mean square layer normalization. + Advances in Neural Information Processing Systems, 32. +.. [3] Shazeer, N., 2020. Glu variants improve transformer. arXiv preprint + arXiv:2002.05202. diff --git a/docs/src/index.rst b/docs/src/index.rst new file mode 100644 index 0000000000..a5bf3e9a2c --- /dev/null +++ b/docs/src/index.rst @@ -0,0 +1,49 @@ +MLX +=== + +.. toctree:: + :caption: Install + :maxdepth: 1 + + install + +.. toctree:: + :caption: Usage + :maxdepth: 1 + + quick_start + using_streams + +.. toctree:: + :caption: Examples + :maxdepth: 1 + + examples/linear_regression + examples/mlp + examples/llama-inference + +.. toctree:: + :caption: Further Reading + :maxdepth: 1 + + dev/extensions + +.. toctree:: + :caption: Python API Reference + :maxdepth: 1 + + python/array + python/devices_and_streams + python/ops + python/random + python/transforms + python/fft + python/nn + python/optimizers + python/tree_utils + +.. toctree:: + :caption: C++ API Reference + :maxdepth: 1 + + cpp/ops diff --git a/docs/src/install.rst b/docs/src/install.rst new file mode 100644 index 0000000000..fd4f9becdf --- /dev/null +++ b/docs/src/install.rst @@ -0,0 +1,102 @@ +Build and Install +================= + +Install from PyPI +----------------- + +MLX is available at Apple's internal PyPI repository. All you have to do to use +MLX with your own Apple silicon computer is + +.. code-block:: shell + + pip install apple-mlx -i https://pypi.apple.com/simple + +Build from source +----------------- + +Build Requirements +^^^^^^^^^^^^^^^^^^ + +- A C++ compiler with C++17 support (e.g. Clang >= 5.0) +- `cmake `_ -- version 3.24 or later, and ``make`` + + +Python API +^^^^^^^^^^ + +To build and install the MLX python library from source, first, clone MLX from +`its GitHub repo `_: + +.. code-block:: shell + + git clone git@github.com:ml-explore/mlx.git mlx && cd mlx + +Make sure that you have `pybind11 `_ +installed. You can install ``pybind11`` with ``pip``, ``brew`` or ``conda`` as follows: + +.. code-block:: shell + + pip install "pybind11[global]" + conda install pybind11 + brew install pybind11 + +Then simply build and install it using pip: + +.. code-block:: shell + + env CMAKE_BUILD_PARALLEL_LEVEL="" pip install . + + +C++ API +^^^^^^^ + +Currently, MLX must be built and installed from source. + +Similarly to the python library, to build and install the MLX C++ library start +by cloning MLX from `its GitHub repo +`_: + +.. code-block:: shell + + git clone git@github.com:ml-explore/mlx.git mlx && cd mlx + +Create a build directory and run CMake and make: + +.. code-block:: shell + + mkdir -p build && cd build + cmake .. && make -j + +Run tests with: + +.. code-block:: shell + + make test + +Install with: + +.. 
code-block:: shell + + make install + +Note that the built ``mlx.metallib`` file should be either at the same +directory as the executable statically linked to ``libmlx.a`` or the +preprocessor constant ``METAL_PATH`` should be defined at build time and it +should point to the path to the built metal library. + +.. list-table:: Build Options + :widths: 25 8 + :header-rows: 1 + + * - Option + - Default + * - MLX_BUILD_TESTS + - ON + * - MLX_BUILD_EXAMPLES + - OFF + * - MLX_BUILD_BENCHMARKS + - OFF + * - MLX_BUILD_METAL + - ON + * - MLX_BUILD_PYTHON_BINDINGS + - OFF diff --git a/docs/src/python/fft.rst b/docs/src/python/fft.rst new file mode 100644 index 0000000000..9e4be084b9 --- /dev/null +++ b/docs/src/python/fft.rst @@ -0,0 +1,22 @@ +.. _fft: + +FFT +=== + +.. currentmodule:: mlx.core.fft + +.. autosummary:: + :toctree: _autosummary + + fft + ifft + fft2 + ifft2 + fftn + ifftn + rfft + irfft + rfft2 + irfft2 + rfftn + irfftn diff --git a/docs/src/python/nn.rst b/docs/src/python/nn.rst new file mode 100644 index 0000000000..114fd8a904 --- /dev/null +++ b/docs/src/python/nn.rst @@ -0,0 +1,172 @@ +.. _nn: + +.. currentmodule:: mlx.nn + +Neural Networks +=============== + +Writing arbitrarily complex neural networks in MLX can be done using only +:class:`mlx.core.array` and :meth:`mlx.core.value_and_grad`. However, this requires the +user to write again and again the same simple neural network operations as well +as handle all the parameter state and initialization manually and explicitly. + +The module :mod:`mlx.nn` solves this problem by providing an intuitive way of +composing neural network layers, initializing their parameters, freezing them +for finetuning and more. + +Quick Start with Neural Networks +--------------------------------- + +.. code-block:: python + + import mlx.core as mx + import mlx.nn as nn + + class MLP(nn.Module): + def __init__(self, in_dims: int, out_dims: int): + super().__init__() + + self.layers = [ + nn.Linear(in_dims, 128), + nn.Linear(128, 128), + nn.Linear(128, out_dims), + ] + + def __call__(self, x): + for i, l in enumerate(self.layers): + x = mx.maximum(x, 0) if i > 0 else x + x = l(x) + return x + + # The model is created with all its parameters but nothing is initialized + # yet because MLX is lazily evaluated + mlp = MLP(2, 10) + + # We can access its parameters by calling mlp.parameters() + params = mlp.parameters() + print(params["layers"][0]["weight"].shape) + + # Printing a parameter will cause it to be evaluated and thus initialized + print(params["layers"][0]) + + # We can also force evaluate all parameters to initialize the model + mx.eval(mlp.parameters()) + + # A simple loss function. + # NOTE: It doesn't matter how it uses the mlp model. It currently captures + # it from the local scope. It could be a positional argument or a + # keyword argument. + def l2_loss(x, y): + y_hat = mlp(x) + return (y_hat - y).square().mean() + + # Calling `nn.value_and_grad` instead of `mx.value_and_grad` returns the + # gradient with respect to `mlp.trainable_parameters()` + loss_and_grad = nn.value_and_grad(mlp, l2_loss) + + +.. _module_class: + +The Module Class +---------------- + +The workhorse of any neural network library is the :class:`Module` class. In +MLX the :class:`Module` class is a container of :class:`mlx.core.array` or +:class:`Module` instances. Its main function is to provide a way to +recursively **access** and **update** its parameters and those of its +submodules. 
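+
+As a small sketch of this recursive structure (the class and attribute names
+below are only illustrative), submodules can be stored directly as attributes
+or inside lists and dictionaries, and their parameters appear in the
+corresponding place of the nested parameter dictionary:
+
+.. code-block:: python
+
+   import mlx.nn as nn
+
+   class Block(nn.Module):
+       def __init__(self):
+           super().__init__()
+           self.proj = nn.Linear(4, 4)
+
+   class Net(nn.Module):
+       def __init__(self):
+           super().__init__()
+           self.blocks = [Block(), Block()]  # submodules can live in a list
+
+   net = Net()
+   params = net.parameters()
+   # The parameter tree mirrors the module structure, e.g.
+   # params["blocks"][0]["proj"]["weight"] is an mlx.core.array
+   print(params["blocks"][0]["proj"]["weight"].shape)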
+ +Parameters +^^^^^^^^^^ + +A parameter of a module is any public member of type :class:`mlx.core.array` (its +name should not start with ``_``). It can be arbitrarily nested in other +:class:`Module` instances or lists and dictionaries. + +:meth:`Module.parameters` can be used to extract a nested dictionary with all +the parameters of a module and its submodules. + +A :class:`Module` can also keep track of "frozen" parameters. +:meth:`Module.trainable_parameters` returns only the subset of +:meth:`Module.parameters` that is not frozen. When using +:meth:`mlx.nn.value_and_grad` the gradients returned will be with respect to these +trainable parameters. + +Updating the parameters +^^^^^^^^^^^^^^^^^^^^^^^ + +MLX modules allow accessing and updating individual parameters. However, most +times we need to update large subsets of a module's parameters. This action is +performed by :meth:`Module.update`. + +Value and grad +-------------- + +Using a :class:`Module` does not preclude using MLX's high order function +transformations (:meth:`mlx.core.value_and_grad`, :meth:`mlx.core.grad`, etc.). However, +these function transformations assume pure functions, namely the parameters +should be passed as an argument to the function being transformed. + +There is an easy pattern to achieve that with MLX modules + +.. code-block:: python + + model = ... + + def f(params, other_inputs): + model.update(params) # <---- Necessary to make the model use the passed parameters + return model(other_inputs) + + f(model.trainable_parameters(), mx.zeros((10,))) + +However, :meth:`mlx.nn.value_and_grad` provides precisely this pattern and only +computes the gradients with respect to the trainable parameters of the model. + +In detail: + +- it wraps the passed function with a function that calls :meth:`Module.update` + to make sure the model is using the provided parameters. +- it calls :meth:`mlx.core.value_and_grad` to transform the function into a function + that also computes the gradients with respect to the passed parameters. +- it wraps the returned function with a function that passes the trainable + parameters as the first argument to the function returned by + :meth:`mlx.core.value_and_grad` + +.. autosummary:: + :toctree: _autosummary + + value_and_grad + +Neural Network Layers +--------------------- + +.. autosummary:: + :toctree: _autosummary + :template: nn-module-template.rst + + Embedding + ReLU + GELU + SiLU + Linear + Conv1d + Conv2d + LayerNorm + RMSNorm + GroupNorm + RoPE + MultiHeadAttention + Sequential + +Layers without parameters (e.g. activation functions) are also provided as +simple functions. + +.. autosummary:: + :toctree: _autosummary_functions + :template: nn-module-template.rst + + gelu + gelu_approx + gelu_fast_approx + relu + silu diff --git a/docs/src/python/nn/module.rst b/docs/src/python/nn/module.rst new file mode 100644 index 0000000000..e14ba96f47 --- /dev/null +++ b/docs/src/python/nn/module.rst @@ -0,0 +1,7 @@ +mlx.nn.Module +============= + +.. currentmodule:: mlx.nn + +.. autoclass:: Module + :members: diff --git a/docs/src/python/optimizers.rst b/docs/src/python/optimizers.rst new file mode 100644 index 0000000000..7f5d3a0679 --- /dev/null +++ b/docs/src/python/optimizers.rst @@ -0,0 +1,41 @@ +.. _optimizers: + +Optimizers +========== + +The optimizers in MLX can be used both with :mod:`mlx.nn` but also with pure +:mod:`mlx.core` functions. 
A typical example involves calling
+:meth:`Optimizer.update` to update a model's parameters based on the loss
+gradients and subsequently calling :func:`mlx.core.eval` to evaluate both the
+model's parameters and the **optimizer state**.
+
+.. code-block:: python
+
+   # Create a model
+   model = MLP(num_layers, train_images.shape[-1], hidden_dim, num_classes)
+   mx.eval(model.parameters())
+
+   # Create the gradient function and the optimizer
+   loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
+   optimizer = optim.SGD(learning_rate=learning_rate)
+
+   for e in range(num_epochs):
+       for X, y in batch_iterate(batch_size, train_images, train_labels):
+           loss, grads = loss_and_grad_fn(model, X, y)
+
+           # Update the model with the gradients. So far no computation has happened.
+           optimizer.update(model, grads)
+
+           # Compute the new parameters but also the optimizer state.
+           mx.eval(model.parameters(), optimizer.state)
+
+.. currentmodule:: mlx.optimizers
+
+.. autosummary::
+   :toctree: _autosummary
+   :template: optimizers-template.rst
+
+   OptimizerState
+   Optimizer
+   SGD
+   Adam
diff --git a/docs/src/python/random.rst b/docs/src/python/random.rst
new file mode 100644
index 0000000000..8ac3eaa11c
--- /dev/null
+++ b/docs/src/python/random.rst
@@ -0,0 +1,45 @@
+.. _random:
+
+Random
+======
+
+Random sampling functions in MLX use an implicit global PRNG state by default.
+However, all functions take an optional ``key`` keyword argument for when more
+fine-grained control or explicit state management is needed.
+
+For example, you can generate random numbers with:
+
+.. code-block:: python
+
+   for _ in range(3):
+       print(mx.random.uniform())
+
+which will print a sequence of unique pseudo random numbers. Alternatively you
+can explicitly set the key:
+
+.. code-block:: python
+
+   key = mx.random.key(0)
+   for _ in range(3):
+       print(mx.random.uniform(key=key))
+
+which will yield the same pseudo random number at each iteration.
+
+Following `JAX's PRNG design `_
+we use a splittable version of Threefry, which is a counter-based PRNG.
+
+.. currentmodule:: mlx.core.random
+
+.. autosummary::
+   :toctree: _autosummary
+
+   seed
+   key
+   split
+   bernoulli
+   categorical
+   gumbel
+   normal
+   randint
+   uniform
+   truncated_normal
diff --git a/docs/src/quick_start.rst b/docs/src/quick_start.rst
new file mode 100644
index 0000000000..c3e2b678b0
--- /dev/null
+++ b/docs/src/quick_start.rst
@@ -0,0 +1,93 @@
+Quick Start Guide
+=================
+
+MLX is a NumPy-like array framework designed for efficient and flexible
+machine learning on Apple silicon. The Python API closely follows NumPy with
+a few exceptions. MLX also has a fully featured C++ API which closely follows
+the Python API.
+
+The main differences between MLX and NumPy are:
+
+  - **Composable function transformations**: MLX has composable function
+    transformations for automatic differentiation, automatic vectorization,
+    and computation graph optimization.
+  - **Lazy computation**: Computations in MLX are lazy. Arrays are only
+    materialized when needed.
+  - **Multi-device**: Operations can run on any of the supported devices (CPU,
+    GPU, ...)
+
+The design of MLX is strongly inspired by frameworks like `PyTorch
+`_, `Jax `_, and
+`ArrayFire `_. A notable difference between MLX and these
+frameworks is the *unified memory model*. Arrays in MLX live in shared
+memory. Operations on MLX arrays can be performed on any of the supported
+device types without performing data copies. Currently supported device types
+are the CPU and GPU.
+
+Basics
+------
+
+..
currentmodule:: mlx.core + +Import ``mlx.core`` and make an :class:`array`: + +.. code-block:: python + + >> import mlx.core as mx + >> a = mx.array([1, 2, 3, 4]) + >> a.shape + [4] + >> a.dtype + int32 + >> b = mx.array([1.0, 2.0, 3.0, 4.0]) + >> b.dtype + float32 + +Operations in MLX are lazy. The outputs of MLX operations are not computed +until they are needed. To force an array to be evaluated use +:func:`eval`. Arrays will automatically be evaluated in a few cases. For +example, inspecting a scalar with :meth:`array.item`, printing an array, +or converting an array from :class:`array` to :class:`numpy.ndarray` all +automatically evaluate the array. + +.. code-block:: python + + >> c = a + b # c not yet evaluated + >> mx.eval(c) # evaluates c + >> c = a + b + >> print(c) # Also evaluates c + array([2, 4, 6, 8], dtype=float32) + >> c = a + b + >> import numpy as np + >> np.array(c) # Also evaluates c + array([2., 4., 6., 8.], dtype=float32) + +Function and Graph Transformations +---------------------------------- + +MLX has standard function transformations like :func:`grad` and :func:`vmap`. +Transformations can be composed arbitrarily. For example +``grad(vmap(grad(fn)))`` (or any other composition) is allowed. + +.. code-block:: python + + >> x = mx.array(0.0) + >> mx.sin(x) + array(0, dtype=float32) + >> mx.grad(mx.sin)(x) + array(1, dtype=float32) + >> mx.grad(mx.grad(mx.sin))(x) + array(-0, dtype=float32) + +Other gradient transformations include :func:`vjp` for vector-Jacobian products +and :func:`jvp` for Jacobian-vector products. + +Use :func:`value_and_grad` to efficiently compute both a function's output and +gradient with respect to the function's input. + + +Devices and Streams +------------------- + + + diff --git a/docs/src/using_streams.rst b/docs/src/using_streams.rst new file mode 100644 index 0000000000..7b42e8d4b8 --- /dev/null +++ b/docs/src/using_streams.rst @@ -0,0 +1,16 @@ +Using Streams +============= + +.. currentmodule:: mlx.core + +Specifying the :obj:`Stream` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +All operations (including random number generation) take an optional +keyword argument ``stream``. The ``stream`` kwarg specifies which +:obj:`Stream` the operation should run on. If the stream is unspecified then +the operation is run on the default stream of the default device: +``mx.default_stream(mx.default_device())``. The ``stream`` kwarg can also +be a :obj:`Device` (e.g. ``stream=my_device``) in which case the operation is +run on the default stream of the provided device +``mx.default_stream(my_device)``. diff --git a/examples/cpp/logistic_regression.cpp b/examples/cpp/logistic_regression.cpp new file mode 100644 index 0000000000..c34ae7feda --- /dev/null +++ b/examples/cpp/logistic_regression.cpp @@ -0,0 +1,52 @@ +#include +#include +#include + +#include "mlx/mlx.h" +#include "timer.h" + +/** + * An example of logistic regression with MLX. 
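+ *
+ * In short: labels are generated as y = 1[X w* > 0], the loss below is the
+ * mean logistic loss mean(logaddexp(0, X w) - y * (X w)), and w is fit with
+ * plain gradient descent, evaluating the updated parameters at every step.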
+ */ +using namespace mlx::core; + +int main() { + int num_features = 100; + int num_examples = 1'000; + int num_iters = 10'000; + float learning_rate = 0.1; + + // True parameters + auto w_star = random::normal({num_features}); + + // The input examples + auto X = random::normal({num_examples, num_features}); + + // Labels + auto y = matmul(X, w_star) > 0; + + // Initialize random parameters + array w = 1e-2 * random::normal({num_features}); + + auto loss_fn = [&](array w) { + auto logits = matmul(X, w); + auto scale = (1.0f / num_examples); + return scale * sum(logaddexp(array(0.0f), logits) - y * logits); + }; + + auto grad_fn = grad(loss_fn); + + auto tic = timer::time(); + for (int it = 0; it < num_iters; ++it) { + auto grad = grad_fn(w); + w = w - learning_rate * grad; + eval(w); + } + auto toc = timer::time(); + + auto loss = loss_fn(w); + auto acc = sum((matmul(X, w) > 0) == y) / num_examples; + auto throughput = num_iters / timer::seconds(toc - tic); + std::cout << "Loss " << loss << ", Accuracy, " << acc << ", Throughput " + << throughput << " (it/s)." << std::endl; +} diff --git a/examples/cpp/tutorial.cpp b/examples/cpp/tutorial.cpp new file mode 100644 index 0000000000..8426878e23 --- /dev/null +++ b/examples/cpp/tutorial.cpp @@ -0,0 +1,97 @@ +#include +#include + +#include "mlx/mlx.h" + +using namespace mlx::core; + +void array_basics() { + // Make a scalar array: + array x(1.0); + + // Get the value out of it: + auto s = x.item(); + assert(s == 1.0); + + // Scalars have a size of 1: + size_t size = x.size(); + assert(size == 1); + + // Scalars have 0 dimensions: + int ndim = x.ndim(); + assert(ndim == 0); + + // The shape should be an empty vector: + auto shape = x.shape(); + assert(shape.empty()); + + // The datatype should be float32: + auto dtype = x.dtype(); + assert(dtype == float32); + + // Specify the dtype when constructing the array: + x = array(1, int32); + assert(x.dtype() == int32); + x.item(); // OK + // x.item(); // Undefined! + + // Make a multidimensional array: + x = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + // mlx is row-major by default so the first row of this array + // is [1.0, 2.0] and the second row is [3.0, 4.0] + + // Make an array of shape {2, 2} filled with ones: + auto y = ones({2, 2}); + + // Pointwise add x and y: + auto z = add(x, y); + + // Same thing: + z = x + y; + + // mlx is lazy by default. At this point `z` only + // has a shape and a type but no actual data: + assert(z.dtype() == float32); + assert(z.shape(0) == 2); + assert(z.shape(1) == 2); + + // To actually run the compuation you must evaluate `z`. + // Under the hood, mlx records operations in a graph. + // The variable `z` is a node in the graph which points to its operation + // and inputs. When `eval` is called on an array (or arrays), the array and + // all of its dependencies are recursively evaluated to produce the result. + // Once an array is evaluated, it has data and is detached from its inputs. + eval(z); + + // Of course the array can still be an input to other operations. You can even + // call eval on the array again, this will just be a no-op: + eval(z); // no-op + + // Some functions or methods on arrays implicitly evaluate them. 
For example + // accessing a value in an array or printing the array implicitly evaluate it: + z = ones({1}); + z.item(); // implicit evaluation + + z = ones({2, 2}); + std::cout << z << std::endl; // implicit evaluation +} + +void automatic_differentiation() { + auto fn = [](array x) { return square(x); }; + + // Computing the derivative function of a function + auto grad_fn = grad(fn); + // Call grad_fn on the input to get the derivative + auto x = array(1.5); + auto dfdx = grad_fn(x); + // dfdx is 2 * x + + // Get the second derivative by composing grad with grad + auto df2dx2 = grad(grad(fn))(x); + // df2dx2 is 2 +} + +int main() { + array_basics(); + automatic_differentiation(); +} diff --git a/examples/extensions/CMakeLists.txt b/examples/extensions/CMakeLists.txt new file mode 100644 index 0000000000..0c9951103d --- /dev/null +++ b/examples/extensions/CMakeLists.txt @@ -0,0 +1,66 @@ +cmake_minimum_required(VERSION 3.24) + +project(mlx_sample_extensions LANGUAGES CXX) + +# ----------------------------- Setup ----------------------------- +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE ON) + +option(BUILD_SHARED_LIBS "Build extensions as a shared library" ON) + +# ----------------------------- Dependencies ----------------------------- +find_package(MLX CONFIG REQUIRED) +find_package(Python COMPONENTS Interpreter Development) +find_package(pybind11 CONFIG REQUIRED) + +# ----------------------------- Extensions ----------------------------- + +# Add library +add_library(mlx_ext) + +# Add sources +target_sources( + mlx_ext + PUBLIC + ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp +) + +# Add include headers +target_include_directories( + mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR} +) + +# Link to mlx +target_link_libraries(mlx_ext PUBLIC mlx) + +# ----------------------------- Metal ----------------------------- + +# Build metallib +if(MLX_BUILD_METAL) + + mlx_build_metallib( + TARGET mlx_ext_metallib + TITLE mlx_ext + SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal + INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS} + OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY} + ) + + add_dependencies( + mlx_ext + mlx_ext_metallib + ) + +endif() + +# ----------------------------- Pybind ----------------------------- +pybind11_add_module( + mlx_sample_extensions + ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp +) +target_link_libraries(mlx_sample_extensions PRIVATE mlx_ext) + +if(BUILD_SHARED_LIBS) + target_link_options(mlx_sample_extensions PRIVATE -Wl,-rpath,@loader_path) +endif() \ No newline at end of file diff --git a/examples/extensions/axpby/axpby.h b/examples/extensions/axpby/axpby.h new file mode 100644 index 0000000000..6abc48a9ed --- /dev/null +++ b/examples/extensions/axpby/axpby.h @@ -0,0 +1,84 @@ +#pragma once + +#include "mlx/ops.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +/////////////////////////////////////////////////////////////////////////////// +// Operation +/////////////////////////////////////////////////////////////////////////////// + +/** + * Scale and sum two vectors elementwise + * z = alpha * x + beta * y + * + * Follow numpy style broadcasting between x and y + * Inputs are upcasted to floats if needed + **/ +array axpby( + const array& x, // Input array x + const array& y, // Input array y + const float alpha, // Scaling factor for x + const float beta, // Scaling factor for y + StreamOrDevice s = {} // Stream on which to schedule the operation +); + 
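+// A usage sketch (not part of the extension; the values are illustrative):
+//
+//   auto x = array({1.0f, 2.0f, 3.0f});
+//   auto y = array({4.0f, 5.0f, 6.0f});
+//   auto z = axpby(x, y, 2.0f, 0.5f); // elementwise 2 * x + 0.5 * y
+//   eval(z);
+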
+/////////////////////////////////////////////////////////////////////////////// +// Primitive +/////////////////////////////////////////////////////////////////////////////// + +class Axpby : public Primitive { + public: + explicit Axpby(Stream stream, float alpha, float beta) + : Primitive(stream), alpha_(alpha), beta_(beta){}; + + /** + * A primitive must know how to evaluate itself on the CPU/GPU + * for the given inputs and populate the output array. + * + * To avoid unecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + /** The Jacobian-vector product. */ + array jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + /** The vector-Jacobian product. */ + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + /** + * The primitive must know how to vectorize itself accross + * the given axes. The output is a pair containing the array + * representing the vectorized computation and the axis which + * corresponds to the output vectorized dimension. + */ + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + /** Print the primitive. */ + void print(std::ostream& os) override { + os << "Axpby"; + } + + /** Equivalence check **/ + bool is_equivalent(const Primitive& other) const override; + + private: + float alpha_; + float beta_; + + /** Fall back implementation for evaluation on CPU */ + void eval(const std::vector& inputs, array& out); +}; + +} // namespace mlx::core \ No newline at end of file diff --git a/examples/extensions/axpby/axpby.metal b/examples/extensions/axpby/axpby.metal new file mode 100644 index 0000000000..fa2849579b --- /dev/null +++ b/examples/extensions/axpby/axpby.metal @@ -0,0 +1,61 @@ +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/utils.h" + +template +[[kernel]] void axpby_general( + device const T* x [[buffer(0)]], + device const T* y [[buffer(1)]], + device T* out [[buffer(2)]], + constant const float& alpha [[buffer(3)]], + constant const float& beta [[buffer(4)]], + constant const int* shape [[buffer(5)]], + constant const size_t* x_strides [[buffer(6)]], + constant const size_t* y_strides [[buffer(7)]], + constant const int& ndim [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + auto x_offset = elem_to_loc(index, shape, x_strides, ndim); + auto y_offset = elem_to_loc(index, shape, y_strides, ndim); + out[index] = + static_cast(alpha) * x[x_offset] + static_cast(beta) * y[y_offset]; +} + +template +[[kernel]] void axpby_contiguous( + device const T* x [[buffer(0)]], + device const T* y [[buffer(1)]], + device T* out [[buffer(2)]], + constant const float& alpha [[buffer(3)]], + constant const float& beta [[buffer(4)]], + uint index [[thread_position_in_grid]]) { + out[index] = + static_cast(alpha) * x[index] + static_cast(beta) * y[index]; +} + +#define instantiate_axpby(type_name, type) \ + template [[host_name("axpby_general_" #type_name)]] \ + [[kernel]] void axpby_general( \ + device const type* x [[buffer(0)]], \ + device const type* y [[buffer(1)]], \ + device type* out [[buffer(2)]], \ + constant const float& alpha [[buffer(3)]], \ + constant const float& beta [[buffer(4)]], \ + constant const int* shape [[buffer(5)]], \ + constant const size_t* x_strides [[buffer(6)]], 
\
+      constant const size_t* y_strides [[buffer(7)]], \
+      constant const int& ndim [[buffer(8)]], \
+      uint index [[thread_position_in_grid]]); \
+  template [[host_name("axpby_contiguous_" #type_name)]] \
+  [[kernel]] void axpby_contiguous( \
+      device const type* x [[buffer(0)]], \
+      device const type* y [[buffer(1)]], \
+      device type* out [[buffer(2)]], \
+      constant const float& alpha [[buffer(3)]], \
+      constant const float& beta [[buffer(4)]], \
+      uint index [[thread_position_in_grid]]);
+
+instantiate_axpby(float32, float);
+instantiate_axpby(float16, half);
+instantiate_axpby(bfloat16, bfloat16_t);
+instantiate_axpby(complex64, complex64_t);
\ No newline at end of file
diff --git a/examples/extensions/mlx_sample_extensions/__init__.py b/examples/extensions/mlx_sample_extensions/__init__.py
new file mode 100644
index 0000000000..1b17da245e
--- /dev/null
+++ b/examples/extensions/mlx_sample_extensions/__init__.py
@@ -0,0 +1,2 @@
+import mlx.core as mx
+from .mlx_sample_extensions import *
diff --git a/examples/extensions/setup.py b/examples/extensions/setup.py
new file mode 100644
index 0000000000..c3817b049f
--- /dev/null
+++ b/examples/extensions/setup.py
@@ -0,0 +1,16 @@
+from mlx import extension
+from setuptools import setup
+
+if __name__ == "__main__":
+    setup(
+        name="mlx_sample_extensions",
+        version="0.0.0",
+        description="Sample C++ and Metal extensions for MLX primitives.",
+        ext_modules=[extension.CMakeExtension("mlx_sample_extensions")],
+        cmdclass={"build_ext": extension.CMakeBuild},
+        packages=["mlx_sample_extensions"],
+        package_dir={"": "."},
+        package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+        zip_safe=False,
+        python_requires=">=3.7",
+    )
diff --git a/examples/python/linear_regression.py b/examples/python/linear_regression.py
new file mode 100644
index 0000000000..f1837a6487
--- /dev/null
+++ b/examples/python/linear_regression.py
@@ -0,0 +1,43 @@
+import mlx.core as mx
+import time
+
+num_features = 100
+num_examples = 1_000
+num_iters = 10_000
+lr = 0.01
+
+# True parameters
+w_star = mx.random.normal((num_features,))
+
+# Input examples (design matrix)
+X = mx.random.normal((num_examples, num_features))
+
+# Noisy labels
+eps = 1e-2 * mx.random.normal((num_examples,))
+y = X @ w_star + eps
+
+# Initialize random parameters
+w = 1e-2 * mx.random.normal((num_features,))
+
+
+def loss_fn(w):
+    return 0.5 * mx.mean(mx.square(X @ w - y))
+
+
+grad_fn = mx.grad(loss_fn)
+
+tic = time.time()
+for _ in range(num_iters):
+    grad = grad_fn(w)
+    w = w - lr * grad
+    mx.eval(w)
+toc = time.time()
+
+loss = loss_fn(w)
+error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5
+throughput = num_iters / (toc - tic)
+
+print(
+    f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
+    f"Throughput {throughput:.5f} (it/s)"
+)
diff --git a/examples/python/logistic_regression.py b/examples/python/logistic_regression.py
new file mode 100644
index 0000000000..928ae34a5f
--- /dev/null
+++ b/examples/python/logistic_regression.py
@@ -0,0 +1,46 @@
+import mlx.core as mx
+import time
+
+num_features = 100
+num_examples = 1_000
+num_iters = 10_000
+lr = 0.1
+
+# True parameters
+w_star = mx.random.normal((num_features,))
+
+# Input examples
+X = mx.random.normal((num_examples, num_features))
+
+# Labels
+y = (X @ w_star) > 0
+
+
+# Initialize random parameters
+w = 1e-2 * mx.random.normal((num_features,))
+
+
+def loss_fn(w):
+    logits = X @ w
+    return mx.mean(mx.logaddexp(0.0, logits) - y * logits)
+
+
+grad_fn = mx.grad(loss_fn)
+
+tic = time.time()
+for _ in
range(num_iters): + grad = grad_fn(w) + w = w - lr * grad + mx.eval(w) + +toc = time.time() + +loss = loss_fn(w) +final_preds = (X @ w) > 0 +acc = mx.mean(final_preds == y) + +throughput = num_iters / (toc - tic) +print( + f"Loss {loss.item():.5f}, Accuracy {acc.item():.5f} " + f"Throughput {throughput:.5f} (it/s)" +) diff --git a/mlx.pc.in b/mlx.pc.in new file mode 100644 index 0000000000..c3828b30b3 --- /dev/null +++ b/mlx.pc.in @@ -0,0 +1,43 @@ +# Find MLX +# +# Defines the following variables: +# +# MLX_FOUND : True if MLX is found +# MLX_INCLUDE_DIRS : Include directory +# MLX_LIBRARIES : Libraries to link against +# MLX_CXX_FLAGS : Additional compiler flags +# MLX_BUILD_ACCELERATE : True if MLX was built with accelerate +# MLX_BUILD_METAL : True if MLX was built with metal + +@PACKAGE_INIT@ + +include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/MLXTargets.cmake) +include(@PACKAGE_MLX_CMAKE_INSTALL_MODULE_DIR@/extension.cmake) + +set_and_check(MLX_LIBRARY_DIRS @PACKAGE_CMAKE_INSTALL_LIBDIR@) +set_and_check(MLX_INCLUDE_DIRS @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@) +set(MLX_LIBRARIES mlx) + +find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS}) + +if (@MLX_BUILD_ACCELERATE@) + set(MLX_BUILD_ACCELERATE @MLX_BUILD_ACCELERATE@) + set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK) +endif() + +if (@MLX_BUILD_METAL@) + set(MLX_BUILD_METAL @MLX_BUILD_METAL@) + set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_) + set_and_check(MLX_INCLUDE_DIRS + ${MLX_INCLUDE_DIRS} + @PACKAGE_CMAKE_INSTALL_INCLUDEDIR@/metal_cpp + ) +endif() + +set_target_properties(mlx PROPERTIES + CXX_STANDARD 17 + INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}" +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS) \ No newline at end of file diff --git a/mlx/3rdparty/.clang-format b/mlx/3rdparty/.clang-format new file mode 100644 index 0000000000..47a38a93f2 --- /dev/null +++ b/mlx/3rdparty/.clang-format @@ -0,0 +1,2 @@ +DisableFormat: true +SortIncludes: Never diff --git a/mlx/allocator.cpp b/mlx/allocator.cpp new file mode 100644 index 0000000000..1d8cddd220 --- /dev/null +++ b/mlx/allocator.cpp @@ -0,0 +1,48 @@ +#include +#include + +#include "mlx/allocator.h" +#include "mlx/scheduler.h" + +namespace mlx::core::allocator { + +Buffer malloc(size_t size) { + auto buffer = allocator().malloc(size); + if (size && !buffer.ptr()) { + std::ostringstream msg; + msg << "[malloc] Unable to allocate " << size << " bytes."; + throw std::runtime_error(msg.str()); + } + return buffer; +} + +void free(Buffer buffer) { + return allocator().free(buffer); +} + +Buffer CommonAllocator::malloc(size_t size) { + return Buffer{std::malloc(size)}; +} + +void CommonAllocator::free(Buffer buffer) { + std::free(buffer.raw_ptr()); +} + +Buffer malloc_or_wait(size_t size) { + auto buffer = allocator().malloc(size); + + while (size && !buffer.ptr() && scheduler::n_active_tasks() > 0) { + scheduler::wait_for_one(); + buffer = allocator().malloc(size); + } + + if (size && !buffer.ptr()) { + std::ostringstream msg; + msg << "[malloc_or_wait] Unable to allocate " << size << " bytes."; + throw std::runtime_error(msg.str()); + } + + return buffer; +} + +} // namespace mlx::core::allocator diff --git a/mlx/array.h b/mlx/array.h new file mode 100644 index 0000000000..c0f02ee4d7 --- /dev/null +++ b/mlx/array.h @@ -0,0 +1,436 @@ +#pragma once +#include +#include +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/dtype.h" + +namespace mlx::core { + +// Forward declaration 
+class Primitive; +using deleter_t = std::function; + +class array { + /* An array is really a node in a graph. It contains a shared ArrayDesc + * object */ + + public: + /** Construct a scalar array with zero dimensions. */ + template + explicit array(T val, Dtype dtype = TypeToDtype()); + + /* Special case since std::complex can't be implicitly converted to other + * types. */ + explicit array(const std::complex& val, Dtype dtype = complex64); + + template + array( + It data, + const std::vector& shape, + Dtype dtype = + TypeToDtype::value_type>()); + + template + array(std::initializer_list data, Dtype dtype = TypeToDtype()); + + /* Special case so empty lists default to float32. */ + array(std::initializer_list data); + + template + array( + std::initializer_list data, + const std::vector& shape, + Dtype dtype = TypeToDtype()); + + /* Build an array from a buffer */ + array( + allocator::Buffer data, + const std::vector& shape, + Dtype dtype, + deleter_t deleter = allocator::free); + + /** Assignment to rvalue does not compile. */ + array& operator=(const array& other) && = delete; + array& operator=(array&& other) && = delete; + + /** Default copy and move constructors otherwise. */ + array& operator=(array&& other) & = default; + array(const array& other) = default; + array(array&& other) = default; + + array& operator=(const array& other) & { + if (this->id() != other.id()) { + this->array_desc_ = other.array_desc_; + } + return *this; + }; + + /** The size of the array's datatype in bytes. */ + size_t itemsize() const { + return size_of(dtype()); + }; + + /** The number of elements in the array. */ + size_t size() const { + return array_desc_->size; + }; + + /** The number of bytes in the array. */ + size_t nbytes() const { + return size() * itemsize(); + }; + + /** The number of dimensions of the array. */ + size_t ndim() const { + return array_desc_->shape.size(); + }; + + /** The shape of the array as a vector of integers. */ + const std::vector& shape() const { + return array_desc_->shape; + }; + + /** + * Get the size of the corresponding dimension. + * + * This function supports negative indexing and provides + * bounds checking. */ + int shape(int dim) const { + return shape().at(dim < 0 ? dim + ndim() : dim); + }; + + /** The strides of the array. */ + const std::vector& strides() const { + return array_desc_->strides; + }; + + /** Get the arrays data type. */ + Dtype dtype() const { + return array_desc_->dtype; + }; + + /** Evaluate the array. */ + void eval(bool retain_graph = false); + + /** Get the value from a scalar array. 
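+   *
+   * A usage sketch (the template type must match the array's dtype):
+   *
+   *   array x(3.14f);
+   *   float v = x.item<float>(); // evaluates x and returns its value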
*/ + template + T item(bool retain_graph = false); + + struct ArrayIterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = size_t; + using value_type = const array; + using reference = value_type; + + explicit ArrayIterator(const array& arr, int idx = 0) : arr(arr), idx(idx) { + if (arr.ndim() == 0) { + throw std::invalid_argument("Cannot iterate over 0-d array."); + } + } + + reference operator*() const; + + ArrayIterator& operator+(difference_type diff) { + idx += diff; + return *this; + } + + ArrayIterator& operator++() { + idx++; + return *this; + } + + friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) { + return a.arr.id() == b.arr.id() && a.idx == b.idx; + }; + friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) { + return !(a == b); + }; + + private: + int idx; + const array& arr; + }; + + ArrayIterator begin() const { + return ArrayIterator(*this); + } + ArrayIterator end() const { + return ArrayIterator(*this, shape(0)); + } + + /** + * The following methods should be used with caution. + * They are intended for use by the backend implementation and the + * API may change. + */ + + array( + const std::vector& shape, + Dtype dtype, + std::unique_ptr primitive, + const std::vector& inputs); + + /** A unique identifier for an array. */ + std::uintptr_t id() const { + return reinterpret_cast(array_desc_.get()); + } + + struct Data { + allocator::Buffer buffer; + deleter_t d; + Data(allocator::Buffer buffer, deleter_t d = allocator::free) + : buffer(buffer), d(d){}; + // Not copyable + Data(const Data& d) = delete; + Data& operator=(const Data& d) = delete; + ~Data() { + d(buffer); + } + }; + + struct Flags { + // True if there are no gaps in the underlying data. Each item + // in the underlying data buffer belongs to at least one index. + bool contiguous : 1; + + bool row_contiguous : 1; + bool col_contiguous : 1; + }; + + /** The array's primitive. */ + Primitive& primitive() const { + return *(array_desc_->primitive); + }; + + /** Check if the array has an attached primitive or is a leaf node. */ + bool has_primitive() const { + return array_desc_->primitive != nullptr; + }; + + /** The array's inputs. */ + const std::vector& inputs() const { + return array_desc_->inputs; + }; + + /** A non-const reference to the array's inputs so that they can be used to + * edit the graph. */ + std::vector& editable_inputs() { + return array_desc_->inputs; + } + + /** Detach the array from the graph. */ + void detach(); + + /** Get the Flags bit-field. */ + const Flags& flags() const { + return array_desc_->flags; + }; + + /** The size (in elements) of the underlying buffer the array points to. */ + size_t data_size() const { + return array_desc_->data_size; + }; + + allocator::Buffer& buffer() { + return array_desc_->data->buffer; + }; + const allocator::Buffer& buffer() const { + return array_desc_->data->buffer; + }; + + template + T* data() { + return static_cast(array_desc_->data_ptr); + }; + + template + const T* data() const { + return static_cast(array_desc_->data_ptr); + }; + + // Check if the array has been evaluated + bool is_evaled() const { + return array_desc_->data != nullptr; + } + + // Mark the array as a tracer array (true) or not. 
+ void set_tracer(bool is_tracer) { + array_desc_->is_tracer = is_tracer; + } + // Check if the array is a tracer array + bool is_tracer() const { + return array_desc_->is_tracer; + } + + void set_data(allocator::Buffer buffer, deleter_t d = allocator::free); + + void set_data( + allocator::Buffer buffer, + size_t data_size, + std::vector strides, + Flags flags, + deleter_t d = allocator::free); + + void copy_shared_buffer( + const array& other, + const std::vector& strides, + Flags flags, + size_t data_size, + size_t offset = 0); + + void copy_shared_buffer(const array& other); + + void overwrite_descriptor(const array& other) { + array_desc_ = other.array_desc_; + } + + private: + // Initialize the arrays data + template + void init(const It src); + + struct ArrayDesc { + std::vector shape; + std::vector strides; + size_t size; + Dtype dtype; + std::unique_ptr primitive{nullptr}; + + // Indicates an array is being used in a graph transform + // and should not be detached from the graph + bool is_tracer{false}; + + // This is a shared pointer so that *different* arrays + // can share the underlying data buffer. + std::shared_ptr data{nullptr}; + + // Properly offset data pointer + void* data_ptr{nullptr}; + + // The size in elements of the data buffer the array accesses + // This can be different than the actual size of the array if it + // has been broadcast or irregularly strided. + size_t data_size; + + // Contains useful meta data about the array + Flags flags; + + std::vector inputs; + + explicit ArrayDesc(const std::vector& shape, Dtype dtype); + + explicit ArrayDesc( + const std::vector& shape, + Dtype dtype, + std::unique_ptr primitive, + const std::vector& inputs); + + ~ArrayDesc(); + }; + + // The ArrayDesc contains the details of the materialized array including the + // shape, strides, the data type. It also includes + // the primitive which knows how to compute the array's data from its inputs + // and a the list of array's inputs for the primitive. 
+ std::shared_ptr array_desc_{nullptr}; +}; + +template +array::array(T val, Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(std::vector{}, dtype)) { + init(&val); +} + +template +array::array( + It data, + const std::vector& shape, + Dtype dtype /* = TypeToDtype::value_type>() */) : + array_desc_(std::make_shared(shape, dtype)) { + init(data); +} + +template +array::array( + std::initializer_list data, + Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared( + std::vector{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + +template +array::array( + std::initializer_list data, + const std::vector& shape, + Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(shape, dtype)) { + if (data.size() != size()) { + throw std::invalid_argument( + "Data size and provided shape mismatch in array construction."); + } + init(data.begin()); +} + +template +T array::item(bool retain_graph /* = false */) { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + eval(retain_graph); + return *data(); +} + +template +void array::init(It src) { + set_data(allocator::malloc(size() * size_of(dtype()))); + switch (dtype()) { + case bool_: + std::copy(src, src + size(), data()); + break; + case uint8: + std::copy(src, src + size(), data()); + break; + case uint16: + std::copy(src, src + size(), data()); + break; + case uint32: + std::copy(src, src + size(), data()); + break; + case uint64: + std::copy(src, src + size(), data()); + break; + case int8: + std::copy(src, src + size(), data()); + break; + case int16: + std::copy(src, src + size(), data()); + break; + case int32: + std::copy(src, src + size(), data()); + break; + case int64: + std::copy(src, src + size(), data()); + break; + case float16: + std::copy(src, src + size(), data()); + break; + case float32: + std::copy(src, src + size(), data()); + break; + case bfloat16: + std::copy(src, src + size(), data()); + break; + case complex64: + std::copy(src, src + size(), data()); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/accelerate/CMakeLists.txt b/mlx/backend/accelerate/CMakeLists.txt new file mode 100644 index 0000000000..34269f9c24 --- /dev/null +++ b/mlx/backend/accelerate/CMakeLists.txt @@ -0,0 +1,9 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp +) diff --git a/mlx/backend/accelerate/matmul.cpp b/mlx/backend/accelerate/matmul.cpp new file mode 100644 index 0000000000..5e8014c692 --- /dev/null +++ b/mlx/backend/accelerate/matmul.cpp @@ -0,0 +1,167 @@ +#include + +#include +#include + +#include "mlx/backend/accelerate/utils.h" +#include "mlx/backend/common/copy.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +std::tuple check_transpose(const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } +} + +inline void matmul_cblas(const array& a_pre, const array& b_pre, 
array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[matmul_cblas] on CPU currently only supports float32"); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto [a_transposed, lda, a] = check_transpose(a_pre); + auto [b_transposed, ldb, b] = check_transpose(b_pre); + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + + for (int i = 0; i < (a.size() / (M * K)); ++i) { + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? CblasTrans : CblasNoTrans, // transB + M, + N, + K, + 1.0f, // alpha + a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), + lda, + b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), + ldb, + 0.0f, // beta + out.data() + M * N * i, + out.shape(-1) // ldc + ); + } +} + +inline void matmul_bnns(const array& a_pre, const array& b_pre, array& out) { + // TODO: Update to utilize BNNS broadcasting + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto [a_transposed, lda, a] = check_transpose(a_pre); + auto [b_transposed, ldb, b] = check_transpose(b_pre); + size_t M = a.shape(-2); + size_t N = b.shape(-1); + size_t K = a.shape(-1); + + BNNSDataType bnns_dtype = to_bnns_dtype(out.dtype()); + + const BNNSLayerParametersBroadcastMatMul gemm_params{ + /* float alpha = */ 1.0, + /* float beta = */ 0.0, + /* bool transA = */ a_transposed, + /* bool transB = */ b_transposed, + /* bool quadratic = */ false, + /* bool a_is_weights = */ false, + /* bool b_is_weights = */ false, + /* BNNSNDArrayDescriptor iA_desc = */ + BNNSNDArrayDescriptor{ + /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, + /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, + + /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ + {lda, (M * K) / lda, 0, 0, 0, 0, 0, 0}, + /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ + {1, lda, 0, 0, 0, 0, 0, 0}, + + /* void * _Nullable data = */ nullptr, + /* BNNSDataType data_type = */ bnns_dtype, + + /* void * _Nullable table_data = */ nullptr, + /* BNNSDataType table_data_type = */ bnns_dtype, + + /* float data_scale = */ 1.0, + /* float data_bias = */ 0.0, + }, + /* BNNSNDArrayDescriptor iB_desc = */ + BNNSNDArrayDescriptor{ + /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, + /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, + + /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ + {ldb, (K * N) / ldb, 0, 0, 0, 0, 0, 0}, + /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ + {1, ldb, 0, 0, 0, 0, 0, 0}, + + /* void * _Nullable data = */ nullptr, + /* BNNSDataType data_type = */ bnns_dtype, + + /* void * _Nullable table_data = */ nullptr, + /* BNNSDataType table_data_type = */ bnns_dtype, + + /* float data_scale = */ 1.0, + /* float data_bias = */ 0.0, + }, + /* BNNSNDArrayDescriptor o_desc = */ + BNNSNDArrayDescriptor{ + /* BNNSNDArrayFlags flags = */ BNNSNDArrayFlagBackpropSet, + /* BNNSDataLayout layout = */ BNNSDataLayoutRowMajorMatrix, + + /* size_t size[BNNS_MAX_TENSOR_DIMENSION] = */ + {N, M, 0, 0, 0, 0, 0, 0}, + /* size_t stride[BNNS_MAX_TENSOR_DIMENSION] = */ + {1, N, 0, 0, 0, 0, 0, 0}, + + /* void * _Nullable data = */ nullptr, + /* BNNSDataType data_type = */ bnns_dtype, + + /* void * _Nullable table_data = */ nullptr, + /* BNNSDataType table_data_type = */ bnns_dtype, + + /* float data_scale = */ 1.0, + /* float data_bias = */ 0.0, + }, + }; + + auto bnns_filter = + BNNSFilterCreateLayerBroadcastMatMul(&gemm_params, nullptr); + + for (int i = 0; i < (a.size() / (M * K)); ++i) { + 
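+    // Apply the BNNS broadcast-matmul filter once per (M x K) * (K x N)
+    // slice of the batch. elem_to_loc() turns the linear slice offset into a
+    // strided element offset, and multiplying by itemsize() converts it to
+    // the byte offset expected by BNNSFilterApplyTwoInput, which works on
+    // raw pointers.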
BNNSFilterApplyTwoInput( + bnns_filter, + a.data() + + elem_to_loc(M * K * i, a.shape(), a.strides()) * a.itemsize(), + b.data() + + elem_to_loc(K * N * i, b.shape(), b.strides()) * b.itemsize(), + out.data() + M * N * i * out.itemsize()); + } + + BNNSFilterDestroy(bnns_filter); +} + +} // namespace + +void Matmul::eval_cpu(const std::vector& inputs, array& out) { + if (out.dtype() == float32) { + return matmul_cblas(inputs[0], inputs[1], out); + } + return matmul_bnns(inputs[0], inputs[1], out); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp new file mode 100644 index 0000000000..1a2b7b3e02 --- /dev/null +++ b/mlx/backend/accelerate/primitives.cpp @@ -0,0 +1,672 @@ +#include +#include + +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/unary.h" +#include "mlx/primitives.h" + +#define DEFAULT(primitive) \ + void primitive::eval_cpu(const std::vector& inputs, array& out) { \ + primitive::eval(inputs, out); \ + } + +namespace mlx::core { + +// Use the default implementation for the following primitives +DEFAULT(Arange) +DEFAULT(ArgPartition) +DEFAULT(ArgReduce) +DEFAULT(ArgSort) +DEFAULT(AsStrided) +DEFAULT(Broadcast) +DEFAULT(Concatenate) +DEFAULT(Copy) +DEFAULT(Equal) +DEFAULT(Erf) +DEFAULT(ErfInv) +DEFAULT(FFT) +DEFAULT(Gather) +DEFAULT(Greater) +DEFAULT(GreaterEqual) +DEFAULT(Less) +DEFAULT(LessEqual) +DEFAULT(Load) +DEFAULT(LogicalNot) +DEFAULT(LogAddExp) +DEFAULT(NotEqual) +DEFAULT(Pad) +DEFAULT(Partition) +DEFAULT(RandomBits) +DEFAULT(Reshape) +DEFAULT(Scatter) +DEFAULT(Sigmoid) +DEFAULT(Sign) +DEFAULT(Slice) +DEFAULT(Sort) +DEFAULT(StopGradient) +DEFAULT(Transpose) + +void Abs::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (in.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vDSP_vabs(in.data(), 1, out.data(), 1, size); + } else if (in.dtype() == int32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vDSP_vabsi(in.data(), 1, out.data(), 1, size); + } else if (is_unsigned(in.dtype())) { + // No-op for unsigned types + out.copy_shared_buffer(in); + } else { + unary(in, out, AbsOp()); + } +} + +void Add::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + + if (a.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x + y; }, + [](const auto* s, const auto* vec, auto* o, auto n) { + vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* vec, const auto* s, auto* o, auto n) { + vDSP_vsadd((const float*)vec, 1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* a, const auto* b, auto* o, auto n) { + vDSP_vadd((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); + }); + } else if (a.dtype() == int32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x + y; }, + [](const auto* s, const auto* vec, auto* o, auto n) { + vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); + }, + [](const auto* vec, const auto* s, auto* o, auto n) { + vDSP_vsaddi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); + }, + [](const 
auto* a, const auto* b, auto* o, auto n) { + vDSP_vaddi((const int*)a, 1, (const int*)b, 1, (int*)o, 1, n); + }); + } else { + binary(a, b, out, [](auto x, auto y) { return x + y; }); + } +} + +void ArcCos::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvacosf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void ArcCosh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvacoshf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void ArcSin::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvasinf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void ArcSinh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvasinhf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void ArcTan::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvatanf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void ArcTanh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvatanhf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void AsType::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + if (in.flags().contiguous) { + auto allocfn = [&in, &out]() { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + }; + // Use accelerate functions if possible + if (in.dtype() == float32 && out.dtype() == uint32) { + allocfn(); + vDSP_vfixu32( + in.data(), 1, out.data(), 1, in.data_size()); + return; + } else if (in.dtype() == float32 && out.dtype() == int32) { + allocfn(); + vDSP_vfix32(in.data(), 1, out.data(), 1, in.data_size()); + return; + } else if (in.dtype() == uint32 && out.dtype() == float32) { + allocfn(); + vDSP_vfltu32( + in.data(), 1, out.data(), 1, in.data_size()); + return; + } else if (in.dtype() == int32 && out.dtype() == float32) { + allocfn(); + vDSP_vflt32(in.data(), 1, out.data(), 1, in.data_size()); + return; + } + } + 
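+  // No vDSP conversion routine covers this dtype pair (or the input is not
+  // contiguous), so fall back to the generic AsType::eval implementation.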
eval(inputs, out); +} + +void Cos::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvcosf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void Cosh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvcoshf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void Divide::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + + if (a.dtype() == int32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x / y; }, + UseDefaultBinaryOp(), + [](const auto* vec, const auto* s, auto* o, auto n) { + vDSP_vsdivi((const int*)vec, 1, (const int*)s, (int*)o, 1, n); + }, + [](const auto* a, const auto* b, auto* o, auto n) { + vDSP_vdivi((const int*)b, 1, (const int*)a, 1, (int*)o, 1, n); + }); + } else if (a.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x / y; }, + [](const auto* s, const auto* vec, auto* o, auto n) { + vDSP_svdiv((const float*)s, (const float*)vec, 1, (float*)o, 1, n); + }, + [](const auto* vec, const auto* s, auto* o, auto n) { + vDSP_vsdiv((const float*)vec, 1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* a, const auto* b, auto* o, auto n) { + vDSP_vdiv((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); + }); + } else { + binary(a, b, out, [](auto x, auto y) { return x / y; }); + } +} + +void Exp::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvexpf(out.data(), in.data(), reinterpret_cast(&size)); + } else if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::exp(x); }); + } else { + throw std::invalid_argument( + "[exp] Cannot exponentiate elements in array" + " with non floating point type."); + } +} + +void Full::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + assert(in.dtype() == out.dtype()); + if (in.data_size() == 1 && out.dtype() == float32) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + vDSP_vfill(in.data(), out.data(), 1, out.size()); + } else { + eval(inputs, out); + } +} + +void Log::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + switch (base_) { + case Base::e: + vvlogf( + out.data(), in.data(), reinterpret_cast(&size)); + break; + case Base::two: + vvlog2f( + out.data(), in.data(), reinterpret_cast(&size)); + break; + case Base::ten: + vvlog10f( + out.data(), in.data(), reinterpret_cast(&size)); + break; + } + } else { + eval(inputs, out); + } +} + +void Log1p::eval_cpu(const 
std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvlog1pf( + out.data(), in.data(), reinterpret_cast(&size)); + } else if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::log1p(x); }); + } else { + throw std::invalid_argument( + "[log1p] Cannot compute log of elements in array with" + " non floating point type."); + } +} + +void Maximum::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + if (out.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return (x > y) ? x : y; }, + UseDefaultBinaryOp(), + UseDefaultBinaryOp(), + [](const auto* a, const auto* b, auto* out, int n) { + vDSP_vmax((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n); + }); + } else { + binary(a, b, out, [](auto x, auto y) { return (x > y) ? x : y; }); + } +} + +void Minimum::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + + if (out.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return (x < y) ? x : y; }, + UseDefaultBinaryOp(), + UseDefaultBinaryOp(), + [](const auto* a, const auto* b, auto* out, int n) { + vDSP_vmin((const float*)a, 1, (const float*)b, 1, (float*)out, 1, n); + }); + } else { + binary(a, b, out, [](auto x, auto y) { return (x < y) ? x : y; }); + } +} + +void Multiply::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + + if (a.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x * y; }, + [](const auto* s, const auto* vec, auto* o, auto n) { + vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* vec, const auto* s, auto* o, auto n) { + vDSP_vsmul((const float*)vec, 1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* a, const auto* b, auto* o, auto n) { + vDSP_vmul((const float*)a, 1, (const float*)b, 1, (float*)o, 1, n); + }); + } else { + binary(a, b, out, [](auto x, auto y) { return x * y; }); + } +} + +void Negative::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (in.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vDSP_vneg(in.data(), 1, out.data(), 1, size); + } else { + unary(in, out, [](auto x) { return -x; }); + } +} + +void Power::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + if (out.dtype() == float32 && a.flags().row_contiguous && + b.flags().row_contiguous) { + int size = a.size(); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + vvpowf(out.data(), a.data(), b.data(), &size); + } else { + eval(inputs, out); + } +} + +void Scan::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (reduce_type_ == Scan::Sum && out.dtype() == float32 && + in.flags().row_contiguous && in.strides()[axis_] == 1 && !inclusive_) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + int stride = in.shape(axis_); + int count = in.size() / 
stride; + const float* input = in.data(); + float* output = out.data(); + float s = 1.0; + if (!reverse_) { + for (int i = 0; i < count; i++) { + vDSP_vrsum(input - 1, 1, &s, output, 1, stride); + input += stride; + output += stride; + } + } else { + for (int i = 0; i < count; i++) { + input += stride - 1; + output += stride - 1; + vDSP_vrsum(input + 1, -1, &s, output, -1, stride); + input++; + output++; + } + } + } else { + eval(inputs, out); + } +} + +void Sin::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvsinf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void Sinh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvsinhf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +void Square::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (in.dtype() == float32 && in.flags().contiguous) { + auto size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vDSP_vsq(in.data(), 1, out.data(), 1, size); + } else { + unary(in, out, [](auto x) { return x * x; }); + } +} + +void Sqrt::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (in.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + if (recip_) { + vvrsqrtf(out.data(), in.data(), &size); + } else { + vvsqrtf(out.data(), in.data(), &size); + } + } else { + eval(inputs, out); + } +} + +void Subtract::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + + if (a.dtype() == float32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x - y; }, + [](const auto* s, const auto* vec, auto* o, auto n) { + float minus_1 = -1; + vDSP_vsmsa( + (const float*)vec, 1, &minus_1, (const float*)s, (float*)o, 1, n); + }, + [](const auto* vec, const auto* s, auto* o, auto n) { + float val = -(*s); + vDSP_vsadd((const float*)vec, 1, &val, (float*)o, 1, n); + }, + [](const auto* a, const auto* b, auto* o, auto n) { + vDSP_vsub((const float*)b, 1, (const float*)a, 1, (float*)o, 1, n); + }); + } else if (a.dtype() == int32) { + binary( + a, + b, + out, + [](auto x, auto y) { return x - y; }, + UseDefaultBinaryOp(), + [](const auto* vec, const auto* s, auto* o, auto n) { + int val = -(*s); + vDSP_vsaddi((const int*)vec, 1, &val, (int*)o, 1, n); + }, + UseDefaultBinaryOp()); + } else { + binary(a, b, out, [](auto x, auto y) { return x - y; }); + } +} + +void Tan::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvtanf(out.data(), in.data(), &size); + } else { + 
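+    // Not a contiguous float32 input: defer to the generic Tan::eval rather
+    // than the vectorized vvtanf path.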
eval(inputs, out); + } +} + +void Tanh::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.dtype() == float32 && in.flags().contiguous) { + int size = in.data_size(); + out.set_data( + allocator::malloc_or_wait(size * out.itemsize()), + size, + in.strides(), + in.flags()); + vvtanhf(out.data(), in.data(), &size); + } else { + eval(inputs, out); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/accelerate/reduce.cpp b/mlx/backend/accelerate/reduce.cpp new file mode 100644 index 0000000000..19de62f56c --- /dev/null +++ b/mlx/backend/accelerate/reduce.cpp @@ -0,0 +1,147 @@ +#include + +#include +#include + +#include "mlx/backend/common/reduce.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +template +void _vectorized_strided_sum(const T* x, T* accum, int size, size_t stride) { + for (int i = 0; i < size; i++) { + size_t s = stride; + T* a = accum; + while (s >= N) { + VT val = (*(VT*)x); + *(VT*)a += val; + x += N; + a += N; + s -= N; + } + while (s-- > 0) { + *a++ += *x++; + } + } +} + +// TODO: Add proper templates for the strided reduce algorithm so we don't have +// to write max/min/sum etc. +template +void _vectorized_strided_max(const T* x, T* accum, int size, size_t stride) { + for (int i = 0; i < size; i++) { + size_t s = stride; + T* a = accum; + while (s >= N) { + *(VT*)a = simd_max((*(VT*)x), (*(VT*)a)); + x += N; + a += N; + s -= N; + } + while (s-- > 0) { + *a = std::max(*a, *x); + a++; + x++; + } + } +} + +template +void _vectorized_strided_min(const T* x, T* accum, int size, size_t stride) { + for (int i = 0; i < size; i++) { + size_t s = stride; + T* a = accum; + while (s >= N) { + *(VT*)a = simd_min((*(VT*)x), (*(VT*)a)); + x += N; + a += N; + s -= N; + } + while (s-- > 0) { + *a = std::min(*a, *x); + a++; + x++; + } + } +} + +template +void _vectorized_sum(const T* x, T* accum, int size) { + VT _sum = {0}; + while (size >= N) { + _sum += (*(VT*)x); + x += N; + size -= N; + } + T sum = _sum[0]; + for (int i = 1; i < N; i++) { + sum += _sum[i]; + } + *accum += sum; +} + +void Reduce::eval_cpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + + if (in.dtype() == float32) { + if (reduce_type_ == Reduce::Sum) { + reduction_op( + in, + out, + axes_, + 0, + [](const auto* x, auto* accum, int size, size_t stride) { + _vectorized_strided_sum( + (const float*)x, (float*)accum, size, stride); + }, + [](const auto* x, auto* accum, int size) { + float acc; + vDSP_sve((const float*)x, 1, &acc, size); + (*accum) += acc; + }, + [](auto* accum, auto x) { *accum += x; }); + return; + } else if (reduce_type_ == Reduce::Max) { + reduction_op( + in, + out, + axes_, + -std::numeric_limits::infinity(), + [](const auto* x, auto* accum, int size, size_t stride) { + _vectorized_strided_max( + (const float*)x, (float*)accum, size, stride); + }, + [](const auto* x, auto* accum, int size) { + float max; + vDSP_maxv((const float*)x, 1, &max, size); + (*accum) = (*accum < max) ? max : *accum; + }, + [](auto* accum, auto x) { (*accum) = (*accum < x) ? x : *accum; }); + return; + } else if (reduce_type_ == Reduce::Min) { + reduction_op( + in, + out, + axes_, + std::numeric_limits::infinity(), + [](const auto* x, auto* accum, int size, size_t stride) { + _vectorized_strided_min( + (const float*)x, (float*)accum, size, stride); + }, + [](const auto* x, auto* accum, int size) { + float min; + vDSP_minv((const float*)x, 1, &min, size); + (*accum) = (*accum > min) ? 
min : *accum; + }, + [](auto* accum, auto x) { (*accum) = (*accum > x) ? x : *accum; }); + return; + } + } + // TODO: Add integer addition and min/max using the templates above and + // simd_int16 and friends. + eval(inputs, out); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt new file mode 100644 index 0000000000..5ab4c29799 --- /dev/null +++ b/mlx/backend/common/CMakeLists.txt @@ -0,0 +1,18 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/binary.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/threefry.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp +) diff --git a/mlx/backend/common/arange.h b/mlx/backend/common/arange.h new file mode 100644 index 0000000000..05d8748840 --- /dev/null +++ b/mlx/backend/common/arange.h @@ -0,0 +1,72 @@ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/array.h" + +namespace mlx::core { + +namespace { + +template +void arange(T start, T next, array& out, size_t size) { + auto ptr = out.data(); + auto step_size = next - start; + for (int i = 0; i < size; ++i) { + ptr[i] = start; + start += step_size; + } +} + +} // namespace + +void arange( + const std::vector& inputs, + array& out, + double start, + double step) { + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + switch (out.dtype()) { + case bool_: + throw std::runtime_error("Bool type unsupported for arange."); + break; + case uint8: + arange(start, start + step, out, out.size()); + break; + case uint16: + arange(start, start + step, out, out.size()); + break; + case uint32: + arange(start, start + step, out, out.size()); + break; + case uint64: + arange(start, start + step, out, out.size()); + break; + case int8: + arange(start, start + step, out, out.size()); + break; + case int16: + arange(start, start + step, out, out.size()); + break; + case int32: + arange(start, start + step, out, out.size()); + break; + case int64: + arange(start, start + step, out, out.size()); + break; + case float16: + arange(start, start + step, out, out.size()); + break; + case float32: + arange(start, start + step, out, out.size()); + break; + case bfloat16: + arange(start, start + step, out, out.size()); + break; + case complex64: + arange(start, start + step, out, out.size()); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/arg_reduce.cpp b/mlx/backend/common/arg_reduce.cpp new file mode 100644 index 0000000000..221df856f3 --- /dev/null +++ b/mlx/backend/common/arg_reduce.cpp @@ -0,0 +1,110 @@ +#include + +#include "mlx/primitives.h" +#include "utils.h" + +namespace mlx::core { + +namespace { + +template +void arg_reduce(const array& in, array& out, const OpT& op, int axis) { + auto axis_size = in.shape()[axis]; + auto axis_stride = in.strides()[axis]; + std::vector strides = in.strides(); + std::vector shape = in.shape(); + strides.erase(strides.begin() + axis); + shape.erase(shape.begin() + axis); + for (uint32_t i = 0; i < out.size(); ++i) { + auto loc = elem_to_loc(i, shape, strides); + auto in_ptr = in.data() + loc; + uint32_t ind_v = 0; + 
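+    // Running reduction along the axis: ind_v tracks the best index seen so
+    // far and v the corresponding value; op() updates both only when a
+    // strictly better element appears (see the ArgMin/ArgMax lambdas below),
+    // so ties keep the earliest index.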
InT v = (*in_ptr); + for (uint32_t j = 0; j < axis_size; ++j, in_ptr += axis_stride) { + op(j, (*in_ptr), &ind_v, &v); + } + out.data()[i] = ind_v; + } +} + +template +void arg_reduce_dispatch( + const array& in, + array& out, + ArgReduce::ReduceType rtype, + int axis) { + switch (rtype) { + case ArgReduce::ArgMin: { + auto op = [](auto ind_x, auto x, auto ind_y, auto y) { + if (x < (*y)) { + (*y) = x; + (*ind_y) = ind_x; + } + }; + arg_reduce(in, out, op, axis); + break; + } + case ArgReduce::ArgMax: { + auto op = [](auto ind_x, auto x, auto ind_y, auto y) { + if (x > (*y)) { + (*y) = x; + (*ind_y) = ind_x; + } + }; + arg_reduce(in, out, op, axis); + break; + } + } +} + +} // namespace + +void ArgReduce::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + switch (in.dtype()) { + case bool_: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case uint8: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case uint16: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case uint32: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case uint64: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case int8: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case int16: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case int32: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case int64: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case float16: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case float32: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case bfloat16: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + case complex64: + arg_reduce_dispatch(in, out, reduce_type_, axis_); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp new file mode 100644 index 0000000000..863aa94114 --- /dev/null +++ b/mlx/backend/common/conv.cpp @@ -0,0 +1,541 @@ +#include + +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +#include "mlx/backend/common/copy.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +/////////////////////////////////////////////////////////////////////////////// +// Naive reference conv +/////////////////////////////////////////////////////////////////////////////// + +template +void slow_conv_1D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + const T* start_wt_ptr = wt.data(); + + const T* in_ptr = in.data(); + T* out_ptr = out.data(); + + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const int iH = in.shape(1); // Input spatial dim + const int oH = out.shape(1); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(2); // In channels + const int wH = wt.shape(1); // Weight spatial dim + + const size_t in_stride_N = in.strides()[0]; + const size_t in_stride_H = in.strides()[1]; + const size_t in_stride_C = in.strides()[2]; + + const size_t wt_stride_O = wt.strides()[0]; + const size_t wt_stride_H = wt.strides()[1]; + const size_t wt_stride_C = wt.strides()[2]; + + const size_t out_stride_N = out.strides()[0]; + const size_t out_stride_H = out.strides()[1]; + const size_t out_stride_O = out.strides()[2]; + + 
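+  // Naive direct 1D convolution: for each batch n, output position oh and
+  // output channel o, accumulate input * weight over the filter taps (wh)
+  // and input channels (c), skipping taps that land in the padding region
+  // (the ih bounds check below). Accumulation is done in float and the
+  // result is cast back to T on the store.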
for (int n = 0; n < N; ++n) { + for (int oh = 0; oh < oH; ++oh) { + for (int o = 0; o < O; ++o) { + const T* filter_wt_ptr = start_wt_ptr + o * wt_stride_O; + float r = 0.; + + for (int wh = 0; wh < wH; ++wh) { + const T* wt_ptr = filter_wt_ptr + wh * wt_stride_H; + + int ih = oh * wt_strides[0] - padding[0] + wh * wt_dilation[0]; + + if (ih >= 0 && ih < iH) { + for (int c = 0; c < C; ++c) { + r += static_cast( + in_ptr[ih * in_stride_H + c * in_stride_C]) * + static_cast(wt_ptr[c * wt_stride_C]); + } // c + + } // ih check + } // wh + + out_ptr[oh * out_stride_H + o * out_stride_O] = static_cast(r); + } // o + } // oh + + in_ptr += in_stride_N; + out_ptr += out_stride_N; + + } // n +} + +template +void slow_conv_2D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + const T* st_wt_ptr = wt.data(); + const T* st_in_ptr = in.data(); + T* st_out_ptr = out.data(); + + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const int iH = in.shape(1); // Input spatial dim + const int iW = in.shape(2); // Input spatial dim + const int oH = out.shape(1); // Output spatial dim + const int oW = out.shape(2); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(3); // In channels + const int wH = wt.shape(1); // Weight spatial dim + const int wW = wt.shape(2); // Weight spatial dim + + const size_t in_stride_N = in.strides()[0]; + const size_t in_stride_H = in.strides()[1]; + const size_t in_stride_W = in.strides()[2]; + const size_t in_stride_C = in.strides()[3]; + + const size_t wt_stride_O = wt.strides()[0]; + const size_t wt_stride_H = wt.strides()[1]; + const size_t wt_stride_W = wt.strides()[2]; + const size_t wt_stride_C = wt.strides()[3]; + + const size_t out_stride_N = out.strides()[0]; + const size_t out_stride_H = out.strides()[1]; + const size_t out_stride_W = out.strides()[2]; + const size_t out_stride_O = out.strides()[3]; + + auto pt_conv_no_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding[0]; + int iw_base = ow * wt_strides[1] - padding[1]; + + for (int o = 0; o < O; ++o) { + float r = 0.; + + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int ih = ih_base + wh * wt_dilation[0]; + int iw = iw_base + ww * wt_dilation[1]; + + const T* wt_ptr_pt = wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = in_ptr + ih * in_stride_H + iw * in_stride_W; + + for (int c = 0; c < C; ++c) { + r += static_cast(in_ptr_pt[0]) * + static_cast(wt_ptr_pt[0]); + in_ptr_pt += in_stride_C; + wt_ptr_pt += wt_stride_C; + } // c + + } // ww + } // wh + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; + + auto pt_conv_all_checks = + [&](const T* in_ptr, const T* wt_ptr, T* out_ptr, int oh, int ow) { + out_ptr += oh * out_stride_H + ow * out_stride_W; + int ih_base = oh * wt_strides[0] - padding[0]; + int iw_base = ow * wt_strides[1] - padding[1]; + + for (int o = 0; o < O; ++o) { + float r = 0.; + + for (int wh = 0; wh < wH; ++wh) { + for (int ww = 0; ww < wW; ++ww) { + int ih = ih_base + wh * wt_dilation[0]; + int iw = iw_base + ww * wt_dilation[1]; + + if (ih >= 0 && ih < iH && iw >= 0 && iw < iW) { + const T* wt_ptr_pt = + wt_ptr + wh * wt_stride_H + ww * wt_stride_W; + const T* in_ptr_pt = + in_ptr + ih * in_stride_H + iw * 
in_stride_W; + + for (int c = 0; c < C; ++c) { + r += static_cast(in_ptr_pt[0]) * + static_cast(wt_ptr_pt[0]); + in_ptr_pt += in_stride_C; + wt_ptr_pt += wt_stride_C; + } // c + + } // ih, iw check + } // ww + } // wh + + out_ptr[0] = static_cast(r); + out_ptr += out_stride_O; + wt_ptr += wt_stride_O; + } // o + }; + + int oH_border_0 = 0; + int oH_border_1 = (padding[0] + wt_strides[0] + 1) / wt_strides[0]; + int oH_border_2 = (iH + padding[0] - wH * wt_dilation[0]) / wt_strides[0]; + int oH_border_3 = oH; + + int oW_border_0 = 0; + int oW_border_1 = (padding[1] + wt_strides[0] + 1) / wt_strides[1]; + int oW_border_2 = (iW + padding[1] - wW * wt_dilation[1]) / wt_strides[1]; + int oW_border_3 = oW; + + for (int n = 0; n < N; ++n) { + // Case 1: oh might put us out of bounds + for (int oh = oH_border_0; oh < oH_border_1; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh + + // Case 2: oh in bounds + for (int oh = oH_border_1; oh < oH_border_2; ++oh) { + // Case a: ow might put us out of bounds + for (int ow = oW_border_0; ow < oW_border_1; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + + // Case b: ow in bounds + for (int ow = oW_border_1; ow < oW_border_2; ++ow) { + pt_conv_no_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + + // Case c: ow might put us out of bounds + for (int ow = oW_border_2; ow < oW_border_3; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + + } // oh + + // Case 3: oh might put us out of bounds + for (int oh = oH_border_2; oh < oH_border_3; ++oh) { + for (int ow = 0; ow < oW; ++ow) { + pt_conv_all_checks(st_in_ptr, st_wt_ptr, st_out_ptr, oh, ow); + } // ow + } // oh + + st_in_ptr += in_stride_N; + st_out_ptr += out_stride_N; + + } // n +} + +void dispatch_slow_conv_1D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + if (in.dtype() == float32) { + return slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation); + } else if (in.dtype() == float16) { + return slow_conv_1D( + in, wt, out, padding, wt_strides, wt_dilation); + } else if (in.dtype() == bfloat16) { + return slow_conv_1D( + in, wt, out, padding, wt_strides, wt_dilation); + } else { + throw std::invalid_argument( + "[Convolution::eval] got unsupported data type."); + } +} + +void dispatch_slow_conv_2D( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + if (in.dtype() == float32) { + return slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation); + } else if (in.dtype() == float16) { + return slow_conv_2D( + in, wt, out, padding, wt_strides, wt_dilation); + } else if (in.dtype() == bfloat16) { + return slow_conv_2D( + in, wt, out, padding, wt_strides, wt_dilation); + } else { + throw std::invalid_argument( + "[Convolution::eval] got unsupported data type."); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Explicit gemm conv +/////////////////////////////////////////////////////////////////////////////// + +void explicit_gemm_conv_1D_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const int iH = in.shape(1); // 
Input spatial dim + const int oH = out.shape(1); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(2); // In channels + const int wH = wt.shape(1); // Weight spatial dim + + auto conv_dtype = float32; + + // Pad input + std::vector padded_shape = {N, iH + 2 * padding[0], C}; + array in_padded(padded_shape, conv_dtype, nullptr, {}); + + // Fill with zeros + copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + + // Pick input slice from padded + size_t data_offset = padding[0] * in_padded.strides()[1]; + array in_padded_slice(in.shape(), in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + + // Make strided view + std::vector strided_shape = {N, oH, wH, C}; + + std::vector strided_strides = { + in_padded.strides()[0], + in_padded.strides()[1] * wt_strides[0], + in_padded.strides()[1], + in_padded.strides()[2]}; + auto flags = in_padded.flags(); + + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); + in_strided_view.copy_shared_buffer( + in_padded, strided_strides, flags, in_strided_view.size(), 0); + + // Materialize strided view + std::vector strided_reshape = {N * oH, wH * C}; + array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); + copy(in_strided_view, in_strided, CopyType::General); + + // Check wt dtype and prepare + auto gemm_wt = wt; + auto gemm_out = out; + + if (wt.dtype() != float32 || !wt.flags().row_contiguous) { + auto ctype = + wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; + gemm_wt = array(wt.shape(), float32, nullptr, {}); + copy(wt, gemm_wt, ctype); + } + + if (out.dtype() != float32) { + gemm_out = array(out.shape(), float32, nullptr, {}); + gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + } + + // Peform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O, // N + strided_reshape[1], // K + 1.0f, // alpha + in_strided.data(), + strided_reshape[1], // lda + gemm_wt.data(), + strided_reshape[1], // ldb + 0.0f, // beta + gemm_out.data(), + O // ldc + ); + + // Copy results if needed + if (out.dtype() != float32) { + copy(gemm_out, out, CopyType::Vector); + } +} + +void explicit_gemm_conv_2D_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + const int N = in.shape(0); // Batch size, should be the same as out.shape(0) + const int iH = in.shape(1); // Input spatial dim + const int iW = in.shape(2); // Input spatial dim + const int oH = out.shape(1); // Output spatial dim + const int oW = out.shape(2); // Output spatial dim + const int O = wt.shape(0); // Out channels + const int C = wt.shape(3); // In channels + const int wH = wt.shape(1); // Weight spatial dim + const int wW = wt.shape(2); // Weight spatial dim + + auto conv_dtype = out.dtype(); + + // Pad input + std::vector padded_shape = { + N, iH + 2 * padding[0], iW + 2 * padding[1], C}; + array in_padded(padded_shape, conv_dtype, nullptr, {}); + + // Fill with zeros + copy(array(0, conv_dtype), in_padded, CopyType::Scalar); + + // Pick input slice from padded + size_t data_offset = + padding[0] * in_padded.strides()[1] + padding[1] * in_padded.strides()[2]; + array in_padded_slice(in.shape(), 
in_padded.dtype(), nullptr, {}); + in_padded_slice.copy_shared_buffer( + in_padded, + in_padded.strides(), + in_padded.flags(), + in_padded_slice.size(), + data_offset); + + // Copy input values into the slice + copy_inplace(in, in_padded_slice, CopyType::GeneralGeneral); + + // Make strided view + std::vector strided_shape = {N, oH, oW, wH, wW, C}; + + std::vector strided_strides = { + in_padded.strides()[0], + in_padded.strides()[1] * wt_strides[0], + in_padded.strides()[2] * wt_strides[1], + in_padded.strides()[1], + in_padded.strides()[2], + in_padded.strides()[3]}; + auto flags = in_padded.flags(); + + array in_strided_view(strided_shape, in_padded.dtype(), nullptr, {}); + in_strided_view.copy_shared_buffer( + in_padded, strided_strides, flags, in_strided_view.size(), 0); + + // Materialize strided view + std::vector strided_reshape = {N * oH * oW, wH * wW * C}; + array in_strided(strided_reshape, in_strided_view.dtype(), nullptr, {}); + copy(in_strided_view, in_strided, CopyType::General); + + // Check wt dtype and prepare + auto gemm_wt = wt; + auto gemm_out = out; + + if (wt.dtype() != float32 || !wt.flags().row_contiguous) { + auto ctype = + wt.flags().row_contiguous ? CopyType::Vector : CopyType::General; + gemm_wt = array(wt.shape(), float32, nullptr, {}); + copy(wt, gemm_wt, ctype); + } + + if (out.dtype() != float32) { + gemm_out = array(out.shape(), float32, nullptr, {}); + gemm_out.set_data(allocator::malloc_or_wait(gemm_out.nbytes())); + } + + // Peform gemm + cblas_sgemm( + CblasRowMajor, + CblasNoTrans, // no trans A + CblasTrans, // transB + strided_reshape[0], // M + O, // N + strided_reshape[1], // K + 1.0f, // alpha + in_strided.data(), + strided_reshape[1], // lda + gemm_wt.data(), + strided_reshape[1], // ldb + 0.0f, // beta + gemm_out.data(), + O // ldc + ); + + // Copy results if needed + if (out.dtype() != float32) { + copy(gemm_out, out, CopyType::Vector); + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Conv routing +/////////////////////////////////////////////////////////////////////////////// + +void conv_1D_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + if (wt_dilation[0] == 1) { + return explicit_gemm_conv_1D_cpu( + in, wt, out, padding, wt_strides, wt_dilation); + } + + return dispatch_slow_conv_1D(in, wt, out, padding, wt_strides, wt_dilation); +} + +void conv_2D_cpu( + const array& in, + const array& wt, + array out, + const std::vector& padding, + const std::vector& wt_strides, + const std::vector& wt_dilation) { + return dispatch_slow_conv_2D(in, wt, out, padding, wt_strides, wt_dilation); +} + +} // namespace + +void Convolution::eval(const std::vector& inputs, array& out) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& in = inputs[0]; + auto& wt = inputs[1]; + + // 2D convolution + if (in.ndim() == (2 + 2)) { + return conv_2D_cpu( + in, wt, out, padding_, kernel_strides_, kernel_dilation_); + } + // 1D convolution + else if (in.ndim() == (1 + 2)) { + return conv_1D_cpu( + in, wt, out, padding_, kernel_strides_, kernel_dilation_); + } + // Throw error + else { + std::ostringstream msg; + msg << "[Convolution::eval] Convolution currently only supports" + << " 1D and 2D convolutions. 
Got inputs with " << in.ndim() - 2 + << " spatial dimensions"; + throw std::invalid_argument(msg.str()); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/copy.cpp b/mlx/backend/common/copy.cpp new file mode 100644 index 0000000000..36d79ac9d2 --- /dev/null +++ b/mlx/backend/common/copy.cpp @@ -0,0 +1,308 @@ +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" + +namespace mlx::core { + +namespace { + +template +void copy_single(const array& src, array& dst) { + auto val = static_cast(src.data()[0]); + auto dst_ptr = dst.data(); + for (int i = 0; i < dst.size(); ++i) { + dst_ptr[i] = val; + } +} + +template +void copy_vector(const array& src, array& dst) { + auto src_ptr = src.data(); + auto dst_ptr = dst.data(); + std::copy(src_ptr, src_ptr + src.data_size(), dst_ptr); +} + +template +void copy_general_dim1(const array& src, array& dst) { + const SrcT* src_ptr = src.data(); + DstT* dst_ptr = dst.data(); + size_t src_idx = 0; + size_t dst_idx = 0; + for (size_t i = 0; i < src.shape()[0]; ++i) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += src.strides()[0]; + } +} + +template +void copy_general_dim2(const array& src, array& dst) { + const SrcT* src_ptr = src.data(); + DstT* dst_ptr = dst.data(); + size_t src_idx = 0; + size_t dst_idx = 0; + for (size_t i = 0; i < src.shape()[0]; ++i) { + for (size_t j = 0; j < src.shape()[1]; ++j) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += src.strides()[1]; + } + src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + } +} + +template +void copy_general_dim3(const array& src, array& dst) { + const SrcT* src_ptr = src.data(); + DstT* dst_ptr = dst.data(); + size_t src_idx = 0; + size_t dst_idx = 0; + for (size_t i = 0; i < src.shape()[0]; ++i) { + for (size_t j = 0; j < src.shape()[1]; ++j) { + for (size_t k = 0; k < src.shape()[2]; ++k) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += src.strides()[2]; + } + src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; + } + src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + } +} + +template +void copy_general_dim4(const array& src, array& dst) { + const SrcT* src_ptr = src.data(); + DstT* dst_ptr = dst.data(); + size_t src_idx = 0; + size_t dst_idx = 0; + for (size_t i = 0; i < src.shape()[0]; ++i) { + for (size_t j = 0; j < src.shape()[1]; ++j) { + for (size_t k = 0; k < src.shape()[2]; ++k) { + for (size_t ii = 0; ii < src.shape()[3]; ++ii) { + dst_ptr[dst_idx++] = static_cast(src_ptr[src_idx]); + src_idx += src.strides()[3]; + } + src_idx += src.strides()[2] - src.strides()[3] * src.shape()[3]; + } + src_idx += src.strides()[1] - src.strides()[2] * src.shape()[2]; + } + src_idx += src.strides()[0] - src.strides()[1] * src.shape()[1]; + } +} + +template +void copy_general(const array& src, array& dst) { + switch (src.ndim()) { + case 1: + copy_general_dim1(src, dst); + return; + case 2: + copy_general_dim2(src, dst); + return; + case 3: + copy_general_dim3(src, dst); + return; + case 4: + copy_general_dim4(src, dst); + return; + } + + auto src_ptr = src.data(); + auto dst_ptr = dst.data(); + for (size_t i = 0; i < dst.size(); ++i) { + size_t src_elem = elem_to_loc(i, src.shape(), src.strides()); + dst_ptr[i] = static_cast(src_ptr[src_elem]); + } +} + +template +inline void copy_general_general_dims( + const array& src, + array& dst, + size_t offset_src, + size_t offset_dst) { + if constexpr (D > 1) { + int axis = src.ndim() - D; + auto stride_src = 
src.strides()[axis]; + auto stride_dst = dst.strides()[axis]; + auto N = src.shape(axis); + for (int i = 0; i < N; i++) { + copy_general_general_dims( + src, dst, offset_src, offset_dst); + offset_src += stride_src; + offset_dst += stride_dst; + } + } else { + int axis = src.ndim() - 1; + auto stride_src = src.strides()[axis]; + auto stride_dst = dst.strides()[axis]; + auto N = src.shape(axis); + const SrcT* src_ptr = src.data() + offset_src; + DstT* dst_ptr = dst.data() + offset_dst; + for (int i = 0; i < N; i++) { + *dst_ptr = static_cast(*src_ptr); + src_ptr += stride_src; + dst_ptr += stride_dst; + } + } +} + +template +void copy_general_general(const array& src, array& dst) { + switch (src.ndim()) { + case 1: + copy_general_general_dims(src, dst, 0, 0); + return; + case 2: + copy_general_general_dims(src, dst, 0, 0); + return; + case 3: + copy_general_general_dims(src, dst, 0, 0); + return; + case 4: + copy_general_general_dims(src, dst, 0, 0); + return; + case 5: + copy_general_general_dims(src, dst, 0, 0); + return; + } + + int size = std::accumulate( + src.shape().begin() - 5, src.shape().end(), 1, std::multiplies()); + for (int i = 0; i < src.size(); i += size) { + size_t offset_src = elem_to_loc(i, src.shape(), src.strides()); + size_t offset_dst = elem_to_loc(i, dst.shape(), dst.strides()); + copy_general_general_dims(src, dst, offset_src, offset_dst); + } +} + +template +void copy(const array& src, array& dst, CopyType ctype) { + switch (ctype) { + case CopyType::Scalar: + copy_single(src, dst); + return; + case CopyType::Vector: + copy_vector(src, dst); + return; + case CopyType::General: + copy_general(src, dst); + return; + case CopyType::GeneralGeneral: + copy_general_general(src, dst); + } +} + +template +void copy(const array& src, array& dst, CopyType ctype) { + switch (dst.dtype()) { + case bool_: + copy(src, dst, ctype); + break; + case uint8: + copy(src, dst, ctype); + break; + case uint16: + copy(src, dst, ctype); + break; + case uint32: + copy(src, dst, ctype); + break; + case uint64: + copy(src, dst, ctype); + break; + case int8: + copy(src, dst, ctype); + break; + case int16: + copy(src, dst, ctype); + break; + case int32: + copy(src, dst, ctype); + break; + case int64: + copy(src, dst, ctype); + break; + case float16: + copy(src, dst, ctype); + break; + case float32: + copy(src, dst, ctype); + break; + case bfloat16: + copy(src, dst, ctype); + break; + case complex64: + copy(src, dst, ctype); + break; + } +} + +} // namespace + +void copy_inplace(const array& src, array& dst, CopyType ctype) { + switch (src.dtype()) { + case bool_: + copy(src, dst, ctype); + break; + case uint8: + copy(src, dst, ctype); + break; + case uint16: + copy(src, dst, ctype); + break; + case uint32: + copy(src, dst, ctype); + break; + case uint64: + copy(src, dst, ctype); + break; + case int8: + copy(src, dst, ctype); + break; + case int16: + copy(src, dst, ctype); + break; + case int32: + copy(src, dst, ctype); + break; + case int64: + copy(src, dst, ctype); + break; + case float16: + copy(src, dst, ctype); + break; + case float32: + copy(src, dst, ctype); + break; + case bfloat16: + copy(src, dst, ctype); + break; + case complex64: + copy(src, dst, ctype); + break; + } +} + +void copy(const array& src, array& dst, CopyType ctype) { + // Allocate the output + switch (ctype) { + case CopyType::Vector: + dst.set_data( + allocator::malloc_or_wait(src.data_size() * dst.itemsize()), + src.data_size(), + src.strides(), + src.flags()); + break; + case CopyType::Scalar: + case 
CopyType::General: + case CopyType::GeneralGeneral: + dst.set_data(allocator::malloc_or_wait(dst.nbytes())); + break; + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_inplace(src, dst, ctype); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp new file mode 100644 index 0000000000..3ed9f27500 --- /dev/null +++ b/mlx/backend/common/default_primitives.cpp @@ -0,0 +1,130 @@ +#include + +#include "mlx/array.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" + +#define DEFAULT(primitive) \ + void primitive::eval_cpu(const std::vector& inputs, array& out) { \ + primitive::eval(inputs, out); \ + } + +namespace mlx::core { + +DEFAULT(Abs) +DEFAULT(Add) +DEFAULT(Arange) +DEFAULT(ArcCos) +DEFAULT(ArcCosh) +DEFAULT(ArcSin) +DEFAULT(ArcSinh) +DEFAULT(ArcTan) +DEFAULT(ArcTanh) +DEFAULT(ArgPartition) +DEFAULT(ArgReduce) +DEFAULT(ArgSort) +DEFAULT(AsType) +DEFAULT(AsStrided) +DEFAULT(Broadcast) +DEFAULT(Concatenate) +DEFAULT(Convolution) +DEFAULT(Copy) +DEFAULT(Cos) +DEFAULT(Cosh) +DEFAULT(Divide) +DEFAULT(Equal) +DEFAULT(Erf) +DEFAULT(ErfInv) +DEFAULT(Exp) +DEFAULT(FFT) +DEFAULT(Full) +DEFAULT(Gather) +DEFAULT(Greater) +DEFAULT(GreaterEqual) +DEFAULT(Less) +DEFAULT(LessEqual) +DEFAULT(Load) +DEFAULT(Log) +DEFAULT(Log1p) +DEFAULT(LogicalNot) +DEFAULT(LogAddExp) +DEFAULT(Maximum) +DEFAULT(Minimum) +DEFAULT(Multiply) +DEFAULT(Negative) +DEFAULT(NotEqual) +DEFAULT(Pad) +DEFAULT(Partition) +DEFAULT(Power) +DEFAULT(RandomBits) +DEFAULT(Reduce) +DEFAULT(Reshape) +DEFAULT(Scan) +DEFAULT(Scatter) +DEFAULT(Sigmoid) +DEFAULT(Sign) +DEFAULT(Sin) +DEFAULT(Sinh) +DEFAULT(Slice) +DEFAULT(Softmax) +DEFAULT(Sort) +DEFAULT(Square) +DEFAULT(Sqrt) +DEFAULT(StopGradient) +DEFAULT(Subtract) +DEFAULT(Tan) +DEFAULT(Tanh) +DEFAULT(Transpose) + +void Matmul::eval_cpu(const std::vector& inputs, array& out) { + if (out.dtype() != float32) { + throw std::runtime_error( + "[Matmul::eval_cpu] Currently only supports float32."); + } + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + + auto check_transpose = [](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy(arr, arr_copy, CopyType::General); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [a_transposed, lda, a] = check_transpose(a_pre); + auto [b_transposed, ldb, b] = check_transpose(b_pre); + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + for (int i = 0; i < (a.size() / (M * K)); ++i) { + cblas_sgemm( + CblasRowMajor, + a_transposed ? CblasTrans : CblasNoTrans, // transA + b_transposed ? 
CblasTrans : CblasNoTrans, // transB + M, + N, + K, + 1.0f, // alpha + a.data() + elem_to_loc(M * K * i, a.shape(), a.strides()), + lda, + b.data() + elem_to_loc(K * N * i, b.shape(), b.strides()), + ldb, + 0.0f, // beta + out.data() + M * N * i, + out.shape(-1) // ldc + ); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/erf.cpp b/mlx/backend/common/erf.cpp new file mode 100644 index 0000000000..0b876a1b0e --- /dev/null +++ b/mlx/backend/common/erf.cpp @@ -0,0 +1,38 @@ +#include + +namespace mlx::core { + +/* Approximation to the inverse error function. + * Based on code from: + * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 + */ +float erfinv(float a) { + auto t = std::fma(a, 0.0f - a, 1.0f); + t = std::log(t); + float p; + if (std::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = std::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = std::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = std::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = std::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = std::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = std::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = std::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = std::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = std::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = std::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = std::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = std::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = std::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = std::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = std::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = std::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = std::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} + +} // namespace mlx::core diff --git a/mlx/backend/common/erf.h b/mlx/backend/common/erf.h new file mode 100644 index 0000000000..d7698754dd --- /dev/null +++ b/mlx/backend/common/erf.h @@ -0,0 +1,10 @@ + +namespace mlx::core { + +/* Approximation to the inverse error function. + * Based on code from: + * https://stackoverflow.com/questions/27229371/inverse-error-function-in-c#answer-49743348 + */ +float erfinv(float a); + +} // namespace mlx::core diff --git a/mlx/backend/common/indexing.cpp b/mlx/backend/common/indexing.cpp new file mode 100644 index 0000000000..9572c1caf8 --- /dev/null +++ b/mlx/backend/common/indexing.cpp @@ -0,0 +1,377 @@ +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/primitives.h" + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +template +inline size_t offset_neg_idx(IdxT idx, size_t size) { + return (idx < 0) ? 
idx + size : idx; +} + +template <> +inline size_t offset_neg_idx(bool idx, size_t) { + return idx; +} + +template <> +inline size_t offset_neg_idx(uint32_t idx, size_t) { + return idx; +} + +template +void gather( + const array& src, + const std::vector& inds, + array& out, + const std::vector& axes, + const std::vector& slice_sizes) { + // If the array is row contiguous then we can do a contiguous copy given + // two conditions on the slice size: + // - Any number of leading ones in the slice sizes are allowed + // - All other slice sizes match the corresponding dimension except the + // first non-singleton slice size + // If the array is col contiguous then the reverse is the case: + // - Any number of trailing ones in the slice sizes are allowed + // - All other slice sizes match the corresponding dimension except the + // first non-singleton slice size from the end + + bool can_copy = false; + if (src.flags().row_contiguous) { + can_copy = true; + + // Ignore leading 1s + int i = 0; + for (; i < slice_sizes.size() && slice_sizes[i] == 1; ++i) + ; + + // Check the remaining + i++; + for (; i < src.ndim() && can_copy; ++i) { + can_copy = (src.shape(i) == slice_sizes[i]); + } + } else if (src.flags().col_contiguous) { + can_copy = true; + + // Ignore trailing 1s + int i = slice_sizes.size() - 1; + for (; i >= 0 && slice_sizes[i] == 1; --i) + ; + + // Skip the next slice size and check the remaining + i--; + for (; i >= 0 && can_copy; --i) { + can_copy = (src.shape(i) == slice_sizes[i]); + } + } + size_t slice_size = 1; + for (auto s : slice_sizes) { + slice_size *= s; + } + size_t ind_size = slice_size == 0 ? 0 : out.size() / slice_size; + const T* src_ptr = src.data(); + T* dst_ptr = out.data(); + size_t out_idx = 0; + + for (int idx = 0; idx < ind_size; idx++) { + size_t src_idx = 0; + for (int ii = 0; ii < inds.size(); ++ii) { + auto ax = axes[ii]; + auto idx_loc = elem_to_loc(idx, inds[ii]); + auto idx_val = + offset_neg_idx(inds[ii].data()[idx_loc], src.shape(ax)); + src_idx += (idx_val * src.strides()[ax]); + } + + if (slice_size == 1) { + dst_ptr[out_idx++] = src_ptr[src_idx]; + } else if (can_copy) { + std::copy( + src_ptr + src_idx, src_ptr + src_idx + slice_size, dst_ptr + out_idx); + out_idx += slice_size; + } else { + for (int jj = 0; jj < slice_size; jj++) { + auto src_offset = elem_to_loc(jj, slice_sizes, src.strides()); + dst_ptr[out_idx++] = src_ptr[src_idx + src_offset]; + } + } + } +} + +template +void dispatch_gather( + const array& src, + const std::vector& inds, + array& out, + const std::vector& axes, + const std::vector& size) { + switch (out.dtype()) { + case bool_: + gather(src, inds, out, axes, size); + break; + case uint8: + gather(src, inds, out, axes, size); + break; + case uint16: + gather(src, inds, out, axes, size); + break; + case uint32: + gather(src, inds, out, axes, size); + break; + case uint64: + gather(src, inds, out, axes, size); + break; + case int8: + gather(src, inds, out, axes, size); + break; + case int16: + gather(src, inds, out, axes, size); + break; + case int32: + gather(src, inds, out, axes, size); + break; + case int64: + gather(src, inds, out, axes, size); + break; + case float16: + gather(src, inds, out, axes, size); + break; + case float32: + gather(src, inds, out, axes, size); + break; + case bfloat16: + gather(src, inds, out, axes, size); + break; + case complex64: + gather(src, inds, out, axes, size); + break; + } +} + +void Gather::eval(const std::vector& inputs, array& out) { + 
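+  // Allocate the output, then dispatch on the index dtype and, inside
+  // dispatch_gather, on the source/output dtype. As a shape illustration
+  // (hypothetical values): gathering from a {4, 4} source with one int32
+  // index array of shape {2}, axes = {0} and slice_sizes = {1, 4} produces
+  // an output of shape {2, 1, 4}.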
out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& src = inputs[0]; + std::vector inds(inputs.begin() + 1, inputs.end()); + + if (inds.empty()) { + dispatch_gather(src, inds, out, axes_, slice_sizes_); + return; + } + + switch (inds[0].dtype()) { + case bool_: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case uint8: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case uint16: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case uint32: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case uint64: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case int8: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case int16: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case int32: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case int64: + dispatch_gather(src, inds, out, axes_, slice_sizes_); + break; + case float16: + case float32: + case bfloat16: + case complex64: + throw std::runtime_error( + "[Gather::eval] Cannot gather with floating point indices."); + break; + } +} + +template +void scatter( + const array& updates, + array& out, + const std::vector& inds, + const std::vector& axes, + const OpT& op) { + int nind = inds.size(); + auto inds_ndim = updates.ndim() - out.ndim(); + size_t n_updates = nind ? inds[0].size() : 1; + + std::vector update_shape( + updates.shape().begin() + inds_ndim, updates.shape().end()); + size_t update_size = 1; + for (auto us : update_shape) { + update_size *= us; + } + + for (int i = 0; i < n_updates; ++i) { + size_t out_offset = 0; + for (int j = 0; j < nind; ++j) { + auto ax = axes[j]; + auto idx_loc = elem_to_loc(i, inds[j]); + auto idx_val = + offset_neg_idx(inds[j].data()[idx_loc], out.shape(ax)); + out_offset += (idx_val * out.strides()[ax]); + } + for (int j = 0; j < update_size; ++j) { + auto update_loc = elem_to_loc(i * update_size + j, updates); + auto out_loc = elem_to_loc(j, update_shape, out.strides()); + op(updates.data()[update_loc], + out.data() + out_offset + out_loc); + } + } +} + +template +void dispatch_scatter_inds( + array& out, + const std::vector& indices, + const array& updates, + const std::vector& axes, + Scatter::ReduceType rtype) { + switch (rtype) { + case Scatter::None: + scatter( + updates, out, indices, axes, [](auto x, auto* y) { (*y) = x; }); + break; + case Scatter::Sum: + scatter( + updates, out, indices, axes, [](auto x, auto* y) { (*y) += x; }); + break; + case Scatter::Prod: + scatter( + updates, out, indices, axes, [](auto x, auto* y) { (*y) *= x; }); + break; + case Scatter::Max: + scatter(updates, out, indices, axes, [](auto x, auto* y) { + (*y) = (*y > x) ? *y : x; + }); + break; + case Scatter::Min: + scatter(updates, out, indices, axes, [](auto x, auto* y) { + (*y) = (*y < x) ? 
*y : x; + }); + break; + } +} + +template +void dispatch_scatter( + array& out, + const std::vector& inds, + const array& updates, + const std::vector& axes, + Scatter::ReduceType rtype) { + if (inds.empty()) { + dispatch_scatter_inds(out, inds, updates, axes, rtype); + return; + } + + switch (inds[0].dtype()) { + case bool_: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case uint8: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case uint16: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case uint32: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case uint64: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case int8: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case int16: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case int32: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case int64: + dispatch_scatter_inds(out, inds, updates, axes, rtype); + break; + case float16: + case float32: + case bfloat16: + case complex64: + throw std::runtime_error( + "[Scatter::eval_cpu] Cannot scatter with floating point indices."); + } +} + +void Scatter::eval(const std::vector& inputs, array& out) { + assert(inputs.size() >= 2); + + auto& src = inputs[0]; + std::vector inds(inputs.begin() + 1, inputs.end() - 1); + auto& updates = inputs.back(); + + // Copy src into out (copy allocates memory for out) + copy(src, out, CopyType::General); + + switch (src.dtype()) { + case bool_: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint8: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case uint64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int8: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case int64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case float16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case float32: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case bfloat16: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + case complex64: + dispatch_scatter(out, inds, updates, axes_, reduce_type_); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/load.cpp b/mlx/backend/common/load.cpp new file mode 100644 index 0000000000..5f5f173299 --- /dev/null +++ b/mlx/backend/common/load.cpp @@ -0,0 +1,52 @@ +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/load.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +void swap_endianess(uint8_t* data_bytes, size_t N) { + struct Elem { + uint8_t bytes[scalar_size]; + }; + + Elem* data = reinterpret_cast(data_bytes); + + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < (scalar_size / 2); j++) { + std::swap(data[i].bytes[j], data[i].bytes[scalar_size - j - 1]); + } + } +} + +} // namespace + +void Load::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + reader_->seek(offset_, 
std::ios_base::beg); + reader_->read(out.data(), out.nbytes()); + + if (swap_endianness_) { + switch (out.itemsize()) { + case 2: + swap_endianess<2>(out.data(), out.data_size()); + break; + case 4: + swap_endianess<4>(out.data(), out.data_size()); + break; + case 8: + swap_endianess<8>(out.data(), out.data_size()); + break; + } + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/primitives.cpp b/mlx/backend/common/primitives.cpp new file mode 100644 index 0000000000..86bc746c68 --- /dev/null +++ b/mlx/backend/common/primitives.cpp @@ -0,0 +1,622 @@ +#include +#include +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/arange.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/erf.h" +#include "mlx/backend/common/threefry.h" +#include "mlx/backend/common/unary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +void Abs::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (is_unsigned(in.dtype())) { + // No-op for unsigned types + out.copy_shared_buffer(in); + } else { + unary(in, out, AbsOp()); + } +} + +void Arange::eval(const std::vector& inputs, array& out) { + arange(inputs, out, start_, step_); +} + +void ArcCos::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::acos(x); }); + } else { + throw std::invalid_argument( + "[arccos] Cannot compute inverse cosine of elements in array" + " with non floating point type."); + } +} + +void ArcCosh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::acosh(x); }); + } else { + throw std::invalid_argument( + "[arccosh] Cannot compute inverse hyperbolic cosine of elements in" + " array with non floating point type."); + } +} + +void ArcSin::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::asin(x); }); + } else { + throw std::invalid_argument( + "[arcsin] Cannot compute inverse sine of elements in array" + " with non floating point type."); + } +} + +void ArcSinh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::asinh(x); }); + } else { + throw std::invalid_argument( + "[arcsinh] Cannot compute inverse hyperbolic sine of elements in" + " array with non floating point type."); + } +} + +void ArcTan::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::atan(x); }); + } else { + throw std::invalid_argument( + "[arctan] Cannot compute inverse tangent of elements in array" + " with non floating point type."); + } +} + +void ArcTanh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::atanh(x); }); + } else { + throw std::invalid_argument( + "[arctanh] Cannot compute inverse hyperbolic tangent of elements in" + " array with non 
floating point type."); + } +} + +void AsType::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + CopyType ctype = in.flags().contiguous ? CopyType::Vector : CopyType::General; + copy(in, out, ctype); +} + +void AsStrided::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + auto& in = inputs[0]; + + if (!in.flags().row_contiguous) { + // Just ensuring that inputs[0] came from the ops which would ensure the + // input is row contiguous. + throw std::runtime_error( + "AsStrided must be used with row contiguous arrays only."); + } + + // Compute the flags given the shape and strides + bool row_contiguous = true, col_contiguous = true; + size_t r = 1, c = 1; + for (int i = strides_.size() - 1, j = 0; i >= 0; i--, j++) { + row_contiguous &= (r == strides_[i]) || (shape_[i] == 1); + col_contiguous &= (c == strides_[j]) || (shape_[j] == 1); + r *= shape_[i]; + c *= shape_[j]; + } + auto flags = in.flags(); + // TODO: Compute the contiguous flag in a better way cause now we are + // unnecessarily strict. + flags.contiguous = row_contiguous || col_contiguous; + flags.row_contiguous = row_contiguous; + flags.col_contiguous = col_contiguous; + + // There is no easy way to compute the actual data size so we use out.size(). + // The contiguous flag will almost certainly not be set so no code should + // rely on data_size anyway. + size_t data_size = out.size(); + + return out.copy_shared_buffer(in, strides_, flags, data_size, offset_); +} + +void Broadcast::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + std::vector strides(out.ndim(), 0); + int diff = out.ndim() - in.ndim(); + for (int i = in.ndim() - 1; i >= 0; --i) { + strides[i + diff] = (in.shape()[i] == 1) ? 
0 : in.strides()[i]; + } + auto flags = in.flags(); + if (out.size() > in.size()) { + flags.row_contiguous = flags.col_contiguous = false; + } + out.copy_shared_buffer(in, strides, flags, in.data_size()); +} + +void Concatenate::eval(const std::vector& inputs, array& out) { + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis_)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis_] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_inplace(inputs[i], out_slice, CopyType::GeneralGeneral); + } +} + +void Copy::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + out.copy_shared_buffer(inputs[0]); +} + +void Cos::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::cos(x); }); + } else { + throw std::invalid_argument( + "[cos] Cannot compute cosine of elements in array" + " with non floating point type."); + } +} + +void Cosh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::cosh(x); }); + } else { + throw std::invalid_argument( + "[cosh] Cannot compute hyperbolic cosine of elements in array" + " with non floating point type."); + } +} + +void Erf::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + switch (out.dtype()) { + case float32: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { return std::erf(x); }); + break; + case float16: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { + return static_cast(std::erf(static_cast(x))); + }); + break; + case bfloat16: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { + return static_cast(std::erf(static_cast(x))); + }); + break; + default: + throw std::invalid_argument( + "[erf] Error function only defined for arrays" + " with real floating point type."); + } +} + +void ErfInv::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + switch (out.dtype()) { + case float32: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { return erfinv(x); }); + break; + case float16: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { + return static_cast(erfinv(static_cast(x))); + }); + break; + case bfloat16: + out.set_data(allocator::malloc_or_wait(out.nbytes())); + unary_op(in, out, [](auto x) { + return static_cast(erfinv(static_cast(x))); + }); + break; + default: + throw std::invalid_argument( + "[erf_inv] Inverse error function only defined for arrays" + " with real floating point type."); + } +} + +void Exp::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { 
return std::exp(x); }); + } else { + throw std::invalid_argument( + "[exp] Cannot exponentiate elements in array" + " with non floating point type."); + } +} + +void Full::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + assert(in.dtype() == out.dtype()); + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy(in, out, ctype); +} + +void Log::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + switch (base_) { + case Base::e: + unary_fp(in, out, [](auto x) { return std::log(x); }); + break; + case Base::two: + unary_fp(in, out, [](auto x) { return std::log2(x); }); + break; + case Base::ten: + unary_fp(in, out, [](auto x) { return std::log10(x); }); + break; + } + } else { + throw std::invalid_argument( + "[log] Cannot compute log of elements in array with" + " non floating point type."); + } +} + +void Log1p::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::log1p(x); }); + } else { + throw std::invalid_argument( + "[log1p] Cannot compute log of elements in array with" + " non floating point type."); + } +} + +void LogicalNot::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + unary(in, out, [](auto x) { return !x; }); +} + +void Negative::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + unary(in, out, [](auto x) { return -x; }); +} + +void Pad::eval(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + // Fill output with val + copy(val, out, CopyType::Scalar); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes_.size(); i++) { + auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i]; + data_offset += out.strides()[ax] * low_pad_size_[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_inplace(in, out_slice, CopyType::GeneralGeneral); +} + +void RandomBits::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) 
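+  // e.g. keys of shape (3, 2) give num_keys = 3; a uint32 out of shape (3, 5)
+  // then has elems_per_key = 5 and bytes_per_key = 20, i.e. five 32-bit words
+  // drawn from the threefry counter stream per key.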
+ auto& keys = inputs[0]; + size_t num_keys = keys.size() / 2; + + size_t elems_per_key = out.size() / num_keys; + size_t bytes_per_key = out.itemsize() * elems_per_key; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto kptr = inputs[0].data(); + auto cptr = out.data(); + size_t out_skip = (bytes_per_key + 4 - 1) / 4; + auto half_size = out_skip / 2; + bool even = out_skip % 2 == 0; + for (int i = 0; i < num_keys; ++i, cptr += bytes_per_key) { + auto ptr = reinterpret_cast(cptr); + // Get ith key + auto kidx = 2 * i; + auto k1_elem = elem_to_loc(kidx, keys.shape(), keys.strides()); + auto k2_elem = elem_to_loc(kidx + 1, keys.shape(), keys.strides()); + auto key = std::make_pair(kptr[k1_elem], kptr[k2_elem]); + + std::pair count{0, half_size + !even}; + for (; count.first + 1 < half_size; count.first++, count.second++) { + std::tie(ptr[count.first], ptr[count.second]) = + random::threefry2x32_hash(key, count); + } + if (count.first < half_size) { + auto rb = random::threefry2x32_hash(key, count); + ptr[count.first++] = rb.first; + if (bytes_per_key % 4 > 0) { + std::copy( + reinterpret_cast(&rb.second), + reinterpret_cast(&rb.second) + bytes_per_key % 4, + cptr + 4 * count.second); + } else { + ptr[count.second] = rb.second; + } + } + if (!even) { + count.second = 0; + ptr[half_size] = random::threefry2x32_hash(key, count).first; + } + } +} + +void Reshape::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (in.flags().row_contiguous) { + // For row contiguous reshapes: + // - Shallow copy the buffer + // - If reshaping into a vector (all singleton dimensions except one) it + // becomes col contiguous again. + auto flags = in.flags(); + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(in, out.strides(), flags, in.data_size()); + } else { + copy(in, out, in.data_size() == 1 ? 
CopyType::Scalar : CopyType::General); + } +} + +void Sigmoid::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + auto sigmoid_op = [](auto x) { + auto one = static_cast(1.0); + return one / (one + std::exp(-x)); + }; + unary_fp(in, out, sigmoid_op); + } else { + throw std::invalid_argument( + "[sigmoid] Cannot sigmoid of elements in array with" + " non floating point type."); + } +} + +void Sign::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (in.dtype() == bool_) { + out.copy_shared_buffer(in); + } else { + unary(in, out, SignOp()); + } +} + +void Sin::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::sin(x); }); + } else { + throw std::invalid_argument( + "[sin] Cannot compute sine of elements in array" + " with non floating point type."); + } +} + +void Sinh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::sinh(x); }); + } else { + throw std::invalid_argument( + "[sinh] Cannot compute hyperbolic sine of elements in array" + " with non floating point type."); + } +} + +void Slice::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + if (out.size() == 0) { + out.set_data(nullptr); + return; + } + auto& in = inputs[0]; + auto strides = in.strides(); + auto flags = in.flags(); + size_t data_offset = 0; + for (int i = 0; i < in.ndim(); ++i) { + data_offset += start_indices_[i] * in.strides()[i]; + strides[i] *= strides_[i]; + } + + // Compute row/col contiguity + size_t data_size = 1; + size_t f_stride = 1; + size_t b_stride = 1; + flags.row_contiguous = true; + flags.col_contiguous = true; + for (int i = 0, ri = out.ndim() - 1; ri >= 0; i++, ri--) { + flags.col_contiguous &= strides[i] == f_stride || out.shape(i) == 1; + flags.row_contiguous &= strides[ri] == b_stride || out.shape(ri) == 1; + f_stride *= out.shape(i); + b_stride *= out.shape(ri); + if (strides[i] > 0) { + data_size *= out.shape(i); + } + } + + if (data_size == 1) { + // Broadcasted scalar array is contiguous. + flags.contiguous = true; + } else if (data_size == in.data_size()) { + // Means we sliced a broadcasted dimension so leave the "no holes" flag + // alone. + } else { + // We sliced something. So either we are row or col contiguous or we + // punched a hole. 
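+    // e.g. slicing a row contiguous (4, 4) array with a step of 2 along the
+    // last axis leaves gaps between the kept elements, so neither row nor
+    // col contiguity holds and the contiguous flag is cleared below.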
+ flags.contiguous &= flags.row_contiguous || flags.col_contiguous; + } + + out.copy_shared_buffer(in, strides, flags, data_size, data_offset); +} + +void Square::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + unary(in, out, [](auto x) { return x * x; }); +} + +void Sqrt::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + if (recip_) { + unary_fp(in, out, [](auto x) { + return static_cast(1.0) / sqrt(x); + }); + } else { + unary_fp(in, out, [](auto x) { return sqrt(x); }); + } +} + +void StopGradient::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + out.copy_shared_buffer(inputs[0]); +} + +void Tan::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::tan(x); }); + } else { + throw std::invalid_argument( + "[tan] Cannot compute tangent of elements in array" + " with non floating point type."); + } +} + +void Tanh::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (is_floating_point(out.dtype())) { + unary_fp(in, out, [](auto x) { return std::tanh(x); }); + } else { + throw std::invalid_argument( + "[tanh] Cannot compute hyperbolic tangent of elements in array" + " with non floating point type."); + } +} + +void Transpose::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + std::vector out_strides(out.ndim()); + auto& in = inputs[0]; + for (int ax = 0; ax < axes_.size(); ++ax) { + out_strides[ax] = in.strides()[axes_[ax]]; + } + + // Conditions for {row/col}_contiguous + // - array must be contiguous (no gaps) + // - underlying buffer size should have the same size as the array + // - cumulative product of shapes is equal to the strides (we can ignore axes + // with size == 1) + // - in the forward direction (column contiguous) + // - in the reverse direction (row contiguous) + // - vectors are both row and col contiguous (hence if both row/col are + // true, they stay true) + auto flags = in.flags(); + if (flags.contiguous && in.data_size() == in.size()) { + size_t f_stride = 1; + size_t b_stride = 1; + flags.col_contiguous = true; + flags.row_contiguous = true; + for (int i = 0, ri = out.ndim() - 1; i < out.ndim(); ++i, --ri) { + flags.col_contiguous &= (out_strides[i] == f_stride || out.shape(i) == 1); + f_stride *= out.shape(i); + flags.row_contiguous &= + (out_strides[ri] == b_stride || out.shape(ri) == 1); + b_stride *= out.shape(ri); + } + } + out.copy_shared_buffer(in, out_strides, flags, in.data_size()); +} + +} // namespace mlx::core diff --git a/mlx/backend/common/reduce.cpp b/mlx/backend/common/reduce.cpp new file mode 100644 index 0000000000..90a466f7d2 --- /dev/null +++ b/mlx/backend/common/reduce.cpp @@ -0,0 +1,215 @@ +#include +#include +#include + +#include "mlx/backend/common/reduce.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +struct Limits { + static const U max; + static const U min; +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr type max = std::numeric_limits::max(); \ + static constexpr type min = std::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); 
+instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static const type max; \ + static const type min; \ + }; + +instantiate_float_limit(float16_t); +instantiate_float_limit(bfloat16_t); +instantiate_float_limit(float); +instantiate_float_limit(complex64_t); + +template <> +struct Limits { + static constexpr bool max = true; + static constexpr bool min = false; +}; + +const float Limits::max = std::numeric_limits::infinity(); +const float Limits::min = -std::numeric_limits::infinity(); +const bfloat16_t Limits::max = + std::numeric_limits::infinity(); +const bfloat16_t Limits::min = + -std::numeric_limits::infinity(); +const float16_t Limits::max = std::numeric_limits::infinity(); +const float16_t Limits::min = + -std::numeric_limits::infinity(); +const complex64_t Limits::max = + std::numeric_limits::infinity(); +const complex64_t Limits::min = + -std::numeric_limits::infinity(); + +struct AndReduce { + template + void operator()(bool* a, T b) { + (*a) &= (b != 0); + } + + void operator()(bool* y, bool x) { + (*y) &= x; + } +}; + +struct OrReduce { + template + void operator()(bool* a, T b) { + (*a) |= (b != 0); + } + + void operator()(bool* y, bool x) { + (*y) |= x; + } +}; + +template +void reduce_dispatch_out( + const array& in, + array& out, + Reduce::ReduceType rtype, + const std::vector& axes) { + switch (rtype) { + case Reduce::And: { + reduction_op(in, out, axes, true, AndReduce()); + break; + } + case Reduce::Or: { + reduction_op(in, out, axes, false, OrReduce()); + break; + } + case Reduce::Sum: { + auto op = [](auto y, auto x) { (*y) = (*y) + x; }; + switch (out.dtype()) { + case bool_: + reduction_op(in, out, axes, false, op); + break; + case uint8: + reduction_op(in, out, axes, 0, op); + break; + case uint16: + reduction_op(in, out, axes, 0, op); + break; + case uint32: + reduction_op(in, out, axes, 0, op); + break; + case uint64: + reduction_op(in, out, axes, 0, op); + break; + case int8: + reduction_op(in, out, axes, 0, op); + break; + case int16: + reduction_op(in, out, axes, 0, op); + break; + case int32: + reduction_op(in, out, axes, 0, op); + break; + case int64: + reduction_op(in, out, axes, 0, op); + break; + case float16: + reduction_op(in, out, axes, 0.0f, op); + break; + case float32: + reduction_op(in, out, axes, 0.0f, op); + break; + case bfloat16: + reduction_op(in, out, axes, 0.0f, op); + break; + case complex64: + reduction_op(in, out, axes, complex64_t{0.0f}, op); + break; + } + } break; + case Reduce::Prod: { + auto op = [](auto y, auto x) { (*y) *= x; }; + reduction_op(in, out, axes, 1, op); + break; + } + case Reduce::Max: { + auto op = [](auto y, auto x) { (*y) = (*y > x) ? *y : x; }; + auto init = Limits::min; + reduction_op(in, out, axes, init, op); + break; + } + case Reduce::Min: { + auto op = [](auto y, auto x) { (*y) = (*y < x) ? 
*y : x; }; + auto init = Limits::max; + reduction_op(in, out, axes, init, op); + break; + } + } +} + +} // namespace + +void Reduce::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + switch (in.dtype()) { + case bool_: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case uint8: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case uint16: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case uint32: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case uint64: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case int8: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case int16: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case int32: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case int64: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case float16: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case float32: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case bfloat16: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + case complex64: + reduce_dispatch_out(in, out, reduce_type_, axes_); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/reduce.h b/mlx/backend/common/reduce.h new file mode 100644 index 0000000000..31b1716c33 --- /dev/null +++ b/mlx/backend/common/reduce.h @@ -0,0 +1,364 @@ +#pragma once + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +namespace { + +enum ReductionOpType { + // Self-explanatory. Read everything and produce 1 output. + ContiguousAllReduce, + + // The input is contiguous and the last axis is reduced + // N1xR1xN2xR2x...xNnxRn + ContiguousReduce, + + // The input is contiguous and the last axis is not reduced + // R1xN1xR2xN2x...xRnxNn + ContiguousStridedReduce, + + // The input is not contiguous but the last axis is and it is reduced so we + // need to figure out the offsets but we can call the contiguous reduce after + // that. + // N3xR1xN1xR4x...xRn + GeneralContiguousReduce, + + // The input is not contiguous but the last reduction axis and the last axis + // are so we need to figure out the offset but we can call the strided reduce + // after that. + GeneralStridedReduce, + + // The input is not contiguous after the reduction axis and it may contain + // 0-stride axes or transpositions. We could copy the strides and produce a + // transposed outcome or we can read the input out of order and write the + // output in order. + GeneralReduce +}; + +// Helper for the ndimensional strided loop +// Should this be in utils? 
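+// e.g. for shape = {2, 3} and strides = {3, 1} the callback is invoked with
+// the offsets 0, 1, 2, 3, 4, 5 in row-major order.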
+inline void nd_loop( + std::function callback, + const std::vector& shape, + const std::vector& strides) { + std::function loop_inner; + loop_inner = [&](int dim, int offset) { + if (dim < shape.size() - 1) { + int size = shape[dim]; + size_t stride = strides[dim]; + for (int i = 0; i < size; i++) { + loop_inner(dim + 1, offset + i * stride); + } + } else { + int size = shape[dim]; + size_t stride = strides[dim]; + for (int i = 0; i < size; i++) { + callback(offset + i * stride); + } + } + }; + loop_inner(0, 0); +} + +std::pair, std::vector> shapes_without_reduction_axes( + const array& x, + const std::vector& axes) { + std::vector shape = x.shape(); + std::vector strides = x.strides(); + + for (int i = axes.size() - 1; i >= 0; i--) { + int a = axes[i]; + shape.erase(shape.begin() + a); + strides.erase(strides.begin() + a); + } + + return std::make_pair(shape, strides); +} + +template +struct DefaultStridedReduce { + Op op; + + DefaultStridedReduce(Op op_) : op(op_) {} + + void operator()(const T* x, U* accumulator, int size, size_t stride) { + for (int i = 0; i < size; i++) { + U* moving_accumulator = accumulator; + for (int j = 0; j < stride; j++) { + op(moving_accumulator, *x); + moving_accumulator++; + x++; + } + } + } +}; + +template +struct DefaultContiguousReduce { + Op op; + + DefaultContiguousReduce(Op op_) : op(op_) {} + + void operator()(const T* x, U* accumulator, int size) { + while (size-- > 0) { + op(accumulator, *x); + x++; + } + } +}; + +struct ReductionPlan { + ReductionOpType type; + std::vector shape; + std::vector strides; + + ReductionPlan( + ReductionOpType type_, + std::vector shape_, + std::vector strides_) + : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} + ReductionPlan(ReductionOpType type_) : type(type_) {} +}; + +ReductionPlan get_reduction_plan(const array& x, const std::vector axes) { + // The data is all there and we are reducing over everything + if (x.size() == x.data_size() && axes.size() == x.ndim() && + (x.flags().row_contiguous || x.flags().col_contiguous)) { + return ContiguousAllReduce; + } + + // Row contiguous input so the output is row contiguous + if (x.flags().row_contiguous) { + // Merge consecutive axes + std::vector shape = {x.shape(axes[0])}; + std::vector strides = {x.strides()[axes[0]]}; + for (int i = 1; i < axes.size(); i++) { + if (axes[i] - 1 == axes[i - 1]) { + shape.back() *= x.shape(axes[i]); + strides.back() = x.strides()[axes[i]]; + } else { + shape.push_back(x.shape(axes[i])); + strides.push_back(x.strides()[axes[i]]); + } + } + + if (strides.back() == 1) { + return ReductionPlan(ContiguousReduce, shape, strides); + } else if (strides.back() > 1) { + return ReductionPlan(ContiguousStridedReduce, shape, strides); + } + } + + // Let's check if we can optimize our access patterns + // + // 1. We have a reduction axis with stride 1. Simply call + // GeneralContiguousReduce and be done with it. + // 2. We have transpositions and we are not reducing over the axis with + // stride 1. However, we are reducing over an axis where everything is + // contiguous in memory to the right of that axis. We can call strided + // reduce and be done with it. + // 2. We have weird transpositions and expands. Copy the strides to the + // output, then call strided reduce. + + // Sort reduction axes by stride in order to merge them and figure out if we + // have a contiguous reduction. 
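+  // e.g. reduction axes with (shape, stride) pairs (5, 24) and (6, 4) merge
+  // into a single (30, 4) pair below, since 24 == 6 * 4 means the two axes
+  // span one contiguous strided block.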
+ std::vector> reductions; + for (auto a : axes) { + reductions.push_back(std::make_pair(x.shape(a), x.strides()[a])); + } + std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) { + return a.second > b.second; + }); + // Extract the two smallest and try to merge them in case the contiguous + // reduction can be bigger than just the last axis. + for (int i = reductions.size() - 1; i >= 1; i--) { + auto a = reductions[i]; + auto b = reductions[i - 1]; + + // b.stride = a.shape * a.stride then a and b are contiguous + if (b.second == a.first * a.second) { + reductions.erase(reductions.begin() + i); + reductions[i - 1] = std::make_pair(a.first * b.first, a.second); + } + } + + std::vector shape; + std::vector strides; + for (auto r : reductions) { + shape.push_back(r.first); + strides.push_back(r.second); + } + + // We can call the contiguous reduction op for every weird way the input is + // structured in the rest of the axes. + if (strides.back() == 1) { + return ReductionPlan(GeneralContiguousReduce, shape, strides); + } + + // Delegate to the general strided reduction op if the axes after + // strides.back() are contiguous. + if (strides.back() > 1) { + int size = 1; + for (int i = x.ndim() - 1; i >= 0; i--) { + if (axes.back() == i) { + continue; + } + if (x.strides()[i] != size) { + break; + } + size *= x.shape(i); + } + if (size >= strides.back()) { + return ReductionPlan(GeneralStridedReduce, shape, strides); + } + } + + return ReductionPlan(GeneralReduce, shape, strides); +} + +template +void reduction_op( + const array& x, + array& out, + const std::vector& axes, + U init, + OpS ops, + OpC opc, + Op op) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + ReductionPlan plan = get_reduction_plan(x, axes); + + if (plan.type == ContiguousAllReduce) { + U* out_ptr = out.data(); + *out_ptr = init; + opc(x.data(), out_ptr, x.size()); + return; + } + + std::vector shape; + std::vector strides; + + if (plan.type == ContiguousReduce && plan.shape.size() == 1) { + int reduction_size = plan.shape[0]; + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + for (int i = 0; i < out.size(); i++, out_ptr++, x_ptr += reduction_size) { + *out_ptr = init; + opc(x_ptr, out_ptr, reduction_size); + } + return; + } + + if (plan.type == GeneralContiguousReduce || plan.type == ContiguousReduce) { + int reduction_size = plan.shape.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + // Unrolling the following loop (and implementing it in order for + // ContiguousReduce) should hold extra performance boost. 
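+    // The remaining (non-reduced) shape and strides locate the start of each
+    // contiguous reduction segment in the input via elem_to_loc below.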
+ std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + opc(x_ptr + offset, out_ptr, reduction_size); + } + } else { + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + *out_ptr = init; + nd_loop( + [&](int extra_offset) { + opc(x_ptr + offset + extra_offset, out_ptr, reduction_size); + }, + plan.shape, + plan.strides); + } + } + return; + } + + if (plan.type == ContiguousStridedReduce && plan.shape.size() == 1) { + int reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + for (int i = 0; i < out.size(); i += reduction_stride) { + std::fill_n(out_ptr, reduction_stride, init); + ops(x_ptr, out_ptr, reduction_size, reduction_stride); + x_ptr += reduction_stride * reduction_size; + out_ptr += reduction_stride; + } + return; + } + + if (plan.type == GeneralStridedReduce || + plan.type == ContiguousStridedReduce) { + int reduction_size = plan.shape.back(); + size_t reduction_stride = plan.strides.back(); + plan.shape.pop_back(); + plan.strides.pop_back(); + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + if (plan.shape.size() == 0) { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + ops(x_ptr + offset, out_ptr, reduction_size, reduction_stride); + out_ptr += reduction_stride; + } + } else { + for (int i = 0; i < out.size(); i += reduction_stride) { + int offset = elem_to_loc(i, shape, strides); + std::fill_n(out_ptr, reduction_stride, init); + nd_loop( + [&](int extra_offset) { + ops(x_ptr + offset + extra_offset, + out_ptr, + reduction_size, + reduction_stride); + }, + plan.shape, + plan.strides); + out_ptr += reduction_stride; + } + } + return; + } + + if (plan.type == GeneralReduce) { + const T* x_ptr = x.data(); + U* out_ptr = out.data(); + std::tie(shape, strides) = shapes_without_reduction_axes(x, axes); + for (int i = 0; i < out.size(); i++, out_ptr++) { + int offset = elem_to_loc(i, shape, strides); + U val = init; + nd_loop( + [&](int extra_offset) { op(&val, *(x_ptr + offset + extra_offset)); }, + plan.shape, + plan.strides); + *out_ptr = val; + } + } +} + +template +void reduction_op( + const array& x, + array& out, + const std::vector& axes, + U init, + Op op) { + DefaultStridedReduce ops(op); + DefaultContiguousReduce opc(op); + reduction_op(x, out, axes, init, ops, opc, op); +} + +} // namespace + +} // namespace mlx::core diff --git a/mlx/backend/common/scan.cpp b/mlx/backend/common/scan.cpp new file mode 100644 index 0000000000..89e09e373c --- /dev/null +++ b/mlx/backend/common/scan.cpp @@ -0,0 +1,323 @@ +#include + +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +namespace { + +template +struct DefaultContiguousScan { + Op op; + U init; + + DefaultContiguousScan(Op op_, U init_) : op(op_), init(init_) {} + + void operator()( + const T* input, + U* output, + int count, + int stride, + bool reverse, + bool inclusive) { + if (!reverse) { + if (inclusive) { + for (int i = 0; i < count; i++) { + *output = *input; + for (int j = 1; j < stride; j++) { + input++; + output++; 
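+          // running inclusive accumulation: out[j] = combine(out[j - 1], in[j])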
+ op(output, output - 1, input); + } + output++; + input++; + } + } else { + for (int i = 0; i < count; i++) { + *output = init; + for (int j = 1; j < stride; j++) { + op(output + 1, output, input); + input++; + output++; + } + output++; + input++; + } + } + } else { + if (inclusive) { + for (int i = 0; i < count; i++) { + output += stride - 1; + input += stride - 1; + *output = *input; + for (int j = 1; j < stride; j++) { + input--; + output--; + op(output, output + 1, input); + } + output += stride; + input += stride; + } + } else { + for (int i = 0; i < count; i++) { + output += stride - 1; + input += stride - 1; + *output = init; + for (int j = 1; j < stride; j++) { + op(output - 1, output, input); + input--; + output--; + } + output += stride; + input += stride; + } + } + } + } +}; + +template +struct DefaultStridedScan { + Op op; + U init; + + DefaultStridedScan(Op op_, U init_) : op(op_), init(init_) {} + + void operator()( + const T* input, + U* output, + int count, + int size, + int stride, + bool reverse, + bool inclusive) { + // TODO: Vectorize the following naive implementation + if (!reverse) { + if (inclusive) { + for (int i = 0; i < count; i++) { + std::copy(input, input + stride, output); + output += stride; + input += stride; + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + op(output, output - stride, input); + output++; + input++; + } + } + } + } else { + for (int i = 0; i < count; i++) { + std::fill(output, output + stride, init); + output += stride; + input += stride; + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + op(output, output - stride, input - stride); + output++; + input++; + } + } + } + } + } else { + if (inclusive) { + for (int i = 0; i < count; i++) { + output += (size - 1) * stride; + input += (size - 1) * stride; + std::copy(input, input + stride, output); + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + output--; + input--; + op(output, output + stride, input); + } + } + output += size * stride; + input += size * stride; + } + } else { + for (int i = 0; i < count; i++) { + output += (size - 1) * stride; + input += (size - 1) * stride; + std::fill(output, output + stride, init); + for (int j = 1; j < size; j++) { + for (int k = 0; k < stride; k++) { + output--; + input--; + op(output, output + stride, input + stride); + } + } + output += size * stride; + input += size * stride; + } + } + } + } +}; + +template +void scan_op( + OpCS opcs, + OpSS opss, + const array& input, + array& output, + int axis, + bool reverse, + bool inclusive) { + output.set_data(allocator::malloc_or_wait(output.nbytes())); + + if (input.flags().row_contiguous) { + if (input.strides()[axis] == 1) { + opcs( + input.data(), + output.data(), + input.size() / input.shape(axis), + input.shape(axis), + reverse, + inclusive); + } else { + opss( + input.data(), + output.data(), + input.size() / input.shape(axis) / input.strides()[axis], + input.shape(axis), + input.strides()[axis], + reverse, + inclusive); + } + } else { + throw std::runtime_error("Scan op supports only contiguous inputs"); + } +} + +template +void scan_dispatch( + Scan::ReduceType rtype, + const array& input, + array& output, + int axis, + bool reverse, + bool inclusive) { + switch (rtype) { + case Scan::Sum: { + auto op = [](U* o, const U* y, const T* x) { *o = *y + *x; }; + auto init = static_cast(0); + auto opcs = DefaultContiguousScan(op, init); + auto opss = DefaultStridedScan(op, init); + scan_op(opcs, opss, input, output, axis, reverse, 
inclusive); + break; + } + case Scan::Prod: { + auto op = [](U* o, const U* y, const T* x) { *o = *y * (*x); }; + auto init = static_cast(1); + auto opcs = DefaultContiguousScan(op, init); + auto opss = DefaultStridedScan(op, init); + scan_op(opcs, opss, input, output, axis, reverse, inclusive); + break; + } + case Scan::Min: { + auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *x : *y; }; + auto init = (is_floating_point(input.dtype())) + ? static_cast(std::numeric_limits::infinity()) + : std::numeric_limits::max(); + auto opcs = DefaultContiguousScan(op, init); + auto opss = DefaultStridedScan(op, init); + scan_op(opcs, opss, input, output, axis, reverse, inclusive); + break; + } + case Scan::Max: { + auto op = [](U* o, const U* y, const T* x) { *o = (*x < *y) ? *y : *x; }; + auto init = (is_floating_point(input.dtype())) + ? static_cast(-std::numeric_limits::infinity()) + : std::numeric_limits::max(); + auto opcs = DefaultContiguousScan(op, init); + auto opss = DefaultStridedScan(op, init); + scan_op(opcs, opss, input, output, axis, reverse, inclusive); + break; + } + } +} + +} // namespace + +void Scan::eval(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // Ensure contiguity + auto in = inputs[0]; + if (!in.flags().row_contiguous) { + array arr_copy(in.shape(), in.dtype(), nullptr, {}); + copy(in, arr_copy, CopyType::General); + in = arr_copy; + } + + switch (in.dtype()) { + case bool_: { + // We could do a full dtype x dtype switch but this is the only case + // where we accumulate in a different type, for now. + // + // TODO: If we add the option to accumulate floats in higher precision + // floats perhaps we should add the full all-to-all dispatch. + if (reduce_type_ == Scan::Sum && out.dtype() == int32) { + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + } else { + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + } + break; + } + case uint8: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case uint64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int8: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case int64: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case float16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case float32: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case bfloat16: + scan_dispatch( + reduce_type_, in, out, axis_, reverse_, inclusive_); + break; + case complex64: + throw std::runtime_error("Scan ops do not support complex types yet"); + break; + } +} + +} // namespace mlx::core diff --git a/mlx/backend/common/threefry.cpp b/mlx/backend/common/threefry.cpp new file mode 100644 index 0000000000..391648ce0b --- /dev/null +++ b/mlx/backend/common/threefry.cpp @@ -0,0 +1,29 @@ +#include "mlx/backend/common/threefry.h" + +namespace mlx::core::random { + +std::pair threefry2x32_hash( + const std::pair& key, + std::pair count) { + constexpr static uint32_t rotations[2][4] = { + {13, 15, 
26, 6}, {17, 29, 16, 24}}; + + uint32_t ks[3] = {key.first, key.second, key.first ^ key.second ^ 0x1BD11BDA}; + + count.first += ks[0]; + count.second += ks[1]; + + for (int i = 0; i < 5; ++i) { + for (auto r : rotations[i % 2]) { + count.first += count.second; + count.second = (count.second << r) | (count.second >> (32 - r)); + count.second ^= count.first; + } + count.first += ks[(i + 1) % 3]; + count.second += ks[(i + 2) % 3] + i + 1; + } + + return count; +} + +} // namespace mlx::core::random diff --git a/mlx/backend/common/threefry.h b/mlx/backend/common/threefry.h new file mode 100644 index 0000000000..efcb5dcfd7 --- /dev/null +++ b/mlx/backend/common/threefry.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include + +namespace mlx::core::random { + +/** Applies the Threefry 2x32 hash function. + * This code is based on the Jax counter-based and splittable PRNG + * https://github.com/google/jax/blob/main/docs/jep/263-prng.md + * + * Original Threefry reference: + * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + */ +std::pair threefry2x32_hash( + const std::pair& key, + std::pair count); + +} // namespace mlx::core::random diff --git a/mlx/backend/common/unary.h b/mlx/backend/common/unary.h new file mode 100644 index 0000000000..271215b530 --- /dev/null +++ b/mlx/backend/common/unary.h @@ -0,0 +1,147 @@ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +struct AbsOp { + template + T operator()(T x) { + return std::abs(x); + } + uint8_t operator()(uint8_t x) { + return x; + } + uint16_t operator()(uint16_t x) { + return x; + } + uint32_t operator()(uint32_t x) { + return x; + } + uint64_t operator()(uint64_t x) { + return x; + } + bool operator()(bool x) { + return x; + } +}; + +struct SignOp { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + } + + uint8_t operator()(uint8_t x) { + return x != 0; + } + uint16_t operator()(uint16_t x) { + return x != 0; + } + uint32_t operator()(uint32_t x) { + return x != 0; + } + uint64_t operator()(uint64_t x) { + return x != 0; + } +}; + +template +void unary_op(const array& a, array& out, Op op) { + const T* a_ptr = a.data(); + if (a.flags().contiguous) { + out.set_data( + allocator::malloc_or_wait(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + T* dst = out.data(); + for (size_t i = 0; i < a.data_size(); ++i) { + dst[i] = op(a_ptr[i]); + } + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + T* dst = out.data(); + for (size_t i = 0; i < out.size(); ++i) { + // TODO this is super inefficient, need to fix. 
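+      // Map the flat output index back to the strided input location one
+      // element at a time.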
+ int a_idx = elem_to_loc(i, a.shape(), a.strides()); + dst[i] = op(a_ptr[a_idx]); + } + } +} + +template +void unary(const array& a, array& out, Op op) { + switch (out.dtype()) { + case bool_: + unary_op(a, out, op); + break; + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + } +} + +template +void unary_fp(const array& a, array& out, Op op) { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_fp] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } +} + +} // namespace + +} // namespace mlx::core diff --git a/mlx/backend/common/utils.h b/mlx/backend/common/utils.h new file mode 100644 index 0000000000..1cc7e9f562 --- /dev/null +++ b/mlx/backend/common/utils.h @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "mlx/array.h" + +namespace mlx::core { + +inline size_t elem_to_loc( + int elem, + const std::vector& shape, + const std::vector& strides) { + size_t loc = 0; + for (int i = shape.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(elem, shape[i]); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; +} + +inline size_t elem_to_loc(int elem, const array& a) { + if (a.flags().row_contiguous) { + return elem; + } + return elem_to_loc(elem, a.shape(), a.strides()); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/CMakeLists.txt b/mlx/backend/metal/CMakeLists.txt new file mode 100644 index 0000000000..d5fd2e07f7 --- /dev/null +++ b/mlx/backend/metal/CMakeLists.txt @@ -0,0 +1,26 @@ +target_sources( + mlx + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/scan.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/sort.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cpp +) + +if (NOT MLX_METAL_PATH) + set(MLX_METAL_PATH ${CMAKE_CURRENT_BINARY_DIR}/kernels/) +endif() + +add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/kernels) + +target_compile_definitions( + mlx PRIVATE METAL_PATH="${MLX_METAL_PATH}/mlx.metallib") diff --git a/mlx/backend/metal/copy.cpp b/mlx/backend/metal/copy.cpp new file mode 100644 index 0000000000..aadb98e9a5 --- /dev/null +++ b/mlx/backend/metal/copy.cpp @@ -0,0 +1,113 @@ +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) { + if (ctype == 
CopyType::Vector) { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + if (ctype == CopyType::GeneralGeneral) { + ctype = CopyType::General; + } + copy_gpu_inplace(in, out, ctype, s); +} + +void copy_gpu(const array& in, array& out, CopyType ctype) { + copy_gpu(in, out, ctype, out.primitive().stream()); +} + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s) { + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(in, out); + auto& strides_in = strides[0]; + auto& strides_out = strides[1]; + + auto& d = metal::device(s.device); + std::ostringstream kname; + switch (ctype) { + case CopyType::Scalar: + kname << "scopy"; + break; + case CopyType::Vector: + kname << "vcopy"; + break; + case CopyType::General: + kname << "gcopy"; + break; + case CopyType::GeneralGeneral: + kname << "ggcopy"; + break; + } + kname << type_to_name(in) << type_to_name(out); + if ((ctype == CopyType::General || ctype == CopyType::GeneralGeneral) && + shape.size() <= MAX_COPY_SPECIALIZED_DIMS) { + kname << "_" << shape.size(); + } + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + + if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) { + size_t ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 3); + if (ctype == CopyType::GeneralGeneral) { + compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 4); + } + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_in.data(), ndim * sizeof(size_t), 2); + if (ctype == CopyType::GeneralGeneral) { + compute_encoder->setBytes(strides_out.data(), ndim * sizeof(size_t), 3); + } + } + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes( + &ndim, sizeof(int), (ctype == CopyType::GeneralGeneral) ? 5 : 4); + } + + int dim0 = ndim > 0 ? shape[ndim - 1] : 1; + int dim1 = ndim > 1 ? 
shape[ndim - 2] : 1; + int rest = in.size() / (dim0 * dim1); + + // NB assuming thread_group_size is a power of 2 larger than 32 x 32 + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::copy] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + size_t nthreads = out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/device.h b/mlx/backend/metal/device.h new file mode 100644 index 0000000000..bad001573a --- /dev/null +++ b/mlx/backend/metal/device.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include "mlx/device.h" + +namespace fs = std::filesystem; + +namespace mlx::core::metal { + +inline std::string get_colocated_mtllib_path(const std::string& lib_name) { + Dl_info info; + std::string mtllib_path; + std::string lib_ext = lib_name + ".metallib"; + + int success = dladdr((void*)get_colocated_mtllib_path, &info); + if (success) { + auto mtllib = fs::path(info.dli_fname).remove_filename() / lib_ext; + mtllib_path = mtllib.c_str(); + } + + return mtllib_path; +} + +class Device { + public: + Device(); + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + ~Device(); + + MTL::Device* mtl_device() { + return device_; + }; + + void new_queue(int index); + MTL::CommandBuffer* new_command_buffer(int index); + MTL::CommandBuffer* get_command_buffer(int index); + int get_command_buffer_ops(int index); + void increment_command_buffer_ops(int index); + void commit_command_buffer(int index); + MTL::ComputeCommandEncoder* get_command_encoder(int index); + void end_encoding(int index); + + void register_library( + const std::string& lib_name, + const std::string& lib_path); + void register_library( + const std::string& lib_name, + const std::function& lib_path_func = + get_colocated_mtllib_path); + + MTL::ComputePipelineState* get_kernel( + const std::string& name, + const std::string& lib_name = "mlx"); + + MTL::ArgumentEncoder* argument_encoder( + const std::vector& arg_descs) const; + + private: + NS::AutoreleasePool* pool_; + MTL::Device* device_; + std::unordered_map queue_map_; + std::unordered_map> buffer_map_; + std::unordered_map encoder_map_; + std::unordered_map kernel_map_; + std::unordered_map library_map_; + std::mutex mtx_; +}; + +Device& device(mlx::core::Device); +NS::AutoreleasePool*& thread_autorelease_pool(); + +} // namespace mlx::core::metal diff --git a/mlx/backend/metal/indexing.cpp b/mlx/backend/metal/indexing.cpp new file mode 100644 index 0000000000..ea5b771f29 --- /dev/null +++ b/mlx/backend/metal/indexing.cpp @@ -0,0 +1,296 @@ +#include +#include +#include +#include + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +static constexpr int 
METAL_MAX_INDEX_ARRAYS = 10; + +} // namespace + +void Gather::eval_gpu(const std::vector& inputs, array& out) { + auto& src = inputs[0]; + int nidx = inputs.size() - 1; + + if (nidx > METAL_MAX_INDEX_ARRAYS) { + std::ostringstream msg; + msg << "[Gather::eval_gpu] Gathering with more than " + << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported."; + throw std::runtime_error(msg.str()); + } + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + + std::ostringstream kname; + std::string idx_type_name = nidx ? type_to_name(inputs[1]) : ""; + kname << "gather" << type_to_name(src) << idx_type_name << "_" << nidx; + + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + + size_t slice_size = 1; + for (auto s : slice_sizes_) { + slice_size *= s; + } + + size_t ndim = src.ndim(); + size_t nthreads = out.size(); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + + compute_encoder->setComputePipelineState(kernel); + + // Make the argument buffer to store the indices for the + // `Indices` struct in kernels/indexing.metal + std::vector arg_descs(4); + arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[0]->setIndex(0); + arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[0]->setArrayLength(nidx); + + // Shapes + arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[1]->setIndex(nidx + 1); + + // Strides + arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[2]->setIndex(nidx + 2); + + // Indices ndim + arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); + arg_descs[3]->setIndex(nidx + 3); + + // Get the argument encoder + auto arg_enc = d.argument_encoder(arg_descs); + + // Allocate and fill buffers for shapes and strides + int idx_ndim = nidx ? 
inputs[1].ndim() : 0; + auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); + auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy( + inputs[i + 1].shape().begin(), + inputs[i + 1].shape().end(), + static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); + std::copy( + inputs[i + 1].strides().begin(), + inputs[i + 1].strides().end(), + static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + } + + // Allocate the argument bufer + auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); + + // Register data with the encoder + arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); + for (int i = 0; i < nidx; ++i) { + set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); + } + arg_enc->setBuffer( + static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); + compute_encoder->useResource( + static_cast(idx_shapes_buf.ptr()), MTL::ResourceUsageRead); + arg_enc->setBuffer( + static_cast(idx_strides_buf.ptr()), 0, nidx + 2); + compute_encoder->useResource( + static_cast(idx_strides_buf.ptr()), MTL::ResourceUsageRead); + *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; + + // Set all the buffers + set_array_buffer(compute_encoder, src, 0); + compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 1); + set_array_buffer(compute_encoder, out, 2); + compute_encoder->setBytes(src.shape().data(), ndim * sizeof(int), 3); + compute_encoder->setBytes(src.strides().data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(&ndim, sizeof(size_t), 5); + compute_encoder->setBytes(slice_sizes_.data(), ndim * sizeof(int), 6); + compute_encoder->setBytes(&slice_size, sizeof(size_t), 7); + compute_encoder->setBytes(axes_.data(), nidx * sizeof(int), 8); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + + // Cleanup temporaries + arg_enc->release(); + d.get_command_buffer(s.index)->addCompletedHandler( + [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { + allocator::free(arg_buf); + allocator::free(idx_shapes_buf); + allocator::free(idx_strides_buf); + }); +} + +void Scatter::eval_gpu(const std::vector& inputs, array& out) { + if (size_of(out.dtype()) == 8) { + std::ostringstream msg; + msg << "[Scatter::eval_gpu] Does not support " << out.dtype(); + throw std::invalid_argument(msg.str()); + } + + int nidx = axes_.size(); + if (nidx > METAL_MAX_INDEX_ARRAYS) { + std::ostringstream msg; + msg << "[Scatter::eval_gpu] Gathering with more than " + << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported."; + throw std::runtime_error(msg.str()); + } + + // Copy src into out + auto copy_type = + inputs[0].data_size() == 1 ? CopyType::Scalar : CopyType::General; + copy_gpu(inputs[0], out, copy_type); + + // Get stream + auto& s = stream(); + auto& d = metal::device(s.device); + + // Get kernel name + std::ostringstream kname; + std::string idx_type_name = nidx ? 
type_to_name(inputs[1]) : ""; + kname << "scatter" << type_to_name(out) << idx_type_name; + switch (reduce_type_) { + case Scatter::None: + kname << "_none"; + break; + case Scatter::Sum: + kname << "_sum"; + break; + case Scatter::Prod: + kname << "_prod"; + break; + case Scatter::Max: + kname << "_max"; + break; + case Scatter::Min: + kname << "_min"; + break; + } + kname << "_" << nidx; + + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + + auto& upd = inputs.back(); + size_t nthreads = upd.size(); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + + compute_encoder->setComputePipelineState(kernel); + + // Make the argument buffer to store the indices for the + // `Indices` struct in kernels/indexing.metal + std::vector arg_descs(4); + arg_descs[0] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[0]->setIndex(0); + arg_descs[0]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[0]->setArrayLength(nidx); + + // Shapes + arg_descs[1] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[1]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[1]->setIndex(nidx + 1); + + // Strides + arg_descs[2] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[2]->setDataType(MTL::DataType::DataTypePointer); + arg_descs[2]->setIndex(nidx + 2); + + // Indices ndim + arg_descs[3] = MTL::ArgumentDescriptor::argumentDescriptor(); + arg_descs[3]->setDataType(MTL::DataType::DataTypeInt); + arg_descs[3]->setIndex(nidx + 3); + + // Get the argument encoder + auto arg_enc = d.argument_encoder(arg_descs); + + // Allocate and fill buffers for shapes and strides + int idx_ndim = nidx ? 
inputs[1].ndim() : 0; + auto idx_shapes_buf = allocator::malloc_or_wait(sizeof(int) * idx_ndim); + auto idx_strides_buf = allocator::malloc_or_wait(sizeof(size_t) * idx_ndim); + for (int i = 0; i < nidx; ++i) { + std::copy( + inputs[i + 1].shape().begin(), + inputs[i + 1].shape().end(), + static_cast(idx_shapes_buf.raw_ptr()) + i * idx_ndim); + std::copy( + inputs[i + 1].strides().begin(), + inputs[i + 1].strides().end(), + static_cast(idx_strides_buf.raw_ptr()) + i * idx_ndim); + } + + // Allocate the argument bufer + auto arg_buf = allocator::malloc_or_wait(arg_enc->encodedLength()); + + // Register data with the encoder + arg_enc->setArgumentBuffer(static_cast(arg_buf.ptr()), 0); + for (int i = 0; i < nidx; ++i) { + set_array_buffer(compute_encoder, arg_enc, inputs[i + 1], i); + } + arg_enc->setBuffer( + static_cast(idx_shapes_buf.ptr()), 0, nidx + 1); + compute_encoder->useResource( + static_cast(idx_shapes_buf.ptr()), MTL::ResourceUsageRead); + arg_enc->setBuffer( + static_cast(idx_strides_buf.ptr()), 0, nidx + 2); + compute_encoder->useResource( + static_cast(idx_strides_buf.ptr()), MTL::ResourceUsageRead); + *static_cast(arg_enc->constantData(nidx + 3)) = idx_ndim; + + compute_encoder->setBuffer(static_cast(arg_buf.ptr()), 0, 0); + size_t upd_ndim = upd.ndim(); + size_t upd_size = 1; + for (int i = idx_ndim; i < upd.ndim(); ++i) { + upd_size *= upd.shape(i); + } + set_array_buffer(compute_encoder, upd, 1); + set_array_buffer(compute_encoder, out, 2); + compute_encoder->setBytes(upd.shape().data(), upd_ndim * sizeof(int), 3); + compute_encoder->setBytes(upd.strides().data(), upd_ndim * sizeof(size_t), 4); + compute_encoder->setBytes(&upd_ndim, sizeof(size_t), 5); + compute_encoder->setBytes(&upd_size, sizeof(size_t), 6); + + size_t out_ndim = out.ndim(); + compute_encoder->setBytes(out.shape().data(), out_ndim * sizeof(int), 7); + compute_encoder->setBytes(out.strides().data(), out_ndim * sizeof(size_t), 8); + compute_encoder->setBytes(&out_ndim, sizeof(size_t), 9); + compute_encoder->setBytes(axes_.data(), axes_.size() * sizeof(int), 10); + + compute_encoder->dispatchThreads(grid_dims, group_dims); + + // Cleanup temporaries + arg_enc->release(); + d.get_command_buffer(s.index)->addCompletedHandler( + [arg_buf, idx_shapes_buf, idx_strides_buf](MTL::CommandBuffer*) { + allocator::free(arg_buf); + allocator::free(idx_shapes_buf); + allocator::free(idx_strides_buf); + }); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/kernels/atomic.h b/mlx/backend/metal/kernels/atomic.h new file mode 100644 index 0000000000..43c8cc3ba9 --- /dev/null +++ b/mlx/backend/metal/kernels/atomic.h @@ -0,0 +1,320 @@ +#pragma once + +#include +#include +#include "mlx/backend/metal/kernels/bf16.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Atomic utils +/////////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable +template +constexpr constant bool is_metal_atomic = _disjunction< + is_same, + is_same, + is_same, + is_same>::value; + +#pragma METAL internals : disable + +template +struct mlx_atomic { + atomic val; +}; + +template +struct mlx_atomic>> { + atomic val; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Native metal atomics +/////////////////////////////////////////////////////////////////////////////// + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, int offset) { + return 
atomic_load_explicit(&(object[offset].val), memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, int offset) { + atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_and_explicit(device mlx_atomic* object, T val, int offset) { + atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, int offset) { + atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_min_explicit(device mlx_atomic* object, T val, int offset) { + atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_max_explicit(device mlx_atomic* object, T val, int offset) { + atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_add_explicit(device mlx_atomic* object, T val, int offset) { + atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_mul_explicit(device mlx_atomic* object, T val, int offset) { + T expected = mlx_atomic_load_explicit(object, offset); + while (!mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val * expected, offset)) { + } +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread T* expected, + T val, + int offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} + +// Specialization for float since it does not atomic_fetch_min_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + float val, + int offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val < expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +// Specialization for float since it does not atomic_fetch_max_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + float val, + int offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val > expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Custom atomics +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +template +constexpr constant uint packing_size = sizeof(uint) / sizeof(T); + +template +union uint_or_packed { + T val[packing_size]; + uint bits; +}; + +template +struct mlx_atomic_update_helper { + uint operator()(uint_or_packed init, T update, int elem_offset) { + Op op; + init.val[elem_offset] = op(update, init.val[elem_offset]); + return init.bits; + } +}; + +template +METAL_FUNC void mlx_atomic_update_and_store( + device mlx_atomic* object, + T update, + int offset) { + int pack_offset = offset / packing_size; + int elem_offset = offset % packing_size; + + mlx_atomic_update_helper helper; + uint_or_packed expected; + expected.bits = + 
atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + + while (Op::condition(update, expected.val[elem_offset]) && + !mlx_atomic_compare_exchange_weak_explicit( + object, + &(expected.bits), + helper(expected, update, elem_offset), + pack_offset)) { + } +} + +template +struct __None { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { +#pragma unused(b) + return a; + } +}; + +template +struct __Add { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { + return a + b; + } +}; + +template +struct __Mul { + static bool condition(T a, T b) { +#pragma unused(a) + return b != 0; + } + + T operator()(T a, T b) { + return a * b; + } +}; + +template +struct __Max { + static bool condition(T a, T b) { + return a > b; + } + + T operator()(T a, T b) { + return max(a, b); + } +}; + +template +struct __Min { + static bool condition(T a, T b) { + return a < b; + } + + T operator()(T a, T b) { + return min(a, b); + } +}; + +} // namespace + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, int offset) { + int pack_offset = offset / sizeof(T); + int elem_offset = offset % sizeof(T); + uint_or_packed packed_val; + packed_val.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + return packed_val.val[elem_offset]; +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, int offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_and_explicit(device mlx_atomic* object, T val, int offset) { + int pack_offset = offset / packing_size; + int elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = __UINT32_MAX__; + identity.val[elem_offset] = val; + + atomic_fetch_and_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_or_explicit(device mlx_atomic* object, T val, int offset) { + int pack_offset = offset / packing_size; + int elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = 0; + identity.val[elem_offset] = val; + + atomic_fetch_or_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_min_explicit(device mlx_atomic* object, T val, int offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_max_explicit(device mlx_atomic* object, T val, int offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_add_explicit(device mlx_atomic* object, T val, int offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_fetch_mul_explicit(device mlx_atomic* object, T val, int offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread uint* expected, + uint val, + int offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} \ No newline at end of file diff --git a/mlx/backend/metal/kernels/bf16.h 
b/mlx/backend/metal/kernels/bf16.h new file mode 100644 index 0000000000..618aefd40b --- /dev/null +++ b/mlx/backend/metal/kernels/bf16.h @@ -0,0 +1,315 @@ +#pragma once + +#include + +using namespace metal; + +#if defined(__HAVE_BFLOAT__) + +typedef bfloat bfloat16_t; + +#else + +///////////////////////////////////////////////////////////////////////////// +// Helpers +///////////////////////////////////////////////////////////////////////////// + +constexpr METAL_FUNC uint16_t float_to_bfloat_bits(float x) { + // Check for nan + if ((as_type(x) & ~_fp_encoding_traits::sign_mask) > + _fp_encoding_traits::inf_mask) { + return uint16_t(as_type(0x7FC0)); + } + // Take bits + uint32_t float_bits = as_type(x); + + // Round to nearest even + float_bits += ((float_bits >> 16) & 1) + as_type(0x7FFF); + + // Take upper 16 bits + return float_bits >> 16; +} + +constexpr METAL_FUNC float bfloat_bits_to_float(uint16_t x) { + // Upper 16 bits are the data and lower 16 bits are 0s + return as_type((uint32_t)x << 16); +} + +struct _MLX_BFloat16; + +template +static constexpr constant bool can_convert_to_bfloat = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_bfloat = + !is_same_v && is_convertible_v; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat struct +///////////////////////////////////////////////////////////////////////////// + +struct _MLX_BFloat16 { + ///////////////////////////////////////////////////////////////////////////// + // Constructors + uint16_t bits_; + _MLX_BFloat16() thread = default; + _MLX_BFloat16() threadgroup = default; + _MLX_BFloat16() device = default; + _MLX_BFloat16() constant = default; + + struct bits_to_bfloat_struct {}; + static constexpr METAL_FUNC bits_to_bfloat_struct bits_to_bfloat() { + return bits_to_bfloat_struct(); + } + constexpr METAL_FUNC _MLX_BFloat16(uint16_t bits, bits_to_bfloat_struct) + : bits_(bits) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions to bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) thread + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) threadgroup + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) device + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC _MLX_BFloat16(T x) constant + : bits_(float_to_bfloat_bits(static_cast(x))) {} + + ///////////////////////////////////////////////////////////////////////////// + // Conversions from bfloat + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const thread { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const threadgroup { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const device { + return static_cast(bfloat_bits_to_float(bits_)); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr METAL_FUNC operator T() const constant { + return 
static_cast(bfloat_bits_to_float(bits_)); + } +}; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat operators +///////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////// +// Unary ops +constexpr METAL_FUNC _MLX_BFloat16 operator-(_MLX_BFloat16 x) { + return -static_cast(x); +} + +///////////////////////////////////////////////////////////////////////////// +// Binary operators +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + constexpr METAL_FUNC otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + constexpr METAL_FUNC otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + constexpr METAL_FUNC otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +///////////////////////////////////////////////////////////////////////////// +// Arithmetic Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, float, half, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +///////////////////////////////////////////////////////////////////////////// +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, half, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop +#undef bfloat_binop_base +#undef bfloat_binop_helper +#undef bfloat_binop + +///////////////////////////////////////////////////////////////////////////// +// Inplace Operators +#define bfloat_inplace_op_helper(__op__, __operator__, itype, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, itype rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } \ + constexpr METAL_FUNC addr_space itype& __operator__( \ + addr_space itype& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__, itype) \ + bfloat_inplace_op_helper(__op__, __operator__, itype, device); \ + bfloat_inplace_op_helper(__op__, __operator__, itype, thread); \ + 
bfloat_inplace_op_helper(__op__, __operator__, itype, threadgroup); + +#define bfloat_inplace_op(itype) \ + bfloat_inplace_op_addr_space_helper(+, operator+=, itype); \ + bfloat_inplace_op_addr_space_helper(-, operator-=, itype); \ + bfloat_inplace_op_addr_space_helper(*, operator*=, itype); \ + bfloat_inplace_op_addr_space_helper(/, operator/=, itype); + +bfloat_inplace_op(float); +bfloat_inplace_op(half); +bfloat_inplace_op(int16_t); +bfloat_inplace_op(int32_t); +bfloat_inplace_op(int64_t); +bfloat_inplace_op(uint16_t); +bfloat_inplace_op(uint32_t); +bfloat_inplace_op(uint64_t); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper +#undef bfloat_inplace_op + +#define bfloat_inplace_op_helper(__op__, __operator__, addr_space) \ + constexpr METAL_FUNC addr_space _MLX_BFloat16& __operator__( \ + addr_space _MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs = static_cast(lhs) __op__ static_cast(rhs); \ + return lhs; \ + } + +#define bfloat_inplace_op_addr_space_helper(__op__, __operator__) \ + bfloat_inplace_op_helper(__op__, __operator__, device); \ + bfloat_inplace_op_helper(__op__, __operator__, thread); \ + bfloat_inplace_op_helper(__op__, __operator__, threadgroup); + +bfloat_inplace_op_addr_space_helper(+, operator+=); +bfloat_inplace_op_addr_space_helper(-, operator-=); +bfloat_inplace_op_addr_space_helper(*, operator*=); +bfloat_inplace_op_addr_space_helper(/, operator/=); + +#undef bfloat_inplace_op_helper +#undef bfloat_inplace_op_addr_space_helper + +///////////////////////////////////////////////////////////////////////////// +// Bfloat typedef +///////////////////////////////////////////////////////////////////////////// + +typedef struct _MLX_BFloat16 bfloat16_t; + +///////////////////////////////////////////////////////////////////////////// +// Bfloat numeric limits +///////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable + +namespace metal { + +template <> +struct _numeric_limits_impl : _fp_numeric_limits_impl_base { + static constexpr constant int digits = 8; + static constexpr constant int digits10 = 2; + static constexpr constant int max_digits10 = 4; + static constexpr constant int radix = 2; + static constexpr constant int min_exponent = -125; + static constexpr constant int min_exponent10 = -37; + static constexpr constant int max_exponent = 128; + static constexpr constant int max_exponent10 = 38; + + static constexpr bfloat16_t min() { + return _MLX_BFloat16(0x0080, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t lowest() { + return _MLX_BFloat16(0xFF7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t max() { + return _MLX_BFloat16(0x7F7F, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t epsilon() { + return _MLX_BFloat16(0x3C00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t round_error() { + return _MLX_BFloat16(0x3F00, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t infinity() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t quiet_NaN() { + return _MLX_BFloat16(0x7FC0, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t signaling_NaN() { + return _MLX_BFloat16(0x7F80, _MLX_BFloat16::bits_to_bfloat()); + } + static constexpr bfloat16_t denorm_min() { + return _MLX_BFloat16(0x0001, _MLX_BFloat16::bits_to_bfloat()); + } +}; + +METAL_FUNC bool isnan(_MLX_BFloat16 x) { + return x != x; +} + +} // namespace metal + +#pragma 
METAL internals : disable + +#endif // defined(__HAVE_BFLOAT__) + +#include "mlx/backend/metal/kernels/bf16_math.h" diff --git a/mlx/backend/metal/kernels/binary.metal b/mlx/backend/metal/kernels/binary.metal new file mode 100644 index 0000000000..5baab0e966 --- /dev/null +++ b/mlx/backend/metal/kernels/binary.metal @@ -0,0 +1,369 @@ +#include +#include + +#include "mlx/backend/metal/kernels/utils.h" +#include "mlx/backend/metal/kernels/bf16.h" + +struct Add { + template T operator()(T x, T y) { return x + y; } +}; + +struct Divide { + template T operator()(T x, T y) { return x / y; } +}; + +struct Equal { + template bool operator()(T x, T y) { return x == y; } +}; + +struct NaNEqual { + template bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) + && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template bool operator()(T x, T y) { return x > y; } +}; + +struct GreaterEqual { + template bool operator()(T x, T y) { return x >= y; } +}; + +struct Less { + template bool operator()(T x, T y) { return x < y; } +}; + +struct LessEqual { + template bool operator()(T x, T y) { return x <= y; } +}; + +struct LogAddExp { + template + T operator()(T x, T y) { + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) ? maxval : + (maxval + log1p(metal::exp(minval - maxval))); + }; +}; + +struct Maximum { + template T operator()(T x, T y) { return metal::max(x, y); } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x >= y ? x : y; + } +}; + +struct Minimum { + template T operator()(T x, T y) { return metal::min(x, y); } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x <= y ? 
x : y; + } +}; + +struct Multiply { + template T operator()(T x, T y) { return x * y; } +}; + +struct NotEqual { + template bool operator()(T x, T y) { return x != y; } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + auto x_theta = metal::atan(x.imag / x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + + +struct Subtract { + template T operator()(T x, T y) { return x - y; } +}; + +template +[[kernel]] void binary_op_s2s( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} + + +template +[[kernel]] void binary_op_ss( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} + +template +[[kernel]] void binary_op_sv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[index]); +} + +template +[[kernel]] void binary_op_vs( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[0]); +} + +template +[[kernel]] void binary_op_vv( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[index], b[index]); +} + +template +[[kernel]] void binary_op_g_nd1( + device const T* a, + device const T* b, + device U* c, + constant const size_t& a_stride, + constant const size_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + c[index] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_op_g_nd2( + device const T* a, + device const T* b, + device U* c, + constant const size_t a_strides[2], + constant const size_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * index.y; + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_op_g_nd3( + device const T* a, + device const T* b, + device U* c, + constant const size_t a_strides[3], + constant const size_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_op_g_nd( + device const T* a, + device const T* b, + device U* c, + constant const int shape[DIM], + constant const size_t a_strides[DIM], + constant const size_t b_strides[DIM], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) 
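+// General DIM-dimensional case: the 3-D grid position is flattened into a
+// linear output index (x fastest), while elem_to_loc_2_nd applies each
+// input's own strides, so a and b may be broadcast or transposed views.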
{ + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides); + size_t out_idx = index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z); + c[out_idx] = Op()(a[idx.x], b[idx.y]); +} + +template +[[kernel]] void binary_op_g( + device const T* a, + device const T* b, + device U* c, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim); + size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z); + c[out_idx] = Op()(a[idx.x], b[idx.y]); +} + +#define instantiate_binary(name, itype, otype, op, bopt) \ + template [[host_name(name)]] \ + [[kernel]] void binary_op_##bopt( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + uint index [[thread_position_in_grid]]); + +#define instantiate_binary_g_dim(name, itype, otype, op, dims) \ + template [[host_name(name "_" #dims)]] \ + [[kernel]] void binary_op_g_nd( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const int shape[dims], \ + constant const size_t a_strides[dims], \ + constant const size_t b_strides[dims], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); + +#define instantiate_binary_g_nd(name, itype, otype, op) \ + template [[host_name(name "_1")]] \ + [[kernel]] void binary_op_g_nd1( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t& a_stride, \ + constant const size_t& b_stride, \ + uint index [[thread_position_in_grid]]); \ + template [[host_name(name "_2")]] \ + [[kernel]] void binary_op_g_nd2( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t a_strides[2], \ + constant const size_t b_strides[2], \ + uint2 index [[thread_position_in_grid]], \ + uint2 grid_dim [[threads_per_grid]]); \ + template [[host_name(name "_3")]] \ + [[kernel]] void binary_op_g_nd3( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const size_t a_strides[3], \ + constant const size_t b_strides[3], \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); \ + instantiate_binary_g_dim(name, itype, otype, op, 4) \ + instantiate_binary_g_dim(name, itype, otype, op, 5) + + +#define instantiate_binary_g(name, itype, otype, op) \ + template [[host_name(name)]] \ + [[kernel]] void binary_op_g( \ + device const itype* a, \ + device const itype* b, \ + device otype* c, \ + constant const int* shape, \ + constant const size_t* a_strides, \ + constant const size_t* b_strides, \ + constant const int& ndim, \ + uint3 index [[thread_position_in_grid]], \ + uint3 grid_dim [[threads_per_grid]]); + +#define instantiate_binary_all(name, tname, itype, otype, op) \ + instantiate_binary("ss" #name #tname, itype, otype, op, ss) \ + instantiate_binary("sv" #name #tname, itype, otype, op, sv) \ + instantiate_binary("vs" #name #tname, itype, otype, op, vs) \ + instantiate_binary("vv" #name #tname, itype, otype, op, vv) \ + instantiate_binary_g("g" #name #tname, itype, otype, op) \ + instantiate_binary_g_nd("g" #name #tname, itype, otype, op) + +#define instantiate_binary_float(name, op) \ + instantiate_binary_all(name, float16, half, half, op) \ + instantiate_binary_all(name, float32, float, float, op) \ + instantiate_binary_all(name, bfloat16, bfloat16_t, 
bfloat16_t, op) + +#define instantiate_binary_types(name, op) \ + instantiate_binary_all(name, bool_, bool, bool, op) \ + instantiate_binary_all(name, uint8, uint8_t, uint8_t, op) \ + instantiate_binary_all(name, uint16, uint16_t, uint16_t, op) \ + instantiate_binary_all(name, uint32, uint32_t, uint32_t, op) \ + instantiate_binary_all(name, uint64, uint64_t, uint64_t, op) \ + instantiate_binary_all(name, int8, int8_t, int8_t, op) \ + instantiate_binary_all(name, int16, int16_t, int16_t, op) \ + instantiate_binary_all(name, int32, int32_t, int32_t, op) \ + instantiate_binary_all(name, int64, int64_t, int64_t, op) \ + instantiate_binary_all(name, complex64, complex64_t, complex64_t, op) \ + instantiate_binary_float(name, op) + +#define instantiate_binary_types_bool(name, op) \ + instantiate_binary_all(name, bool_, bool, bool, op) \ + instantiate_binary_all(name, uint8, uint8_t, bool, op) \ + instantiate_binary_all(name, uint16, uint16_t, bool, op) \ + instantiate_binary_all(name, uint32, uint32_t, bool, op) \ + instantiate_binary_all(name, uint64, uint64_t, bool, op) \ + instantiate_binary_all(name, int8, int8_t, bool, op) \ + instantiate_binary_all(name, int16, int16_t, bool, op) \ + instantiate_binary_all(name, int32, int32_t, bool, op) \ + instantiate_binary_all(name, int64, int64_t, bool, op) \ + instantiate_binary_all(name, float16, half, bool, op) \ + instantiate_binary_all(name, float32, float, bool, op) \ + instantiate_binary_all(name, bfloat16, bfloat16_t, bool, op) \ + instantiate_binary_all(name, complex64, complex64_t, bool, op) + +instantiate_binary_types(add, Add) +instantiate_binary_float(div, Divide) +instantiate_binary_types_bool(eq, Equal) +instantiate_binary_types_bool(ge, Greater) +instantiate_binary_types_bool(geq, GreaterEqual) +instantiate_binary_types_bool(le, Less) +instantiate_binary_types_bool(leq, LessEqual) +instantiate_binary_types_bool(neq, NotEqual) +instantiate_binary_float(lae, LogAddExp) +instantiate_binary_types(max, Maximum) +instantiate_binary_types(min, Minimum) +instantiate_binary_types(mul, Multiply) +instantiate_binary_types(sub, Subtract) +instantiate_binary_types(pow, Power) + +// NaNEqual only needed for floating point types with boolean output +instantiate_binary_all(naneq, float16, half, bool, NaNEqual) +instantiate_binary_all(naneq, float32, float, bool, NaNEqual) +instantiate_binary_all(naneq, bfloat16, bfloat16_t, bool, NaNEqual) +instantiate_binary_all(naneq, complex64, complex64_t, bool, NaNEqual) diff --git a/mlx/backend/metal/kernels/complex.h b/mlx/backend/metal/kernels/complex.h new file mode 100644 index 0000000000..b4be6d51e8 --- /dev/null +++ b/mlx/backend/metal/kernels/complex.h @@ -0,0 +1,110 @@ +#pragma once + +#include + +using namespace metal; + +struct complex64_t; + +template +static constexpr constant bool can_convert_to_complex64 = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_complex64 = + !is_same_v && + (is_convertible_v || is_convertible_v); + +struct complex64_t { + float real; + float imag; + + // Constructors + constexpr complex64_t(float real, float imag) : real(real), imag(imag){}; + + // Conversions to complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) thread : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) 
device : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) constant : real(x), imag(0) {} + + // Converstions from complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const thread { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const threadgroup { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const device { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const constant { + return static_cast(real); + } +}; + +constexpr complex64_t operator-(complex64_t x) { + return {-x.real, -x.imag}; +} + +constexpr bool operator>=(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); +} + +constexpr bool operator>(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); +} + +constexpr bool operator<=(complex64_t a, complex64_t b) { + return operator>=(b, a); +} + +constexpr bool operator<(complex64_t a, complex64_t b) { + return operator>(b, a); +} + +constexpr bool operator==(complex64_t a, complex64_t b) { + return a.real == b.real && a.imag == b.imag; +} + +constexpr complex64_t operator+(complex64_t a, complex64_t b) { + return {a.real + b.real, a.imag + b.imag}; +} + +constexpr complex64_t operator-(complex64_t a, complex64_t b) { + return {a.real - b.real, a.imag - b.imag}; +} + +constexpr complex64_t operator*(complex64_t a, complex64_t b) { + return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; +} diff --git a/mlx/backend/metal/kernels/defines.h b/mlx/backend/metal/kernels/defines.h new file mode 100644 index 0000000000..38102e987f --- /dev/null +++ b/mlx/backend/metal/kernels/defines.h @@ -0,0 +1,14 @@ +#pragma once + +#ifdef __METAL__ +#define MTL_CONST constant +#else +#define MTL_CONST +#endif + +static MTL_CONST constexpr int MAX_BINARY_SPECIALIZED_DIMS = 5; +static MTL_CONST constexpr int MAX_COPY_SPECIALIZED_DIMS = 5; +static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; +static MTL_CONST constexpr int REDUCE_N_READS = 16; +static MTL_CONST constexpr int SOFTMAX_N_READS = 4; +static MTL_CONST constexpr int SOFTMAX_LOOPED_LIMIT = 4096; diff --git a/mlx/backend/metal/kernels/gemm/conv.h b/mlx/backend/metal/kernels/gemm/conv.h new file mode 100644 index 0000000000..faf550e34e --- /dev/null +++ b/mlx/backend/metal/kernels/gemm/conv.h @@ -0,0 +1,479 @@ +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/conv_params.h" + +#define MLX_MTL_CONST static constant constexpr const + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int vec_size, + int tgp_size, + int tgp_padding = 0> +struct Conv2DInputBlockLoader { + // Destination dimensions + MLX_MTL_CONST int dst_fd = BM; + MLX_MTL_CONST int dst_ld = BK + tgp_padding; + MLX_MTL_CONST int n_vecs = BK / vec_size; + + // Stride along block row within the block + MLX_MTL_CONST int bstride = tgp_size / n_vecs; + MLX_MTL_CONST int n_rows = dst_fd / bstride; + + // Thread location 
indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>& params; + + int weight_h; + int weight_w; + + int offsets_n[n_rows]; + int offsets_oh[n_rows]; + int offsets_ow[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoader( + const device T* src_, + threadgroup T* dst_, + const constant MLXConvParams<2>& params_, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / n_vecs), + bj(vec_size * (thread_idx % n_vecs)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bj), + params(params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params.oS[0] * params.oS[1]; + + for (int i = 0; i < n_rows; ++i) { + int offset_nhw = tid.y * BM + bi + i * bstride; + offsets_n[i] = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + offsets_oh[i] = hw / params.oS[1]; + offsets_ow[i] = hw % params.oS[1]; + } + + (void)lid; + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { +#pragma clang loop unroll(full) + for (short i = 0, is = 0; i < n_rows; ++i, is += bstride) { + int n = offsets_n[i]; + int oh = offsets_oh[i]; + int ow = offsets_ow[i]; + + int ih = oh * params.str[0] - params.pad[0] + weight_h * params.dil[0]; + int iw = ow * params.str[1] - params.pad[1] + weight_w * params.dil[1]; + + // Read from input if in bounds + if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { + const device T* curr_src = src + n * params.in_strides[0] + + ih * params.in_strides[1] + iw * params.in_strides[2]; + +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = curr_src[j]; + } + } + + // Zero pad otherwize + else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params.wS[1]) { + return; + } + + weight_w = 0; + + if (++weight_h < params.wS[0]) { + return; + } + + weight_h = 0; + + src += BK; + } +}; + +template < + typename T, + int BM, + int BN, + int BK, + int vec_size, + int tgp_size, + int tgp_padding = 0> +struct Conv2DWeightBlockLoader { + // Destination dimensions + MLX_MTL_CONST int dst_fd = BN; + MLX_MTL_CONST int dst_ld = BK + tgp_padding; + MLX_MTL_CONST int n_vecs = BK / vec_size; + + // Stride along block row within the block + MLX_MTL_CONST int bstride = tgp_size / n_vecs; + MLX_MTL_CONST int n_rows = dst_fd / bstride; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>& params; + + int weight_h; + int weight_w; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const constant MLXConvParams<2>& params_, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_.wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + 
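+        // Split the flat thread index into a tile row (bi) and a vec_size-wide
+        // column offset (bj); n_vecs = BK / vec_size threads cover each row.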
bi(thread_idx / n_vecs), + bj(vec_size * (thread_idx % n_vecs)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_h(0), + weight_w(0) { + (void)lid; + (void)tid; + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = + src + weight_h * params.wt_strides[1] + weight_w * params.wt_strides[2]; +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params.wS[1]) { + return; + } + + weight_w = 0; + + if (++weight_h < params.wS[0]) { + return; + } + + weight_h = 0; + + src += BK; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Transforms +/////////////////////////////////////////////////////////////////////////////// + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + int tgp_padding_a = 0, + int tgp_padding_b = 0, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct Conv2DBlockMMA { + // Warp tile size along M + MLX_MTL_CONST int TM = BM / (WM * 8); + // Warp tile size along N + MLX_MTL_CONST int TN = BN / (WN * 8); + + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TN_stride = 8 * WN; + + // Leading dimensions of threadgroup A, B blocks + MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; + MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; + + // Strides of A, B along reduction axis + MLX_MTL_CONST short simd_stride_a = + transpose_a ? TM_stride : TM_stride * lda_tgp; + MLX_MTL_CONST short simd_stride_b = + transpose_b ? TN_stride * ldb_tgp : TN_stride; + + // Jump between elements + MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; + MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1; + + // Offsets within threadgroup + const int tm; + const int tn; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + short sm; + short sn; + + /* Constructor */ + METAL_FUNC Conv2DBlockMMA( + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { +// Iterate over BK in blocks of 8 +#pragma clang loop unroll(full) + for (short kk = 0; kk < BK; kk += 8) { + short2 offset_a = + transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); + short2 offset_b = + transpose_b ? 
short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); + + const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; + const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup A as simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = static_cast(As__[0]); + Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); + As__ += simd_stride_a; + } + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup B as simdgroup matrices +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); + Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); + Bs__ += simd_stride_b; + } + + simdgroup_barrier(mem_flags::mem_none); +// Multiply and accumulate into resulr simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + simdgroup_multiply_accumulate( + results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device T* C, const int ldc) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + + METAL_FUNC void + store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { + if (tm + i * TM_stride + sm < dst_tile_dims.y) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + if (tn + j * TN_stride + sn < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + } + + if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct Conv2DImplicitGEMMKernel { + MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); + MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); + MLX_MTL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + MLX_MTL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + MLX_MTL_CONST short tgp_size = WM * WN * 32; + MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 
8 : 4; + + using loader_a_t = + Conv2DInputBlockLoader; + using loader_b_t = + Conv2DWeightBlockLoader; + using mma_t = Conv2DBlockMMA< + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + tgp_padding_a, + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>& params [[buffer(3)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const int K = params.wt_strides[0]; + const int N = params.O; + + B += c_col * K; + C += c_row * N + c_col; + + // Prepare threadgroup memory for loading + threadgroup T* As = tgp_memory; + threadgroup T* Bs = tgp_memory + tgp_mem_size_a; + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, As, params, tid, lid, simd_gid, simd_lid); + loader_b_t loader_b(B, Bs, params, tid, lid, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + mma_op.store_result(C, N); + } +}; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemm/gemm.h b/mlx/backend/metal/kernels/gemm/gemm.h new file mode 100644 index 0000000000..f6c7b58327 --- /dev/null +++ b/mlx/backend/metal/kernels/gemm/gemm.h @@ -0,0 +1,536 @@ +#pragma once + +#include +#include +#include + +#define MLX_MTL_CONST static constant constexpr const + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BROWS, + int BCOLS, + int BK, + int vec_size, + int tgp_size, + bool transpose, + bool ldK, + int tgp_padding = 0> +struct BlockLoader { + // Destination dimensions + MLX_MTL_CONST int dst_fd = transpose ? BCOLS : BROWS; + MLX_MTL_CONST int dst_ld = (transpose ? BROWS : BCOLS) + tgp_padding; + MLX_MTL_CONST int n_vecs = (transpose ? 
BROWS : BCOLS) / vec_size; + + // Stride along block row within the block + MLX_MTL_CONST int bstride = tgp_size / n_vecs; + + // Leading dimension for src + const int src_ld; + // Stride along reduction axis between blocks + const int tstride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tstride( + BK * ((int)(transpose ^ !ldK) * src_ld + (int)(transpose ^ ldK))), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / n_vecs), + bj(vec_size * (thread_idx % n_vecs)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = transpose ? src_tile_dim.yx : src_tile_dim.xy; + + // Iterate over rows of block +#pragma clang loop unroll(full) + for (short i = 0; i < dst_fd; i += bstride) { + // Row is in bounds, we check against column + if ((bi + i) < src_tile_dim.y) { + // Use fast thread memory for bound checks + short tmp_idx[vec_size]; + T tmp_val[vec_size]; + + // Make sure tmp_idx only contains valid indices +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = bj + j < src_tile_dim.x ? j : 0; + } + + // Read all valid indcies into tmp_val +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[i * src_ld + tmp_idx[j]]; + } + + // Zero out uneeded values +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = bj + j < src_tile_dim.x ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + + // Row is out of bounds, we just fill tgp memory with zeros + else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tstride; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Transforms +/////////////////////////////////////////////////////////////////////////////// + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + int tgp_padding_a = 0, + int tgp_padding_b = 0, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct BlockMMA { + // Warp tile size along M + MLX_MTL_CONST int TM = BM / (WM * 8); + // Warp tile size along N + MLX_MTL_CONST int TN = BN / (WN * 8); + + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TM_stride = 8 * WM; + // Warp tile simdgroup matrix strides along M + MLX_MTL_CONST int TN_stride = 8 * WN; + + // Leading dimensions of threadgroup A, B blocks + MLX_MTL_CONST int lda_tgp = (transpose_a ? BM : BK) + tgp_padding_a; + MLX_MTL_CONST int ldb_tgp = (transpose_b ? BK : BN) + tgp_padding_b; + + // Strides of A, B along reduction axis + MLX_MTL_CONST short simd_stride_a = + transpose_a ? TM_stride : TM_stride * lda_tgp; + MLX_MTL_CONST short simd_stride_b = + transpose_b ? TN_stride * ldb_tgp : TN_stride; + + // Jump between elements + MLX_MTL_CONST short jump_a = transpose_a ? lda_tgp : 1; + MLX_MTL_CONST short jump_b = transpose_b ? ldb_tgp : 1; + + // Offsets within threadgroup + const int tm; + const int tn; + + // Simdgroup matrices + simdgroup_matrix Asimd[TM]; + simdgroup_matrix Bsimd[TN]; + simdgroup_matrix results[TM * TN] = { + simdgroup_matrix(0)}; + + short sm; + short sn; + + /* Constructor */ + METAL_FUNC BlockMMA( + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : tm(8 * (simd_group_id / WN)), tn(8 * (simd_group_id % WN)) { + short qid = simd_lane_id / 4; + sm = (qid & 4) + (simd_lane_id / 2) % 4; + sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { +// Iterate over BK in blocks of 8 +#pragma clang loop unroll(full) + for (short kk = 0; kk < BK; kk += 8) { + short2 offset_a = + transpose_a ? short2(tm + sm, kk + sn) : short2(kk + sn, tm + sm); + short2 offset_b = + transpose_b ? 
short2(kk + sm, tn + sn) : short2(tn + sn, kk + sm); + + const threadgroup T* As__ = As + offset_a.y * lda_tgp + offset_a.x; + const threadgroup T* Bs__ = Bs + offset_b.y * ldb_tgp + offset_b.x; + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup A as simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { + Asimd[i].thread_elements()[0] = static_cast(As__[0]); + Asimd[i].thread_elements()[1] = static_cast(As__[jump_a]); + As__ += simd_stride_a; + } + + simdgroup_barrier(mem_flags::mem_none); +// Load elements from threadgroup B as simdgroup matrices +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + Bsimd[j].thread_elements()[0] = static_cast(Bs__[0]); + Bsimd[j].thread_elements()[1] = static_cast(Bs__[jump_b]); + Bs__ += simd_stride_b; + } + + simdgroup_barrier(mem_flags::mem_none); +// Multiply and accumulate into resulr simdgroup matrices +#pragma clang loop unroll(full) + for (short i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (short j = 0; j < TN; j++) { + simdgroup_multiply_accumulate( + results[i * TN + j], Asimd[i], Bsimd[j], results[i * TN + j]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device T* C, const int ldc) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + C[(i * TM_stride + sm + tm) * ldc + j * TN_stride + tn + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + + METAL_FUNC void + store_result_safe(device T* C, const int ldc, short2 dst_tile_dims) const { +#pragma clang loop unroll(full) + for (int i = 0; i < TM; i++) { + if (tm + i * TM_stride + sm < dst_tile_dims.y) { +#pragma clang loop unroll(full) + for (int j = 0; j < TN; j++) { + if (tn + j * TN_stride + sn < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn] = + Epilogue::apply(results[i * TN + j].thread_elements()[0]); + } + + if (tn + j * TN_stride + sn + 1 < dst_tile_dims.x) { + C[(tm + i * TM_stride + sm) * ldc + tn + j * TN_stride + sn + 1] = + Epilogue::apply(results[i * TN + j].thread_elements()[1]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + MLX_MTL_CONST short tgp_padding_a = 16 / sizeof(T); + MLX_MTL_CONST short tgp_padding_b = 16 / sizeof(T); + MLX_MTL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + MLX_MTL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + MLX_MTL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + MLX_MTL_CONST short tgp_size = WM * WN * 32; + MLX_MTL_CONST short vec_size = (BM == 64 && BN == 64) ? 
8 : 4; + + using loader_a_t = BlockLoader< + T, + BM, + BK, + BK, + vec_size, + tgp_size, + transpose_a, + true, + tgp_padding_a>; + using loader_b_t = BlockLoader< + T, + BK, + BN, + BK, + vec_size, + tgp_size, + transpose_b, + false, + tgp_padding_b>; + using mma_t = BlockMMA< + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + tgp_padding_a, + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant int& M [[buffer(3)]], + const constant int& N [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& batch_stride_a [[buffer(6)]], + const constant int& batch_stride_b [[buffer(7)]], + const constant int& batch_stride_c [[buffer(8)]], + threadgroup T* tgp_memory [[threadgroup(0)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + // Adjust for batch + A += batch_stride_a * tid.z; + B += batch_stride_b * tid.z; + C += batch_stride_c * tid.z; + + // Adjust for transpose + const int lda_dev = transpose_a ? M : K; + const int ldb_dev = transpose_b ? K : N; + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + + A += transpose_a ? c_row : c_row * K; + B += transpose_b ? c_col * K : c_col; + C += c_row * N + c_col; + + // Prepare threadgroup memory for loading + threadgroup T* As = tgp_memory; + threadgroup T* Bs = tgp_memory + tgp_mem_size_a; + + // Prepare threadgroup loading operations + loader_a_t loader_a(A, lda_dev, As, simd_group_id, simd_lane_id); + loader_b_t loader_b(B, ldb_dev, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned && K_aligned) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + mma_op.store_result(C, N); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN aligned, K unaligned loop + else if (MN_aligned && !K_aligned) { + // Main loop + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Loop tail + threadgroup_barrier(mem_flags::mem_threadgroup); + + loader_a.load_safe(short2(K - k, BM)); + loader_b.load_safe(short2(BN, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + + // Store results to device memory + mma_op.store_result(C, N); + return; + + } + 
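For orientation, the branches of GEMMKernel::run above and below differ only in how they cover the K tail and the M/N edges: the fully aligned path uses load_unsafe throughout, while the unaligned paths finish with load_safe. A minimal host-side C++ sketch (hypothetical tile sizes and names, not part of this patch) of how the alignment flags, K tail, and grid size could be derived for a given problem:

// Host-side sketch only; assumes the BM/BN/BK tiling convention of GEMMKernel.
#include <cstdio>

struct TileConfig {
  int BM, BN, BK;
};

int main() {
  TileConfig t{32, 32, 16};    // assumed tile sizes, for illustration
  int M = 100, N = 64, K = 70; // example GEMM dimensions

  bool MN_aligned = (M % t.BM == 0) && (N % t.BN == 0);
  bool K_aligned = (K % t.BK == 0);

  int full_k_iters = K / t.BK;              // iterations that can use load_unsafe()
  int k_tail = K - full_k_iters * t.BK;     // remainder handled by load_safe()

  // Threadgroups needed to cover the output matrix.
  int grid_x = (N + t.BN - 1) / t.BN;
  int grid_y = (M + t.BM - 1) / t.BM;

  std::printf(
      "MN_aligned=%d K_aligned=%d full_k_iters=%d k_tail=%d grid=(%d, %d)\n",
      (int)MN_aligned, (int)K_aligned, full_k_iters, k_tail, grid_x, grid_y);
  return 0;
}

Only the unaligned branches pay for the bound checks in load_safe; the specialization flags keep the common aligned case on the cheapest path.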
/////////////////////////////////////////////////////////////////////////////// + // MNK unaligned loop + else { // Loop over K - unaligned case + + short2 src_tile_dims(min(BN, N - c_col), min(BM, M - c_row)); + + if (src_tile_dims.y == BM && src_tile_dims.x == BN) { + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + if (k < K) { + loader_a.load_safe(short2(K - k, BM)); + loader_b.load_safe(short2(BN, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + mma_op.store_result(C, N); + return; + + } else { + int k = 0; + for (; k + BK <= K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_safe(short2(BK, src_tile_dims.y)); + loader_b.load_safe(short2(src_tile_dims.x, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + if (k < K) { + loader_a.load_safe(short2(K - k, src_tile_dims.y)); + loader_b.load_safe(short2(src_tile_dims.x, K - k)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + threadgroup_barrier(mem_flags::mem_none); + mma_op.store_result_safe(C, N, src_tile_dims); + + return; + } + } + } +}; \ No newline at end of file diff --git a/mlx/backend/metal/kernels/gemv.metal b/mlx/backend/metal/kernels/gemv.metal new file mode 100644 index 0000000000..9b484871bc --- /dev/null +++ b/mlx/backend/metal/kernels/gemv.metal @@ -0,0 +1,302 @@ +#include +#include + +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/bf16.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +static constant constexpr int SIMD_SIZE = 32; + +template /* Thread cols (in elements) */ +[[kernel]] void gemv( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(2)]], + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + const constant int& vector_batch_stride [[buffer(5)]], + const constant int& matrix_batch_stride [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + + static_assert(BN == SIMD_SIZE, "gemv block must have a width of SIMD_SIZE"); + + // - The matrix of size (M = out_vec_size, N = in_vec_size) is divided up + // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each thead group is launched with (BN, BM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for the block + // 3. 
At the end, each thread has accumulated results over all blocks across the rows + // These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid will have blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results remain zero) + // * The last thread that partialy overlaps with the matrix is shifted inwards + // such that the thread block fits exactly in the matrix + + // Update batch offsets + in_vec += tid.z * vector_batch_stride; + mat += tid.z * matrix_batch_stride; + out_vec += tid.z * out_vec_size; + + // Threadgroup in_vec cache + threadgroup T in_vec_block[BN][TN * 2]; + + // Thread local accumulation results + thread T result[TM] = {0}; + thread T inter[TN]; + thread T v_coeff[TN]; + + // Block position + int out_row = (tid.x * BM + simd_gid) * TM; + + // Exit simdgroup if rows out of bound + if(out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Advance matrix + mat += out_row * in_vec_size; + + // Loop over in_vec in blocks of BN * TN + for(int bn = simd_lid * TN; bn < in_vec_size; bn += BN * TN) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Prefetch in_vector for threadgroup use + if(simd_gid == 0) { + // Main load loop + if(bn + TN <= in_vec_size) { + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + in_vec_block[simd_lid][tn] = in_vec[bn + tn]; + } + } else { // Edgecase + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + in_vec_block[simd_lid][tn] = bn + tn < in_vec_size ? in_vec[bn + tn] : 0; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load for all rows + #pragma clang loop unroll(full) + for(int tn = 0; tn < TN; tn++) { + v_coeff[tn] = in_vec_block[simd_lid][tn]; + } + + // Per thread work loop + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + // Load for the row + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[tm * in_vec_size + bn + tn]; + } + + // Accumulate results + for(int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + + // Simdgroup accumulations + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + result[tm] = simd_sum(result[tm]); + } + + // Write outputs + if(simd_lid == 0) { + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = result[tm]; + } + } + +} + +#define instantiate_gemv(name, itype, bm, bn, tm, tn) \ + template [[host_name("gemv_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ + [[kernel]] void gemv( \ + const device itype* mat [[buffer(0)]], \ + const device itype* vec [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant int& in_vec_size [[buffer(3)]], \ + const constant int& out_vec_size [[buffer(4)]], \ + const constant int& vector_batch_stride [[buffer(5)]], \ + const constant int& matrix_batch_stride [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint simd_gid [[simdgroup_index_in_threadgroup]], \ + uint simd_lid [[thread_index_in_simdgroup]]); + +#define instantiate_gemv_blocks(name, itype) \ + instantiate_gemv(name, itype, 4, 32, 1, 4) \ + instantiate_gemv(name, itype, 4, 32, 4, 4) \ + instantiate_gemv(name, itype, 8, 32, 4, 4) + +instantiate_gemv_blocks(float32, float) +instantiate_gemv_blocks(float16, half) 
+instantiate_gemv_blocks(bfloat16, bfloat16_t) + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template /* Thread cols (in elements) */ +[[kernel]] void gemv_t( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(2)]], + const constant int& in_vec_size [[buffer(3)]], + const constant int& out_vec_size [[buffer(4)]], + const constant int& vector_batch_stride [[buffer(5)]], + const constant int& matrix_batch_stride [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (BM * TM, BN * TN) divided amoung threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each thead group is launched with (BN, BM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across the rows + // These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid will have blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results remain zero) + // * The last thread that partialy overlaps with the matrix is shifted inwards + // such that the thread block fits exactly in the matrix + + // Update batch offsets + in_vec += tid.z * vector_batch_stride; + mat += tid.z * matrix_batch_stride; + out_vec += tid.z * out_vec_size; + + // Thread local accumulation results + T result[TN] = {0}; + T inter[TN]; + T v_coeff[TM]; + + // Threadgroup accumulation results + threadgroup T tgp_results[BN][BM][TM]; + + int out_col = (tid.x * BN + lid.x) * TN; + int in_row = lid.y * TM; + + // Edgecase handling + if (out_col < out_vec_size) { + // Edgecase handling + out_col = out_col + TN < out_vec_size ? 
out_col : out_vec_size - TN; + + // Per thread accumulation main loop + int bm = in_row; + for(; bm < in_vec_size; bm += BM * TM) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if(bm + TM <= in_vec_size) { + + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + } + + #pragma clang loop unroll(full) + for(int tm = 0; tm < TM; tm++) { + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + } + for(int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + + } else { // Edgecase handling + for(int tm = 0; bm + tm < in_vec_size; tm++) { + v_coeff[tm] = in_vec[bm + tm]; + + for(int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * out_vec_size + out_col + tn]; + } + for(int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + + } + + // Threadgroup collection + for(int i = 0; i < TN; i++) { + tgp_results[lid.x][lid.y][i] = result[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if(lid.y == 0 && out_col < out_vec_size) { + // Threadgroup accumulation + #pragma clang loop unroll(full) + for(int i = 1; i < BM; i++) { + for(int j = 0; j < TN; j++) { + result[j] += tgp_results[lid.x][i][j]; + } + } + + for(int j = 0; j < TN; j++) { + out_vec[out_col + j] = result[j]; + } + } + +} + +#define instantiate_gemv_t(name, itype, bm, bn, tm, tn) \ + template [[host_name("gemv_t_" #name "_bm" #bm "_bn" #bn "_tm" #tm "_tn" #tn)]] \ + [[kernel]] void gemv_t( \ + const device itype* mat [[buffer(0)]], \ + const device itype* vec [[buffer(1)]], \ + device itype* out [[buffer(2)]], \ + const constant int& in_vec_size [[buffer(3)]], \ + const constant int& out_vec_size [[buffer(4)]], \ + const constant int& vector_batch_stride [[buffer(5)]], \ + const constant int& matrix_batch_stride [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_gemv_t_blocks(name, itype) \ + instantiate_gemv_t(name, itype, 8, 8, 4, 1) \ + instantiate_gemv_t(name, itype, 8, 8, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 16, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 32, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 64, 4, 4) \ + instantiate_gemv_t(name, itype, 8, 128, 4, 4) + +instantiate_gemv_t_blocks(float32, float) +instantiate_gemv_t_blocks(float16, half) +instantiate_gemv_t_blocks(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/softmax.metal b/mlx/backend/metal/kernels/softmax.metal new file mode 100644 index 0000000000..07a49bc07a --- /dev/null +++ b/mlx/backend/metal/kernels/softmax.metal @@ -0,0 +1,226 @@ +#include +#include +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +using namespace metal; + +template +inline T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause it is gonna be x + // will be in (-oo, 0] anyway and subsequently it will be divided by + // sum(exp(x_i)). 
+ return fast::exp(x); +} + +template +[[kernel]] void softmax_single_row( + const device T* in, + device T* out, + constant int& axis_size, + threadgroup T* local_max [[threadgroup(0)]], + threadgroup T* local_normalizer [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + T ld[N_READS]; + + in += gid * axis_size + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i=0; i::finite_min); + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::finite_min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + T maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + T normalizer = 0; + for (int i = 0; i < N_READS; i++) { + T exp_x = softmax_exp(ld[i] - maxval); + ld[i] = exp_x; + normalizer += exp_x; + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + local_normalizer[0] = normalizer; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = 1 / local_normalizer[0]; + + // Normalize and write to the output + out += gid * axis_size + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i=0; i +[[kernel]] void softmax_looped( + const device T* in, + device T* out, + constant int& axis_size, + threadgroup T* local_max [[threadgroup(0)]], + threadgroup T* local_normalizer [[threadgroup(1)]], + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * axis_size; + + // Get the max and the normalizer in one go + T prevmax; + T maxval = Limits::finite_min; + T normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + T vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = in[offset + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = + (offset + i < axis_size) ? in[offset + i] : T(Limits::finite_min); + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += softmax_exp(vals[i] - maxval); + } + } + // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * + // lsize) parts. We need to combine them. + // 1. We start by finding the max across simd groups + // 2. 
We then change the partial normalizers to account for a possible + // change in max + // 3. We sum all normalizers + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= softmax_exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + // Now the normalizer and max value is correct for each simdgroup. We write + // them shared memory and combine them. + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= softmax_exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + normalizer = 1 / normalizer; + + // Finally given the normalizer and max value we can directly write the + // softmax output + out += gid * axis_size; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + if (offset + N_READS <= axis_size) { + for (int i=0; i( \ + const device itype* in, \ + device itype* out, \ + constant int& axis_size, \ + threadgroup itype* local_max [[threadgroup(0)]], \ + threadgroup itype* local_normalizer [[threadgroup(1)]], \ + uint gid [[thread_position_in_grid]], \ + uint _lid [[thread_position_in_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_softmax_looped(name, itype) \ + template [[host_name("softmax_looped_" #name)]] [[kernel]] void \ + softmax_looped( \ + const device itype* in, \ + device itype* out, \ + constant int& axis_size, \ + threadgroup itype* local_max [[threadgroup(0)]], \ + threadgroup itype* local_normalizer [[threadgroup(1)]], \ + uint gid [[threadgroup_position_in_grid]], \ + uint lid [[thread_position_in_threadgroup]], \ + uint lsize [[threads_per_threadgroup]], \ + uint simd_lane_id [[thread_index_in_simdgroup]], \ + uint simd_group_id [[simdgroup_index_in_threadgroup]]); + +#define instantiate_softmax(name, itype) \ + instantiate_softmax_single_row(name, itype) \ + instantiate_softmax_looped(name, itype) + +instantiate_softmax(float32, float) instantiate_softmax(float16, half) + instantiate_softmax(bfloat16, bfloat16_t) diff --git a/mlx/backend/metal/kernels/sort.metal b/mlx/backend/metal/kernels/sort.metal new file mode 100644 index 0000000000..adafe6e13e --- /dev/null +++ b/mlx/backend/metal/kernels/sort.metal @@ -0,0 +1,818 @@ +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/kernels/utils.h" + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal;\ + +// Based on GPU merge sort algorithm at https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template +struct LessThan { + static constexpr constant T init = Limits::max; + + METAL_FUNC bool operator()(T a, T b) { + return a < b; + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static 
METAL_FUNC void sort( + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + + CompareOp op; + + MLX_MTL_LOOP_UNROLL + for(short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for(short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if(op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = ThreadSort; + static METAL_FUNC int merge_partition( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + short A_sz, + short B_sz, + short sort_md) { + + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while(A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if(op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + + } + + static METAL_FUNC void merge_step( + const threadgroup val_t* As, + const threadgroup val_t* Bs, + const threadgroup idx_t* As_idx, + const threadgroup idx_t* Bs_idx, + short A_sz, + short B_sz, + thread val_t (&vals)[N_PER_THREAD], + thread idx_t (&idxs)[N_PER_THREAD]) { + + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for(int i = 0; i < N_PER_THREAD; ++i) { + auto a = As[a_idx]; + auto b = Bs[b_idx]; + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + + b_idx += short(pred); + a_idx += short(!pred); + } + + } + + static METAL_FUNC void sort( + threadgroup val_t* tgp_vals [[threadgroup(0)]], + threadgroup idx_t* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for(int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if(ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if(idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for(int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if(ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup val_t* As = tgp_vals + A_st; + const threadgroup val_t* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], 
Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition( + As, + Bs, + A_sz, + B_sz, + sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup idx_t* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup idx_t* Bs_idx = ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step( + As, + Bs, + As_idx, + Bs_idx, + A_sz, + B_sz, + thread_vals, + thread_idxs); + + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for(int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if(ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using val_t = T; + using idx_t = uint; + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + const constant int& stride_segment_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // tid.y tells us the segment index + inp += tid.y * stride_segment_axis; + out += tid.y * stride_segment_axis; + + // Copy into threadgroup memory + for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? 
inp[i * stride_sorted_axis] : val_t(CompareOp::init); + if(ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for(int i = lid.x; i < size_sorted_axis; i+= BLOCK_THREADS) { + if(ARG_SORT) { + out[i * stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& stride_segment_axis [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using sort_kernel = KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + if(ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } + +} + +constant constexpr const int zero_helper = 0; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& stride_sorted_axis [[buffer(3)]], + const constant int& nc_dim [[buffer(4)]], + const device int* nc_shape [[buffer(5)]], + const device size_t* nc_strides [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using sort_kernel = KernelMergeSort; + using val_t = typename sort_kernel::val_t; + using idx_t = typename sort_kernel::idx_t; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out += block_idx; + + if(ARG_SORT) { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + stride_sorted_axis, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } + +} + +/////////////////////////////////////////////////////////////////////////////// +// Instantiations +/////////////////////////////////////////////////////////////////////////////// + + +#define instantiate_block_sort(name, itname, itype, otname, otype, arg_sort, bn, tn) \ + template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn)]] \ + [[kernel]] void block_sort( \ + const device itype* inp [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant int& size_sorted_axis 
[[buffer(2)]], \ + const constant int& stride_sorted_axis [[buffer(3)]], \ + const constant int& stride_segment_axis [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); \ + template [[host_name(#name "_" #itname "_" #otname "_bn" #bn "_tn" #tn "_nc")]] \ + [[kernel]] void block_sort_nc( \ + const device itype* inp [[buffer(0)]], \ + device otype* out [[buffer(1)]], \ + const constant int& size_sorted_axis [[buffer(2)]], \ + const constant int& stride_sorted_axis [[buffer(3)]], \ + const constant int& nc_dim [[buffer(4)]], \ + const device int* nc_shape [[buffer(5)]], \ + const device size_t* nc_strides [[buffer(6)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define instantiate_arg_block_sort_base(itname, itype, bn, tn) \ + instantiate_block_sort(arg_block_merge_sort, itname, itype, uint32, uint32_t, true, bn, tn) + +#define instantiate_block_sort_base(itname, itype, bn, tn) \ + instantiate_block_sort(block_merge_sort, itname, itype, itname, itype, false, bn, tn) + +#define instantiate_block_sort_tn(itname, itype, bn) \ + instantiate_block_sort_base(itname, itype, bn, 8) \ + instantiate_arg_block_sort_base(itname, itype, bn, 8) + +#define instantiate_block_sort_bn(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 256) \ + instantiate_block_sort_tn(itname, itype, 512) + +instantiate_block_sort_bn(uint8, uint8_t) +instantiate_block_sort_bn(uint16, uint16_t) +instantiate_block_sort_bn(uint32, uint32_t) +instantiate_block_sort_bn(int8, int8_t) +instantiate_block_sort_bn(int16, int16_t) +instantiate_block_sort_bn(int32, int32_t) +instantiate_block_sort_bn(float16, half) +instantiate_block_sort_bn(float32, float) +instantiate_block_sort_bn(bfloat16, bfloat16_t) + +#define instantiate_block_sort_long(itname, itype) \ + instantiate_block_sort_tn(itname, itype, 128) \ + instantiate_block_sort_tn(itname, itype, 256) + +instantiate_block_sort_long(uint64, uint64_t) +instantiate_block_sort_long(int64, int64_t) + +/////////////////////////////////////////////////////////////////////////////// +// Multi block merge sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device val_t* inp, + device val_t* out_vals, + device idx_t* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup val_t* tgp_vals, + threadgroup idx_t* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for(short i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ? 
inp[idx * stride_sorted_axis] : val_t(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for(int i = lid.x; i < N_PER_BLOCK; i+= BLOCK_THREADS) { + int idx = base_idx + i; + if(idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device val_t* As, + const device val_t* Bs, + int A_sz, + int B_sz, + int sort_md) { + + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while(A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if(op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + + } +}; + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device val_t* inp [[buffer(0)]], + device val_t* out_vals [[buffer(1)]], + device idx_t* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const device int* nc_shape [[buffer(6)]], + const device size_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using sort_kernel = KernelMultiBlockMergeSort; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_partiton( + device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals [[buffer(1)]], + const device idx_t* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + // Find location in merge step + int merge_group = lid.x / merge_tiles; + int merge_lane = lid.x % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + + 
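The merge_partition search used here (and in BlockMergeSort above) is a standard merge-path binary search: given two sorted ranges and a target count sort_md, it finds how many elements the first range contributes to the first sort_md elements of the merged output. A standalone C++ sketch with hypothetical names, not part of this patch, that mirrors its logic on host arrays:

#include <algorithm>
#include <cassert>
#include <vector>

// Returns p such that A[0:p) and B[0:sort_md-p) together hold the first
// sort_md elements of merge(A, B), for sorted A and B.
int merge_partition_ref(
    const std::vector<int>& A,
    const std::vector<int>& B,
    int sort_md) {
  int lo = std::max(0, sort_md - static_cast<int>(B.size()));
  int hi = std::min(sort_md, static_cast<int>(A.size()));
  while (lo < hi) {
    int md = lo + (hi - lo) / 2;
    // If B's candidate element is strictly smaller, too many As were taken.
    if (B[sort_md - 1 - md] < A[md]) {
      hi = md;
    } else {
      lo = md + 1;
    }
  }
  return hi;
}

int main() {
  std::vector<int> A = {1, 4, 6, 9};
  std::vector<int> B = {2, 3, 5, 7};
  int p = merge_partition_ref(A, B, /*sort_md=*/5);
  // First 5 merged elements are {1, 2, 3, 4, 5} = A[0:2) + B[0:3), so p == 2.
  assert(p == 2);
  return 0;
}

Each merge lane therefore receives a disjoint, fixed-size slice of the merged output, which is what lets the partition and merge kernels process their tiles independently.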
block_partitions[lid.x] = A_st + partition; + +} + +template < + typename val_t, + typename idx_t, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_merge( + const device idx_t* block_partitions [[buffer(0)]], + const device val_t* dev_vals_in [[buffer(1)]], + const device idx_t* dev_idxs_in [[buffer(2)]], + device val_t* dev_vals_out [[buffer(3)]], + device idx_t* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + + using sort_kernel = KernelMultiBlockMergeSort< + val_t, + idx_t, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md - A_st); + int B_ed = min(size_sorted_axis, 2 * sort_st + sort_sz/2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = min(size_sorted_axis, sort_st + sort_sz/2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread val_t thread_vals[N_PER_THREAD]; + thread idx_t thread_idxs[N_PER_THREAD]; + for(int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if(idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? 
dev_idxs_in[A_st + idx] : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup val_t tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup idx_t tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for(int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, + tgp_vals + A_sz, + A_sz, + B_sz, + sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for(int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for(int i = lid.x; i < sort_kernel::N_PER_BLOCK; i+= BLOCK_THREADS) { + int idx = base_idx + i; + if(idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } + +} + +#define instantiate_multi_block_sort(vtname, vtype, itname, itype, arg_sort, bn, tn) \ + template [[host_name("mb_block_sort_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ + [[kernel]] void mb_block_sort( \ + const device vtype* inp [[buffer(0)]], \ + device vtype* out_vals [[buffer(1)]], \ + device itype* out_idxs [[buffer(2)]], \ + const constant int& size_sorted_axis [[buffer(3)]], \ + const constant int& stride_sorted_axis [[buffer(4)]], \ + const constant int& nc_dim [[buffer(5)]], \ + const device int* nc_shape [[buffer(6)]], \ + const device size_t* nc_strides [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); \ + template [[host_name("mb_block_partiton_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ + [[kernel]] void mb_block_partiton( \ + device itype* block_partitions [[buffer(0)]], \ + const device vtype* dev_vals [[buffer(1)]], \ + const device itype* dev_idxs [[buffer(2)]], \ + const constant int& size_sorted_axis [[buffer(3)]], \ + const constant int& merge_tiles [[buffer(4)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]], \ + uint3 tgp_dims [[threads_per_threadgroup]]); \ + template [[host_name("mb_block_merge_" #vtname "_" #itname "_bn" #bn "_tn" #tn)]] \ + [[kernel]] void mb_block_merge( \ + const device itype* block_partitions [[buffer(0)]], \ + const device vtype* dev_vals_in [[buffer(1)]], \ + const device itype* dev_idxs_in [[buffer(2)]], \ + device vtype* dev_vals_out [[buffer(3)]], \ + device itype* dev_idxs_out [[buffer(4)]], \ + const constant int& size_sorted_axis [[buffer(5)]], \ + const constant int& merge_tiles [[buffer(6)]], \ + const constant int& num_tiles [[buffer(7)]], \ + uint3 tid [[threadgroup_position_in_grid]], \ + uint3 lid [[thread_position_in_threadgroup]]); + +#define 
instantiate_multi_block_sort_base(vtname, vtype) \ + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 512, 8) + +instantiate_multi_block_sort_base(uint8, uint8_t) +instantiate_multi_block_sort_base(uint16, uint16_t) +instantiate_multi_block_sort_base(uint32, uint32_t) +instantiate_multi_block_sort_base(int8, int8_t) +instantiate_multi_block_sort_base(int16, int16_t) +instantiate_multi_block_sort_base(int32, int32_t) +instantiate_multi_block_sort_base(float16, half) +instantiate_multi_block_sort_base(float32, float) +instantiate_multi_block_sort_base(bfloat16, bfloat16_t) + +#define instantiate_multi_block_sort_long(vtname, vtype) \ + instantiate_multi_block_sort(vtname, vtype, uint32, uint32_t, true, 256, 8) + +instantiate_multi_block_sort_long(uint64, uint64_t) +instantiate_multi_block_sort_long(int64, int64_t) \ No newline at end of file diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h new file mode 100644 index 0000000000..22b4a271b9 --- /dev/null +++ b/mlx/backend/metal/kernels/utils.h @@ -0,0 +1,244 @@ +#pragma once + +#include +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/complex.h" + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static const constant U max; + static const constant U min; + static const constant U finite_max; + static const constant U finite_min; +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits { + static constexpr constant bool max = true; + static constexpr constant bool min = false; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +inline size_t elem_to_loc( + uint elem, + device const int* shape, + device const size_t* strides, + int ndim) { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= shape[i]; + } + return loc; +} + +inline size_t elem_to_loc( + uint elem, + constant const int* shape, + constant const size_t* strides, + int ndim) { + size_t loc = 0; + for (int i = ndim - 1; i >= 0; --i) { + loc += (elem % shape[i]) * strides[i]; + elem /= 
shape[i]; + } + return loc; +} + +template +inline uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t a_strides[NDIM], + constant const size_t b_strides[NDIM]) { + uint2 loc = { + static_cast( + elem.x * a_strides[NDIM - 1] + elem.y * a_strides[NDIM - 2]), + static_cast( + elem.x * b_strides[NDIM - 1] + elem.y * b_strides[NDIM - 2])}; + for (int d = NDIM - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +template +inline size_t elem_to_loc_nd( + uint3 elem, + constant const int shape[NDIM], + constant const size_t strides[NDIM]) { + size_t loc = elem.x * strides[NDIM - 1] + elem.y * strides[NDIM - 2]; + for (int d = NDIM - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) { + return elem * stride; +} + +inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) { + return elem.x * strides[1] + elem.y * strides[0]; +} + +inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) { + return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0]; +} + +// Non templated version to handle arbitrary dims +inline size_t elem_to_loc( + uint3 elem, + constant const int* shape, + constant const size_t* strides, + int ndim) { + size_t loc = elem.x * strides[ndim - 1] + elem.y * strides[ndim - 2]; + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +inline uint2 elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const size_t* a_strides, + constant const size_t* b_strides, + int ndim) { + uint2 loc = { + static_cast( + elem.x * a_strides[ndim - 1] + elem.y * a_strides[ndim - 2]), + static_cast( + elem.x * b_strides[ndim - 1] + elem.y * b_strides[ndim - 2])}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * a_strides[d]; + loc.y += l * b_strides[d]; + elem.z /= shape[d]; + } + return loc; +} + +template +inline uint elem_to_loc_nd( + uint elem, + device const int* shape, + device const size_t* strides); + +template <> +inline uint elem_to_loc_nd<1>( + uint elem, + device const int* shape, + device const size_t* strides) { + return (elem % shape[0]) * strides[0]; +} + +template <> +inline uint elem_to_loc_nd<2>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +template <> +inline uint elem_to_loc_nd<3>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +template <> +inline uint elem_to_loc_nd<4>( + uint elem, + device const int* shape, + device const size_t* strides) { + uint loc = (elem % shape[3]) * strides[3]; + elem /= shape[3]; + loc += (elem % shape[2]) * strides[2]; + elem /= shape[2]; + loc += (elem % shape[1]) * strides[1]; + elem /= shape[1]; + loc += (elem % shape[0]) * strides[0]; + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute 
ceil((float)N/(float)M) */ +inline size_t ceildiv(size_t N, size_t M) { + return (N + M - 1) / M; +} + +// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + return (xp1 == 1.0f) ? x : x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + bfloat16_t ret = + (xp1 == 1.0f) ? x : bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); + return ret; +} diff --git a/mlx/backend/metal/matmul.cpp b/mlx/backend/metal/matmul.cpp new file mode 100644 index 0000000000..d6fdc34ece --- /dev/null +++ b/mlx/backend/metal/matmul.cpp @@ -0,0 +1,446 @@ +#include +#include +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/matmul.h" +#include "mlx/backend/metal/mps/gemm.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +bool use_mps() { + auto get_val = []() { + if (const char* buff_str = std::getenv("MLX_USE_MPS")) { + return std::string(buff_str) != "OFF"; + } else { + return false; + } + }; + static bool use_mps_ = get_val(); + return use_mps_; +} + +#define MAX_OPS_PER_BUFFER max_ops_per_buffer() + +inline void mps_matmul( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + MPS::DataType mps_dtype = MPS::DataTypeFloat32; + + if (out.dtype() == float16) { + mps_dtype = MPS::DataTypeFloat16; + } else if (out.dtype() == bfloat16) { + mps_dtype = MPS::DataTypeBFloat16; + } + + // Used batched MPSMatrixMultiplication if batch_size_out > 1 + // We only accept the following cases: + // 1. Both a, b have batch_size_out matrices worth of data + // 2. Only one of a or b has batch_size_out matrices worth of data and + // the other has matrix worth of data + + // The matrix dimsenisons of a and b are sure to be regularly strided + if (batch_size_out > 1) { + // No broadcasting defaults + auto batch_size_a = a.data_size() / (M * K); + auto batch_size_b = b.data_size() / (K * N); + + auto matrix_stride_a = M * K; + auto matrix_stride_b = K * N; + auto matrix_stride_out = M * N; + + // At this point, batch_size_a, batch_size_b show the number of matrices + // in data, no broadcasted strides considered + if (batch_size_out == std::max(batch_size_a, batch_size_b)) { + // Handle simple broadcasting + if (std::min(batch_size_a, batch_size_b) == 1) { + matrix_stride_a = (batch_size_a == 1) ? 0 : matrix_stride_a; + matrix_stride_b = (batch_size_b == 1) ? 
0 : matrix_stride_b; + + batch_size_a = batch_size_out; + batch_size_b = batch_size_out; + } + + // Only proceed if broadcasting between a and b is simple + // At this point, batch_size_a, batch_size_b show the number of matrices + // after broadcasting + if (batch_size_a == batch_size_b) { + auto a_desc = MPS::MatrixDescriptor::matrixDescriptor( + (M * K) / lda, + lda, + batch_size_a, + lda * a.itemsize(), + (matrix_stride_a * a.itemsize()), + mps_dtype); + + auto b_desc = MPS::MatrixDescriptor::matrixDescriptor( + (K * N) / ldb, + ldb, + batch_size_b, + ldb * b.itemsize(), + (matrix_stride_b * b.itemsize()), + mps_dtype); + + auto out_desc = MPS::MatrixDescriptor::matrixDescriptor( + M, + N, + batch_size_out, + N * out.itemsize(), + matrix_stride_out * out.itemsize(), + mps_dtype); + + auto a_buf = static_cast(a.buffer().ptr()); + auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc); + + auto b_buf = static_cast(b.buffer().ptr()); + auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc); + + auto out_buf = static_cast(out.buffer().ptr()); + auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); + + auto kernel = MPS::MatrixMultiplication::alloc()->init( + d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0); + + auto command_buffer = d.get_command_buffer(s.index); + kernel->setBatchSize(batch_size_out); + kernel->setBatchStart(0); + kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat); + command_buffer->addCompletedHandler( + [a_mat, b_mat, out_mat, kernel, copies]( + MTL::CommandBuffer*) mutable { + a_mat->release(); + b_mat->release(); + out_mat->release(); + kernel->release(); + copies.clear(); + }); + + return; + } + } + } + + // Schedule as many calls to MPSMatrixMultiplication as needed otherwise + auto a_desc = MPS::MatrixDescriptor::matrixDescriptor( + a.data_size() / lda, lda, lda * a.itemsize(), mps_dtype); + + auto b_desc = MPS::MatrixDescriptor::matrixDescriptor( + b.data_size() / ldb, ldb, ldb * b.itemsize(), mps_dtype); + + auto out_desc = MPS::MatrixDescriptor::matrixDescriptor( + batch_size_out * M, N, N * out.itemsize(), mps_dtype); + + auto a_buf = static_cast(a.buffer().ptr()); + auto a_mat = MPS::Matrix::alloc()->init(a_buf, a_desc); + + auto b_buf = static_cast(b.buffer().ptr()); + auto b_mat = MPS::Matrix::alloc()->init(b_buf, b_desc); + + auto out_buf = static_cast(out.buffer().ptr()); + auto out_mat = MPS::Matrix::alloc()->init(out_buf, out_desc); + + auto kernel = MPS::MatrixMultiplication::alloc()->init( + d.mtl_device(), transpose_a, transpose_b, M, N, K, 1.0, 0.0); + + auto command_buffer = d.get_command_buffer(s.index); + for (int i = 0; i < batch_size_out; ++i) { + auto a_row = elem_to_loc(M * K * i, a.shape(), a.strides()) / lda; + auto b_row = elem_to_loc(K * N * i, b.shape(), b.strides()) / ldb; + kernel->setLeftMatrixOrigin({a_row, 0, 0}); + kernel->setRightMatrixOrigin({b_row, 0, 0}); + kernel->setResultMatrixOrigin({i * static_cast(M), 0, 0}); + kernel->encodeToCommandBuffer(command_buffer, a_mat, b_mat, out_mat); + } + + command_buffer->addCompletedHandler( + [a_mat, b_mat, out_mat, kernel, copies](MTL::CommandBuffer*) mutable { + a_mat->release(); + b_mat->release(); + out_mat->release(); + kernel->release(); + copies.clear(); + }); +} + +} // namespace + +void mlx_matmul( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies) { + // Account for batch 
sizes and basic broadcasting + int batch_size_a = a.data_size() / (M * K); + int batch_size_b = b.data_size() / (K * N); + + int matrix_stride_a = (batch_size_a == 1) ? 0 : M * K; + int matrix_stride_b = (batch_size_b == 1) ? 0 : K * N; + int matrix_stride_out = M * N; + + // Determine dispatch kernel + int bm = 32, bn = 32, bk = 16; + int wm = 2, wn = 2; + + if ((size_t)batch_size_out * M * N >= 2ul << 20) { + if (!transpose_a && transpose_b) { + bm = 64; + bn = (out.dtype() == float32) ? 64 : 32; + bk = (out.dtype() == float32) ? 16 : 32; + } else { + bm = 64; + bn = 64; + } + } + + std::ostringstream kname; + kname << "gemm_" << (transpose_a ? 't' : 'n') << (transpose_b ? 't' : 'n') + << "_" << type_to_name(a) << "_" << type_to_name(out) << "_bm" << bm + << "_bn" << bn << "_bk" << bk << "_wm" << wm << "_wn" << wn << "_MN_" + << ((M % bm == 0 && N % bn == 0) ? "t" : "n") << "aligned" + << "_K_" << ((K % bk == 0) ? "t" : "n") << "aligned"; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + // Launch only 1 kernel in the case of simple batching / broadcasting + if (batch_size_out == std::max(batch_size_a, batch_size_b) && + (batch_size_a == batch_size_b || + std::min(batch_size_a, batch_size_b) == 1)) { + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = + MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, batch_size_out); + + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, out, 2); + + compute_encoder->setBytes(&M, sizeof(int), 3); + compute_encoder->setBytes(&N, sizeof(int), 4); + compute_encoder->setBytes(&K, sizeof(int), 5); + compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6); + compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7); + compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } else { // Other launch kernels with set offsets + + for (int i = 0; i < batch_size_out; ++i) { + auto a_off = elem_to_loc(M * K * i, a.shape(), a.strides()); + auto b_off = elem_to_loc(K * N * i, b.shape(), b.strides()); + + MTL::Size group_dims = MTL::Size(32, wn, wm); + MTL::Size grid_dims = MTL::Size((N + bn - 1) / bn, (M + bm - 1) / bm, 1); + + auto a_buf = static_cast(a.buffer().ptr()); + auto b_buf = static_cast(b.buffer().ptr()); + auto out_buf = static_cast(out.buffer().ptr()); + + compute_encoder->setBuffer(a_buf, a_off * a.itemsize(), 0); + compute_encoder->setBuffer(b_buf, b_off * b.itemsize(), 1); + compute_encoder->setBuffer(out_buf, i * M * N * out.itemsize(), 2); + + compute_encoder->setBytes(&M, sizeof(int), 3); + compute_encoder->setBytes(&N, sizeof(int), 4); + compute_encoder->setBytes(&K, sizeof(int), 5); + compute_encoder->setBytes(&matrix_stride_a, sizeof(int), 6); + compute_encoder->setBytes(&matrix_stride_b, sizeof(int), 7); + compute_encoder->setBytes(&matrix_stride_out, sizeof(int), 8); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + } + } + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; +} + +void Matmul::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 2); + if (!is_floating_point(out.dtype())) { + throw std::runtime_error( + "[matmul] Does not yet support non-floating point types."); + } + 
out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + + auto& a_pre = inputs[0]; + auto& b_pre = inputs[1]; + + // Keep a vector with copies to be cleared in the completed buffer to release + // the arrays + std::vector copies; + auto check_transpose = [&copies, &s](const array& arr) { + auto stx = arr.strides()[arr.ndim() - 2]; + auto sty = arr.strides()[arr.ndim() - 1]; + if (stx == arr.shape(-1) && sty == 1) { + return std::make_tuple(false, stx, arr); + } else if (stx == 1 && sty == arr.shape(-2)) { + return std::make_tuple(true, sty, arr); + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + size_t stx = arr.shape(-1); + return std::make_tuple(false, stx, arr_copy); + } + }; + + auto [a_transposed, a_cols, a] = check_transpose(a_pre); + auto [b_transposed, b_cols, b] = check_transpose(b_pre); + + int M = a.shape(-2); + int N = b.shape(-1); + int K = a.shape(-1); + + auto batch_size_out = out.size() / (M * N); + + // Route to gemv if needed + if (std::min(M, N) == 1) { + // Collect problem info + bool is_b_matrix = N != 1; + + auto& mat = is_b_matrix ? b : a; + auto& vec = is_b_matrix ? a : b; + bool transpose_mat = is_b_matrix ? !b_transposed : a_transposed; + int in_vector_len = K; + int out_vector_len = is_b_matrix ? N : M; + + int mat_cols = transpose_mat ? out_vector_len : in_vector_len; + int mat_rows = transpose_mat ? in_vector_len : out_vector_len; + + int batch_size_mat = mat.data_size() / (mat_cols * mat_rows); + int stride_mat = batch_size_mat == batch_size_out ? mat_cols * mat_rows : 0; + + int batch_size_vec = vec.data_size() / in_vector_len; + int stride_vec = batch_size_vec == batch_size_out ? in_vector_len : 0; + + // Determine dispatch kernel + int tm = 4, tn = 4; + int bm, bn, n_out_per_tgp; + std::ostringstream kname; + + if (transpose_mat) { + bm = 8; + bn = 8; + if (out_vector_len >= 24576) { + bn = 128; + } else if (out_vector_len >= 16384) { + bn = 64; + } else if (out_vector_len >= 8192) { + bn = 16; + } + + // Specialized kernel for very small outputs + tn = out_vector_len < tn ? 1 : tn; + + n_out_per_tgp = bn * tn; + kname << "gemv_t_" << type_to_name(out); + + } else { + bm = out_vector_len >= 4096 ? 8 : 4; + bn = 32; + + // Specialized kernel for very small outputs + tm = out_vector_len < tm ? 
1 : tm; + + n_out_per_tgp = bm * tm; + kname << "gemv_" << type_to_name(out); + } + + kname << "_bm" << bm << "_bn" << bn << "_tm" << tm << "_tn" << tn; + + // Encode and dispatch kernel + auto compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname.str()); + compute_encoder->setComputePipelineState(kernel); + + int n_tgp = (out_vector_len + n_out_per_tgp - 1) / n_out_per_tgp; + MTL::Size group_dims = MTL::Size(bn, bm, 1); + MTL::Size grid_dims = MTL::Size(n_tgp, 1, batch_size_out); + + set_array_buffer(compute_encoder, mat, 0); + set_array_buffer(compute_encoder, vec, 1); + set_array_buffer(compute_encoder, out, 2); + + compute_encoder->setBytes(&in_vector_len, sizeof(int), 3); + compute_encoder->setBytes(&out_vector_len, sizeof(int), 4); + compute_encoder->setBytes(&stride_vec, sizeof(int), 5); + compute_encoder->setBytes(&stride_mat, sizeof(int), 6); + compute_encoder->dispatchThreadgroups(grid_dims, group_dims); + + d.get_command_buffer(s.index)->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + return; + } + + d.end_encoding(s.index); + + if (use_mps()) { + mps_matmul( + s, + d, + a, + b, + out, + M, + N, + K, + batch_size_out, + a_cols, + b_cols, + a_transposed, + b_transposed, + copies); + return; + } + + mlx_matmul( + s, + d, + a, + b, + out, + M, + N, + K, + batch_size_out, + a_cols, + b_cols, + a_transposed, + b_transposed, + copies); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/matmul.h b/mlx/backend/metal/matmul.h new file mode 100644 index 0000000000..34e86733b4 --- /dev/null +++ b/mlx/backend/metal/matmul.h @@ -0,0 +1,29 @@ +#include +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/mps/gemm.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/utils.h" + +namespace mlx::core { + +void mlx_matmul( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies); + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/metal/mps/gemm.h b/mlx/backend/metal/mps/gemm.h new file mode 100644 index 0000000000..6201644697 --- /dev/null +++ b/mlx/backend/metal/mps/gemm.h @@ -0,0 +1,368 @@ +#pragma once + +#include + +#define _MPS_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol) +#define _MPS_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor) + +namespace MTL::Private::Class { +_MTL_PRIVATE_DEF_CLS(MPSMatrixDescriptor); +_MTL_PRIVATE_DEF_CLS(MPSMatrix); +_MTL_PRIVATE_DEF_CLS(MPSVectorDescriptor); +_MTL_PRIVATE_DEF_CLS(MPSVector); +_MTL_PRIVATE_DEF_CLS(MPSKernel); +_MTL_PRIVATE_DEF_CLS(MPSMatrixMultiplication); +_MTL_PRIVATE_DEF_CLS(MPSMatrixVectorMultiplication); +} // namespace MTL::Private::Class + +namespace MTL::Private::Selector { +_MTL_PRIVATE_DEF_SEL( + matrixDescriptorWithRows_columns_rowBytes_dataType, + "matrixDescriptorWithRows:columns:rowBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL( + matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType, + "matrixDescriptorWithRows:columns:matrices:rowBytes:matrixBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL(rows, "rows"); +_MTL_PRIVATE_DEF_SEL(initWithBuffer_descriptor, "initWithBuffer:descriptor:"); +_MTL_PRIVATE_DEF_SEL( + initWithDevice_, + "initWithDevice:transposeLeft:transposeRight:" + "resultRows:resultColumns:interiorColumns:alpha:beta:"); +_MTL_PRIVATE_DEF_SEL( + 
encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix, + "encodeToCommandBuffer:leftMatrix:rightMatrix:resultMatrix:"); +_MTL_PRIVATE_DEF_SEL(setLeftMatrixOrigin_, "setLeftMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setRightMatrixOrigin_, "setRightMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setResultMatrixOrigin_, "setResultMatrixOrigin:"); +_MTL_PRIVATE_DEF_SEL(setBatchStart_, "setBatchStart:"); +_MTL_PRIVATE_DEF_SEL(setBatchSize_, "setBatchSize:"); +_MTL_PRIVATE_DEF_SEL( + vectorDescriptorWithLength_dataType, + "vectorDescriptorWithLength:dataType:"); +_MTL_PRIVATE_DEF_SEL( + vectorDescriptorWithLength_vectors_vectorBytes_dataType, + "vectorDescriptorWithLength:vectors:vectorBytes:dataType:"); +_MTL_PRIVATE_DEF_SEL( + initWithDevice_transpose_rows_columns_alpha_beta, + "initWithDevice:transpose:rows:columns:alpha:beta:"); +_MTL_PRIVATE_DEF_SEL( + encodeToCommandBuffer_inputMatrix_inputVector_resultVector, + "encodeToCommandBuffer:inputMatrix:inputVector:resultVector:"); +} // namespace MTL::Private::Selector + +namespace MPS { + +typedef enum DataType : uint32_t { + DataTypeFloatBit = 0x10000000, + DataTypeAlternateEncodingBit = 0x80000000, + DataTypeFloat16 = DataTypeFloatBit | 16, + DataTypeFloat32 = DataTypeFloatBit | 32, + DataTypeBFloat16 = DataTypeAlternateEncodingBit | DataTypeFloat16 +} DataType; + +class MatrixDescriptor : public NS::Copying { + public: + static class MatrixDescriptor* matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger rowBytes, + NS::UInteger dataType); + static class MatrixDescriptor* matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger matrices, + NS::UInteger rowBytes, + NS::UInteger matrixBytes, + NS::UInteger dataType); + NS::UInteger rows() const; +}; + +class Matrix : public NS::Referencing { + public: + static class Matrix* alloc(); + Matrix* init(MTL::Buffer* buffer, MatrixDescriptor* descriptor); + Matrix* init(const MTL::Buffer* buffer, MatrixDescriptor* descriptor); +}; + +class Kernel : public NS::Referencing { + public: + NS::String* label() const; + MTL::Device* device() const; +}; + +class MatrixMultiplication + : public NS::Referencing { + public: + static class MatrixMultiplication* alloc(); + + MatrixMultiplication* init( + MTL::Device* device, + bool transposeLeft, + bool transposeRight, + NS::UInteger resultRows, + NS::UInteger resultColumns, + NS::UInteger interiorColumns, + double alpha, + double beta); + + void encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* leftMatrix, + Matrix* rightMatrix, + Matrix* resultMatrix); + + void setLeftMatrixOrigin(MTL::Origin origin); + void setRightMatrixOrigin(MTL::Origin origin); + void setResultMatrixOrigin(MTL::Origin origin); + void setBatchStart(NS::UInteger batchStart); + void setBatchSize(NS::UInteger batchSize); +}; + +class VectorDescriptor : public NS::Copying { + public: + static class VectorDescriptor* vectorDescriptor( + NS::UInteger length, + NS::UInteger dataType); + static class VectorDescriptor* vectorDescriptor( + NS::UInteger length, + NS::UInteger vectors, + NS::UInteger vectorBytes, + NS::UInteger dataType); +}; + +class Vector : public NS::Referencing { + public: + static class Vector* alloc(); + Vector* init(MTL::Buffer* buffer, VectorDescriptor* descriptor); + Vector* init(const MTL::Buffer* buffer, VectorDescriptor* descriptor); +}; + +class MatrixVectorMultiplication + : public NS::Referencing { + public: + static class MatrixVectorMultiplication* alloc(); + + MatrixVectorMultiplication* init( + MTL::Device* 
device, + bool transpose, + NS::UInteger rows, + NS::UInteger columns, + double alpha, + double beta); + + void encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* inputMatrix, + Vector* inputVector, + Vector* resultVector); +}; + +_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger rowBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSMatrixDescriptor), + _MPS_PRIVATE_SEL(matrixDescriptorWithRows_columns_rowBytes_dataType), + rows, + columns, + rowBytes, + dataType); +} + +_MTL_INLINE MatrixDescriptor* MatrixDescriptor::matrixDescriptor( + NS::UInteger rows, + NS::UInteger columns, + NS::UInteger matrices, + NS::UInteger rowBytes, + NS::UInteger matrixBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSMatrixDescriptor), + _MPS_PRIVATE_SEL( + matrixDescriptorWithRows_columns_matrices_rowBytes_matrixBytes_dataType), + rows, + columns, + matrices, + rowBytes, + matrixBytes, + dataType); +} + +_MTL_INLINE NS::UInteger MatrixDescriptor::rows() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(rows)); +} + +_MTL_INLINE Matrix* Matrix::alloc() { + return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSMatrix)); +} + +_MTL_INLINE Matrix* Matrix::init( + MTL::Buffer* buffer, + MatrixDescriptor* descriptor) { + return Object::sendMessage( + this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); +} + +_MTL_INLINE Matrix* Matrix::init( + const MTL::Buffer* buffer, + MatrixDescriptor* descriptor) { + return init(const_cast(buffer), descriptor); +} + +_MTL_INLINE NS::String* Kernel::label() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::Device* Kernel::device() const { + return Object::sendMessage(this, _MPS_PRIVATE_SEL(device)); +} + +_MTL_INLINE MatrixMultiplication* MatrixMultiplication::alloc() { + return NS::Object::alloc( + _MPS_PRIVATE_CLS(MPSMatrixMultiplication)); +} + +_MTL_INLINE MatrixMultiplication* MatrixMultiplication::init( + MTL::Device* device, + bool transposeLeft, + bool transposeRight, + NS::UInteger resultRows, + NS::UInteger resultColumns, + NS::UInteger interiorColumns, + double alpha, + double beta) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL(initWithDevice_), + device, + transposeLeft, + transposeRight, + resultRows, + resultColumns, + interiorColumns, + alpha, + beta); +} + +_MTL_INLINE void MatrixMultiplication::encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* leftMatrix, + Matrix* rightMatrix, + Matrix* resultMatrix) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL( + encodeToCommandBuffer_leftMatrix_rightMatrix_resultMatrix), + commandBuffer, + leftMatrix, + rightMatrix, + resultMatrix); +} + +_MTL_INLINE void MatrixMultiplication::setLeftMatrixOrigin(MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setLeftMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setRightMatrixOrigin( + MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setRightMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setResultMatrixOrigin( + MTL::Origin origin) { + Object::sendMessage( + this, _MPS_PRIVATE_SEL(setResultMatrixOrigin_), origin); +} + +_MTL_INLINE void MatrixMultiplication::setBatchStart(NS::UInteger batchStart) { + Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchStart_), batchStart); +} + +_MTL_INLINE void 
MatrixMultiplication::setBatchSize(NS::UInteger batchSize) { + Object::sendMessage(this, _MPS_PRIVATE_SEL(setBatchSize_), batchSize); +} + +_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( + NS::UInteger length, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSVectorDescriptor), + _MPS_PRIVATE_SEL(vectorDescriptorWithLength_dataType), + length, + dataType); +} + +_MTL_INLINE VectorDescriptor* VectorDescriptor::vectorDescriptor( + NS::UInteger length, + NS::UInteger vectors, + NS::UInteger vectorBytes, + NS::UInteger dataType) { + return Object::sendMessage( + _MPS_PRIVATE_CLS(MPSVectorDescriptor), + _MPS_PRIVATE_SEL(vectorDescriptorWithLength_vectors_vectorBytes_dataType), + length, + vectors, + vectorBytes, + dataType); +} + +_MTL_INLINE Vector* Vector::alloc() { + return NS::Object::alloc(_MPS_PRIVATE_CLS(MPSVector)); +} + +_MTL_INLINE Vector* Vector::init( + MTL::Buffer* buffer, + VectorDescriptor* descriptor) { + return Object::sendMessage( + this, _MPS_PRIVATE_SEL(initWithBuffer_descriptor), buffer, descriptor); +} + +_MTL_INLINE Vector* Vector::init( + const MTL::Buffer* buffer, + VectorDescriptor* descriptor) { + return init(const_cast(buffer), descriptor); +} + +_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::alloc() { + return NS::Object::alloc( + _MPS_PRIVATE_CLS(MPSMatrixVectorMultiplication)); +} + +_MTL_INLINE MatrixVectorMultiplication* MatrixVectorMultiplication::init( + MTL::Device* device, + bool transpose, + NS::UInteger rows, + NS::UInteger columns, + double alpha, + double beta) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL(initWithDevice_transpose_rows_columns_alpha_beta), + device, + transpose, + rows, + columns, + alpha, + beta); +} + +_MTL_INLINE void MatrixVectorMultiplication::encodeToCommandBuffer( + MTL::CommandBuffer* commandBuffer, + Matrix* inputMatrix, + Vector* inputVector, + Vector* resultVector) { + return Object::sendMessage( + this, + _MPS_PRIVATE_SEL( + encodeToCommandBuffer_inputMatrix_inputVector_resultVector), + commandBuffer, + inputMatrix, + inputVector, + resultVector); +} + +} // namespace MPS \ No newline at end of file diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp new file mode 100644 index 0000000000..48ceb0548f --- /dev/null +++ b/mlx/backend/metal/primitives.cpp @@ -0,0 +1,604 @@ +#include +#include +#include +#include + +#include "mlx/backend/common/binary.h" +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/kernels/defines.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +static constexpr int METAL_MAX_INDEX_ARRAYS = 10; + +void binary_op( + const std::vector& inputs, + array& out, + const std::string op) { + assert(inputs.size() == 2); + auto& a = inputs[0]; + auto& b = inputs[1]; + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + // Try to collapse contiguous dims + auto [shape, strides] = collapse_contiguous_dims(a, b, out); + auto& strides_a = strides[0]; + auto& strides_b = strides[1]; + auto& strides_out = strides[2]; + + std::ostringstream kname; + switch (bopt) { + case ScalarScalar: + kname << "ss"; + break; + case ScalarVector: + kname << "sv"; + break; + case VectorScalar: + kname << "vs"; + break; + case VectorVector: + kname << "vv"; + break; + case General: + kname << "g"; + break; + } + kname << op << type_to_name(a); + if (bopt == 
General && out.ndim() <= MAX_BINARY_SPECIALIZED_DIMS) { + kname << "_" << shape.size(); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, a, 0); + set_array_buffer(compute_encoder, b, 1); + set_array_buffer(compute_encoder, out, 2); + + if (bopt == General) { + auto ndim = shape.size(); + if (ndim > 3) { + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 3); + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 5); + } else { + // The shape is implicit in the grid for <= 3D + compute_encoder->setBytes(strides_a.data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(strides_b.data(), ndim * sizeof(size_t), 4); + } + + if (ndim > MAX_BINARY_SPECIALIZED_DIMS) { + compute_encoder->setBytes(&ndim, sizeof(int), 6); + } + + // Launch up to 3D grid of threads + int dim0 = ndim > 0 ? shape[ndim - 1] : 1; + int dim1 = ndim > 1 ? shape[ndim - 2] : 1; + int rest = out.size() / (dim0 * dim1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size != 1024) { + throw std::runtime_error("[Metal::binary] Must use 1024 sized block"); + } + auto group_dims = get_block_dims(dim0, dim1, rest); + MTL::Size grid_dims = MTL::Size(dim0, dim1, rest); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + // Launch a 1D grid of threads + size_t nthreads = bopt == General ? out.size() : out.data_size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } +} + +void unary_op( + const std::vector& inputs, + array& out, + const std::string op) { + auto& in = inputs[0]; + bool contig = in.flags().contiguous; + if (contig) { + out.set_data( + allocator::malloc_or_wait(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } else { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + } + + auto& s = out.primitive().stream(); + auto& d = metal::device(s.device); + std::string tname = type_to_name(in); + std::string opt_name = contig ? "v" : "g"; + auto kernel = d.get_kernel(opt_name + op + tname); + + size_t nthreads = contig ? 
in.data_size() : in.size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (thread_group_size > nthreads) { + thread_group_size = nthreads; + } + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + if (!contig) { + compute_encoder->setBytes(in.shape().data(), in.ndim() * sizeof(int), 2); + compute_encoder->setBytes( + in.strides().data(), in.ndim() * sizeof(size_t), 3); + int ndim = in.ndim(); + compute_encoder->setBytes(&ndim, sizeof(int), 4); + } + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +} // namespace + +void Abs::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "abs"); +} + +void Add::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "add"); +} + +template +void arange_set_scalars(T start, T next, MTL::ComputeCommandEncoder* enc) { + enc->setBytes(&start, sizeof(T), 0); + T step = next - start; + enc->setBytes(&step, sizeof(T), 1); +} + +void Arange::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 0); + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + auto kernel = d.get_kernel("arange" + type_to_name(out)); + size_t nthreads = out.size(); + MTL::Size grid_dims = MTL::Size(nthreads, 1, 1); + MTL::Size group_dims = MTL::Size( + std::min(nthreads, kernel->maxTotalThreadsPerThreadgroup()), 1, 1); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + + switch (out.dtype()) { + case bool_: // unsupported + throw std::runtime_error("[Arange::eval_gpu] Does not support bool"); + case uint8: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case uint16: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case uint32: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case uint64: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case int8: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case int16: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case int32: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case int64: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case float16: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case float32: + arange_set_scalars(start_, start_ + step_, compute_encoder); + break; + case bfloat16: + throw std::runtime_error("[Arange::eval_gpu] Does not support bfloat16"); + case complex64: + throw std::runtime_error("[Arange::eval_gpu] Does not support complex64"); + } + + set_array_buffer(compute_encoder, out, 2); + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +void ArcCos::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arccos"); +} + +void ArcCosh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arccosh"); +} + +void ArcSin::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arcsin"); +} + +void ArcSinh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arcsinh"); +} + +void ArcTan::eval_gpu(const std::vector& inputs, 
array& out) { + unary_op(inputs, out, "arctan"); +} + +void ArcTanh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "arctanh"); +} + +void ArgReduce::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + auto& in = inputs[0]; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + auto& s = stream(); + auto& d = metal::device(s.device); + std::string op_name; + switch (reduce_type_) { + case ArgReduce::ArgMin: + op_name = "argmin_"; + break; + case ArgReduce::ArgMax: + op_name = "argmax_"; + break; + } + + // Prepare the shapes, strides and axis arguments. + std::vector in_strides = in.strides(); + std::vector shape = in.shape(); + std::vector out_strides = out.strides(); + size_t axis_stride = in_strides[axis_]; + size_t axis_size = shape[axis_]; + if (out_strides.size() == in_strides.size()) { + out_strides.erase(out_strides.begin() + axis_); + } + in_strides.erase(in_strides.begin() + axis_); + shape.erase(shape.begin() + axis_); + size_t ndim = shape.size(); + + // ArgReduce + int simd_size = 32; + int n_reads = 4; + auto compute_encoder = d.get_command_encoder(s.index); + { + auto kernel = d.get_kernel(op_name + type_to_name(in)); + NS::UInteger thread_group_size = std::min( + (axis_size + n_reads - 1) / n_reads, + kernel->maxTotalThreadsPerThreadgroup()); + // round up to the closest number divisible by simd_size + thread_group_size = + (thread_group_size + simd_size - 1) / simd_size * simd_size; + assert(thread_group_size <= kernel->maxTotalThreadsPerThreadgroup()); + + size_t n_threads = out.size() * thread_group_size; + MTL::Size grid_dims = MTL::Size(n_threads, 1, 1); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(shape.data(), ndim * sizeof(int), 2); + compute_encoder->setBytes(in_strides.data(), ndim * sizeof(size_t), 3); + compute_encoder->setBytes(out_strides.data(), ndim * sizeof(size_t), 4); + compute_encoder->setBytes(&ndim, sizeof(size_t), 5); + compute_encoder->setBytes(&axis_stride, sizeof(size_t), 6); + compute_encoder->setBytes(&axis_size, sizeof(size_t), 7); + compute_encoder->setThreadgroupMemoryLength( + simd_size * (sizeof(uint32_t) + in.itemsize()), 0); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } +} + +void AsType::eval_gpu(const std::vector& inputs, array& out) { + CopyType ctype = + inputs[0].flags().contiguous ? 
CopyType::Vector : CopyType::General; + copy_gpu(inputs[0], out, ctype); +} + +void AsStrided::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void Broadcast::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void Concatenate::eval_gpu(const std::vector& inputs, array& out) { + std::vector sizes; + sizes.push_back(0); + for (auto& p : inputs) { + sizes.push_back(p.shape(axis_)); + } + std::partial_sum(sizes.cbegin(), sizes.cend(), sizes.begin()); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto strides = out.strides(); + auto flags = out.flags(); + flags.row_contiguous = false; + flags.col_contiguous = false; + flags.contiguous = false; + for (int i = 0; i < inputs.size(); i++) { + array out_slice(inputs[i].shape(), out.dtype(), nullptr, {}); + size_t data_offset = strides[axis_] * sizes[i]; + out_slice.copy_shared_buffer( + out, strides, flags, out_slice.size(), data_offset); + copy_gpu_inplace(inputs[i], out_slice, CopyType::GeneralGeneral, stream()); + } +} + +void Copy::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void Cos::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "cos"); +} + +void Cosh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "cosh"); +} + +void Divide::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "div"); +} + +void Equal::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, equal_nan_ ? "naneq" : "eq"); +} + +void Erf::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "erf"); +} + +void ErfInv::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "erfinv"); +} + +void Exp::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "exp"); +} + +void Full::eval_gpu(const std::vector& inputs, array& out) { + auto in = inputs[0]; + CopyType ctype; + if (in.data_size() == 1) { + ctype = CopyType::Scalar; + } else if (in.flags().contiguous) { + ctype = CopyType::Vector; + } else { + ctype = CopyType::General; + } + copy_gpu(in, out, ctype); +} + +void Greater::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "ge"); +} + +void GreaterEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "geq"); +} + +void Less::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "le"); +} + +void LessEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "leq"); +} + +void Load::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void Log::eval_gpu(const std::vector& inputs, array& out) { + switch (base_) { + case Base::e: + unary_op(inputs, out, "log"); + break; + case Base::two: + unary_op(inputs, out, "log2"); + break; + case Base::ten: + unary_op(inputs, out, "log10"); + break; + } +} + +void Log1p::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "log1p"); +} + +void LogicalNot::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "lnot"); +} + +void LogAddExp::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "lae"); +} + +void Maximum::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "max"); +} + +void Minimum::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "min"); +} + +void Multiply::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "mul"); 
+} + +void Negative::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "neg"); +} + +void NotEqual::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "neq"); +} + +void Pad::eval_gpu(const std::vector& inputs, array& out) { + // Inputs must be base input array and scalar val array + assert(inputs.size() == 2); + auto& in = inputs[0]; + auto& val = inputs[1]; + + // Padding value must be a scalar + assert(val.size() == 1); + + // Padding value, input and output must be of the same type + assert(val.dtype() == in.dtype() && in.dtype() == out.dtype()); + + // Fill output with val + copy_gpu(val, out, CopyType::Scalar, stream()); + + // Find offset for start of input values + size_t data_offset = 0; + for (int i = 0; i < axes_.size(); i++) { + auto ax = axes_[i] < 0 ? out.ndim() + axes_[i] : axes_[i]; + data_offset += out.strides()[ax] * low_pad_size_[i]; + } + + // Extract slice from output where input will be pasted + array out_slice(in.shape(), out.dtype(), nullptr, {}); + out_slice.copy_shared_buffer( + out, out.strides(), out.flags(), out_slice.size(), data_offset); + + // Copy input values into the slice + copy_gpu_inplace(in, out_slice, CopyType::GeneralGeneral, stream()); +} + +void Power::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "pow"); +} + +void RandomBits::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + // keys has shape (N1, ..., NK, 2) + // out has shape (N1, ..., NK, M1, M2, ...) + auto& keys = inputs[0]; + size_t num_keys = keys.size() / 2; + + size_t elems_per_key = out.size() / num_keys; + size_t bytes_per_key = out.itemsize() * elems_per_key; + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + size_t out_per_key = (bytes_per_key + 4 - 1) / 4; + size_t half_size = out_per_key / 2; + bool odd = out_per_key % 2; + + auto& s = stream(); + auto& d = metal::device(s.device); + std::string kname = keys.flags().row_contiguous ? 
"rbitsc" : "rbits"; + auto kernel = d.get_kernel(kname); + + // organize into grid nkeys x elem_per_key + MTL::Size grid_dims = MTL::Size(num_keys, half_size + odd, 1); + NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + auto nthreads = std::min(num_keys * (half_size + odd), thread_group_size); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, keys, 0); + set_array_buffer(compute_encoder, out, 1); + compute_encoder->setBytes(&odd, sizeof(bool), 2); + compute_encoder->setBytes(&bytes_per_key, sizeof(size_t), 3); + + if (!keys.flags().row_contiguous) { + int ndim = keys.ndim(); + compute_encoder->setBytes(&ndim, sizeof(int), 4); + compute_encoder->setBytes( + keys.shape().data(), keys.ndim() * sizeof(int), 5); + compute_encoder->setBytes( + keys.strides().data(), keys.ndim() * sizeof(size_t), 6); + } + + compute_encoder->dispatchThreads(grid_dims, group_dims); +} + +void Reshape::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + const auto& in = inputs[0]; + if (in.flags().row_contiguous) { + auto flags = in.flags(); + auto max_dim = std::max_element(out.shape().begin(), out.shape().end()); + flags.col_contiguous = out.size() <= 1 || out.size() == *max_dim; + out.copy_shared_buffer(in, out.strides(), flags, in.data_size()); + } else { + copy_gpu(in, out, CopyType::General); + } +} + +void Sigmoid::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sigmoid"); +} + +void Sign::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sign"); +} + +void Sin::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sin"); +} + +void Sinh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "sinh"); +} + +void Square::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "square"); +} + +void Sqrt::eval_gpu(const std::vector& inputs, array& out) { + if (recip_) { + unary_op(inputs, out, "rsqrt"); + } else { + unary_op(inputs, out, "sqrt"); + } +} + +void Slice::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void StopGradient::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +void Subtract::eval_gpu(const std::vector& inputs, array& out) { + binary_op(inputs, out, "sub"); +} + +void Tan::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "tan"); +} + +void Tanh::eval_gpu(const std::vector& inputs, array& out) { + unary_op(inputs, out, "tanh"); +} + +void Transpose::eval_gpu(const std::vector& inputs, array& out) { + eval(inputs, out); +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/scan.cpp b/mlx/backend/metal/scan.cpp new file mode 100644 index 0000000000..9c91f89e9c --- /dev/null +++ b/mlx/backend/metal/scan.cpp @@ -0,0 +1,130 @@ +#include +#include + +#include "mlx/backend/metal/copy.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/utils.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void Scan::eval_gpu(const std::vector& inputs, array& out) { + assert(inputs.size() == 1); + + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& s = stream(); + auto& d = metal::device(s.device); + + // Ensure contiguity + std::vector copies; + auto in = inputs[0]; + if (!in.flags().row_contiguous) { + array arr_copy(in.shape(), in.dtype(), nullptr, {}); + 
copy_gpu(in, arr_copy, CopyType::General, s); + copies.push_back(arr_copy); + in = arr_copy; + } + + std::ostringstream kname; + if (in.strides()[axis_] == 1) { + kname << "contiguous_scan_"; + if (reverse_) { + kname << "reverse_"; + } + kname << ((inclusive_) ? "inclusive_" : "exclusive_"); + switch (reduce_type_) { + case Scan::Sum: + kname << "sum_"; + break; + case Scan::Prod: + kname << "prod_"; + break; + case Scan::Max: + kname << "max_"; + break; + case Scan::Min: + kname << "min_"; + break; + } + kname << type_to_name(in) << "_" << type_to_name(out); + + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + size_t size = in.shape(axis_); + compute_encoder->setBytes(&size, sizeof(size_t), 2); + + // Compute the thread grid + int n_reads = (in.itemsize() <= 4) ? 4 : 2; + int elements_per_simd = n_reads * 32; + int thread_groups = in.size() / size; + int thread_group_size = kernel->maxTotalThreadsPerThreadgroup(); + if (size < n_reads * 1024) { + thread_group_size = ((size + elements_per_simd - 1) / elements_per_simd) * + elements_per_simd; + } else if (size < n_reads * 2048) { + thread_group_size = + ((size / 2 + elements_per_simd - 1) / elements_per_simd) * + elements_per_simd; + } + thread_group_size = std::min( + thread_group_size, + static_cast(kernel->maxTotalThreadsPerThreadgroup())); + MTL::Size grid_dims = MTL::Size(thread_groups * thread_group_size, 1, 1); + MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } else { + kname << "strided_scan_"; + if (reverse_) { + kname << "reverse_"; + } + kname << ((inclusive_) ? "inclusive_" : "exclusive_"); + switch (reduce_type_) { + case Scan::Sum: + kname << "sum_"; + break; + case Scan::Prod: + kname << "prod_"; + break; + case Scan::Max: + kname << "max_"; + break; + case Scan::Min: + kname << "min_"; + break; + } + kname << type_to_name(in) << "_" << type_to_name(out); + + auto kernel = d.get_kernel(kname.str()); + auto compute_encoder = d.get_command_encoder(s.index); + compute_encoder->setComputePipelineState(kernel); + set_array_buffer(compute_encoder, in, 0); + set_array_buffer(compute_encoder, out, 1); + size_t size = in.shape(axis_); + size_t stride = in.strides()[axis_]; + compute_encoder->setBytes(&size, sizeof(size_t), 2); + compute_encoder->setBytes(&stride, sizeof(size_t), 3); + + // Compute the thread grid + int n_reads = (in.itemsize() <= 4) ? 
4 : 2; + int tile_x = 32; + int tile_y = 32; + int elements_per_tile_x = tile_x * n_reads; + int grid_y = in.size() / size / stride; + int grid_x = (stride + elements_per_tile_x - 1) / elements_per_tile_x; + MTL::Size grid_dims = MTL::Size(grid_x * tile_x, grid_y * tile_y, 1); + MTL::Size group_dims = MTL::Size(tile_x, tile_y, 1); + compute_encoder->dispatchThreads(grid_dims, group_dims); + } + + if (copies.size() > 0) { + auto command_buffer = d.get_command_buffer(s.index); + command_buffer->addCompletedHandler( + [copies](MTL::CommandBuffer*) mutable { copies.clear(); }); + } +} + +} // namespace mlx::core diff --git a/mlx/backend/metal/utils.h b/mlx/backend/metal/utils.h new file mode 100644 index 0000000000..0165761e85 --- /dev/null +++ b/mlx/backend/metal/utils.h @@ -0,0 +1,167 @@ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +namespace { + +void set_array_buffer( + MTL::ComputeCommandEncoder* compute_encoder, + MTL::ArgumentEncoder* enc, + const array& a, + int idx) { + auto a_buf = static_cast(a.buffer().ptr()); + auto offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + enc->setBuffer(a_buf, offset, idx); + // MTL::Resource usage through argument buffer needs to be explicity + // flagged to enable hazard tracking + compute_encoder->useResource(a_buf, MTL::ResourceUsageRead); +} + +void set_array_buffer( + MTL::ComputeCommandEncoder* enc, + const array& a, + int idx) { + auto a_buf = static_cast(a.buffer().ptr()); + auto offset = a.data() - + static_cast(const_cast(a_buf)->contents()); + enc->setBuffer(a_buf, offset, idx); +} + +std::string type_to_name(const array& a) { + std::string tname; + switch (a.dtype()) { + case bool_: + tname = "bool_"; + break; + case uint8: + tname = "uint8"; + break; + case uint16: + tname = "uint16"; + break; + case uint32: + tname = "uint32"; + break; + case uint64: + tname = "uint64"; + break; + case int8: + tname = "int8"; + break; + case int16: + tname = "int16"; + break; + case int32: + tname = "int32"; + break; + case int64: + tname = "int64"; + break; + case float16: + tname = "float16"; + break; + case float32: + tname = "float32"; + break; + case bfloat16: + tname = "bfloat16"; + break; + case complex64: + tname = "complex64"; + break; + } + return tname; +} + +MTL::Size get_block_dims(int dim0, int dim1, int dim2) { + int pows[3] = {0, 0, 0}; + int sum = 0; + while (true) { + int presum = sum; + // Check all the pows + if (dim0 >= (1 << (pows[0] + 1))) { + pows[0]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim1 >= (1 << (pows[1] + 1))) { + pows[1]++; + sum++; + } + if (sum == 10) { + break; + } + if (dim2 >= (1 << (pows[2] + 1))) { + pows[2]++; + sum++; + } + if (sum == presum || sum == 10) { + break; + } + } + return MTL::Size{1ul << pows[0], 1ul << pows[1], 1ul << pows[2]}; +} + +// Collapse dims that are contiguous to possibly route to a better kernel +// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) +// should return {{2, 4}, {{1, 2}}}. +// +// When multiple arrays are passed they should all have the same shape. The +// collapsed axes are also the same so one shape is returned. +std::tuple, std::vector>> +collapse_contiguous_dims(const std::vector& xs) { + // Make a vector that has axes separated with -1. Collapse all axes between + // -1. 
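  // Worked example (editorial, derived from the docstring above): for the
  // transposed array with shape {2, 2, 2} and strides {1, 4, 2}, to_collapse
  // becomes {0, -1, 1, 2, -1}; dim 1 cannot merge with dim 0, but dim 2 merges
  // with dim 1, yielding the collapsed shape {2, 4} and strides {1, 2}.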
+ std::vector to_collapse; + if (xs[0].ndim() > 0) { + to_collapse.push_back(0); + for (int i = 1; i < xs[0].ndim(); i++) { + bool contiguous = true; + for (auto& x : xs) { + if (x.strides()[i] * x.shape()[i] != x.strides()[i - 1]) { + contiguous = false; + } + if (!contiguous) { + break; + } + } + if (!contiguous) { + to_collapse.push_back(-1); + } + to_collapse.push_back(i); + } + to_collapse.push_back(-1); + } + + std::vector out_shape; + std::vector> out_strides(xs.size()); + for (int i = 0; i < to_collapse.size(); i++) { + int current_shape = xs[0].shape()[to_collapse[i]]; + while (to_collapse[++i] != -1) { + current_shape *= xs[0].shape()[to_collapse[i]]; + } + out_shape.push_back(current_shape); + for (int j = 0; j < xs.size(); j++) { + out_strides[j].push_back(xs[j].strides()[to_collapse[i - 1]]); + } + } + + return std::make_tuple(out_shape, out_strides); +} + +template +std::tuple, std::vector>> +collapse_contiguous_dims(Arrays... xs) { + return collapse_contiguous_dims( + std::vector{std::forward(xs)...}); +} + +} // namespace + +} // namespace mlx::core diff --git a/mlx/backend/no_metal/metal.cpp b/mlx/backend/no_metal/metal.cpp new file mode 100644 index 0000000000..ec45188874 --- /dev/null +++ b/mlx/backend/no_metal/metal.cpp @@ -0,0 +1,18 @@ +#include + +#include "mlx/backend/metal/metal.h" + +namespace mlx::core::metal { + +void new_stream(Stream) {} + +std::function make_task( + array& arr, + std::vector> deps, + std::shared_ptr> p, + bool retain_graph) { + throw std::runtime_error( + "[metal::make_task] Cannot make GPU task without metal backend"); +} + +} // namespace mlx::core::metal diff --git a/mlx/device.h b/mlx/device.h new file mode 100644 index 0000000000..66035ddca9 --- /dev/null +++ b/mlx/device.h @@ -0,0 +1,27 @@ +#pragma once + +namespace mlx::core { + +struct Device { + enum class DeviceType { + cpu, + gpu, + }; + + static constexpr DeviceType cpu = DeviceType::cpu; + static constexpr DeviceType gpu = DeviceType::gpu; + + Device(DeviceType type, int index = 0) : type(type), index(index){}; + + DeviceType type; + int index; +}; + +const Device& default_device(); + +void set_default_device(const Device& d); + +bool operator==(const Device& lhs, const Device& rhs); +bool operator!=(const Device& lhs, const Device& rhs); + +} // namespace mlx::core diff --git a/mlx/dtype.cpp b/mlx/dtype.cpp new file mode 100644 index 0000000000..23749e9a6a --- /dev/null +++ b/mlx/dtype.cpp @@ -0,0 +1,205 @@ +#include +#include +#include + +#include "mlx/dtype.h" +#include "mlx/utils.h" + +namespace mlx::core { + +namespace { + +static constexpr int num_types = 13; + +static constexpr Dtype::Kind type_kinds[num_types] = { + Dtype::Kind::b, // bool_, + Dtype::Kind::u, // uint8, + Dtype::Kind::u, // uint16, + Dtype::Kind::u, // uint32, + Dtype::Kind::u, // uint64, + Dtype::Kind::i, // int8, + Dtype::Kind::i, // int16, + Dtype::Kind::i, // int32, + Dtype::Kind::i, // int64, + Dtype::Kind::f, // float16, + Dtype::Kind::f, // float32, + Dtype::Kind::V, // bfloat16, + Dtype::Kind::c // complex64, +}; + +// Following Jax type promotion rules: +// https://jax.readthedocs.io/en/latest/type_promotion.html +// clang-format off +static constexpr Dtype type_rules[num_types][num_types] = { +// bool uint8 uint16 uint32 uint64 int8 int16 int32 int64 float16 float32 bfloat16 complex64 + {bool_, uint8, uint16, uint32, uint64, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // bool + {uint8, uint8, uint16, uint32, uint64, int16, int16, int32, int64, float16, float32, bfloat16, 
complex64}, // uint8 + {uint16, uint16, uint16, uint32, uint64, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // uint16 + {uint32, uint32, uint32, uint32, uint64, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // uint32 + {uint64, uint64, uint64, uint64, uint64, float32, float32, float32, float32, float16, float32, bfloat16, complex64}, // uint64 + {int8, int16, int32, int64, float32, int8, int16, int32, int64, float16, float32, bfloat16, complex64}, // int8 + {int16, int16, int32, int64, float32, int16, int16, int32, int64, float16, float32, bfloat16, complex64}, // int16 + {int32, int32, int32, int64, float32, int32, int32, int32, int64, float16, float32, bfloat16, complex64}, // int32 + {int64, int64, int64, int64, float32, int64, int64, int64, int64, float16, float32, bfloat16, complex64}, // int64 + {float16, float16, float16, float16, float16, float16, float16, float16, float16, float16, float32, float32, complex64}, // float16 + {float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, float32, complex64}, // float32 + {bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, bfloat16, float32, float32, bfloat16, complex64}, // bfloat16 + {complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64, complex64}, // complex64 +}; + +// clang-format on + +inline bool is_big_endian() { + union ByteOrder { + int32_t i; + uint8_t c[4]; + }; + ByteOrder b = {0x01234567}; + + return b.c[0] == 0x01; +} + +} // namespace + +Dtype promote_types(const Dtype& t1, const Dtype& t2) { + return Dtype(type_rules[static_cast(t1.val)][static_cast(t2.val)]); +} + +Dtype::Kind kindof(const Dtype& t) { + return type_kinds[static_cast(t.val)]; +} + +template <> +TypeToDtype::operator Dtype() { + return bool_; +} + +template <> +TypeToDtype::operator Dtype() { + return uint8; +} + +template <> +TypeToDtype::operator Dtype() { + return uint16; +} + +template <> +TypeToDtype::operator Dtype() { + return uint32; +} + +template <> +TypeToDtype::operator Dtype() { + return uint64; +} + +template <> +TypeToDtype::operator Dtype() { + return int8; +} + +template <> +TypeToDtype::operator Dtype() { + return int16; +} + +template <> +TypeToDtype::operator Dtype() { + return int32; +} + +template <> +TypeToDtype::operator Dtype() { + return int64; +} + +template <> +TypeToDtype::operator Dtype() { + return float16; +} + +template <> +TypeToDtype::operator Dtype() { + return float32; +} + +template <> +TypeToDtype::operator Dtype() { + return float32; +} + +template <> +TypeToDtype::operator Dtype() { + return bfloat16; +} + +template <> +TypeToDtype::operator Dtype() { + return complex64; +} + +// Array protocol typestring for Dtype +std::string dtype_to_array_protocol(const Dtype& t) { + std::ostringstream r; + if (size_of(t) > 1) + r << (is_big_endian() ? ">" : "<"); + else + r << "|"; + r << kindof(t) << (int)size_of(t); + return r.str(); +} + +// Dtype from array protocol type string +Dtype dtype_from_array_protocol(const std::string& t) { + if (t.length() == 2 || t.length() == 3) { + std::string r = t.length() == 3 ? 
t.substr(1, 2) : t; + + if (r == "V2") { + return bfloat16; + } + + uint8_t size = r[1] - '0'; + + switch (r[0]) { + case 'b': { + if (size == 1) + return bool_; + } + case 'i': { + if (size == 1) + return int8; + else if (size == 2) + return int16; + else if (size == 4) + return int32; + else if (size == 8) + return int64; + } + case 'u': { + if (size == 1) + return uint8; + else if (size == 2) + return uint16; + else if (size == 4) + return uint32; + else if (size == 8) + return uint64; + } + case 'f': { + if (size == 2) + return float16; + else if (size == 4) + return float32; + } + case 'c': { + return complex64; + } + } + } + + throw std::invalid_argument( + "[from_str] Invalid array protocol type-string: " + t); +} + +} // namespace mlx::core diff --git a/mlx/graph_utils.cpp b/mlx/graph_utils.cpp new file mode 100644 index 0000000000..59bd37367d --- /dev/null +++ b/mlx/graph_utils.cpp @@ -0,0 +1,144 @@ +#include +#include +#include +#include +#include + +#include "mlx/graph_utils.h" +#include "mlx/primitives.h" +#include "mlx/utils.h" + +namespace mlx::core { + +using OptionalArrayRef = std::optional>; + +struct ArrayNames { + std::unordered_map names; + + std::string get_name(const array& x) { + auto it = names.find(x.id()); + if (it == names.end()) { + // Get the next name in the sequence + // [A, B, ..., Z, AA, AB, ...] + std::vector letters; + auto var_num = names.size() + 1; + while (var_num > 0) { + letters.push_back('A' + (var_num - 1) % 26); + var_num = (var_num - 1) / 26; + } + std::string name(letters.rbegin(), letters.rend()); + names.insert({x.id(), name}); + return name; + } + return it->second; + } +}; + +void depth_first_traversal( + std::function callback, + const std::vector& outputs) { + std::function recurse; + std::unordered_set cache; + recurse = [&](OptionalArrayRef parent, const array& x, int input_index) { + auto id = x.id(); + if (cache.find(id) != cache.end()) { + return; + } + cache.insert(id); + for (int i = 0; i < x.inputs().size(); i++) { + recurse(x, x.inputs()[i], i); + } + callback(parent, x, input_index); + }; + + for (auto x : outputs) { + recurse(std::nullopt, x, 0); + } +} + +void depth_first_traversal( + std::function callback, + const std::vector& outputs) { + depth_first_traversal( + [&callback](OptionalArrayRef p, const array& x, int input_index) { + callback(x); + }, + outputs); +} + +void print_graph(std::ostream& os, const std::vector& outputs) { + std::vector tape; + std::vector inputs; + + depth_first_traversal( + [&](const array& x) { + if (x.has_primitive()) { + tape.push_back(x); + } else { + inputs.push_back(x); + } + }, + outputs); + + ArrayNames namer; + auto print_arr = [&namer, &os](const array& a) { + os << namer.get_name(a); + os << " [" << a.shape() << ", " << a.dtype() << "]"; + }; + + auto print_arrs = [&](const std::vector& arrs) { + for (auto& arr : arrs) { + print_arr(arr); + if (&arr != &arrs.back()) { + os << ", "; + } + } + }; + + os << "Inputs: "; + print_arrs(inputs); + os << "\nOutputs: "; + print_arrs(outputs); + os << "\n"; + + for (auto& arr : tape) { + arr.primitive().print(os); + os << " "; + print_arrs(arr.inputs()); + os << " -> "; + print_arr(arr); + os << "\n"; + } +} + +void export_to_dot(std::ostream& os, const std::vector& outputs) { + os << "digraph {" << std::endl; + + ArrayNames namer; + depth_first_traversal( + [&namer, &os](auto parent, const array& x, int input_index) { + os << "{ "; + if (!x.has_primitive()) { + os << "rank=source; "; + } + if (!parent) { + os << "rank=sink; "; + } + os << 
namer.get_name(x); + if (x.has_primitive()) { + os << " [label =\""; + x.primitive().print(os); + os << "\"]"; + } + os << "; }" << std::endl; + + for (auto c : x.inputs()) { + os << namer.get_name(c) << " -> " << namer.get_name(x) << std::endl; + } + }, + outputs); + + os << "}"; +} + +} // namespace mlx::core diff --git a/mlx/load.h b/mlx/load.h new file mode 100644 index 0000000000..8b433ac224 --- /dev/null +++ b/mlx/load.h @@ -0,0 +1,112 @@ +#pragma once + +#include +#include +#include + +namespace mlx::core { + +namespace io { + +class Reader { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() const = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void read(char* data, size_t n) = 0; + virtual std::string label() const = 0; +}; + +class Writer { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() const = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void write(const char* data, size_t n) = 0; + virtual std::string label() const = 0; +}; + +class FileReader : public Reader { + public: + explicit FileReader(const std::shared_ptr& is) + : is_(is), label_("stream") {} + explicit FileReader(const std::string& file_path) + : is_(std::make_shared(file_path, std::ios::binary)), + label_(file_path) {} + + bool is_open() const override { + return is_->is_open(); + } + + bool good() const override { + return is_->good(); + } + + size_t tell() const override { + return is_->tellg(); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + is_->seekg(off, way); + } + + void read(char* data, size_t n) override { + is_->read(data, n); + } + + std::string label() const override { + return "file " + label_; + } + + private: + std::shared_ptr is_; + std::string label_; +}; + +class FileWriter : public Writer { + public: + explicit FileWriter(const std::shared_ptr& is) + : os_(is), label_("stream") {} + explicit FileWriter(const std::string& file_path) + : os_(std::make_shared(file_path, std::ios::binary)), + label_(file_path) {} + + bool is_open() const override { + return os_->is_open(); + } + + bool good() const override { + return os_->good(); + } + + size_t tell() const override { + return os_->tellp(); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + os_->seekp(off, way); + } + + void write(const char* data, size_t n) override { + os_->write(data, n); + } + + std::string label() const override { + return "file " + label_; + } + + private: + std::shared_ptr os_; + std::string label_; +}; + +} // namespace io +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/mlx.h b/mlx/mlx.h new file mode 100644 index 0000000000..68b6c182c3 --- /dev/null +++ b/mlx/mlx.h @@ -0,0 +1,11 @@ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/device.h" +#include "mlx/fft.h" +#include "mlx/ops.h" +#include "mlx/random.h" +#include "mlx/stream.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" diff --git a/mlx/ops.h b/mlx/ops.h new file mode 100644 index 0000000000..50666fa25c --- /dev/null +++ b/mlx/ops.h @@ -0,0 +1,932 @@ +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "load.h" +#include "stream.h" + +namespace mlx::core { + +using StreamOrDevice = std::variant; + +Stream to_stream(StreamOrDevice s); + +/** Creation 
operations */ + +/** + * A 1D array of numbers starting at `start` (optional), + * stopping at stop, stepping by `step` (optional). **/ +array arange( + double start, + double stop, + double step, + Dtype dtype, + StreamOrDevice s = {}); +array arange(double start, double stop, double step, StreamOrDevice s = {}); +array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double start, double stop, StreamOrDevice s = {}); +array arange(double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double stop, StreamOrDevice s = {}); + +array arange(int start, int stop, int step, StreamOrDevice s = {}); +array arange(int start, int stop, StreamOrDevice s = {}); +array arange(int stop, StreamOrDevice s = {}); + +/** Convert an array to the given data type. */ +array astype(const array& a, Dtype dtype, StreamOrDevice s = {}); + +/** Create a view of an array with the given shape and strides. */ +array as_strided( + const array& a, + std::vector shape, + std::vector strides, + size_t offset, + StreamOrDevice s = {}); + +/** Copy another array. */ +array copy(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with the given value(s). */ +array full( + const std::vector& shape, + const array& vals, + Dtype dtype, + StreamOrDevice s = {}); +array full( + const std::vector& shape, + const array& vals, + StreamOrDevice s = {}); +template +array full( + const std::vector& shape, + T val, + Dtype dtype, + StreamOrDevice s = {}) { + return full(shape, array(val, dtype), to_stream(s)); +} +template +array full(const std::vector& shape, T val, StreamOrDevice s = {}) { + return full(shape, array(val), to_stream(s)); +} + +/** Fill an array of the given shape with zeros. */ +array zeros(const std::vector& shape, Dtype dtype, StreamOrDevice s = {}); +inline array zeros(const std::vector& shape, StreamOrDevice s = {}) { + return zeros(shape, float32, s); +} +array zeros_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with ones. */ +array ones(const std::vector& shape, Dtype dtype, StreamOrDevice s = {}); +inline array ones(const std::vector& shape, StreamOrDevice s = {}) { + return ones(shape, float32, s); +} +array ones_like(const array& a, StreamOrDevice s = {}); + +/** array manipulation */ + +/** Reshape an array to the given shape. */ +array reshape(const array& a, std::vector shape, StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axes. */ +array squeeze( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axis. */ +inline array squeeze(const array& a, int axis, StreamOrDevice s = {}) { + return squeeze(a, std::vector{axis}, s); +} + +/** Remove all singleton dimensions. */ +array squeeze(const array& a, StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axes. */ +array expand_dims( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axis. */ +inline array expand_dims(const array& a, int axis, StreamOrDevice s = {}) { + return expand_dims(a, std::vector{axis}, s); +} + +/** Slice an array. */ +array slice( + const array& a, + std::vector start, + std::vector stop, + std::vector strides, + StreamOrDevice s = {}); + +/** Slice an array with a stride of 1 in each dimension. 
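 *
 * Example (illustrative only; assumes `using namespace mlx::core` and the
 * declarations above):
 *
 *   auto x = reshape(arange(16), {4, 4});
 *   auto r = slice(x, {1, 0}, {3, 4});          // rows 1-2, shape {2, 4}
 *   auto e = slice(x, {0, 0}, {4, 4}, {2, 1});  // every other row, shape {2, 4}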
*/ +array slice( + const array& a, + const std::vector& start, + const std::vector& stop, + StreamOrDevice s = {}); + +/** Split an array into sub-arrays along a given axis. */ +std::vector +split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); +std::vector split(const array& a, int num_splits, StreamOrDevice s = {}); +std::vector split( + const array& a, + const std::vector& indices, + int axis, + StreamOrDevice s = {}); +std::vector +split(const array& a, const std::vector& indices, StreamOrDevice s = {}); + +/** Concatenate arrays along a given axis. */ +array concatenate( + const std::vector& arrays, + int axis, + StreamOrDevice s = {}); +array concatenate(const std::vector& arrays, StreamOrDevice s = {}); + +/** Permutes the dimensions according to the given axes. */ +array transpose(const array& a, std::vector axes, StreamOrDevice s = {}); +inline array transpose( + const array& a, + std::initializer_list axes, + StreamOrDevice s = {}) { + return transpose(a, std::vector(axes), s); +} + +/** Pad an array with a constant value */ +array pad( + const array& a, + const std::vector& axes, + const std::vector& low_pad_size, + const std::vector& high_pad_size, + const array& pad_value = array(0), + StreamOrDevice s = {}); + +/** Pad an array with a constant value along all axes */ +array pad( + const array& a, + const std::vector>& pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); +array pad( + const array& a, + const std::pair& pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); +array pad( + const array& a, + int pad_width, + const array& pad_value = array(0), + StreamOrDevice s = {}); + +/** Permutes the dimensions in reverse order. */ +array transpose(const array& a, StreamOrDevice s = {}); + +/** Broadcast an array to a given shape. */ +array broadcast_to( + const array& a, + const std::vector& shape, + StreamOrDevice s = {}); + +/** Broadcast a vector of arrays against one another. */ +std::vector broadcast_arrays( + const std::vector& inputs, + StreamOrDevice s = {}); + +/** Comparison operations */ + +/** Returns the bool array with (a == b) element-wise. */ +array equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator==(const array& a, const array& b) { + return equal(a, b); +} +template +array operator==(T a, const array& b) { + return equal(array(a), b); +} +template +array operator==(const array& a, T b) { + return equal(a, array(b)); +} + +/** Returns the bool array with (a != b) element-wise. */ +array not_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator!=(const array& a, const array& b) { + return not_equal(a, b); +} +template +array operator!=(T a, const array& b) { + return not_equal(array(a), b); +} +template +array operator!=(const array& a, T b) { + return not_equal(a, array(b)); +} + +/** Returns bool array with (a > b) element-wise. */ +array greater(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>(const array& a, const array& b) { + return greater(a, b); +} +template +array operator>(T a, const array& b) { + return greater(array(a), b); +} +template +array operator>(const array& a, T b) { + return greater(a, array(b)); +} + +/** Returns bool array with (a >= b) element-wise. 
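 *
 * Example (illustrative only):
 *
 *   auto a = arange(5);
 *   auto m = a >= 2;  // bool array: {false, false, true, true, true}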
*/ +array greater_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>=(const array& a, const array& b) { + return greater_equal(a, b); +} +template +array operator>=(T a, const array& b) { + return greater_equal(array(a), b); +} +template +array operator>=(const array& a, T b) { + return greater_equal(a, array(b)); +} + +/** Returns bool array with (a < b) element-wise. */ +array less(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<(const array& a, const array& b) { + return less(a, b); +} +template +array operator<(T a, const array& b) { + return less(array(a), b); +} +template +array operator<(const array& a, T b) { + return less(a, array(b)); +} + +/** Returns bool array with (a <= b) element-wise. */ +array less_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<=(const array& a, const array& b) { + return less_equal(a, b); +} +template +array operator<=(T a, const array& b) { + return less_equal(array(a), b); +} +template +array operator<=(const array& a, T b) { + return less_equal(a, array(b)); +} + +/** True if two arrays have the same shape and elements. */ +array array_equal( + const array& a, + const array& b, + bool equal_nan, + StreamOrDevice s = {}); +inline array +array_equal(const array& a, const array& b, StreamOrDevice s = {}) { + return array_equal(a, b, false, s); +} + +/** Select from x or y depending on condition. */ +array where( + const array& condition, + const array& x, + const array& y, + StreamOrDevice s = {}); + +/** Reduction operations */ + +/** True if all elements in the array are true (or non-zero). **/ +array all(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array all(const array& a, StreamOrDevice s = {}) { + return all(a, false, to_stream(s)); +} + +/** True if the two arrays are equal within the specified tolerance. */ +array allclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axes. An output value is true + * if all the corresponding inputs are true. + **/ +array all( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if all the corresponding inputs are true. + **/ +array all( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** True if any elements in the array are true (or non-zero). **/ +array any(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array any(const array& a, StreamOrDevice s = {}) { + return any(a, false, to_stream(s)); +} + +/** + * Reduces the input along the given axes. An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array. */ +array sum(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array sum(const array& a, StreamOrDevice s = {}) { + return sum(a, false, to_stream(s)); +} + +/** Sums the elements of an array along the given axes. 
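 *
 * Example (illustrative only; arrays default to float32):
 *
 *   auto x = ones({3, 4});
 *   auto s = sum(x, 0);             // shape {4}, every entry 3
 *   auto t = sum(x, {0, 1}, true);  // keepdims: shape {1, 1}, value 12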
*/ +array sum( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array along the given axis. */ +array sum( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array. */ +array mean(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array mean(const array& a, StreamOrDevice s = {}) { + return mean(a, false, to_stream(s)); +} + +/** Computes the mean of the elements of an array along the given axes */ +array mean( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array along the given axis */ +array mean( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array. */ +array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array var(const array& a, StreamOrDevice s = {}) { + return var(a, false, 0, to_stream(s)); +} + +/** Computes the var of the elements of an array along the given axes */ +array var( + const array& a, + const std::vector& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the var of the elements of an array along the given axis */ +array var( + const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** The product of all elements of the array. */ +array prod(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array prod(const array& a, StreamOrDevice s = {}) { + return prod(a, false, to_stream(s)); +} + +/** The product of the elements of an array along the given axes. */ +array prod( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The product of the elements of an array along the given axis. */ +array prod( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of all elements of the array. */ +array max(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array max(const array& a, StreamOrDevice s = {}) { + return max(a, false, to_stream(s)); +} + +/** The maximum of the elements of an array along the given axes. */ +array max( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of the elements of an array along the given axis. */ +array max( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of all elements of the array. */ +array min(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array min(const array& a, StreamOrDevice s = {}) { + return min(a, false, to_stream(s)); +} + +/** The minimum of the elements of an array along the given axes. */ +array min( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of the elements of an array along the given axis. */ +array min( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the minimum value in the array. */ +array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmin(const array& a, StreamOrDevice s = {}) { + return argmin(a, false, s); +} + +/** Returns the indices of the minimum values along a given axis. 
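 *
 * Example (illustrative only):
 *
 *   auto x = reshape(arange(6), {2, 3});  // rows {0, 1, 2} and {3, 4, 5}
 *   auto i = argmin(x, 1);                // shape {2}, both entries 0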
*/ +array argmin( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the maximum value in the array. */ +array argmax(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmax(const array& a, StreamOrDevice s = {}) { + return argmax(a, false, s); +} + +/** Returns the indices of the maximum values along a given axis. */ +array argmax( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns a sorted copy of the flattened array. */ +array sort(const array& a, StreamOrDevice s = {}); + +/** Returns a sorted copy of the array along a given axis. */ +array sort(const array& a, int axis, StreamOrDevice s = {}); + +/** Returns indices that sort the flattened array. */ +array argsort(const array& a, StreamOrDevice s = {}); + +/** Returns indices that sort the array along a given axis. */ +array argsort(const array& a, int axis, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the flattened array + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the array along a given axis + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** + * Returns indices that partition the flattened array + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns indices that partition the array along a given axis + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** Returns topk elements of the flattened array. */ +array topk(const array& a, int k, StreamOrDevice s = {}); + +/** Returns topk elements of the array along a given axis. */ +array topk(const array& a, int k, int axis, StreamOrDevice s = {}); + +/** The logsumexp of all elements of the array. */ +array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array logsumexp(const array& a, StreamOrDevice s = {}) { + return logsumexp(a, false, to_stream(s)); +} + +/** The logsumexp of the elements of an array along the given axes. */ +array logsumexp( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The logsumexp of the elements of an array along the given axis. */ +array logsumexp( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Simple arithmetic operations */ + +/** Absolute value of elements in an array. */ +array abs(const array& a, StreamOrDevice s = {}); + +/** Negate an array. */ +array negative(const array& a, StreamOrDevice s = {}); +array operator-(const array& a); + +/** The sign of the elements in an array. */ +array sign(const array& a, StreamOrDevice s = {}); + +/** Logical not of an array */ +array logical_not(const array& a, StreamOrDevice s = {}); + +/** The reciprocal (1/x) of the elements in an array. */ +array reciprocal(const array& a, StreamOrDevice s = {}); + +/** Add two arrays. */ +array add(const array& a, const array& b, StreamOrDevice s = {}); +array operator+(const array& a, const array& b); +template +array operator+(T a, const array& b) { + return add(array(a), b); +} +template +array operator+(const array& a, T b) { + return add(a, array(b)); +} + +/** Subtract two arrays. 
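 *
 * Example (illustrative only; shapes broadcast as in the other binary ops):
 *
 *   auto a = ones({2, 3});
 *   auto b = arange(3);  // shape {3} broadcasts against {2, 3}
 *   auto c = a - b;      // element-wise, shape {2, 3}
 *   auto d = c - 1.0f;   // scalar operands are wrapped via array(1.0f)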
*/ +array subtract(const array& a, const array& b, StreamOrDevice s = {}); +array operator-(const array& a, const array& b); +template +array operator-(T a, const array& b) { + return subtract(array(a), b); +} +template +array operator-(const array& a, T b) { + return subtract(a, array(b)); +} + +/** Multiply two arrays. */ +array multiply(const array& a, const array& b, StreamOrDevice s = {}); +array operator*(const array& a, const array& b); +template +array operator*(T a, const array& b) { + return multiply(array(a), b); +} +template +array operator*(const array& a, T b) { + return multiply(a, array(b)); +} + +/** Divide two arrays. */ +array divide(const array& a, const array& b, StreamOrDevice s = {}); +array operator/(const array& a, const array& b); +array operator/(double a, const array& b); +array operator/(const array& a, double b); + +/** Element-wise maximum between two arrays. */ +array maximum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise minimum between two arrays. */ +array minimum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Square the elements of an array. */ +array square(const array& a, StreamOrDevice s = {}); + +/** Exponential of the elements of an array. */ +array exp(const array& a, StreamOrDevice s = {}); + +/** Sine of the elements of an array */ +array sin(const array& a, StreamOrDevice s = {}); + +/** Cosine of the elements of an array */ +array cos(const array& a, StreamOrDevice s = {}); + +/** Tangent of the elements of an array */ +array tan(const array& a, StreamOrDevice s = {}); + +/** Arc Sine of the elements of an array */ +array arcsin(const array& a, StreamOrDevice s = {}); + +/** Arc Cosine of the elements of an array */ +array arccos(const array& a, StreamOrDevice s = {}); + +/** Arc Tangent of the elements of an array */ +array arctan(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Sine of the elements of an array */ +array sinh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Cosine of the elements of an array */ +array cosh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Tangent of the elements of an array */ +array tanh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Sine of the elements of an array */ +array arcsinh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Cosine of the elements of an array */ +array arccosh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Tangent of the elements of an array */ +array arctanh(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of the elements of an array. */ +array log(const array& a, StreamOrDevice s = {}); + +/** Log base 2 of the elements of an array. */ +array log2(const array& a, StreamOrDevice s = {}); + +/** Log base 10 of the elements of an array. */ +array log10(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */ +array log1p(const array& a, StreamOrDevice s = {}); + +/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */ +array logaddexp(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */ +array sigmoid(const array& a, StreamOrDevice s = {}); + +/** Computes the error function of the elements of an array. */ +array erf(const array& a, StreamOrDevice s = {}); + +/** Computes the inverse error function of the elements of an array. 
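 *
 * Example (illustrative only): `erfinv` inverts `erf`, so
 *
 *   auto x = full({4}, 0.5f);
 *   auto y = erfinv(erf(x));  // approximately recovers x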
*/ +array erfinv(const array& a, StreamOrDevice s = {}); + +/** Stop the flow of gradients. */ +array stop_gradient(const array& a, StreamOrDevice s = {}); + +/** Matrix-matrix multiplication. */ +array matmul(const array& a, const array& b, StreamOrDevice s = {}); + +/** Gather array entries given indices and slices */ +array gather( + const array& a, + const std::vector& indices, + const std::vector& axes, + const std::vector& slice_sizes, + StreamOrDevice s = {}); +inline array gather( + const array& a, + const array& indices, + int axis, + const std::vector& slice_sizes, + StreamOrDevice s = {}) { + return gather(a, {indices}, std::vector{axis}, slice_sizes, s); +} + +/** Take array slices at the given indices of the specified axis. */ +array take( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Take array entries at the given indices treating the array as flattened. */ +array take(const array& a, const array& indices, StreamOrDevice s = {}); + +/** Take array entries given indices along the axis */ +array take_along_axis( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Scatter updates to given linear indices */ +array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and add updates to given indices */ +array scatter_add( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_add( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_add(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and prod updates to given indices */ +array scatter_prod( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_prod( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_prod(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and max updates to given linear indices */ +array scatter_max( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_max( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_max(a, {indices}, updates, std::vector{axis}, s); +} +/** Scatter and min updates to given linear indices */ +array scatter_min( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_min( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_min(a, {indices}, updates, std::vector{axis}, s); +} + +/** Square root the elements of an array. */ +array sqrt(const array& a, StreamOrDevice s = {}); + +/** Square root and reciprocal the elements of an array. */ +array rsqrt(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +array softmax( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Softmax of an array. 
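 *
 * Example (illustrative only):
 *
 *   auto logits = ones({2, 5});
 *   auto p = softmax(logits, 1);  // each row sums to 1 (entries are all 0.2)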
*/ +array softmax(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +inline array softmax(const array& a, int axis, StreamOrDevice s = {}) { + return softmax(a, std::vector{axis}, s); +} + +/** Raise elements of a to the power of b element-wise */ +array power(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator^(const array& a, const array& b) { + return power(a, b); +} +template +array operator^(T a, const array& b) { + return power(array(a), b); +} +template +array operator^(const array& a, T b) { + return power(a, array(b)); +} + +/** Cumulative sum of an array. */ +array cumsum( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array. */ +array cumprod( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array. */ +array cummax( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array. */ +array cummin( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Convolution operations */ + +/** 1D convolution with a filter */ +array conv1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D convolution with a filter */ +array conv2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** Serialization operations */ + +/** Save array to out stream in .npy format */ +void save( + std::shared_ptr out_stream, + array a, + bool retain_graph = true); + +/** Save array to file in .npy format */ +void save(const std::string& file, array a, bool retain_graph = true); + +/** Load array from reader in .npy format */ +array load(std::shared_ptr in_stream, StreamOrDevice s = {}); + +/** Load array from file in .npy format */ +array load(const std::string& file, StreamOrDevice s = {}); + +} // namespace mlx::core diff --git a/mlx/primitives.h b/mlx/primitives.h new file mode 100644 index 0000000000..44159f86d1 --- /dev/null +++ b/mlx/primitives.h @@ -0,0 +1,1527 @@ +#pragma once + +#include "array.h" +#include "device.h" +#include "load.h" +#include "stream.h" + +#define DEFINE_GRADS() \ + array jvp( \ + const std::vector& primals, \ + const std::vector& tangents, \ + const std::vector& argnums) override; \ + \ + std::vector vjp( \ + const std::vector& primals, \ + const array& cotan, \ + const std::vector& argnums) override; + +#define DEFINE_PRINT(PRIMITIVE) \ + void print(std::ostream& os) override { \ + os << #PRIMITIVE; \ + } + +#define DEFINE_DEFAULT_IS_EQUIVALENT() \ + bool is_equivalent(const Primitive& other) const override { \ + return true; \ + } + +namespace mlx::core { + +// Abstract base class +class Primitive { + public: + explicit Primitive(Stream stream) : stream_(stream) {} + + /** The device the primitive will run on. */ + const Device& device() { + return stream().device; + } + + /** The stream the primitive will run on. */ + const Stream& stream() { + return stream_; + } + + /** + * A primitive must know how to evaluate itself on + * the CPU/GPU for the given inputs and populate the output array. 
+ * + * To avoid unecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + virtual void eval_cpu(const std::vector& inputs, array& out) = 0; + virtual void eval_gpu(const std::vector& inputs, array& out) = 0; + + /** + * The Jacobian-vector product. + */ + virtual array jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums); + + /** + * The vector-Jacobian product. + */ + virtual std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums); + + /** + * The primitive must know how to vectorize itself accross + * the given axes. The output is a pair containing the array + * representing the vectorized computation and the axis which + * corresponds to the output vectorized dimension. + */ + virtual std::pair vmap( + const std::vector& inputs, + const std::vector& axes); + + /** Print the primitive. */ + virtual void print(std::ostream& os) = 0; + + /** Equivalence check defaults to false unless overriden by the primitive */ + virtual bool is_equivalent(const Primitive& other) const { + return false; + } + + virtual ~Primitive() = default; + Primitive(const Primitive& other) = delete; + Primitive(Primitive&& other) = delete; + Primitive& operator=(const Primitive& other) = delete; + Primitive& operator=(Primitive&& other) = delete; + + private: + // Every primitive stores the stream it should run in + Stream stream_; +}; + +class Abs : public Primitive { + public: + explicit Abs(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Abs) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Add : public Primitive { + public: + explicit Add(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Add) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Arange : public Primitive { + public: + explicit Arange(Stream stream, double start, double stop, double step) + : Primitive(stream), start_(start), stop_(stop), step_(step){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Arange) + bool is_equivalent(const Primitive& other) const override; + + private: + double start_; + double stop_; + double step_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArcCos : public Primitive { + public: + explicit ArcCos(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcCos) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcCosh : public Primitive { + public: + explicit ArcCosh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void 
eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcCosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcSin : public Primitive { + public: + explicit ArcSin(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcSin) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcSinh : public Primitive { + public: + explicit ArcSinh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcSinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcTan : public Primitive { + public: + explicit ArcTan(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcTan) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArcTanh : public Primitive { + public: + explicit ArcTanh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ArcTanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ArgPartition : public Primitive { + public: + explicit ArgPartition(Stream stream, int kth, int axis) + : Primitive(stream), kth_(kth), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(ArgPartition) + bool is_equivalent(const Primitive& other) const override; + + private: + int kth_; + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArgReduce : public Primitive { + public: + enum ReduceType { + ArgMin, + ArgMax, + }; + + explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) + : Primitive(stream), reduce_type_(reduce_type), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(ArgReduce) + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class ArgSort : public Primitive { + public: + explicit ArgSort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + 
std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(ArgSort) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class AsType : public Primitive { + public: + explicit AsType(Stream stream, Dtype dtype) + : Primitive(stream), dtype_(dtype){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(AsType) + bool is_equivalent(const Primitive& other) const override; + + private: + Dtype dtype_; + + void eval(const std::vector& inputs, array& out); +}; + +class AsStrided : public Primitive { + public: + explicit AsStrided( + Stream stream, + const std::vector& shape, + const std::vector& strides, + size_t offset) + : Primitive(stream), shape_(shape), strides_(strides), offset_(offset){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_PRINT(AsStrided) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + std::vector strides_; + size_t offset_; + + void eval(const std::vector& inputs, array& out); +}; + +class Broadcast : public Primitive { + public: + explicit Broadcast(Stream stream, const std::vector& shape) + : Primitive(stream), shape_(shape){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Broadcast) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + + void eval(const std::vector& inputs, array& out); +}; + +class Concatenate : public Primitive { + public: + explicit Concatenate(Stream stream, int axis) + : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Concatenate) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class Convolution : public Primitive { + public: + explicit Convolution( + Stream stream, + const std::vector& padding, + const std::vector& kernel_strides, + const std::vector& kernel_dilation, + const std::vector& input_dilation) + : Primitive(stream), + padding_(padding), + kernel_strides_(kernel_strides), + kernel_dilation_(kernel_dilation), + input_dilation_(input_dilation){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + DEFINE_PRINT(Convolution) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector padding_; + std::vector kernel_strides_; + std::vector kernel_dilation_; + std::vector input_dilation_; + + void eval(const std::vector& inputs, array& out); +}; + +class Copy : public Primitive { + public: + explicit Copy(Stream stream) : Primitive(stream){}; + + void 
eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Copy) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cos : public Primitive { + public: + explicit Cos(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Cos) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cosh : public Primitive { + public: + explicit Cosh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Cosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Divide : public Primitive { + public: + explicit Divide(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Divide) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Equal : public Primitive { + public: + explicit Equal(Stream stream, bool equal_nan = false) + : Primitive(stream), equal_nan_(equal_nan){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Equal) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); + bool equal_nan_; +}; + +class Erf : public Primitive { + public: + explicit Erf(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Erf) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class ErfInv : public Primitive { + public: + explicit ErfInv(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(ErfInv) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Exp : public Primitive { + public: + explicit Exp(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Exp) + 
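+  // DEFINE_GRADS() presumably declares the jvp/vjp overrides, and
+  // DEFINE_DEFAULT_IS_EQUIVALENT() an is_equivalent() that matches on the
+  // primitive type alone; primitives that carry parameters (ArgPartition,
+  // AsType, Reduce, ...) instead override is_equivalent() so those
+  // parameters are compared as well.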
DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class FFT : public Primitive { + public: + explicit FFT( + Stream stream, + const std::vector& axes, + bool inverse, + bool real) + : Primitive(stream), axes_(axes), inverse_(inverse), real_(real){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(FFT) + + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + bool inverse_; + bool real_; + + void eval(const std::vector& inputs, array& out); +}; + +class Full : public Primitive { + public: + explicit Full(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Full) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Gather : public Primitive { + public: + explicit Gather( + Stream stream, + const std::vector& axes, + const std::vector& slice_sizes) + : Primitive(stream), axes_(axes), slice_sizes_(slice_sizes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Gather) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; + std::vector slice_sizes_; +}; + +class Greater : public Primitive { + public: + explicit Greater(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Greater) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class GreaterEqual : public Primitive { + public: + explicit GreaterEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(GreaterEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Less : public Primitive { + public: + explicit Less(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Less) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LessEqual : public Primitive { + public: + explicit LessEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + 
const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LessEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Load : public Primitive { + public: + explicit Load( + Stream stream, + std::shared_ptr reader, + size_t offset, + bool swap_endianness = false) + : Primitive(stream), + reader_(reader), + offset_(offset), + swap_endianness_(swap_endianness){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Load) + + private: + void eval(const std::vector& inputs, array& out); + std::shared_ptr reader_; + size_t offset_; + bool swap_endianness_; +}; + +class Log : public Primitive { + public: + enum Base { two, ten, e }; + + explicit Log(Stream stream, Base base) : Primitive(stream), base_(base){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Log) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + Base base_; + void eval(const std::vector& inputs, array& out); +}; + +class Log1p : public Primitive { + public: + explicit Log1p(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Log1p) + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LogicalNot : public Primitive { + public: + explicit LogicalNot(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LogicalNot) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class LogAddExp : public Primitive { + public: + explicit LogAddExp(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(LogAddExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Matmul : public Primitive { + public: + explicit Matmul(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(Matmul) + DEFINE_DEFAULT_IS_EQUIVALENT() +}; + +class Maximum : public Primitive { + public: + explicit Maximum(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Maximum) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void 
eval(const std::vector& inputs, array& out); +}; + +class Minimum : public Primitive { + public: + explicit Minimum(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Minimum) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Multiply : public Primitive { + public: + explicit Multiply(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Multiply) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Negative : public Primitive { + public: + explicit Negative(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Negative) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class NotEqual : public Primitive { + public: + explicit NotEqual(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(NotEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Pad : public Primitive { + public: + explicit Pad( + Stream stream, + const std::vector& axes, + const std::vector& low_pad_size, + const std::vector& high_pad_size) + : Primitive(stream), + axes_(axes), + low_pad_size_(low_pad_size), + high_pad_size_(high_pad_size){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Pad) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + std::vector low_pad_size_; + std::vector high_pad_size_; + + void eval(const std::vector& inputs, array& out); +}; + +class Partition : public Primitive { + public: + explicit Partition(Stream stream, int kth, int axis) + : Primitive(stream), kth_(kth), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Partition) + bool is_equivalent(const Primitive& other) const override; + + private: + int kth_; + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class Power : public Primitive { + public: + explicit Power(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& 
axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Power) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class RandomBits : public Primitive { + public: + explicit RandomBits(Stream stream, const std::vector& shape, int width) + : Primitive(stream), shape_(shape), width_(width){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(RandomBits) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + int width_; + + void eval(const std::vector& inputs, array& out); +}; + +class Reshape : public Primitive { + public: + explicit Reshape(Stream stream, const std::vector& shape) + : Primitive(stream), shape_(shape){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Reshape) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector shape_; + + void eval(const std::vector& inputs, array& out); +}; + +class Reduce : public Primitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + explicit Reduce( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector vjp( + const std::vector& primals, + const array& cotan, + const std::vector& argnums) override; + + void print(std::ostream& os) override { + switch (reduce_type_) { + case And: + os << "And"; + case Or: + os << "And"; + break; + case Sum: + os << "Sum"; + break; + case Prod: + os << "Prod"; + break; + case Min: + os << "Min"; + break; + case Max: + os << "Max"; + break; + } + os << " Reduce"; + } + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + std::vector axes_; + + void eval(const std::vector& inputs, array& out); +}; + +class Scan : public Primitive { + public: + enum ReduceType { Max, Min, Sum, Prod }; + + explicit Scan( + Stream stream, + ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive) + : Primitive(stream), + reduce_type_(reduce_type), + axis_(axis), + reverse_(reverse), + inclusive_(inclusive){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS(); + void print(std::ostream& os) override { + os << "Cum"; + switch (reduce_type_) { + case Sum: + os << "Sum"; + break; + case Prod: + os << "Prod"; + break; + case Min: + os << "Min"; + break; + case Max: + os << "Max"; + break; + } + os << " Reduce"; + } + bool is_equivalent(const Primitive& other) const override; + + private: + ReduceType reduce_type_; + int axis_; + bool reverse_; + bool inclusive_; + + void eval(const std::vector& inputs, array& out); +}; + +class Scatter : public Primitive { + public: + enum ReduceType { Max, Min, Sum, Prod, None }; + + explicit Scatter( + Stream stream, + ReduceType 
reduce_type, + const std::vector& axes) + : Primitive(stream), reduce_type_(reduce_type), axes_(axes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_PRINT(Scatter) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + ReduceType reduce_type_; + std::vector axes_; +}; + +class Sigmoid : public Primitive { + public: + explicit Sigmoid(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sigmoid) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sign : public Primitive { + public: + explicit Sign(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sign) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sin : public Primitive { + public: + explicit Sin(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sin) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sinh : public Primitive { + public: + explicit Sinh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Slice : public Primitive { + public: + explicit Slice( + Stream stream, + const std::vector& start_indices, + const std::vector& end_indices, + const std::vector& strides) + : Primitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Slice) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector start_indices_; + std::vector end_indices_; + std::vector strides_; + + void eval(const std::vector& inputs, array& out); +}; + +class Softmax : public Primitive { + public: + explicit Softmax(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Softmax) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sort : public Primitive { + 
public: + explicit Sort(Stream stream, int axis) : Primitive(stream), axis_(axis){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sort) + bool is_equivalent(const Primitive& other) const override; + + private: + int axis_; + + void eval(const std::vector& inputs, array& out); +}; + +class Square : public Primitive { + public: + explicit Square(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Square) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Sqrt : public Primitive { + public: + explicit Sqrt(Stream stream, bool recip = false) + : Primitive(stream), recip_(recip){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Sqrt) + bool is_equivalent(const Primitive& other) const override; + + private: + void eval(const std::vector& inputs, array& out); + bool recip_; +}; + +class StopGradient : public Primitive { + public: + explicit StopGradient(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(StopGradient) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Subtract : public Primitive { + public: + explicit Subtract(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Subtract) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Tan : public Primitive { + public: + explicit Tan(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Tan) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Tanh : public Primitive { + public: + explicit Tanh(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Tanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Uniform : public Primitive { + public: + explicit Uniform(Stream stream) : Primitive(stream){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& 
inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_PRINT(Uniform) + DEFINE_DEFAULT_IS_EQUIVALENT() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Transpose : public Primitive { + public: + explicit Transpose(Stream stream, const std::vector& axes) + : Primitive(stream), axes_(axes){}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::pair vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_GRADS() + DEFINE_PRINT(Transpose) + bool is_equivalent(const Primitive& other) const override; + + private: + std::vector axes_; + + void eval(const std::vector& inputs, array& out); +}; + +} // namespace mlx::core diff --git a/mlx/random.cpp b/mlx/random.cpp new file mode 100644 index 0000000000..41ce2b5a66 --- /dev/null +++ b/mlx/random.cpp @@ -0,0 +1,300 @@ +#include +#include + +#include "mlx/ops.h" +#include "mlx/primitives.h" +#include "mlx/random.h" +#include "mlx/utils.h" + +namespace mlx::core::random { + +KeySequence::KeySequence(uint64_t seed) : key_(key(seed)) {} + +void KeySequence::seed(uint64_t seed) { + key_ = key((seed)); +} + +array KeySequence::next() { + auto out = split(key_); + key_ = out.first; + return out.second; +} + +void seed(uint64_t seed) { + KeySequence::default_().seed(seed); +} + +array key(uint64_t seed) { + uint32_t k1 = static_cast(seed >> 32); + uint32_t k2 = static_cast(seed); + return array({k1, k2}); +} + +array bits( + const std::vector& shape, + int width /* 4 */, + const std::optional& key_ /*= nullopt */, + StreamOrDevice s /* = {} */) { + auto key = key_ ? *key_ : KeySequence::default_().next(); + if (key.dtype() != uint32) { + std::ostringstream msg; + msg << "Expected key type uint32 but received " << key.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + if (key.shape() != std::vector{2}) { + std::ostringstream msg; + msg << "Expected key shape (2) but received " << key.shape() << "."; + throw std::invalid_argument(msg.str()); + } + + auto get_dtype = [width]() { + switch (width) { + case 4: + return uint32; + case 2: + return uint16; + case 1: + return uint8; + default: + std::ostringstream msg; + msg << "[bits] Bit width must be in {1, 2, 4} but got " << width << "."; + throw std::invalid_argument(msg.str()); + } + }; + return array( + shape, + get_dtype(), + std::make_unique(to_stream(s), shape, width), + {key}); +} + +std::pair split(const array& key, StreamOrDevice s /* = {} */) { + auto stream = to_stream(s); + auto out = mlx::core::split(random::split(key, 2, stream), 2, stream); + return {reshape(out[0], {2}, stream), reshape(out[1], {2}, stream)}; +} + +array split(const array& key, int num, StreamOrDevice s /* = {} */) { + return bits({num, 2}, 4, key, s); +} + +array uniform( + const array& low, + const array& high, + const std::vector& shape, + Dtype dtype /* = float32 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + if (!is_floating_point(dtype)) { + throw std::invalid_argument( + "Can only generate uniform numbers with floating point type."); + } + + auto stream = to_stream(s); + auto range = subtract(high, low, stream); + auto out_shape = broadcast_shapes(shape, range.shape()); + if (out_shape != shape) { + std::ostringstream msg; + msg << "Cannot generate random values of shape " << shape + << " from broadcasted shape " << out_shape << "."; + throw 
std::invalid_argument(msg.str()); + } + // Get random values between [0, nextafter(maxval, 0.0f)] since samples must + // be in [low, high) + // TODO replace minimum with modulo uint32_t(nextafter(maxval, 0.0f)) to avoid + // clipping effects + float maxval = std::numeric_limits::max(); + auto upper = array(std::nextafter(maxval, 0.0f), dtype); + auto out = minimum(bits(shape, size_of(dtype), key, stream), upper, stream); + out = divide(out, array(maxval, dtype), stream); + return add(multiply(range, out, stream), low, stream); +} + +array uniform( + const std::vector& shape, + Dtype dtype, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + return uniform( + array(0.0, dtype), array(1.0, dtype), shape, dtype, key, to_stream(s)); +} + +array normal( + const std::vector& shape, + Dtype dtype, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + auto stream = to_stream(s); + auto low = array(std::nextafter(-1.0f, 0.0f), dtype); + auto high = array(1.0f, dtype); + auto samples = uniform(low, high, shape, dtype, key, stream); + return multiply( + array(std::sqrt(2.0), dtype), erfinv(samples, stream), stream); +} + +array randint( + const array& low, + const array& high, + const std::vector& shape, + Dtype dtype /* = int32 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + if (!is_integral(dtype)) { + throw std::invalid_argument( + "[randint] randint only accepts integer dtypes and bool."); + } + auto u = uniform(low, high, shape, float32, key, s); + return astype(maximum(u, low, s), dtype, s); +} + +array bernoulli( + const array& p, + const std::vector& shape, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + if (!is_floating_point(p.dtype())) { + throw std::invalid_argument( + "[bernoulli] bernoulli probability `p` must be a float type."); + } + auto res = uniform(shape, p.dtype(), key, s); + res = less(res, p, s); + if (res.shape() != shape) { + throw std::invalid_argument( + "[bernoulli] shape of `p` is incompatible with argument `shape`."); + } + return res; +} + +array bernoulli( + const array& p, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + return bernoulli(p, p.shape(), key, s); +} + +array bernoulli( + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + return bernoulli(array(0.5f), key, s); +} + +array truncated_normal( + const array& lower, + const array& upper, + const std::vector& shape, + Dtype dtype /* = float32 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + // Same as + // https://jax.readthedocs.io/en/latest/_modules/jax/_src/random.html#truncated_normal + + if (!is_floating_point(dtype)) { + throw std::invalid_argument( + "[trunc_normal] trunc_normal only accepts floating point dtypes."); + } + + auto sqrt2 = array(std::sqrt(2.0), dtype); + auto lower_t = astype(lower, dtype, s); + auto upper_t = astype(upper, dtype, s); + auto a = erf(divide(lower_t, sqrt2, s), s); + auto b = erf(divide(upper_t, sqrt2, s), s); + auto u = uniform(a, b, shape, dtype, key, s); + auto out = multiply(sqrt2, erfinv(u, s), s); + + // Clip in bouds + return maximum(minimum(upper_t, out, s), lower_t, s); +} + +array truncated_normal( + const array& lower, + const array& upper, + Dtype dtype /* = float32 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + auto shape = broadcast_shapes(lower.shape(), upper.shape()); + return truncated_normal(lower, upper, shape, dtype, key, 
s); +} + +array gumbel( + const std::vector& shape, + Dtype dtype /* = float32 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + // -log(-log(uniform(shape))) + return negative( + log(negative(log(uniform(shape, dtype, key, s), s), s), s), s); +} + +int get_valid_axis(int axis, int ndim) { + int ax = axis < 0 ? axis + ndim : axis; + if (ax < 0 || ax >= ndim) { + std::ostringstream msg; + msg << "[categorical] Invalid axis " << axis << " for logits with " << ndim + << " dimensions."; + throw std::invalid_argument(msg.str()); + } + return ax; +} + +array categorical_impl( + const array& logits, + int axis, + const std::vector& shape, + const std::optional& key /*= nullopt */, + StreamOrDevice s) { + auto gumbel_shape = shape; + auto offset = axis + shape.size() - logits.ndim() + 1; + gumbel_shape.insert(gumbel_shape.begin() + offset, logits.shape(axis)); + auto g = gumbel(gumbel_shape, float32, key, s); + return argmax(add(g, logits, s), offset, false, s); +} + +array categorical( + const array& logits, + int axis, + const std::vector& shape, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + // Validate and normalize axis + axis = get_valid_axis(axis, logits.ndim()); + + // Check that shape broadcasts with reduce(logits, axis) + auto reduced_shape = logits.shape(); + reduced_shape.erase(reduced_shape.begin() + axis); + if (broadcast_shapes(shape, reduced_shape) != shape) { + std::ostringstream msg; + msg << "[categorical] Requested shape " << shape + << " is not broadcast compatable with reduced logits shape" + << reduced_shape << "."; + throw std::invalid_argument(msg.str()); + } + + return categorical_impl(logits, axis, shape, key, s); +} + +array categorical( + const array& logits_, + int axis, + int num_samples, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + axis = get_valid_axis(axis, logits_.ndim()); + auto logits = expand_dims(logits_, -1); + auto shape = logits.shape(); + shape.erase(shape.begin() + axis); + shape.back() = num_samples; + return categorical_impl(logits, axis, shape, key, s); +} + +array categorical( + const array& logits, + int axis /* = -1 */, + const std::optional& key /*= nullopt */, + StreamOrDevice s /* = {} */) { + axis = get_valid_axis(axis, logits.ndim()); + auto shape = logits.shape(); + shape.erase(shape.begin() + axis); + return categorical_impl(logits, axis, shape, key, s); +} + +} // namespace mlx::core::random diff --git a/mlx/scheduler.cpp b/mlx/scheduler.cpp new file mode 100644 index 0000000000..cb2842a593 --- /dev/null +++ b/mlx/scheduler.cpp @@ -0,0 +1,43 @@ +#include "mlx/scheduler.h" +#include "mlx/backend/metal/metal.h" + +namespace mlx::core { + +Stream default_stream(Device d) { + if (!metal::is_available() && d == Device::gpu) { + throw std::invalid_argument( + "[default_stream] Cannot get gpu stream without gpu backend."); + } + return scheduler::scheduler().get_default_stream(d); +} + +void set_default_stream(Stream s) { + if (!metal::is_available() && s.device == Device::gpu) { + throw std::invalid_argument( + "[set_default_stream] Cannot set gpu stream without gpu backend."); + } + return scheduler::scheduler().set_default_stream(s); +} + +Stream new_stream(Device d) { + if (!metal::is_available() && d == Device::gpu) { + throw std::invalid_argument( + "[new_stream] Cannot make gpu stream without gpu backend."); + } + return scheduler::scheduler().new_stream(d); +} + +Stream new_stream() { + return scheduler::scheduler().new_stream(default_device()); 
+} + +namespace scheduler { + +/** A singleton scheduler to manage devices, streams, and task execution. */ +Scheduler& scheduler() { + static Scheduler scheduler; + return scheduler; +} + +} // namespace scheduler +} // namespace mlx::core diff --git a/mlx/scheduler.h b/mlx/scheduler.h new file mode 100644 index 0000000000..ba615d37e0 --- /dev/null +++ b/mlx/scheduler.h @@ -0,0 +1,170 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/backend/metal/metal.h" +#include "mlx/device.h" +#include "mlx/stream.h" + +namespace mlx::core::scheduler { + +struct StreamThread { + std::mutex mtx; + std::queue> q; + std::condition_variable cond; + bool stop; + Stream stream; + std::thread thread; + + StreamThread(Stream stream) + : stop(false), stream(stream), thread(&StreamThread::thread_fn, this) {} + + ~StreamThread() { + { + std::unique_lock lk(mtx); + stop = true; + } + cond.notify_one(); + thread.join(); + } + + void thread_fn() { + metal::new_stream(stream); + while (true) { + std::function task; + { + std::unique_lock lk(mtx); + cond.wait(lk, [this] { return !this->q.empty() || this->stop; }); + if (q.empty() && stop) { + return; + } + task = std::move(q.front()); + q.pop(); + } + task(); + } + } + + template + void enqueue(F&& f) { + { + std::unique_lock lk(mtx); + if (stop) { + throw std::runtime_error( + "Cannot enqueue work after stream is stopped."); + } + q.emplace(std::forward(f)); + } + cond.notify_one(); + } +}; + +class Scheduler { + public: + Scheduler() : n_active_tasks_(0) { + if (metal::is_available()) { + default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); + } + default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); + } + + // Not copyable or moveable + Scheduler(const Scheduler&) = delete; + Scheduler(Scheduler&&) = delete; + Scheduler& operator=(const Scheduler&) = delete; + Scheduler& operator=(Scheduler&&) = delete; + + Stream new_stream(const Device& d) { + auto stream = Stream(streams_.size(), d); + streams_.push_back(new StreamThread{stream}); + return stream; + } + + template + void enqueue(const Stream& stream, F&& f); + + Stream get_default_stream(const Device& d) { + return default_streams_.at(d.type); + } + + void set_default_stream(const Stream& s) { + default_streams_.at(s.device.type) = s; + } + + void notify_new_task(const Stream& stream) { + { + std::unique_lock lk(mtx); + n_active_tasks_++; + } + completion_cv.notify_all(); + } + + void notify_task_completion(const Stream& stream) { + { + std::unique_lock lk(mtx); + n_active_tasks_--; + } + completion_cv.notify_all(); + } + + int n_active_tasks() const { + return n_active_tasks_; + } + + void wait_for_one() { + std::unique_lock lk(mtx); + int n_tasks_old = n_active_tasks(); + if (n_tasks_old > 1) { + completion_cv.wait(lk, [this, n_tasks_old] { + return this->n_active_tasks() != n_tasks_old; + }); + } + } + + ~Scheduler() { + for (auto s : streams_) { + delete s; + } + } + + private: + int n_active_tasks_; + std::vector streams_; + std::unordered_map default_streams_; + std::condition_variable completion_cv; + std::mutex mtx; +}; + +template +void Scheduler::enqueue(const Stream& stream, F&& f) { + streams_[stream.index]->enqueue(std::forward(f)); +} + +Scheduler& scheduler(); + +template +void enqueue(const Stream& stream, F&& f) { + scheduler().enqueue(stream, std::forward(f)); +} + +inline int n_active_tasks() { + return scheduler().n_active_tasks(); +} + +inline void notify_new_task(const Stream& stream) { + scheduler().notify_new_task(stream); +} + 
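+// Typical flow, as a sketch: the code that schedules an op calls
+// notify_new_task(stream), enqueue()s a closure that runs the op on that
+// stream's StreamThread, and has the closure call notify_task_completion()
+// when it is done. wait_for_one() blocks until the outstanding-task count
+// changes, and returns immediately when at most one task is outstanding.
+//
+//   auto s = new_stream(Device::cpu);
+//   notify_new_task(s);
+//   enqueue(s, [s]() {
+//     /* ... run the op ... */
+//     notify_task_completion(s);
+//   });
+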
+inline void notify_task_completion(const Stream& stream) { + scheduler().notify_task_completion(stream); +} + +inline void wait_for_one() { + scheduler().wait_for_one(); +} + +} // namespace mlx::core::scheduler diff --git a/mlx/stream.h b/mlx/stream.h new file mode 100644 index 0000000000..000185f0f5 --- /dev/null +++ b/mlx/stream.h @@ -0,0 +1,30 @@ +#pragma once + +#include "mlx/device.h" + +namespace mlx::core { + +struct Stream { + int index; + Device device; + explicit Stream(int index, Device device) : index(index), device(device) {} +}; + +/** Get the default stream for the given device. */ +Stream default_stream(Device d); + +/** Make the stream the default for its device. */ +void set_default_stream(Stream s); + +/** Make a new stream on the given device. */ +Stream new_stream(Device d); + +inline bool operator==(const Stream& lhs, const Stream& rhs) { + return lhs.index == rhs.index; +} + +inline bool operator!=(const Stream& lhs, const Stream& rhs) { + return !(lhs == rhs); +} + +} // namespace mlx::core diff --git a/mlx/types/half_types.h b/mlx/types/half_types.h new file mode 100644 index 0000000000..47ea3d08eb --- /dev/null +++ b/mlx/types/half_types.h @@ -0,0 +1,54 @@ +#pragma once +#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#include +namespace mlx::core { +typedef __fp16 float16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/fp16.h" +namespace mlx::core { +typedef struct _MLX_Float16 float16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC +#ifdef __ARM_FEATURE_BF16 + +#include +namespace mlx::core { +typedef __bf16 bfloat16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/bf16.h" +namespace mlx::core { +typedef struct _MLX_BFloat16 bfloat16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_BF16 + +#ifdef ADD_HALF_BINOPS +namespace mlx::core { + +// clang-format off +#define fp16_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp16_bf16_binop_helper(+, operator+) +fp16_bf16_binop_helper(-, operator-) +fp16_bf16_binop_helper(*, operator*) +fp16_bf16_binop_helper(/, operator/) +// clang-format on + +} // namespace mlx::core +#endif diff --git a/mlx/utils.cpp b/mlx/utils.cpp new file mode 100644 index 0000000000..027e104e92 --- /dev/null +++ b/mlx/utils.cpp @@ -0,0 +1,255 @@ +#include +#include + +#include "utils.h" + +namespace mlx::core { + +Dtype result_type(const std::vector& arrays) { + std::vector dtypes(1, bool_); + for (auto& arr : arrays) { + dtypes.push_back(promote_types(dtypes.back(), arr.dtype())); + } + return dtypes.back(); +} + +std::vector broadcast_shapes( + const std::vector& s1, + const std::vector& s2) { + // Use the same broadcasting rules as numpy + // https://numpy.org/doc/1.20/user/theory.broadcasting.html + // "The size of the trailing axes for both arrays in an operation must + // either be the same size or one of them must be one." + int ndim1 = s1.size(); + int ndim2 = s2.size(); + int ndim = std::max(ndim1, ndim2); + int diff = std::abs(ndim1 - ndim2); + const auto& big = ndim1 > ndim2 ? s1 : s2; + const auto& small = ndim1 > ndim2 ? 
s2 : s1; + std::vector out_shape(ndim); + for (int i = ndim - 1; i >= diff; --i) { + int a = big[i]; + int b = small[i - diff]; + if (b == a) { + out_shape[i] = a; + } else if (a == 1 || b == 1) { + // 0 if a or b is 0 otherwise max(a, b) + out_shape[i] = a * b; + } else { + std::ostringstream msg; + msg << "Shapes " << s1 << " and " << s2 << " cannot be broadcast."; + throw std::invalid_argument(msg.str()); + } + } + for (int i = diff - 1; i >= 0; --i) { + out_shape[i] = big[i]; + } + return out_shape; +} + +std::ostream& operator<<(std::ostream& os, const Device& d) { + os << "Device("; + switch (d.type) { + case Device::cpu: + os << "cpu"; + break; + case Device::gpu: + os << "gpu"; + break; + } + os << ", " << d.index << ")"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const Stream& s) { + os << "Stream("; + os << s.device; + os << ", " << s.index << ")"; + return os; +} + +std::ostream& operator<<(std::ostream& os, int8_t x) { + os << static_cast(x); + return os; +} + +std::ostream& operator<<(std::ostream& os, uint8_t x) { + os << static_cast(x); + return os; +} + +namespace { + +inline size_t elem_to_loc( + int elem, + const std::vector& shape, + const std::vector& strides) { + size_t loc = 0; + for (int i = shape.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(elem, shape[i]); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; +} + +template +void print_subarray(std::ostream& os, const array& a, size_t index, int dim) { + int num_print = 3; + int n = a.shape(dim); + size_t s = a.strides()[dim]; + bool is_last = dim == a.ndim() - 1; + auto prefix = is_last ? "" : std::string(7 + dim, ' '); + auto postfix = is_last ? ", " : ",\n"; + os << "["; + for (int i = 0; i < n; ++i) { + os << (i == 0 ? "" : prefix); + if (i == num_print && n > 2 * num_print) { + os << "..."; + i = n - num_print - 1; + index += s * (n - 2 * num_print - 1); + } else if (is_last) { + os << a.data()[index]; + } else { + print_subarray(os, a, index, dim + 1); + } + os << (i == n - 1 ? 
"" : postfix); + index += s; + } + os << "]"; +} + +template +void print_array(std::ostream& os, const array& a) { + std::vector indices(a.ndim(), 0); + os << std::boolalpha; + os << "array("; + if (a.ndim() == 0) { + auto data = a.data(); + os << data[0]; + } else { + print_subarray(os, a, 0, 0); + } + os << ", dtype=" << a.dtype() << ")"; + os << std::noboolalpha; +} + +} // namespace + +std::ostream& operator<<(std::ostream& os, const Dtype& dtype) { + switch (dtype) { + case bool_: + return os << "bool"; + case uint8: + return os << "uint8"; + case uint16: + return os << "uint16"; + case uint32: + return os << "uint32"; + case uint64: + return os << "uint64"; + case int8: + return os << "int8"; + case int16: + return os << "int16"; + case int32: + return os << "int32"; + case int64: + return os << "int64"; + case float16: + return os << "float16"; + case float32: + return os << "float32"; + case bfloat16: + return os << "bfloat16"; + case complex64: + return os << "complex64"; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k) { + switch (k) { + case Dtype::Kind::b: + return os << "b"; + case Dtype::Kind::i: + return os << "i"; + case Dtype::Kind::u: + return os << "u"; + case Dtype::Kind::f: + return os << "f"; + case Dtype::Kind::c: + return os << "c"; + case Dtype::Kind::V: + return os << "V"; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, array a) { + if (!a.is_evaled()) { + a.eval(); + } + switch (a.dtype()) { + case bool_: + print_array(os, a); + break; + case uint8: + print_array(os, a); + break; + case uint16: + print_array(os, a); + break; + case uint32: + print_array(os, a); + break; + case uint64: + print_array(os, a); + break; + case int8: + print_array(os, a); + break; + case int16: + print_array(os, a); + break; + case int32: + print_array(os, a); + break; + case int64: + print_array(os, a); + break; + case float16: + print_array(os, a); + break; + case bfloat16: + print_array(os, a); + break; + case float32: + print_array(os, a); + break; + case complex64: + print_array(os, a); + break; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, const std::vector& v) { + os << "("; + for (int i = 0; i < v.size(); ++i) { + os << v[i] << ((i == v.size() - 1) ? "" : ","); + } + os << ")"; + return os; +} + +std::ostream& operator<<(std::ostream& os, const std::vector& v) { + os << "("; + for (int i = 0; i < v.size(); ++i) { + os << v[i] << ((i == v.size() - 1) ? "" : ","); + } + os << ")"; + return os; +} + +} // namespace mlx::core diff --git a/mlx/utils.h b/mlx/utils.h new file mode 100644 index 0000000000..b004ae8ba4 --- /dev/null +++ b/mlx/utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include "array.h" +#include "device.h" +#include "dtype.h" +#include "stream.h" + +namespace mlx::core { + +/** The type from promoting the arrays' types with one another. 
*/ +Dtype result_type(const std::vector& arrays); + +std::vector broadcast_shapes( + const std::vector& s1, + const std::vector& s2); + +std::ostream& operator<<(std::ostream& os, const Device& d); +std::ostream& operator<<(std::ostream& os, const Stream& s); +std::ostream& operator<<(std::ostream& os, const Dtype& d); +std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); +std::ostream& operator<<(std::ostream& os, array a); +std::ostream& operator<<(std::ostream& os, const std::vector& v); +std::ostream& operator<<(std::ostream& os, const std::vector& v); +inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { + return os << v.real() << (v.imag() > 0 ? "+" : "") << v.imag() << "j"; +} +inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { + return os << static_cast(v); +} +inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { + return os << static_cast(v); +} +} // namespace mlx::core diff --git a/python/README.md b/python/README.md new file mode 100644 index 0000000000..4d6979c19f --- /dev/null +++ b/python/README.md @@ -0,0 +1,37 @@ +### Packaging for PyPI + +Install `build` and `twine`: + +``` +pip install --user --upgrade build +pip install --user --upgrade twine +``` + +Generate the source distribution and wheel: + +``` +python -m build +``` + +*Warning* use a test server first + +#### Test Upload + +Upload to test server: + +``` +python -m twine upload --repository testpypi dist/* +``` + +Install from test server and check that it works: + +``` +python -m pip install --index-url https://test.pypi.org/simple/ --no-deps mlx +``` + +#### Upload + +``` +python -m twine upload dist/* +``` + diff --git a/python/mlx/_reprlib_fix.py b/python/mlx/_reprlib_fix.py new file mode 100644 index 0000000000..51111e5dd2 --- /dev/null +++ b/python/mlx/_reprlib_fix.py @@ -0,0 +1,18 @@ +import array +import reprlib + + +class FixedRepr(reprlib.Repr): + """Only route python array instances to repr_array.""" + + def repr_array(self, x, maxlevel): + if isinstance(x, array.array): + return super().repr_array(x, maxlevel) + else: + return self.repr_instance(x, maxlevel) + + +# We need to monkey-patch reprlib so that we can use the debugger without +# renaming the array to something else +fixed_repr = FixedRepr() +reprlib.repr = fixed_repr.repr diff --git a/python/mlx/extension.py b/python/mlx/extension.py new file mode 100644 index 0000000000..8dd7a6a723 --- /dev/null +++ b/python/mlx/extension.py @@ -0,0 +1,94 @@ +import os +import re +import subprocess +import sys +from pathlib import Path + +from setuptools import Extension, setup, find_namespace_packages +from setuptools.command.build_ext import build_ext + +import mlx + +_MLX_PATH = str(mlx.__path__[0]) + + +# A CMakeExtension needs a sourcedir instead of a file list. +class CMakeExtension(Extension): + def __init__(self, name: str, sourcedir: str = "") -> None: + super().__init__(name, sources=[]) + self.sourcedir = os.fspath(Path(sourcedir).resolve()) + + +class CMakeBuild(build_ext): + def build_extension(self, ext: CMakeExtension) -> None: + # Must be in this form due to bug in .resolve() only fixed in Python 3.10+ + ext_fullpath = Path.cwd() / self.get_ext_fullpath(ext.name) # type: ignore[no-untyped-call] + extdir = ext_fullpath.parent.resolve() + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. 
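+        # For example, setting CMAKE_GENERATOR=Ninja in the environment makes
+        # CMake configure the build with the Ninja generator.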
+ cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}{os.sep}", + f"-DCMAKE_BUILD_TYPE={cfg}", + "-DBUILD_SHARED_LIBS=ON", + ] + build_args = [] + # Adding CMake arguments set as environment variable + # (needed e.g. to build for ARM OSx on conda-forge) + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + if sys.platform.startswith("darwin"): + # Cross-compile support for macOS - respect ARCHFLAGS if set + archs = re.findall(r"-arch (\S+)", os.environ.get("ARCHFLAGS", "")) + if archs: + cmake_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format(";".join(archs))] + + # Set CMAKE_BUILD_PARALLEL_LEVEL to control the parallel build level + # across all generators. + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + build_temp = Path(self.build_temp) / ext.name + if not build_temp.exists(): + build_temp.mkdir(parents=True) + + # Make sure cmake can find MLX + os.environ["MLX_DIR"] = _MLX_PATH + + subprocess.run( + ["cmake", ext.sourcedir, *cmake_args], cwd=build_temp, check=True + ) + subprocess.run( + ["cmake", "--build", ".", *build_args], cwd=build_temp, check=True + ) + + def run(self): + super().run() + + # Based on https://github.com/pypa/setuptools/blob/main/setuptools/command/build_ext.py#L102 + if self.inplace: + for ext in self.extensions: + if isinstance(ext, CMakeExtension): + # Resolve inplace package dir + build_py = self.get_finalized_command("build_py") + inplace_file, regular_file = self._get_inplace_equivalent( + build_py, ext + ) + + inplace_dir = str(Path(inplace_file).parent.resolve()) + regular_dir = str(Path(regular_file).parent.resolve()) + + self.copy_tree(regular_dir, inplace_dir) diff --git a/python/mlx/nn/layers/base.py b/python/mlx/nn/layers/base.py new file mode 100644 index 0000000000..3a3cd3d26b --- /dev/null +++ b/python/mlx/nn/layers/base.py @@ -0,0 +1,401 @@ +import textwrap +from typing import Any, Callable, List, Union, Optional + +import mlx.core as mx +from mlx.utils import tree_flatten, tree_unflatten + + +class Module(dict): + """Base class for building neural networks with MLX. + + All the layers provided in :mod:`mlx.nn.layers` subclass this class and + your models should do the same. + + A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array` + instances in arbitrary nesting of python lists or dicts. The ``Module`` + then allows recursively extracting all the :class:`mlx.core.array` instances + using :meth:`mlx.nn.Module.parameters`. + + In addition, the ``Module`` has the concept of trainable and non trainable + parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad` + the gradients are returned only with respect to the trainable parameters. + All arrays in a module are trainable unless they are added in the "frozen" + set by calling :meth:`freeze`. + + .. 
code-block:: python + + import mlx.core as mx + import mlx.nn as nn + + class MyMLP(nn.Module): + def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16): + super().__init__() + + self.in_proj = nn.Linear(in_dims, hidden_dims) + self.out_proj = nn.Linear(hidden_dims, out_dims) + + def __call__(self, x): + x = self.in_proj(x) + x = mx.maximum(x, 0) + return self.out_proj(x) + + model = MyMLP(2, 1) + + # All the model parameters are created but since MLX is lazy by + # default, they are not evaluated yet. Calling `mx.eval` actually + # allocates memory and initializes the parameters. + mx.eval(model.parameters()) + + # Setting a parameter to a new value is as simply as accessing that + # parameter and assigning a new array to it. + model.in_proj.weight = model.in_proj.weight * 2 + mx.eval(model.parameters()) + """ + + def __init__(self): + """Should be called by the subclasses of ``Module``.""" + self._no_grad = set() + self._training = True + + @property + def training(self): + return self._training + + def _extra_repr(self): + return "" + + def __repr__(self): + children = tree_flatten(self.children(), is_leaf=self.is_module) + value = f"{type(self).__name__}({self._extra_repr()}" + for k, v in children: + value += "\n" + value += textwrap.indent(f"({k}): {repr(v)}", prefix=" ") + if children: + value += "\n" + value += ")" + + return value + + def __getattr__(self, key: str): + if key in self: + return self[key] + else: + raise AttributeError(f"{type(self)!r} has no attribute {key!r}") + + def __setattr__(self, key: str, val: Any): + self[key] = val + + def load_weights(self, file: str): + """ + Load and update the model's weights from a `.npz` file. + """ + self.update(tree_unflatten(list(mx.load(file).items()))) + + def save_weights(self, file: str): + """ + Save the model's weights to a `.npz` file. + """ + mx.savez(file, **dict(tree_flatten(self.parameters()))) + + @staticmethod + def is_module(value): + return isinstance(value, Module) + + @staticmethod + def valid_child_filter(module, key, value): + return isinstance(value, (dict, list)) + + @staticmethod + def valid_parameter_filter(module, key, value): + return isinstance(value, (dict, list, mx.array)) and not key.startswith("_") + + @staticmethod + def trainable_parameter_filter(module, key, value): + return ( + Module.valid_parameter_filter(module, key, value) + and key not in module._no_grad + ) + + def filter_and_map( + self, + filter_fn: Callable[["mlx.nn.Module", str, Any], bool], + map_fn: Optional[Callable] = None, + is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, + ): + """Recursively filter the contents of the module using ``filter_fn``, + namely only select keys and values where ``filter_fn`` returns true. + + This is used to implement :meth:`parameters` and :meth:`trainable_parameters` + but it can also be used to extract any subset of the module's parameters. + + Args: + filter_fn (Callable): Given a value, the key in which it is found + and the containing module, decide whether to keep the value or + drop it. + map_fn (Callable, optional): Optionally transform the value before + returning it. + is_leaf_fn (Callable, optional): Given a value, the key in which it + is found and the containing module decide if it is a leaf. 
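+                For example, ``is_leaf_fn=lambda m, k, v: isinstance(v, mx.array)``
+                treats every array as a leaf.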
+ + Returns: + A dictionary containing the contents of the module recursively filtered + """ + + map_fn = map_fn or (lambda x: x) + is_leaf_fn = is_leaf_fn or ( + lambda m, k, v: not isinstance(v, (Module, dict, list)) + ) + + def unwrap(vk, v): + if is_leaf_fn(self, vk, v): + return map_fn(v) + + if isinstance(v, Module): + return v.filter_and_map(filter_fn, map_fn, is_leaf_fn) + + if isinstance(v, dict): + nd = {} + for k, v in v.items(): + tk = f"{vk}.{k}" + nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {} + return nd + + if isinstance(v, list): + nl = [] + for i, vi in enumerate(v): + tk = f"{vk}.{i}" + nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {}) + return nl + + raise RuntimeError("Unexpected leaf found while traversing the module") + + return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)} + + def parameters(self): + """Recursively return all the :class:`mlx.core.array` members of this Module + as a dict of dicts and lists.""" + return self.filter_and_map(self.valid_parameter_filter) + + def trainable_parameters(self): + """Recursively return all the non frozen :class:`mlx.core.array` members of + this Module as a dict of dicts and lists.""" + return self.filter_and_map(self.trainable_parameter_filter) + + def children(self): + """Return the direct descendants of this Module instance.""" + return self.filter_and_map( + self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module) + ) + + def leaf_modules(self): + """Return the submodules that do not contain other modules.""" + + def _is_leaf_module(m, k, v): + return isinstance(v, Module) and len(tree_flatten(v.children())) == 0 + + return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module) + + def update(self, parameters: dict): + """Replace the parameters of this Module with the provided ones in the + dict of dicts and lists. + + Commonly used by the optimizer to change the model to the updated + (optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the + tracers in the model in order to compute gradients. + + The passed in parameters dictionary need not be a full dictionary + similar to :meth:`parameters`. Only the provided locations will be + updated. + + Args: + parameters (dict): A complete or partial dictionary of the modules + parameters. + """ + + def apply(dst, parameters): + if isinstance(parameters, dict): + for k in parameters: + if k in dst: + current_value = dst[k] + new_value = parameters[k] + if isinstance(current_value, mx.array): + dst[k] = new_value + elif isinstance(current_value, Module): + current_value.update(new_value) + elif isinstance(current_value, (dict, list)): + apply(current_value, new_value) + elif isinstance(parameters, list): + for i in range(len(dst)): + current_value = dst[i] + new_value = parameters[i] + if isinstance(current_value, mx.array): + dst[i] = new_value + elif isinstance(current_value, Module): + current_value.update(new_value) + elif isinstance(current_value, (dict, list)): + apply(current_value, new_value) + + apply(self, parameters) + + def apply( + self, + map_fn: Callable[[mx.array], mx.array], + filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None, + ): + """Map all the parameters using the provided ``map_fn`` and immediately + update the module with the mapped parameters. + + For instance running ``model.apply(lambda x: x.astype(mx.float16))`` + casts all parameters to 16 bit floats. 
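A brief, hypothetical illustration of a partial ``update`` and of ``apply`` for dtype casting (module names and sizes are placeholders):

import mlx.core as mx
import mlx.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# update() only touches the locations present in the passed tree: here only
# the first layer's weight is replaced, everything else is left untouched.
model.update({"layers": [{"weight": mx.zeros((8, 4))}, {}]})

# apply() maps every parameter array and writes the result back in place.
model.apply(lambda x: x.astype(mx.float16))
print(model.layers[0].weight.dtype)  # float16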
+ + Args: + map_fn (Callable): Maps an array to another array + filter_fn (Callable, optional): Filter to select which arrays to + map (default: :meth:`Module.valid_parameter_filter`). + """ + filter_fn = filter_fn or Module.valid_parameter_filter + self.update(self.filter_and_map(filter_fn, map_fn)) + + def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]): + """Apply a function to all the modules in this instance (including this + instance). + + Args: + apply_fn (Callable): The function to apply to the modules. + """ + module_stack = [("", self)] + while module_stack: + prefix, mod = module_stack.pop() + apply_fn(prefix, mod) + prefix = "." + prefix if prefix else "" + module_stack.extend( + tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module) + ) + + def modules(self): + """Return a list with all the modules in this instance. + + Returns: + A list of :class:`mlx.nn.Module` instances. + """ + modulelist = [] + self.apply_to_modules(lambda k, m: modulelist.append(m)) + return modulelist + + def named_modules(self): + """Return a list with all the modules in this instance and their name + with dot notation. + + Returns: + A list of tuples (str, :class:`mlx.nn.Module`). + """ + modulelist = [] + self.apply_to_modules(lambda k, m: modulelist.append((k, m))) + return modulelist + + def _validate_keys(self, keys, strict): + keys = keys if isinstance(keys, list) else [keys] + if strict: + for k in keys: + if k not in self: + raise KeyError(f"Module doesn't contain member {k}.") + return keys + + def freeze( + self, + *, + recurse: bool = True, + keys: Optional[Union[str, List[str]]] = None, + strict: bool = False, + ): + """Freeze the Module's parameters or some of them. Freezing a parameter means not + computing gradients for it. + + This function is idempotent ie freezing a frozen model is a noop. + + For instance to only train the attention parameters from a transformer: + + model = ... + model.freeze() + model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None) + + Args: + recurse (bool, optional): If True then freeze the parameters of the + submodules as well (default: True). + keys (str or list[str], optional): If provided then only these + parameters will be frozen otherwise all the parameters of a + module. For instance freeze all biases by calling + ``module.freeze(keys="bias")``. + strict (bool, optional): If set to True validate that the passed keys exist + (default: False). + """ + + def _freeze_impl(_, m): + local_keys = keys + if local_keys is None: + local_keys = tree_flatten( + m.filter_and_map( + lambda m, k, v: (not isinstance(v, Module)) + and m.valid_parameter_filter(m, k, v) + ) + ) + local_keys = [k for (k, v) in local_keys] + + local_keys = m._validate_keys(local_keys, strict) + m._no_grad.update(local_keys) + + if recurse: + self.apply_to_modules(_freeze_impl) + else: + _freeze_impl("", self) + + def unfreeze( + self, + *, + recurse: bool = True, + keys: Optional[Union[str, List[str]]] = None, + strict: bool = False, + ): + """Unfreeze the Module's parameters or some of them. + + This function is idempotent ie unfreezing a model that is not frozen is + a noop. + + For instance to only train the biases one can do: + + model = ... + model.freeze() + model.unfreeze(keys="bias") + + Args: + recurse (bool, optional): If True then unfreeze the parameters of the + submodules as well (default: True). 
+ keys (str or list[str], optional): If provided then only these + parameters will be unfrozen otherwise all the parameters of a + module. For instance unfreeze all biases by calling + ``module.unfreeze(keys="bias")``. + strict (bool, optional): If set to True validate that the passed keys exist + (default: False). + """ + + def _unfreeze_impl(_, m): + if keys is None: + m._no_grad.clear() + + else: + local_keys = m._validate_keys(keys, strict) + m._no_grad.difference_update(local_keys) + + if recurse: + self.apply_to_modules(_unfreeze_impl) + else: + _unfreeze_impl("", self) + + def train(self, mode: bool = True): + def _set_train(_, m): + m._training = mode + + self.apply_to_modules(_set_train) + + def eval(self): + self.train(False) diff --git a/python/mlx/nn/layers/containers.py b/python/mlx/nn/layers/containers.py new file mode 100644 index 0000000000..8403be069c --- /dev/null +++ b/python/mlx/nn/layers/containers.py @@ -0,0 +1,22 @@ +from mlx.nn.layers.base import Module + + +class Sequential(Module): + """A layer that calls the passed callables in order. + + We can pass either modules or plain callables to the Sequential module. If + our functions have learnable parameters they should be implemented as + ``nn.Module`` instances. + + Args: + modules (tuple of Callables): The modules to call in order + """ + + def __init__(self, *modules): + super().__init__() + self.layers = list(modules) + + def __call__(self, x): + for m in self.layers: + x = m(x) + return x diff --git a/python/mlx/nn/layers/dropout.py b/python/mlx/nn/layers/dropout.py new file mode 100644 index 0000000000..e6e889387a --- /dev/null +++ b/python/mlx/nn/layers/dropout.py @@ -0,0 +1,33 @@ +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class Dropout(Module): + """Randomly zero a portion of the elements during training. + + The remaining elements are multiplied with :math:`\frac{1}{1-p}` where + :math:`p` is the probability of zeroing an element. This is done so the + expected value of a given element will remain the same. + + Args: + p (float): The probability to zero an element + """ + + def __init__(self, p: float = 0.5): + super().__init__() + + if p < 0 or p >= 1: + raise ValueError("The dropout probability should be in [0, 1)") + + self._p_1 = 1 - p + + def _extra_repr(self): + return f"p={1-self._p_1}" + + def __call__(self, x): + if self._p_1 == 1 or not self.training: + return x + + mask = mx.random.bernoulli(self._p_1, x.shape) + + return (1 / self._p_1) * mask.astype(x.dtype) * x diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py new file mode 100644 index 0000000000..6d0dd4b3a5 --- /dev/null +++ b/python/mlx/nn/layers/normalization.py @@ -0,0 +1,178 @@ +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class LayerNorm(Module): + r"""Applies layer normalization [1] on the inputs. + + Computes + + .. math:: + + y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta, + + where :math:`\gamma` and :math:`\beta` are learned per feature dimension + parameters initialized at 1 and 0 respectively. 
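Before moving on to the normalization layers, a hypothetical fine-tuning sketch tying together ``freeze``/``unfreeze``, ``train``/``eval``, ``Sequential`` and ``Dropout`` (layer sizes are arbitrary):

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

model = nn.Sequential(nn.Linear(32, 64), nn.Dropout(p=0.2), nn.Linear(64, 10))

# Freeze everything, then unfreeze only the biases; nn.value_and_grad will
# subsequently differentiate only with respect to the unfrozen arrays.
model.freeze()
model.unfreeze(keys="bias")
print(len(tree_flatten(model.trainable_parameters())))  # 2 bias arrays

# Dropout is active only in training mode; eval() switches it off.
x = mx.random.normal((8, 32))
model.eval()
y_deterministic = model(x)
model.train()
y_stochastic = model(x)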
+ + [1]: https://arxiv.org/abs/1607.06450 + + Args: + dims (int): The feature dimension of the input to normalize over + eps (float): A small additive constant for numerical stability + affine (bool): If True learn an affine transform to apply after the + normalization + """ + + def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True): + super().__init__() + if affine: + self.bias = mx.zeros((dims,)) + self.weight = mx.ones((dims,)) + self.eps = eps + self.dims = dims + + def _extra_repr(self): + return f"{self.dims}, eps={self.eps}, affine={'weight' in self}" + + def __call__(self, x): + means = mx.mean(x, axis=-1, keepdims=True) + var = mx.var(x, axis=-1, keepdims=True) + x = (x - means) * mx.rsqrt(var + self.eps) + return (self.weight * x + self.bias) if "weight" in self else x + + +class RMSNorm(Module): + r"""Applies Root Mean Square normalization [1] to the inputs. + + Computes + + .. math:: + + y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma + + where :math:`\gamma` is a learned per feature dimension parameter initialized at + 1. + + [1]: https://arxiv.org/abs/1910.07467 + + Args: + dims (int): The feature dimension of the input to normalize over + eps (float): A small additive constant for numerical stability + """ + + def __init__(self, dims: int, eps: float = 1e-5): + super().__init__() + self.weight = mx.ones((dims,)) + self.eps = eps + + def _extra_repr(self): + return f"{self.weight.shape[0]}, eps={self.eps}" + + def __call__(self, x): + # S is 1/sqrt(N) where N is the size of the features of x and is used + # to compute a numerically more stable RMS of x by multiplying with S + # first and summing. + # + # This way we prefer underflow over overflow which is controlled with + # the parameter epsilon anyway. + S = 1 / x.shape[-1] ** 0.5 + + n = (x * S).square().sum(axis=-1, keepdims=True) + n = mx.rsqrt(n + self.eps) + + return self.weight * x * n + + +class GroupNorm(Module): + r"""Applies Group Normalization [1] to the inputs. + + Computes the same normalization as layer norm, namely + + .. math:: + + y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta, + + where :math:`\gamma` and :math:`\beta` are learned per feature dimension + parameters initialized at 1 and 0 respectively. However, the mean and + variance are computed over the spatial dimensions and each group of + features. In particular, the input is split into num_groups accross the + feature dimension. + + The feature dimension is assumed to be the last dimension and the dimensions + that precede it (except the first) are considered the spatial dimensions. + + [1]: https://arxiv.org/abs/1803.08494 + + Args: + num_groups (int): Number of groups to separate the features into + dims (int): The feature dimensions of the input to normalize over + eps (float): A small additive constant for numerical stability + affine (bool): If True learn an affine transform to apply after the + normalization. + pytorch_compatible (bool): If True perform the group normalization in + the same order/grouping as PyTorch. 
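An illustrative check of the layer and RMS normalization layers above (shapes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 16, 32))    # (batch, sequence, features)

ln = nn.LayerNorm(dims=32)
rms = nn.RMSNorm(dims=32)

# Both normalize over the last (feature) axis and preserve the input shape.
print(ln(x).shape, rms(x).shape)

# With affine=False, LayerNorm keeps no learnable weight/bias.
ln_plain = nn.LayerNorm(dims=32, affine=False)
print("weight" in ln_plain)          # False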
+ """ + + def __init__( + self, + num_groups: int, + dims: int, + eps: float = 1e-5, + affine: bool = True, + pytorch_compatible: bool = False, + ): + super().__init__() + if affine: + self.bias = mx.zeros((dims,)) + self.weight = mx.ones((dims,)) + self.num_groups = num_groups + self.dims = dims + self.eps = eps + self.pytorch_compatible = pytorch_compatible + + def _extra_repr(self): + return ( + f"{self.num_groups}, {self.dims}, eps={self.eps}, " + f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}" + ) + + def _pytorch_compatible_group_norm(self, x): + num_groups = self.num_groups + batch, *rest, dims = x.shape + + # Split into groups + x = x.reshape(batch, -1, num_groups, dims // num_groups) + x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups) + + # Normalize + means = mx.mean(x, axis=1, keepdims=True) + var = mx.var(x, axis=1, keepdims=True) + x = (x - means) * mx.rsqrt(var + self.eps) + x = x.reshape(batch, -1, dims // num_groups, num_groups) + x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims) + + return x + + def _group_norm(self, x): + num_groups = self.num_groups + batch, *rest, dims = x.shape + + # Split into groups + x = x.reshape(batch, -1, num_groups) + + # Normalize + means = mx.mean(x, axis=1, keepdims=True) + var = mx.var(x, axis=1, keepdims=True) + x = (x - means) * mx.rsqrt(var + self.eps) + x = x.reshape(batch, *rest, dims) + + return x + + def __call__(self, x): + group_norm = ( + self._pytorch_compatible_group_norm + if self.pytorch_compatible + else self._group_norm + ) + x = group_norm(x) + return (self.weight * x + self.bias) if "weight" in self else x diff --git a/python/mlx/nn/layers/positional_encoding.py b/python/mlx/nn/layers/positional_encoding.py new file mode 100644 index 0000000000..9ccc2b6928 --- /dev/null +++ b/python/mlx/nn/layers/positional_encoding.py @@ -0,0 +1,142 @@ +import math +from typing import Optional + +import mlx.core as mx +from mlx.nn.layers.base import Module + + +class RoPE(Module): + """Implements the rotary positional encoding [1]. + + The traditional implementation rotates consecutive pairs of elements in the + feature dimension while the default implementation rotates pairs with + stride half the feature dimensions for efficiency. + + [1]: https://arxiv.org/abs/2104.09864 + + Args: + dims (int): The feature dimensions to be rotated. If the input feature + is larger than dims then the rest is left unchanged. + traditional (bool): If set to True choose the traditional + implementation which is slightly less efficient. 
+ """ + + def __init__(self, dims: int, traditional: bool = False): + super().__init__() + self.dims = dims + self.traditional = traditional + + def _extra_repr(self): + return f"{self.dims}, traditional={self.traditional}" + + def _compute_rope(self, costheta, sintheta, x): + x1 = x[..., : self.dims // 2] + x2 = x[..., self.dims // 2 : self.dims] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1) + else: + rx = mx.concatenate([rx1, rx2], axis=-1) + + return rx + + def _compute_traditional_rope(self, costheta, sintheta, x): + x1 = x[..., ::2] + x2 = x[..., 1::2] + rx1 = x1 * costheta - x2 * sintheta + rx2 = x1 * sintheta + x2 * costheta + + if self.dims < x.shape[-1]: + raise NotImplementedError( + "RoPE doesn't implement partial traditional application" + ) + + rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1) + + return rx + + def __call__(self, x, offset: int = 0): + shape = x.shape + x = mx.reshape(x, (-1, shape[-2], shape[-1])) + N = x.shape[1] + offset + costheta, sintheta = RoPE.create_cos_sin_theta( + N, self.dims, offset=offset, dtype=x.dtype + ) + + rope = ( + self._compute_traditional_rope if self.traditional else self._compute_rope + ) + rx = rope(costheta, sintheta, x) + + return mx.reshape(rx, shape) + + @staticmethod + def create_cos_sin_theta( + N: int, D: int, offset: int = 0, base: float = 10000, dtype=mx.float32 + ): + D = D // 2 + positions = mx.arange(offset, N, dtype=dtype) + freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D)) + theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1)) + costheta = mx.cos(theta) + sintheta = mx.sin(theta) + + return costheta, sintheta + + +class SinusoidalPositionalEncoding(Module): + """Implements sinusoidal positional encoding similar to [1]. + + [1]: https://arxiv.org/abs/1706.03762 + + Args: + dims (int): The dimensionality of the resulting positional embeddings. 
+ min_freq (float): The minimum frequency expected (default: 0.0001) + max_freq (float): The maximum frequency expected (default: 1) + scale (float): Scale the embeddings by that number (default: sqrt(dims//2)) + cos_first (bool): If set to True embed using ``[cos(x); sin(x)]`` + instead of the other way around (default: False) + full_turns (bool): If set to True multiply the frequencies + with ``2 pi`` (default: False) + """ + + def __init__( + self, + dims: int, + min_freq: float = 0.0001, + max_freq: float = 1, + scale: Optional[float] = None, + cos_first: bool = False, + full_turns: bool = False, + ): + super().__init__() + + one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1) + min_freq = math.log(min_freq) + max_freq = math.log(max_freq) + + # Start with underscore so it is not included in the parameters + self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq) + if full_turns: + self._sigmas = self._sigmas * (2 * math.pi) + + # Save some constants that define the implementation + self.scale = scale or (2 / dims) ** 0.5 + self.cos_first = cos_first + + def __call__(self, x): + y = x[..., None] * self._sigmas + cosy = mx.cos(y) + siny = mx.sin(y) + + if self.cos_first: + y = mx.concatenate([cosy, siny], axis=-1) + else: + y = mx.concatenate([siny, cosy], axis=-1) + + if self.scale != 1: + y = y * self.scale + + return y diff --git a/python/mlx/nn/layers/transformer.py b/python/mlx/nn/layers/transformer.py new file mode 100644 index 0000000000..93359a4c73 --- /dev/null +++ b/python/mlx/nn/layers/transformer.py @@ -0,0 +1,136 @@ +import math +from typing import Optional + +import mlx.core as mx +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear +from mlx.nn.layers.normalization import LayerNorm + + +class MultiHeadAttention(Module): + """Implements the scaled dot product attention with multiple heads. + + Given inputs for queries, keys and values the ``MultiHeadAttention`` produces + new values by aggregating information from the input values according to + the similarities of the input queries and keys. + + All inputs as well as the output are lineary projected without biases. + + MultiHeadAttention also expects an additive attention mask that should be + broadcastable with (batch, num_heads, # queries, # keys). The mask should + have ``-inf`` or very negative numbers to the positions that should *not* be + attended to. + + Args: + dims (int): The model dimensions. If no other dims are provided then + dims is used for queries, keys, values and the output. + num_heads (int): How many attention heads to use + query_input_dims (int, optional): The input dimensions of the queries (default: dims). + key_input_dims (int, optional): The input dimensions of the keys (default: dims). + value_input_dims (int, optional): The input dimensions of the values (default: key_input_dims). + value_dims (int, optional): The dimensions of the values after the projection (default: dims). + value_output_dims (int, optional): The dimensions the new values will be projected to (default: dims). 
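A hypothetical self-attention call using the arguments above together with the causal-mask helper defined further down (batch and sequence sizes are arbitrary):

import mlx.core as mx
import mlx.nn as nn

attn = nn.MultiHeadAttention(dims=64, num_heads=8)

x = mx.random.normal((2, 12, 64))    # (batch, sequence, dims)
mask = nn.MultiHeadAttention.create_additive_causal_mask(12)

# Self-attention: queries, keys and values all come from the same input.
y = attn(x, x, x, mask=mask)
print(y.shape)                       # (2, 12, 64)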
+ """ + + def __init__( + self, + dims: int, + num_heads: int, + query_input_dims: Optional[int] = None, + key_input_dims: Optional[int] = None, + value_input_dims: Optional[int] = None, + value_dims: Optional[int] = None, + value_output_dims: Optional[int] = None, + ): + super().__init__() + + if (dims % num_heads) != 0: + raise ValueError( + f"The input feature dimensions should be divisble by the number of heads ({dims} % {num_heads}) != 0" + ) + + query_input_dims = query_input_dims or dims + key_input_dims = key_input_dims or dims + value_input_dims = value_input_dims or key_input_dims + value_dims = value_dims or dims + value_output_dims = value_output_dims or dims + + self.num_heads = num_heads + self.query_proj = Linear(query_input_dims, dims, False) + self.key_proj = Linear(key_input_dims, dims, False) + self.value_proj = Linear(value_input_dims, value_dims, False) + self.out_proj = Linear(value_dims, value_output_dims, False) + + def __call__(self, queries, keys, values, mask=None): + queries = self.query_proj(queries) + keys = self.key_proj(keys) + values = self.value_proj(values) + + num_heads = self.num_heads + B, L, D = queries.shape + _, S, _ = keys.shape + queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1) + values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3) + + # Dimensions are [batch x num heads x sequence x hidden dim] + scale = math.sqrt(1 / queries.shape[-1]) + scores = (queries * scale) @ keys + if mask is not None: + scores = scores + mask.astype(scores.dtype) + scores = mx.softmax(scores, axis=-1) + values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1) + + return self.out_proj(values_hat) + + @staticmethod + def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32): + indices = mx.arange(N) + mask = indices[:, None] < indices[None] + # usually inf but 1e9 is as good and softmax(full(1e9)) != nan + # TODO: Should replace this with finfo(dtype).min + mask = mask.astype(dtype) * -1e9 + return mask + + +class TransformerEncoderLayer(Module): + def __init__(self, dims: int, num_heads: int, mlp_dims: Optional[int] = None): + super().__init__() + mlp_dims = mlp_dims or dims * 4 + self.attention = MultiHeadAttention(dims, num_heads) + self.ln1 = LayerNorm(dims) + self.ln2 = LayerNorm(dims) + self.linear1 = Linear(dims, mlp_dims) + self.linear2 = Linear(mlp_dims, dims) + + def __call__(self, x, mask): + y = self.ln1(x) + y = self.attention(y, y, y, mask) + x = x + y + + y = self.ln2(x) + y = self.linear1(y) + y = mx.maximum(y, 0) + y = self.linear2(y) + x = x + y + + return x + + +class TransformerEncoder(Module): + def __init__( + self, num_layers: int, dims: int, num_heads: int, mlp_dims: Optional[int] = None + ): + super().__init__() + self.layers = [ + TransformerEncoderLayer(dims, num_heads, mlp_dims) + for i in range(num_layers) + ] + self.ln = LayerNorm(dims) + + def __call__(self, x, mask): + for l in self.layers: + x = l(x, mask) + x = self.ln(x) + + return x diff --git a/python/mlx/nn/utils.py b/python/mlx/nn/utils.py new file mode 100644 index 0000000000..c1023619f6 --- /dev/null +++ b/python/mlx/nn/utils.py @@ -0,0 +1,31 @@ +from typing import Callable + +import mlx.core as mx + + +def value_and_grad(model: "mlx.nn.Module", fn: Callable): + """Transform the passed function ``fn`` to a function that computes the + gradients of ``fn`` wrt the model's trainable parameters and also its + value. 
+ + Args: + model (mlx.nn.Module): The model whose trainable parameters to compute + gradients for + fn (Callable): The scalar function to compute gradients for + + Returns: + A callable that returns the value of ``fn`` and the gradients wrt the + trainable parameters of ``model`` + """ + + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(*args, **kwargs) + + value_grad_fn = mx.value_and_grad(inner_fn) + + def wrapped_value_grad_fn(*args, **kwargs): + value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs) + return value, grad + + return wrapped_value_grad_fn diff --git a/python/mlx/optimizers.py b/python/mlx/optimizers.py new file mode 100644 index 0000000000..63ccf03d84 --- /dev/null +++ b/python/mlx/optimizers.py @@ -0,0 +1,152 @@ +import math +from typing import List + +import mlx.core as mx +from mlx.utils import tree_map + + +class OptimizerState(dict): + """The optimizer state implements a recursively defined + :class:`collections.defaultdict`, namely a missing key in an optimizer + state is an :class:`OptimizerState`. + + .. note:: + :meth:`OptimizerState.get` in contrast to a normal dictionary also sets + the key to the ``default`` value if the ``key`` was not present in the + dictionary. + """ + + def __getitem__(self, key): + if key not in self: + self[key] = OptimizerState() + return super().__getitem__(key) + + def get(self, key, default): + """If ``key`` doesn't exist set its value to ``default`` and then return it.""" + if key not in self: + self[key] = default + return super().__getitem__(key) + + +class Optimizer: + """The base class for all optimizers. It allows us to implement an + optimizer on a per-parameter basis and apply it to a parameter tree. + + Attributes: + state (OptimizerState): It holds the optimizer's state dictionary. + """ + + def __init__(self): + self.state = OptimizerState() + + def update(self, model: "mlx.nn.Module", gradients: dict): + """Apply the gradients to the parameters of the model and update the + model with the new parameters. + + Args: + model (mlx.nn.Module): An mlx module to be updated. + gradients (dict): A Python tree of gradients, most likely computed + via :func:`mlx.nn.value_and_grad`. + """ + model.update(self.apply_gradients(gradients, model)) + + def apply_gradients(self, gradients: dict, model: dict): + """Apply the gradients to the parameters and return the updated parameters. + + Can be used to update a model via + ``model.update(opt.apply_gradients(grads, model))`` which is precisely + how :meth:`Optimizer.update` is implemented. + + Args: + gradients (dict): A Python tree of gradients. + model (dict): A Python tree of parameters. It can be a superset of + the gradients. In that case the returned python tree + will be of the same structure as the gradients. + """ + return tree_map(self.apply_single, gradients, model, self.state) + + def apply_single( + self, gradient: mx.array, parameter: mx.array, state: OptimizerState + ): + """To be extended by the children classes to implement each optimizer's + update.""" + raise NotImplementedError() + + +class SGD(Optimizer): + r"""Stochastic gradient descent optimizer. + + Updates a parameter :math:`w` with a gradient :math:`g` as follows + + .. 
math:: + + v_{t+1} &= \mu v_t + (1 - \mu) g_t \\ + w_{t+1} &= w_t - \lambda v_{t+1} + + Args: + learning_rate (float): The learning :math:`\lambda` for the update + momentum (float): The momentum strength :math:`\mu` + """ + + def __init__(self, learning_rate: float, momentum: float = 0.0): + super().__init__() + + self.learning_rate = learning_rate + self.momentum = momentum + + def apply_single( + self, gradient: mx.array, parameter: mx.array, state: OptimizerState + ): + """Performs the SGD parameter update and stores :math:`v` in the + optimizer state.""" + if self.momentum <= 0: + return parameter - self.learning_rate * gradient + + v = state.get("v", mx.zeros_like(gradient)) + v = self.momentum * v + (1 - self.momentum) * gradient + state["v"] = v + return parameter - self.learning_rate * v + + +class Adam(Optimizer): + r"""Implementation of the Adam optimizer [1]. + + Our Adam implementation follows the original paper and omits the bias + correction in the first and second moment estimates. In detail, + + .. math:: + + m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\ + v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\ + w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + + [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic + optimization. ICLR 2015. + """ + + def __init__( + self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8 + ): + super().__init__() + + self.learning_rate = learning_rate + self.betas = betas + self.eps = eps + + def apply_single( + self, gradient: mx.array, parameter: mx.array, state: OptimizerState + ): + """Performs the Adam parameter update and stores :math:`v` and + :math:`m` in the optimizer state.""" + lr = self.learning_rate + b1, b2 = self.betas + eps = self.eps + + m = state.get("m", gradient) + v = state.get("v", mx.square(gradient)) + m = b1 * m + (1 - b1) * gradient + v = b2 * v + (1 - b2) * mx.square(gradient) + state["m"] = m + state["v"] = v + + return parameter - lr * m / (mx.sqrt(v) + eps) diff --git a/python/src/CMakeLists.txt b/python/src/CMakeLists.txt new file mode 100644 index 0000000000..5ab8a50bf1 --- /dev/null +++ b/python/src/CMakeLists.txt @@ -0,0 +1,32 @@ +pybind11_add_module( + core + ${CMAKE_CURRENT_SOURCE_DIR}/mlx.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/array.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/metal.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/stream.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/transforms.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/random.cpp +) + +if (NOT MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY) + set(MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) +endif() + +set_target_properties( + core + PROPERTIES + LIBRARY_OUTPUT_DIRECTORY + ${MLX_PYTHON_BINDINGS_OUTPUT_DIRECTORY} +) + +target_link_libraries(core PRIVATE mlx) +target_compile_definitions(core PRIVATE _VERSION_=${MLX_VERSION}) + +if(BUILD_SHARED_LIBS) + target_link_options(core PRIVATE -Wl,-rpath,@loader_path/lib) +endif() diff --git a/python/src/fft.cpp b/python/src/fft.cpp new file mode 100644 index 0000000000..d20e2b2fe6 --- /dev/null +++ b/python/src/fft.cpp @@ -0,0 +1,468 @@ +#include +#include + +#include "python/src/utils.h" + +#include "mlx/fft.h" +#include "mlx/ops.h" + +namespace py = pybind11; +using namespace py::literals; + +using namespace mlx::core; + +void init_fft(py::module_& parent_module) { + auto m = 
parent_module.def_submodule( + "fft", "mlx.core.fft: Fast Fourier Transforms."); + m.def( + "fft", + [](const array& a, + const std::optional& n, + int axis, + StreamOrDevice s) { + if (n.has_value()) { + return fft::fft(a, n.value(), axis, s); + } else { + return fft::fft(a, axis, s); + } + }, + "a"_a, + "n"_a = none, + "axis"_a = -1, + "stream"_a = none, + R"pbdoc( + One dimensional discrete Fourier Transform. + + Args: + a (array): The input array. + n (int, optional): Size of the transformed axis. The + corresponding axis in the input is truncated or padded with + zeros to match ``n``. The default value is ``a.shape[axis]``. + axis (int, optional): Axis along which to perform the FFT. The + default is ``-1``. + + Returns: + array: The DFT of the input along the given axis. + )pbdoc"); + m.def( + "ifft", + [](const array& a, + const std::optional& n, + int axis, + StreamOrDevice s) { + if (n.has_value()) { + return fft::ifft(a, n.value(), axis, s); + } else { + return fft::ifft(a, axis, s); + } + }, + "a"_a, + "n"_a = none, + "axis"_a = -1, + "stream"_a = none, + R"pbdoc( + One dimensional inverse discrete Fourier Transform. + + Args: + a (array): The input array. + n (int, optional): Size of the transformed axis. The + corresponding axis in the input is truncated or padded with + zeros to match ``n``. The default value is ``a.shape[axis]``. + axis (int, optional): Axis along which to perform the FFT. The + default is ``-1``. + + Returns: + array: The inverse DFT of the input along the given axis. + )pbdoc"); + m.def( + "fft2", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::fftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::fftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::fftn(a, n.value(), axes_, s); + } else { + return fft::fftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = std::vector{-2, -1}, + "stream"_a = none, + R"pbdoc( + Two dimensional discrete Fourier Transform. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``[-2, -1]``. + + Returns: + array: The DFT of the input along the given axes. + )pbdoc"); + m.def( + "ifft2", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::ifftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::ifftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::ifftn(a, n.value(), axes_, s); + } else { + return fft::ifftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = std::vector{-2, -1}, + "stream"_a = none, + R"pbdoc( + Two dimensional inverse discrete Fourier Transform. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. 
+ axes (list(int), optional): Axes along which to perform the FFT. + The default is ``[-2, -1]``. + + Returns: + array: The inverse DFT of the input along the given axes. + )pbdoc"); + m.def( + "fftn", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::fftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::fftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::fftn(a, n.value(), axes_, s); + } else { + return fft::fftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = none, + "stream"_a = none, + R"pbdoc( + n-dimensional discrete Fourier Transform. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``None`` in which case the FFT is over the last + ``len(s)`` axes are or all axes if ``s`` is also ``None``. + + Returns: + array: The DFT of the input along the given axes. + )pbdoc"); + m.def( + "ifftn", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::ifftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::ifftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::ifftn(a, n.value(), axes_, s); + } else { + return fft::ifftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = none, + "stream"_a = none, + R"pbdoc( + n-dimensional inverse discrete Fourier Transform. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``None`` in which case the FFT is over the last + ``len(s)`` axes or all axes if ``s`` is also ``None``. + + Returns: + array: The inverse DFT of the input along the given axes. + )pbdoc"); + m.def( + "rfft", + [](const array& a, + const std::optional& n, + int axis, + StreamOrDevice s) { + if (n.has_value()) { + return fft::rfft(a, n.value(), axis, s); + } else { + return fft::rfft(a, axis, s); + } + }, + "a"_a, + "n"_a = none, + "axis"_a = -1, + "stream"_a = none, + R"pbdoc( + One dimensional discrete Fourier Transform on a real input. + + The output has the same shape as the input except along ``axis`` in + which case it has size ``n // 2 + 1``. + + Args: + a (array): The input array. If the array is complex it will be silently + cast to a real type. + n (int, optional): Size of the transformed axis. The + corresponding axis in the input is truncated or padded with + zeros to match ``n``. The default value is ``a.shape[axis]``. + axis (int, optional): Axis along which to perform the FFT. The + default is ``-1``. + + Returns: + array: The DFT of the input along the given axis. The output + data type will be complex. 
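An illustrative round trip through these bindings (values are arbitrary):

import mlx.core as mx

x = mx.random.normal((8, 16))

# Complex transform and its inverse along the last axis.
X = mx.fft.fft(x)
x_back = mx.fft.ifft(X)
print(X.dtype, x_back.shape)          # complex dtype, shape (8, 16)

# The real-input transform keeps only the non-redundant half: n // 2 + 1 bins.
R = mx.fft.rfft(x)
print(R.shape)                        # (8, 9)

# The 2D transforms operate over the last two axes by default.
X2 = mx.fft.fft2(x)
print(X2.shape)                       # (8, 16)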
+ )pbdoc"); + m.def( + "irfft", + [](const array& a, + const std::optional& n, + int axis, + StreamOrDevice s) { + if (n.has_value()) { + return fft::irfft(a, n.value(), axis, s); + } else { + return fft::irfft(a, axis, s); + } + }, + "a"_a, + "n"_a = none, + "axis"_a = -1, + "stream"_a = none, + R"pbdoc( + The inverse of :func:`rfft`. + + The output has the same shape as the input except along ``axis`` in + which case it has size ``n``. + + Args: + a (array): The input array. + n (int, optional): Size of the transformed axis. The + corresponding axis in the input is truncated or padded with + zeros to match ``n // 2 + 1``. The default value is + ``a.shape[axis] // 2 + 1``. + axis (int, optional): Axis along which to perform the FFT. The + default is ``-1``. + + Returns: + array: The real array containing the inverse of :func:`rfft`. + )pbdoc"); + m.def( + "rfft2", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::rfftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::rfftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::rfftn(a, n.value(), axes_, s); + } else { + return fft::rfftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = std::vector{-2, -1}, + "stream"_a = none, + R"pbdoc( + Two dimensional real discrete Fourier Transform. + + The output has the same shape as the input except along the dimensions in + ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is + treated as the real axis and will have size ``s[-1] // 2 + 1``. + + Args: + a (array): The input array. If the array is complex it will be silently + cast to a real type. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``[-2, -1]``. + + Returns: + array: The real DFT of the input along the given axes. The output + data type will be complex. + )pbdoc"); + m.def( + "irfft2", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::irfftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::irfftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::irfftn(a, n.value(), axes_, s); + } else { + return fft::irfftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = std::vector{-2, -1}, + "stream"_a = none, + R"pbdoc( + The inverse of :func:`rfft2`. + + Note the input is generally complex. The dimensions of the input + specified in ``axes`` are padded or truncated to match the sizes + from ``s``. The last axis in ``axes`` is treated as the real axis + and will have size ``s[-1] // 2 + 1``. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s`` except for the last axis + which has size ``s[-1] // 2 + 1``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. 
+ The default is ``[-2, -1]``. + + Returns: + array: The real array containing the inverse of :func:`rfft2`. + )pbdoc"); + m.def( + "rfftn", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::rfftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::rfftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::rfftn(a, n.value(), axes_, s); + } else { + return fft::rfftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = none, + "stream"_a = none, + R"pbdoc( + n-dimensional real discrete Fourier Transform. + + The output has the same shape as the input except along the dimensions in + ``axes`` in which case it has sizes from ``s``. The last axis in ``axes`` is + treated as the real axis and will have size ``s[-1] // 2 + 1``. + + Args: + a (array): The input array. If the array is complex it will be silently + cast to a real type. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``None`` in which case the FFT is over the last + ``len(s)`` axes or all axes if ``s`` is also ``None``. + + Returns: + array: The real DFT of the input along the given axes. The output + )pbdoc"); + m.def( + "irfftn", + [](const array& a, + const std::optional>& n, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value() && n.has_value()) { + return fft::irfftn(a, n.value(), axes.value(), s); + } else if (axes.has_value()) { + return fft::irfftn(a, axes.value(), s); + } else if (n.has_value()) { + std::vector axes_(n.value().size()); + std::iota(axes_.begin(), axes_.end(), -n.value().size()); + return fft::irfftn(a, n.value(), axes_, s); + } else { + return fft::irfftn(a, s); + } + }, + "a"_a, + "s"_a = none, + "axes"_a = none, + "stream"_a = none, + R"pbdoc( + The inverse of :func:`rfftn`. + + Note the input is generally complex. The dimensions of the input + specified in ``axes`` are padded or truncated to match the sizes + from ``s``. The last axis in ``axes`` is treated as the real axis + and will have size ``s[-1] // 2 + 1``. + + Args: + a (array): The input array. + s (list(int), optional): Sizes of the transformed axes. The + corresponding axes in the input are truncated or padded with + zeros to match the sizes in ``s``. The default value is the + sizes of ``a`` along ``axes``. + axes (list(int), optional): Axes along which to perform the FFT. + The default is ``None`` in which case the FFT is over the last + ``len(s)`` axes or all axes if ``s`` is also ``None``. + + Returns: + array: The real array containing the inverse of :func:`rfftn`. 
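And a short sketch of the real 2D transforms, showing how ``s`` recovers the original size on the inverse (again purely illustrative):

import mlx.core as mx

x = mx.random.normal((4, 10, 10))

# Real 2D FFT over the last two axes: the last axis shrinks to 10 // 2 + 1.
R = mx.fft.rfft2(x)
print(R.shape)                        # (4, 10, 6)

# Passing s recovers the original (possibly odd) size on the inverse.
x_back = mx.fft.irfft2(R, s=[10, 10])
print(x_back.shape)                   # (4, 10, 10)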
+ )pbdoc"); +} diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp new file mode 100644 index 0000000000..968a537473 --- /dev/null +++ b/python/src/indexing.cpp @@ -0,0 +1,635 @@ +#include +#include + +#include "python/src/indexing.h" + +#include "mlx/ops.h" + +bool is_none_slice(const py::slice& in_slice) { + return ( + py::getattr(in_slice, "start").is_none() && + py::getattr(in_slice, "stop").is_none() && + py::getattr(in_slice, "step").is_none()); +} + +int get_slice_int(py::object obj, int default_val) { + if (!obj.is_none()) { + if (!py::isinstance(obj)) { + throw std::invalid_argument("Slice indices must be integers or None."); + } + return py::cast(py::cast(obj)); + } + return default_val; +} + +void get_slice_params( + int& starts, + int& ends, + int& strides, + const py::slice& in_slice, + int axis_size) { + // Following numpy's convention + // Assume n is the number of elements in the dimension being sliced. + // Then, if i is not given it defaults to 0 for k > 0 and n - 1 for + // k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for + // k < 0 . If k is not given it defaults to 1 + + strides = get_slice_int(py::getattr(in_slice, "step"), 1); + starts = get_slice_int( + py::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0); + ends = get_slice_int( + py::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size); + + // starts = (starts < 0) ? starts + axis_size : starts; + // ends = (ends < 0) ? ends + axis_size : ends; +} + +array get_int_index(py::object idx, int axis_size) { + int idx_ = py::cast(idx); + idx_ = (idx_ < 0) ? idx_ + axis_size : idx_; + + return array(idx_, uint32); +} + +bool is_valid_index_type(const py::object& obj) { + return py::isinstance(obj) || py::isinstance(obj) || + py::isinstance(obj) || obj.is_none() || py::ellipsis().is(obj); +} + +array mlx_get_item_slice(const array& src, const py::slice& in_slice) { + // Check input and raise error if 0 dim for parity with np + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + // Return a copy of the array if none slice is request + if (is_none_slice(in_slice)) { + return src; + } + + std::vector starts(src.ndim(), 0); + std::vector ends = src.shape(); + std::vector strides(src.ndim(), 1); + + // Check and update slice params + get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]); + return slice(src, starts, ends, strides); +} + +array mlx_get_item_array(const array& src, const array& indices) { + // Check input and raise error if 0 dim for parity with np + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + if (indices.dtype() == bool_) { + throw std::invalid_argument("boolean indices are not yet supported"); + } + + // If only one input array is mentioned, we set axis=0 in take + // for parity with np + return take(src, indices, 0); +} + +array mlx_get_item_int(const array& src, const py::int_& idx) { + // Check input and raise error if 0 dim for parity with np + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + // If only one input idx is mentioned, we set axis=0 in take + // for parity with np + return take(src, get_int_index(idx, src.shape(0)), 0); +} + +array mlx_gather_nd( + array src, + const std::vector& indices, + bool gather_first, + int& max_dims) { + max_dims = 0; + std::vector gather_indices; + std::vector is_slice(indices.size(), false); 
+ int num_slices = 0; + // gather all the arrays + for (int i = 0; i < indices.size(); i++) { + auto& idx = indices[i]; + + if (py::isinstance(idx)) { + int start, end, stride; + get_slice_params(start, end, stride, idx, src.shape(i)); + gather_indices.push_back(arange(start, end, stride, uint32)); + num_slices++; + is_slice[i] = true; + } else if (py::isinstance(idx)) { + gather_indices.push_back(get_int_index(idx, src.shape(i))); + } else if (py::isinstance(idx)) { + auto arr = py::cast(idx); + max_dims = std::max(static_cast(arr.ndim()), max_dims); + gather_indices.push_back(arr); + } + } + + // reshape them so that the int/array indices are first + if (gather_first) { + int slice_index = 0; + for (int i = 0; i < gather_indices.size(); i++) { + if (is_slice[i]) { + std::vector index_shape(max_dims + num_slices, 1); + index_shape[max_dims + slice_index] = gather_indices[i].shape(0); + gather_indices[i] = reshape(gather_indices[i], index_shape); + slice_index++; + } else { + std::vector index_shape = gather_indices[i].shape(); + index_shape.insert(index_shape.end(), num_slices, 1); + gather_indices[i] = reshape(gather_indices[i], index_shape); + } + } + } else { + // reshape them so that the int/array indices are last + for (int i = 0; i < gather_indices.size(); i++) { + if (i < num_slices) { + std::vector index_shape(max_dims + num_slices, 1); + index_shape[i] = gather_indices[i].shape(0); + gather_indices[i] = reshape(gather_indices[i], index_shape); + } + } + } + + // Do the gather + std::vector axes(indices.size()); + std::iota(axes.begin(), axes.end(), 0); + std::vector slice_sizes = src.shape(); + std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1); + src = gather(src, gather_indices, axes, slice_sizes); + + // Squeeze the dims + std::vector out_shape; + out_shape.insert( + out_shape.end(), + src.shape().begin(), + src.shape().begin() + max_dims + num_slices); + out_shape.insert( + out_shape.end(), + src.shape().begin() + max_dims + num_slices + indices.size(), + src.shape().end()); + src = reshape(src, out_shape); + + return src; +} + +array mlx_get_item_nd(array src, const py::tuple& entries) { + // No indices make this a noop + if (entries.size() == 0) { + return src; + } + + // The plan is as follows: + // 1. Replace the ellipsis with a series of slice(None) + // 2. Loop over the indices and calculate the gather indices + // 3. 
Calculate the remaining slices and reshapes + + // Ellipsis handling + std::vector indices; + { + int non_none_indices_before = 0; + int non_none_indices_after = 0; + std::vector r_indices; + int i = 0; + for (; i < entries.size(); i++) { + auto idx = entries[i]; + if (!is_valid_index_type(idx)) { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } + if (!py::ellipsis().is(idx)) { + indices.push_back(idx); + non_none_indices_before += !idx.is_none(); + } else { + break; + } + } + for (int j = entries.size() - 1; j > i; j--) { + auto idx = entries[j]; + if (!is_valid_index_type(idx)) { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } + if (py::ellipsis().is(idx)) { + throw std::invalid_argument( + "An index can only have a single ellipsis (...)"); + } + r_indices.push_back(idx); + non_none_indices_after += !idx.is_none(); + } + for (int axis = non_none_indices_before; + axis < src.ndim() - non_none_indices_after; + axis++) { + indices.push_back(py::slice(0, src.shape(axis), 1)); + } + indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend()); + } + + // Check for the number of indices passed + { + int cnt = src.ndim(); + for (auto& idx : indices) { + if (!idx.is_none()) { + cnt--; + } + } + if (cnt < 0) { + std::ostringstream msg; + msg << "Too many indices for array with " << src.ndim() << "dimensions."; + throw std::invalid_argument(msg.str()); + } + } + + // Gather handling + // + // Check whether we have arrays or integer indices and delegate to gather_nd + // after removing the slices at the end and all Nones. + std::vector remaining_indices; + bool have_array = false; + { + // First check whether the results of gather are going to be 1st or + // normally in between. 
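Since this code backs the Python ``__getitem__`` (and, further below, ``__setitem__``) of ``mlx.core.array``, a hypothetical sketch of the NumPy-style behaviors it enables:

import mlx.core as mx

a = mx.reshape(mx.arange(12), (3, 4))

print(a[1].shape)                     # (4,): integer index takes along axis 0
print(a[::2].shape)                   # (2, 4): basic slicing
print(a[mx.array([0, 2])].shape)      # (2, 4): array index gathers along axis 0
print(a[None].shape)                  # (1, 3, 4): None inserts a new axis
print(a[..., 1].shape)                # (3,): ellipsis expands to full slices
print(a[1:, mx.array([0, 3])].shape)  # (2, 2): mixed slice + array indexing

# __setitem__ follows the same rules; updates broadcast to the target shape.
a[0] = 7
a[:, 1:3] = mx.zeros((3, 2))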
+ bool have_non_array = false; + bool gather_first = false; + for (auto& idx : indices) { + if (py::isinstance(idx) || py::isinstance(idx)) { + if (have_array && have_non_array) { + gather_first = true; + break; + } + have_array = true; + } else { + have_non_array |= have_array; + } + } + + if (have_array) { + int last_array; + // Then find the last array + for (last_array = indices.size() - 1; last_array >= 0; last_array--) { + auto& idx = indices[last_array]; + if (py::isinstance(idx) || py::isinstance(idx)) { + break; + } + } + + std::vector gather_indices; + for (int i = 0; i <= last_array; i++) { + auto& idx = indices[i]; + if (!idx.is_none()) { + gather_indices.push_back(idx); + } + } + int max_dims; + src = mlx_gather_nd(src, gather_indices, gather_first, max_dims); + + // Reassemble the indices for the slicing or reshaping if there are any + if (gather_first) { + for (int i = 0; i < max_dims; i++) { + remaining_indices.push_back( + py::slice(py::none(), py::none(), py::none())); + } + for (int i = 0; i < last_array; i++) { + auto& idx = indices[i]; + if (idx.is_none()) { + remaining_indices.push_back(indices[i]); + } else if (py::isinstance(idx)) { + remaining_indices.push_back( + py::slice(py::none(), py::none(), py::none())); + } + } + for (int i = last_array + 1; i < indices.size(); i++) { + remaining_indices.push_back(indices[i]); + } + } else { + for (int i = 0; i < indices.size(); i++) { + auto& idx = indices[i]; + if (py::isinstance(idx) || py::isinstance(idx)) { + break; + } else if (idx.is_none()) { + remaining_indices.push_back(idx); + } else { + remaining_indices.push_back( + py::slice(py::none(), py::none(), py::none())); + } + } + for (int i = 0; i < max_dims; i++) { + remaining_indices.push_back( + py::slice(py::none(), py::none(), py::none())); + } + for (int i = last_array + 1; i < indices.size(); i++) { + remaining_indices.push_back(indices[i]); + } + } + } + } + if (have_array && remaining_indices.empty()) { + return src; + } + if (remaining_indices.empty()) { + remaining_indices = indices; + } + + // Slice handling + { + std::vector starts(src.ndim(), 0); + std::vector ends = src.shape(); + std::vector strides(src.ndim(), 1); + int axis = 0; + for (auto& idx : remaining_indices) { + if (!idx.is_none()) { + get_slice_params( + starts[axis], ends[axis], strides[axis], idx, ends[axis]); + axis++; + } + } + src = slice(src, starts, ends, strides); + } + + // Unsqueeze handling + if (remaining_indices.size() > src.ndim()) { + std::vector out_shape; + int axis = 0; + for (auto& idx : remaining_indices) { + if (idx.is_none()) { + out_shape.push_back(1); + } else { + out_shape.push_back(src.shape(axis++)); + } + } + src = reshape(src, out_shape); + } + + return src; +} + +array mlx_get_item(const array& src, const py::object& obj) { + if (py::isinstance(obj)) { + return mlx_get_item_slice(src, obj); + } else if (py::isinstance(obj)) { + return mlx_get_item_array(src, py::cast(obj)); + } else if (py::isinstance(obj)) { + return mlx_get_item_int(src, obj); + } else if (py::isinstance(obj)) { + return mlx_get_item_nd(src, obj); + } else if (obj.is_none()) { + std::vector s(1, 1); + s.insert(s.end(), src.shape().begin(), src.shape().end()); + return reshape(src, s); + } + throw std::invalid_argument("Cannot index mlx array using the given type."); +} + +array mlx_set_item_int( + const array& src, + const py::int_& idx, + const array& update) { + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + // Remove 
any leading singleton dimensions from the update + // and then broadcast update to shape of src[0, ...] + int s = 0; + for (; s < update.ndim() && update.shape(s) == 1; s++) + ; + auto up_shape = + std::vector(update.shape().begin() + s, update.shape().end()); + auto shape = src.shape(); + shape[0] = 1; + return scatter( + src, + get_int_index(idx, src.shape(0)), + broadcast_to(reshape(update, up_shape), shape), + 0); +} + +array mlx_set_item_array( + const array& src, + const array& indices, + const array& update) { + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + // Remove any leading singleton dimensions from the update + int s = 0; + for (; s < update.ndim() && update.shape(s) == 1; s++) + ; + auto up_shape = + std::vector(update.shape().begin() + s, update.shape().end()); + auto up = reshape(update, up_shape); + + // The update shape must broadcast with indices.shape + [1] + src.shape[1:] + up_shape = indices.shape(); + up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); + up = broadcast_to(up, up_shape); + up_shape.insert(up_shape.begin() + indices.ndim(), 1); + up = reshape(up, up_shape); + + return scatter(src, indices, up, 0); +} + +array mlx_set_item_slice( + const array& src, + const py::slice& in_slice, + const array& update) { + // Check input and raise error if 0 dim for parity with np + if (src.ndim() == 0) { + throw std::invalid_argument( + "too many indices for array: array is 0-dimensional"); + } + + // If none slice is requested broadcast the update + // to the src size and return it. + if (is_none_slice(in_slice)) { + int s = 0; + for (; s < update.ndim() && update.shape(s) == 1; s++) + ; + auto up_shape = + std::vector(update.shape().begin() + s, update.shape().end()); + return broadcast_to(reshape(update, up_shape), src.shape()); + } + + int start = 0; + int end = src.shape(0); + int stride = 1; + + // Check and update slice params + get_slice_params(start, end, stride, in_slice, end); + + return mlx_set_item_array(src, arange(start, end, stride, uint32), update); +} + +array mlx_set_item_nd( + const array& src, + const py::tuple& entries, + const array& update) { + std::vector indices; + int non_none_indices = 0; + + // Expand ellipses into a series of ':' slices + { + int non_none_indices_before = 0; + int non_none_indices_after = 0; + bool has_ellipsis = false; + int indices_before = 0; + for (int i = 0; i < entries.size(); ++i) { + auto idx = entries[i]; + if (!is_valid_index_type(idx)) { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } else if (!py::ellipsis().is(idx)) { + if (!has_ellipsis) { + indices_before++; + non_none_indices_before += !idx.is_none(); + } else { + non_none_indices_after += !idx.is_none(); + } + indices.push_back(idx); + } else if (has_ellipsis) { + throw std::invalid_argument( + "An index can only have a single ellipsis (...)"); + } else { + has_ellipsis = true; + } + } + if (has_ellipsis) { + for (int axis = non_none_indices_before; + axis < src.ndim() - non_none_indices_after; + axis++) { + indices.insert( + indices.begin() + indices_before, py::slice(0, src.shape(axis), 1)); + } + non_none_indices = src.ndim(); + } else { + non_none_indices = non_none_indices_before + non_none_indices_after; + } + } + + if (non_none_indices > src.ndim()) { + std::ostringstream msg; + msg << "Too many indices for array with " << src.ndim() << "dimensions."; + throw std::invalid_argument(msg.str()); + } + + // Remove 
leading singletons dimensions from the update + int s = 0; + for (; s < update.ndim() && update.shape(s) == 1; s++) { + }; + auto up_shape = + std::vector(update.shape().begin() + s, update.shape().end()); + auto up = reshape(update, up_shape); + + // If no non-None indices return the broadcasted update + if (non_none_indices == 0) { + return broadcast_to(up, src.shape()); + } + + unsigned long max_dim = 0; + bool arrays_first = false; + int num_slices = 0; + int num_arrays = 0; + { + bool have_array = false; + bool have_non_array = false; + for (auto& idx : indices) { + if (py::isinstance(idx) || idx.is_none()) { + have_non_array = have_array; + num_slices++; + } else if (py::isinstance(idx)) { + have_array = true; + if (have_array && have_non_array) { + arrays_first = true; + } + max_dim = std::max(py::cast(idx).ndim(), max_dim); + num_arrays++; + } + } + } + + std::vector arr_indices; + int slice_num = 0; + int array_num = 0; + int ax = 0; + for (int i = 0; i < indices.size(); ++i) { + auto& pyidx = indices[i]; + if (py::isinstance(pyidx)) { + int start, end, stride; + get_slice_params(start, end, stride, pyidx, src.shape(ax++)); + auto idx = arange(start, end, stride, uint32); + std::vector idx_shape(max_dim + num_slices, 1); + auto loc = slice_num + (arrays_first ? max_dim : 0); + slice_num++; + idx_shape[loc] = idx.size(); + arr_indices.push_back(reshape(idx, idx_shape)); + } else if (py::isinstance(pyidx)) { + arr_indices.push_back(get_int_index(pyidx, src.shape(ax++))); + } else if (pyidx.is_none()) { + slice_num++; + } else if (py::isinstance(pyidx)) { + ax++; + auto idx = py::cast(pyidx); + std::vector idx_shape; + if (!arrays_first) { + idx_shape.insert(idx_shape.end(), slice_num, 1); + } + idx_shape.insert(idx_shape.end(), max_dim - idx.ndim(), 1); + idx_shape.insert(idx_shape.end(), idx.shape().begin(), idx.shape().end()); + idx_shape.insert( + idx_shape.end(), num_slices - (arrays_first ? 
0 : slice_num), 1); + arr_indices.push_back(reshape(idx, idx_shape)); + if (!arrays_first && ++array_num == num_arrays) { + slice_num += max_dim; + } + } else { + throw std::invalid_argument( + "Cannot index mlx array using the given type yet"); + } + } + + arr_indices = broadcast_arrays(arr_indices); + up_shape = arr_indices[0].shape(); + up_shape.insert( + up_shape.end(), + src.shape().begin() + non_none_indices, + src.shape().end()); + up = broadcast_to(up, up_shape); + up_shape.insert( + up_shape.begin() + arr_indices[0].ndim(), non_none_indices, 1); + up = reshape(up, up_shape); + + std::vector axes(arr_indices.size(), 0); + std::iota(axes.begin(), axes.end(), 0); + return scatter(src, arr_indices, up, axes); +} + +void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v) { + auto vals = to_array(v, src.dtype()); + auto impl = [&src, &obj, &vals]() { + if (py::isinstance(obj)) { + return mlx_set_item_slice(src, obj, vals); + } else if (py::isinstance(obj)) { + return mlx_set_item_array(src, py::cast(obj), vals); + } else if (py::isinstance(obj)) { + return mlx_set_item_int(src, obj, vals); + } else if (py::isinstance(obj)) { + return mlx_set_item_nd(src, obj, vals); + } else if (obj.is_none()) { + return broadcast_to(vals, src.shape()); + } + throw std::invalid_argument("Cannot index mlx array using the given type."); + }; + auto out = impl(); + src.overwrite_descriptor(out); +} diff --git a/python/src/indexing.h b/python/src/indexing.h new file mode 100644 index 0000000000..cb5d3106c0 --- /dev/null +++ b/python/src/indexing.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +#include "mlx/array.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace mlx::core; + +array mlx_get_item(const array& src, const py::object& obj); +void mlx_set_item(array& src, const py::object& obj, const ScalarOrArray& v); diff --git a/python/src/load.cpp b/python/src/load.cpp new file mode 100644 index 0000000000..d3f1f92502 --- /dev/null +++ b/python/src/load.cpp @@ -0,0 +1,290 @@ +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mlx/load.h" +#include "mlx/ops.h" +#include "mlx/utils.h" +#include "python/src/load.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +/////////////////////////////////////////////////////////////////////////////// +// Helpers +/////////////////////////////////////////////////////////////////////////////// + +bool is_istream_object(const py::object& file) { + return py::hasattr(file, "read") && py::hasattr(file, "seek") && + py::hasattr(file, "tell") && py::hasattr(file, "closed"); +} + +bool is_ostream_object(const py::object& file) { + return py::hasattr(file, "write") && py::hasattr(file, "seek") && + py::hasattr(file, "tell") && py::hasattr(file, "closed"); +} + +bool is_zip_file(const py::module_& zipfile, const py::object& file) { + if (is_istream_object(file)) { + auto st_pos = file.attr("tell")(); + bool r = (zipfile.attr("is_zipfile")(file)).cast(); + file.attr("seek")(st_pos, 0); + return r; + } + return zipfile.attr("is_zipfile")(file).cast(); +} + +class ZipFileWrapper { + public: + ZipFileWrapper( + const py::module_& zipfile, + const py::object& file, + char mode = 'r', + int compression = 0) + : zipfile_module_(zipfile), + zipfile_object_(zipfile.attr("ZipFile")( + file, + "mode"_a = mode, + "compression"_a = compression, + "allowZip64"_a = true)), + 
files_list_(zipfile_object_.attr("namelist")()), + open_func_(zipfile_object_.attr("open")), + read_func_(zipfile_object_.attr("read")), + close_func_(zipfile_object_.attr("close")) {} + + std::vector namelist() const { + return files_list_.cast>(); + } + + py::object open(const std::string& key, char mode = 'r') { + // Following numpy : + // https://github.com/numpy/numpy/blob/db4f43983cb938f12c311e1f5b7165e270c393b4/numpy/lib/npyio.py#L742C36-L742C47 + if (mode == 'w') { + return open_func_(key, "mode"_a = mode, "force_zip64"_a = true); + } + return open_func_(key, "mode"_a = mode); + } + + private: + py::module_ zipfile_module_; + py::object zipfile_object_; + py::list files_list_; + py::object open_func_; + py::object read_func_; + py::object close_func_; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Loading +/////////////////////////////////////////////////////////////////////////////// + +class PyFileReader : public io::Reader { + public: + PyFileReader(py::object file) + : pyistream_(file), + readinto_func_(file.attr("readinto")), + seek_func_(file.attr("seek")), + tell_func_(file.attr("tell")) {} + + bool is_open() const override { + return !pyistream_.attr("closed").cast(); + } + + bool good() const override { + return !pyistream_.is_none(); + } + + size_t tell() const override { + return tell_func_().cast(); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + seek_func_(off, (int)way); + } + + void read(char* data, size_t n) override { + py::object bytes_read = + readinto_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); + if (bytes_read.is_none() || py::cast(bytes_read) < n) { + throw std::runtime_error("[load] Failed to read from python stream"); + } + } + + std::string label() const override { + return "python file object"; + } + + private: + py::object pyistream_; + py::object readinto_func_; + py::object seek_func_; + py::object tell_func_; +}; + +DictOrArray mlx_load_helper(py::object file, StreamOrDevice s) { + py::module_ zipfile = py::module_::import("zipfile"); + + // Assume .npz file if it is zipped + if (is_zip_file(zipfile, file)) { + // Output dictionary filename in zip -> loaded array + std::unordered_map array_dict; + + // Create python ZipFile object + ZipFileWrapper zipfile_object(zipfile, file); + for (const std::string& st : zipfile_object.namelist()) { + // Open zip file as a python file stream + py::object sub_file = zipfile_object.open(st); + + // Create array from python fille stream + auto arr = load(std::make_shared(sub_file), s); + + // Remove .npy from file if it is there + auto key = st; + if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") + key = st.substr(0, st.length() - 4); + + // Add array to dict + array_dict.insert({key, arr}); + } + + // If we don't own the stream and it was passed to us, eval immediately + for (auto& [key, arr] : array_dict) { + arr.eval(); + } + + return {array_dict}; + } else if (py::isinstance(file)) { // Assume .npy file path string + return {load(py::cast(file), s)}; + } else if (is_istream_object(file)) { + // If we don't own the stream and it was passed to us, eval immediately + auto arr = load(std::make_shared(file), s); + arr.eval(); + return {arr}; + } + + throw std::invalid_argument( + "[load] Input must be a file-like object, string, or pathlib.Path"); +} + +/////////////////////////////////////////////////////////////////////////////// +// Saving 
+/////////////////////////////////////////////////////////////////////////////// + +class PyFileWriter : public io::Writer { + public: + PyFileWriter(py::object file) + : pyostream_(file), + write_func_(file.attr("write")), + seek_func_(file.attr("seek")), + tell_func_(file.attr("tell")) {} + + bool is_open() const override { + return !pyostream_.attr("closed").cast(); + } + + bool good() const override { + return !pyostream_.is_none(); + } + + size_t tell() const override { + return tell_func_().cast(); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + seek_func_(off, (int)way); + } + + void write(const char* data, size_t n) override { + py::object bytes_written = + write_func_(py::memoryview::from_buffer(data, {n}, {sizeof(char)})); + if (bytes_written.is_none() || py::cast(bytes_written) < n) { + throw std::runtime_error("[load] Failed to write to python stream"); + } + } + + std::string label() const override { + return "python file object"; + } + + private: + py::object pyostream_; + py::object write_func_; + py::object seek_func_; + py::object tell_func_; +}; + +void mlx_save_helper(py::object file, array a, bool retain_graph) { + if (py::isinstance(file)) { + save(py::cast(file), a, retain_graph); + return; + } else if (is_ostream_object(file)) { + save(std::make_shared(file), a, retain_graph); + return; + } + + throw std::invalid_argument( + "[save] Input must be a file-like object, string, or pathlib.Path"); +} + +void mlx_savez_helper( + py::object file_, + py::args args, + const py::kwargs& kwargs, + bool compressed) { + // Add .npz to the end of the filename if not already there + py::object file = file_; + + if (py::isinstance(file_)) { + std::string fname = file_.cast(); + + // Add .npz to file name if it is not there + if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") + fname += ".npz"; + + file = py::str(fname); + } + + // Collect args and kwargs + auto arrays_dict = kwargs.cast>(); + auto arrays_list = args.cast>(); + + for (int i = 0; i < arrays_list.size(); i++) { + std::string arr_name = "arr_" + std::to_string(i); + + if (arrays_dict.count(arr_name) > 0) { + throw std::invalid_argument( + "[savez] Cannot use un-named variables and keyword " + arr_name); + } + + arrays_dict.insert({arr_name, arrays_list[i]}); + } + + // Create python ZipFile object depending on compression + py::module_ zipfile = py::module_::import("zipfile"); + int compression = compressed ? 
zipfile.attr("ZIP_DEFLATED").cast() + : zipfile.attr("ZIP_STORED").cast(); + char mode = 'w'; + ZipFileWrapper zipfile_object(zipfile, file, mode, compression); + + // Save each array + for (auto [k, a] : arrays_dict) { + std::string fname = k + ".npy"; + auto py_ostream = zipfile_object.open(fname, 'w'); + save(std::make_shared(py_ostream), a); + } + + return; +} diff --git a/python/src/ops.cpp b/python/src/ops.cpp new file mode 100644 index 0000000000..ca33cabd2a --- /dev/null +++ b/python/src/ops.cpp @@ -0,0 +1,2422 @@ +#include +#include +#include + +#include +#include +#include + +#include "mlx/ops.h" +#include "mlx/utils.h" +#include "python/src/load.h" +#include "python/src/utils.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; + +using Scalar = std::variant; + +Dtype scalar_to_dtype(Scalar scalar) { + if (std::holds_alternative(scalar)) { + return int32; + } else { + return float32; + } +} + +double scalar_to_double(Scalar s) { + if (std::holds_alternative(s)) { + return std::get(s); + } else { + return static_cast(std::get(s)); + } +} + +void init_ops(py::module_& m) { + m.def( + "reshape", + &reshape, + "a"_a, + py::pos_only(), + "shape"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Reshape an array while preserving the size. + + Args: + a (array): Input array. + shape (tuple(int)): New shape. + stream (Stream, optional): Stream or device. Defaults to ```None``` + in which case the default stream of the default device is used. + + Returns: + array: The reshaped array. + )pbdoc"); + m.def( + "squeeze", + [](const array& a, const IntOrVec& v, const StreamOrDevice& s) { + if (std::holds_alternative(v)) { + return squeeze(a, s); + } else if (auto pv = std::get_if(&v); pv) { + return squeeze(a, *pv, s); + } else { + return squeeze(a, std::get>(v), s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Remove length one axes from an array. + + Args: + a (array): Input array. + axis (int or tuple(int), optional): Axes to remove. Defaults + to ```None``` in which case all size one axes are removed. + + Returns: + array: The output array with size one axes removed. + )pbdoc"); + m.def( + "expand_dims", + [](const array& a, + const std::variant>& v, + StreamOrDevice s) { + if (auto pv = std::get_if(&v); pv) { + return expand_dims(a, *pv, s); + } else { + return expand_dims(a, std::get>(v), s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Add a size one dimension at the given axis. + + Args: + a (array): Input array. + axes (int or tuple(int)): The index of the inserted dimensions. + + Returns: + array: The array with inserted dimensions. + )pbdoc"); + m.def( + "abs", + &mlx::core::abs, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise absolute value. + + Args: + a (array): Input array. + + Returns: + array: The absolute value of ``a``. + )pbdoc"); + m.def( + "sign", + &sign, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise sign. + + Args: + a (array): Input array. + + Returns: + array: The sign of ``a``. + )pbdoc"); + m.def( + "negative", + &negative, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise negation. + + Args: + a (array): Input array. + + Returns: + array: The negative of ``a``. 
+ )pbdoc"); + m.def( + "add", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return add(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise addition. + + Add two arrays with numpy-style broadcasting semantics. Either or both input arrays + can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The sum of ``a`` and ``b``. + )pbdoc"); + m.def( + "subtract", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return subtract(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise subtraction. + + Subtract one array from another with numpy-style broadcasting semantics. Either or both + input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The difference ``a - b``. + )pbdoc"); + m.def( + "multiply", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return multiply(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise multiplication. + + Multiply two arrays with numpy-style broadcasting semantics. Either or both + input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The multiplication ``a * b``. + )pbdoc"); + m.def( + "divide", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return divide(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise division. + + Divide two arrays with numpy-style broadcasting semantics. Either or both + input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The quotient ``a / b``. + )pbdoc"); + m.def( + "equal", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return equal(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise equality. + + Equality comparison on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a == b``. + )pbdoc"); + m.def( + "not_equal", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return not_equal(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise not equal. + + Not equal comparison on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a != b``. + )pbdoc"); + m.def( + "less", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return less(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise less than. 
+ + Strict less than on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a < b``. + )pbdoc"); + m.def( + "less_equal", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return less_equal(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise less than or equal. + + Less than or equal on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a <= b``. + )pbdoc"); + m.def( + "greater", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return greater(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise greater than. + + Strict greater than on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a > b``. + )pbdoc"); + m.def( + "greater_equal", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return greater_equal(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise greater or equal. + + Greater than or equal on two arrays with numpy-style broadcasting semantics. + Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The element-wise comparison ``a >= b``. + )pbdoc"); + m.def( + "array_equal", + [](const ScalarOrArray& a_, + const ScalarOrArray& b_, + bool equal_nan, + StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return array_equal(a, b, equal_nan, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "equal_nan"_a = false, + "stream"_a = none, + R"pbdoc( + Array equality check. + + Compare two arrays for equality. Returns ``True`` if and only if the arrays + have the same shape and their values are equal. The arrays need not have + the same type to be considered equal. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + equal_nan (bool): If ``True``, NaNs are treated as equal. + Defaults to ``False``. + + Returns: + array: A scalar boolean array. + )pbdoc"); + m.def( + "matmul", + &matmul, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Matrix multiplication. + + Perform the (possibly batched) matrix multiplication of two arrays. This function supports + broadcasting for arrays with more than two dimensions. + + - If the first array is 1-D then a 1 is prepended to its shape to make it + a matrix. Similarly if the second array is 1-D then a 1 is appended to its + shape to make it a matrix. In either case the singleton dimension is removed + from the result. + - A batched matrix multiplication is performed if the arrays have more than + 2 dimensions. The matrix dimensions for the matrix product are the last + two dimensions of each input. 
+ - All but the last two dimensions of each input are broadcast with one another using + standard numpy-style broadcasting semantics. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The matrix product of ``a`` and ``b``. + )pbdoc"); + m.def( + "square", + &square, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise square. + + Args: + a (array): Input array. + + Returns: + array: The square of ``a``. + )pbdoc"); + m.def( + "sqrt", + &mlx::core::sqrt, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise square root. + + Args: + a (array): Input array. + + Returns: + array: The square root of ``a``. + )pbdoc"); + m.def( + "rsqrt", + &rsqrt, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise reciprocal and square root. + + Args: + a (array): Input array. + + Returns: + array: One over the square root of ``a``. + )pbdoc"); + m.def( + "reciprocal", + &reciprocal, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise reciprocal. + + Args: + a (array): Input array. + + Returns: + array: The reciprocal of ``a``. + )pbdoc"); + m.def( + "logical_not", + [](const ScalarOrArray& a, StreamOrDevice s) { + return logical_not(to_array(a), s); + }, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise logical not. + + Args: + a (array): Input array or scalar. + + Returns: + array: The boolean array containing the logical not of ``a``. + )pbdoc"); + m.def( + "logaddexp", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return logaddexp(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise log-add-exp. + + This is a numerically stable log-add-exp of two arrays with numpy-style + broadcasting semantics. Either or both input arrays can also be scalars. + + The computation is is a numerically stable version of ``log(exp(a) + exp(b))``. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The log-add-exp of ``a`` and ``b``. + )pbdoc"); + m.def( + "exp", + &mlx::core::exp, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise exponential. + + Args: + a (array): Input array. + + Returns: + array: The exponential of ``a``. + )pbdoc"); + m.def( + "erf", + &mlx::core::erf, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise error function. + + .. math:: + \mathrm{erf}(x) = \frac{2}{\sqrt{\pi}} \int_0^t e^{-t^2} \, dx + + Args: + a (array): Input array. + + Returns: + array: The error function of ``a``. + )pbdoc"); + m.def( + "erfinv", + &mlx::core::erfinv, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse of :func:`erf`. + + Args: + a (array): Input array. + + Returns: + array: The inverse error function of ``a``. + )pbdoc"); + m.def( + "sin", + &mlx::core::sin, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise sine. + + Args: + a (array): Input array. + + Returns: + array: The sine of ``a``. + )pbdoc"); + m.def( + "cos", + &mlx::core::cos, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise cosine. + + Args: + a (array): Input array. + + Returns: + array: The cosine of ``a``. 
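A usage sketch (the ``mx.array`` constructor is assumed to accept a Python list; results are rounded for readability):

      .. code-block:: python

          import math
          import mlx.core as mx

          x = mx.array([0.0, math.pi / 2, math.pi])
          mx.cos(x)    # approximately [1.0, 0.0, -1.0]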
+ )pbdoc"); + m.def( + "tan", + &mlx::core::tan, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise tangent. + + Args: + a (array): Input array. + + Returns: + array: The tangent of ``a``. + )pbdoc"); + m.def( + "arcsin", + &mlx::core::arcsin, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse sine. + + Args: + a (array): Input array. + + Returns: + array: The inverse sine of ``a``. + )pbdoc"); + m.def( + "arccos", + &mlx::core::arccos, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse cosine. + + Args: + a (array): Input array. + + Returns: + array: The inverse cosine of ``a``. + )pbdoc"); + m.def( + "arctan", + &mlx::core::arctan, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse tangent. + + Args: + a (array): Input array. + + Returns: + array: The inverse tangent of ``a``. + )pbdoc"); + m.def( + "sinh", + &mlx::core::sinh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise hyperbolic sine. + + Args: + a (array): Input array. + + Returns: + array: The hyperbolic sine of ``a``. + )pbdoc"); + m.def( + "cosh", + &mlx::core::cosh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise hyperbolic cosine. + + Args: + a (array): Input array. + + Returns: + array: The hyperbolic cosine of ``a``. + )pbdoc"); + m.def( + "tanh", + &mlx::core::tanh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise hyperbolic tangent. + + Args: + a (array): Input array. + + Returns: + array: The hyperbolic tangent of ``a``. + )pbdoc"); + m.def( + "arcsinh", + &mlx::core::arcsinh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse hyperbolic sine. + + Args: + a (array): Input array. + + Returns: + array: The inverse hyperbolic sine of ``a``. + )pbdoc"); + m.def( + "arccosh", + &mlx::core::arccosh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse hyperbolic cosine. + + Args: + a (array): Input array. + + Returns: + array: The inverse hyperbolic cosine of ``a``. + )pbdoc"); + m.def( + "arctanh", + &mlx::core::arctanh, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise inverse hyperbolic tangent. + + Args: + a (array): Input array. + + Returns: + array: The inverse hyperbolic tangent of ``a``. + )pbdoc"); + m.def( + "log", + &mlx::core::log, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise natural logarithm. + + Args: + a (array): Input array. + + Returns: + array: The natural logarithm of ``a``. + )pbdoc"); + m.def( + "log2", + &mlx::core::log2, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise base-2 logarithm. + + Args: + a (array): Input array. + + Returns: + array: The base-2 logarithm of ``a``. + )pbdoc"); + m.def( + "log10", + &mlx::core::log10, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise base-10 logarithm. + + Args: + a (array): Input array. + + Returns: + array: The base-10 logarithm of ``a``. + )pbdoc"); + m.def( + "log1p", + &mlx::core::log1p, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise natural log of one plus the array. + + Args: + a (array): Input array. 
+ + Returns: + array: The natural logarithm of one plus ``a``. + )pbdoc"); + m.def( + "stop_gradient", + &stop_gradient, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Stop gradients from being computed. + + The operation is the identity but it prevents gradients from flowing + through the array. + + Args: + a (array): Input array. + + Returns: + array: The unchanged input ``a`` but without gradient flowing + through it. + )pbdoc"); + m.def( + "sigmoid", + &sigmoid, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise logistic sigmoid. + + The logistic sigmoid function is: + + .. math:: + \mathrm{sigmoid}(x) = \frac{1}{1 + e^{-x}} + + Args: + a (array): Input array. + + Returns: + array: The logistic sigmoid of ``a``. + )pbdoc"); + m.def( + "power", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return power(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise power operation. + + Raise the elements of a to the powers in elements of b with numpy-style + broadcasting semantics. Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: Bases of ``a`` raised to powers in ``b``. + )pbdoc"); + { + // Disable function signature just for arange which we write manually + py::options options; + options.disable_function_signatures(); + m.def( + "arange", + [](Scalar stop, std::optional dtype_, StreamOrDevice s) { + Dtype dtype = + dtype_.has_value() ? dtype_.value() : scalar_to_dtype(stop); + + return arange(0.0, scalar_to_double(stop), 1.0, dtype, s); + }, + "stop"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar start, + Scalar stop, + std::optional dtype_, + StreamOrDevice s) { + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types(scalar_to_dtype(start), scalar_to_dtype(stop)); + return arange( + scalar_to_double(start), scalar_to_double(stop), dtype, s); + }, + "start"_a, + "stop"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar stop, + Scalar step, + std::optional dtype_, + StreamOrDevice s) { + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types(scalar_to_dtype(stop), scalar_to_dtype(step)); + + return arange( + 0.0, scalar_to_double(stop), scalar_to_double(step), dtype, s); + }, + "stop"_a, + "step"_a, + "dtype"_a = none, + "stream"_a = none); + m.def( + "arange", + [](Scalar start, + Scalar stop, + Scalar step, + std::optional dtype_, + StreamOrDevice s) { + // Determine the final dtype based on input types + Dtype dtype = dtype_.has_value() + ? dtype_.value() + : promote_types( + scalar_to_dtype(start), + promote_types( + scalar_to_dtype(stop), scalar_to_dtype(step))); + + return arange( + scalar_to_double(start), + scalar_to_double(stop), + scalar_to_double(step), + dtype, + s); + }, + "start"_a, + "stop"_a, + "step"_a, + "dtype"_a = none, + "stream"_a = none, + R"pbdoc( + arange(start, stop, step, dtype: Optional[Dtype] = None, *, stream: Union[None, Stream, Device] = None) -> array + + Generates ranges of numbers. + + Generate numbers in the half-open interval ``[start, stop)`` in + increments of ``step``. + + Args: + start (float or int, optional): Starting value which defaults to ``0``. + stop (float or int): Stopping value. + step (float or int, optional): Increment which defaults to ``1``. 
+ dtype (Dtype, optional): Specifies the data type of the output. + If unspecified will default to ``float32`` if any of ``start``, + ``stop``, or ``step`` are ``float``. Otherwise will default to + ``int32``. + + Returns: + array: The range of values. + + Note: + Following the Numpy convention the actual increment used to + generate numbers is ``dtype(start + step) - dtype(start)``. + This can lead to unexpected results for example if `start + step` + is a fractional value and the `dtype` is integral. + )pbdoc"); + } + m.def( + "take", + [](const array& a, + const array& indices, + const std::optional& axis, + StreamOrDevice s) { + if (axis.has_value()) { + return take(a, indices, axis.value(), s); + } else { + return take(a, indices, s); + } + }, + "a"_a, + py::pos_only(), + "indices"_a, + "axis"_a = std::nullopt, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Take elements along an axis. + + The elements are taken from ``indices`` along the specified axis. + If the axis is not specified the array is treated as a flattened + 1-D array prior to performing the take. + + As an example, if the ``axis=1`` this is equialent to ``a[:, indices, ...]``. + + Args: + a (array): Input array. + indices (array): Input array with integral type. + axis (int, optional): Axis along which to perform the take. If unspecified + the array is treated as a flattened 1-D vector. + + Returns: + array: The indexed values of ``a``. + )pbdoc"); + m.def( + "take_along_axis", + [](const array& a, + const array& indices, + const std::optional& axis, + StreamOrDevice s) { + if (axis.has_value()) { + return take_along_axis(a, indices, axis.value(), s); + } else { + return take_along_axis(reshape(a, {-1}, s), indices, 0, s); + } + }, + "a"_a, + py::pos_only(), + "indices"_a, + "axis"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Take values along an axis at the specified indices. + + Args: + a (array): Input array. + indices (array): Indices array. These should be broadcastable with + the input array excluding the `axis` dimension. + axis (int or None): Axis in the input to take the values from. If + ``axis == None`` the array is flattened to 1D prior to the indexing + operation. + + Returns: + array: The output array with the specified shape and values. + )pbdoc"); + m.def( + "full", + [](const std::variant>& shape, + const ScalarOrArray& vals, + std::optional dtype, + StreamOrDevice s) { + if (auto pv = std::get_if(&shape); pv) { + return full({*pv}, to_array(vals, dtype), s); + } else { + return full( + std::get>(shape), to_array(vals, dtype), s); + } + }, + "shape"_a, + "vals"_a, + "dtype"_a = std::nullopt, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Construct an array with the given value. + + Constructs an array of size ``shape`` filled with ``vals``. If ``vals`` + is an :obj:`array` it must be broadcastable to the given ``shape``. + + Args: + shape (int or list(int)): The shape of the output array. + vals (float or int or array): Values to fill the array with. + dtype (Dtype, optional): Data type of the output array. If + unspecified the output type is inferred from ``vals``. + + Returns: + array: The output array with the specified shape and values. 
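A usage sketch of the scalar and broadcast forms (the ``mx.array`` constructor is assumed; output formatting is indicative):

      .. code-block:: python

          import mlx.core as mx

          mx.full([2, 3], 7)                    # 2x3 array filled with 7
          mx.full([2, 3], mx.array([1, 2, 3]))  # each row is [1, 2, 3] via broadcasting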
+ )pbdoc"); + m.def( + "zeros", + [](const std::variant>& shape, + std::optional dtype, + StreamOrDevice s) { + auto t = dtype.value_or(float32); + if (auto pv = std::get_if(&shape); pv) { + return zeros({*pv}, t, s); + } else { + return zeros(std::get>(shape), t, s); + } + }, + "shape"_a, + "dtype"_a = std::nullopt, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Construct an array of zeros. + + Args: + shape (int or list(int)): The shape of the output array. + dtype (Dtype, optional): Data type of the output array. If + unspecified the output type defaults to ``float32``. + + Returns: + array: The array of zeros with the specified shape. + )pbdoc"); + m.def( + "zeros_like", + &zeros_like, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An array of zeros like the input. + + Args: + a (array): The input to take the shape and type from. + + Returns: + array: The output array filled with zeros. + )pbdoc"); + m.def( + "ones", + [](const std::variant>& shape, + std::optional dtype, + StreamOrDevice s) { + auto t = dtype.value_or(float32); + if (auto pv = std::get_if(&shape); pv) { + return ones({*pv}, t, s); + } else { + return ones(std::get>(shape), t, s); + } + }, + "shape"_a, + "dtype"_a = std::nullopt, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Construct an array of ones. + + Args: + shape (int or list(int)): The shape of the output array. + dtype (Dtype, optional): Data type of the output array. If + unspecified the output type defaults to ``float32``. + + Returns: + array: The array of ones with the specified shape. + )pbdoc"); + m.def( + "ones_like", + &ones_like, + "a"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An array of ones like the input. + + Args: + a (array): The input to take the shape and type from. + + Returns: + array: The output array filled with ones. + )pbdoc"); + m.def( + "allclose", + &allclose, + "a"_a, + "b"_a, + py::pos_only(), + "rtol"_a = 1e-5, + "atol"_a = 1e-8, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Approximate comparison of two arrays. + + The arrays are considered equal if: + + .. code-block:: + + all(abs(a - b) <= (atol + rtol * abs(b))) + + Note unlike :func:`array_equal`, this function supports numpy-style + broadcasting. + + Args: + a (array): Input array. + b (array): Input array. + rtol (float): Relative tolerance. + atol (float): Absolute tolerance. + + Returns: + array: The boolean output scalar indicating if the arrays are close. + )pbdoc"); + m.def( + "all", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return all(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An `and` reduction over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "any", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return any(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An `or` reduction over the given axes. 
+ + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "minimum", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return minimum(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + "stream"_a = none, + R"pbdoc( + Element-wise minimum. + + Take the element-wise min of two arrays with numpy-style broadcasting + semantics. Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The min of ``a`` and ``b``. + )pbdoc"); + m.def( + "maximum", + [](const ScalarOrArray& a_, const ScalarOrArray& b_, StreamOrDevice s) { + auto [a, b] = to_arrays(a_, b_); + return maximum(a, b, s); + }, + "a"_a, + "b"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Element-wise maximum. + + Take the element-wise max of two arrays with numpy-style broadcasting + semantics. Either or both input arrays can also be scalars. + + Args: + a (array): Input array or scalar. + b (array): Input array or scalar. + + Returns: + array: The max of ``a`` and ``b``. + )pbdoc"); + m.def( + "transpose", + [](const array& a, + const std::optional>& axes, + StreamOrDevice s) { + if (axes.has_value()) { + return transpose(a, get_reduce_axes(axes.value(), a.ndim()), s); + } else { + return transpose(a, s); + } + }, + "a"_a, + py::pos_only(), + "axes"_a = std::nullopt, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Transpose the dimensions of the array. + + Args: + a (array): Input array. + axes (list(int), optional): Specifies the source axis for each axis + in the new array. The default is to reverse the axes. + + Returns: + array: The transposed array. + )pbdoc"); + m.def( + "sum", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return sum(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "array"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Sum reduce the array over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "prod", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return prod(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An product reduction over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. 
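A usage sketch of the full and per-axis reductions (the ``mx.array`` constructor is assumed; values are indicative):

      .. code-block:: python

          import mlx.core as mx

          a = mx.array([[1, 2], [3, 4]])
          mx.prod(a)           # product of all elements -> 24
          mx.prod(a, axis=0)   # per-column products -> [3, 8]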
+ )pbdoc"); + m.def( + "min", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return min(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An `min` reduction over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "max", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return max(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + An `max` reduction over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "logsumexp", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return logsumexp(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + A `log-sum-exp` reduction over the given axes. + + The log-sum-exp reduction is a numerically stable version of: + + .. code-block:: + + log(sum(exp(a), axis)) + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the corresponding axes reduced. + )pbdoc"); + m.def( + "mean", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + StreamOrDevice s) { + return mean(a, get_reduce_axes(axis, a.ndim()), keepdims, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Compute the mean(s) over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array of means. + )pbdoc"); + m.def( + "var", + [](const array& a, + const IntOrVec& axis, + bool keepdims, + int ddof, + StreamOrDevice s) { + return var(a, get_reduce_axes(axis, a.ndim()), keepdims, ddof, s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + "keepdims"_a = false, + "ddof"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Compute the variance(s) over the given axes. + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or + axes to reduce over. If unspecified this defaults + to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. 
+ ddof (int, optional): The divisor to compute the variance + is ``N - ddof``, defaults to 0. + + Returns: + array: The output array of variances. + )pbdoc"); + m.def( + "split", + [](const array& a, + const std::variant>& indices_or_sections, + int axis, + StreamOrDevice s) { + if (auto pv = std::get_if(&indices_or_sections); pv) { + return split(a, *pv, axis, s); + } else { + return split( + a, std::get>(indices_or_sections), axis, s); + } + }, + "a"_a, + py::pos_only(), + "indices_or_sections"_a, + "axis"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Split an array along a given axis. + + Args: + a (array): Input array. + indices_or_sections (int or list(int)): If ``indices_or_sections`` + is an integer the array is split into that many sections of equal + size. An error is raised if this is not possible. If ``indices_or_sections`` + is a list, the list contains the indices of the start of each subarray + along the given axis. + axis (int, optional): Axis to split along, defaults to `0`. + + Returns: + list(array): A list of split arrays. + )pbdoc"); + m.def( + "argmin", + [](const array& a, + std::optional axis, + bool keepdims, + StreamOrDevice s) { + if (axis) { + return argmin(a, *axis, keepdims, s); + } else { + return argmin(a, keepdims, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Indices of the minimum values along the axis. + + Args: + a (array): Input array. + axis (int, optional): Optional axis to reduce over. If unspecified + this defaults to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the indices of the minimum values. + )pbdoc"); + m.def( + "argmax", + [](const array& a, + std::optional axis, + bool keepdims, + StreamOrDevice s) { + if (axis) { + return argmax(a, *axis, keepdims, s); + } else { + return argmax(a, keepdims, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + "keepdims"_a = false, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Indices of the maximum values along the axis. + + Args: + a (array): Input array. + axis (int, optional): Optional axis to reduce over. If unspecified + this defaults to reducing over the entire array. + keepdims (bool, optional): Keep reduced axes as + singleton dimensions, defaults to `False`. + + Returns: + array: The output array with the indices of the minimum values. + )pbdoc"); + m.def( + "sort", + [](const array& a, std::optional axis, StreamOrDevice s) { + if (axis) { + return sort(a, *axis, s); + } else { + return sort(a, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = -1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Returns a sorted copy of the array. + + Args: + a (array): Input array. + axis (int or None, optional): Optional axis to sort over. + If ``None``, this sorts over the flattened array. + If unspecified, it defaults to -1 (sorting over the last axis). + + Returns: + array: The sorted array. + )pbdoc"); + m.def( + "argsort", + [](const array& a, std::optional axis, StreamOrDevice s) { + if (axis) { + return argsort(a, *axis, s); + } else { + return argsort(a, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = -1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Returns the indices that sort the array. + + Args: + a (array): Input array. + axis (int or None, optional): Optional axis to sort over. + If ``None``, this sorts over the flattened array. 
+ If unspecified, it defaults to -1 (sorting over the last axis). + + Returns: + array: The indices that sort the input array. + )pbdoc"); + m.def( + "partition", + [](const array& a, int kth, std::optional axis, StreamOrDevice s) { + if (axis) { + return partition(a, kth, *axis, s); + } else { + return partition(a, kth, s); + } + }, + "a"_a, + py::pos_only(), + "kth"_a, + "axis"_a = -1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Returns a partitioned copy of the array such that the smaller ``kth`` + elements are first. + + The ordering of the elements in partitions is undefined. + + Args: + a (array): Input array. + kth (int): Element at the ``kth`` index will be in its sorted + position in the output. All elements before the kth index will + be less or equal to the ``kth`` element and all elements after + will be greater or equal to the ``kth`` element in the output. + axis (int or None, optional): Optional axis to partition over. + If ``None``, this partitions over the flattened array. + If unspecified, it defaults to ``-1``. + + Returns: + array: The partitioned array. + )pbdoc"); + m.def( + "argpartition", + [](const array& a, int kth, std::optional axis, StreamOrDevice s) { + if (axis) { + return argpartition(a, kth, *axis, s); + } else { + return argpartition(a, kth, s); + } + }, + "a"_a, + py::pos_only(), + "kth"_a, + "axis"_a = -1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Returns the indices that partition the array. + + The ordering of the elements within a partition in given by the indices + is undefined. + + Args: + a (array): Input array. + kth (int): Element index at the ``kth`` position in the output will + give the sorted position. All indices before the ``kth`` position + will be of elements less or equal to the element at the ``kth`` + index and all indices after will be of elements greater or equal + to the element at the ``kth`` index. + axis (int or None, optional): Optional axis to partiton over. + If ``None``, this partitions over the flattened array. + If unspecified, it defaults to ``-1``. + + Returns: + array: The indices that partition the input array. + )pbdoc"); + m.def( + "topk", + [](const array& a, int k, std::optional axis, StreamOrDevice s) { + if (axis) { + return topk(a, k, *axis, s); + } else { + return topk(a, k, s); + } + }, + "a"_a, + py::pos_only(), + "k"_a, + "axis"_a = -1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Returns the ``k`` largest elements from the input along a given axis. + + The elements will not necessarily be in sorted order. + + Args: + a (array): Input array. + k (int): ``k`` top elements to be returned + axis (int or None, optional): Optional axis to select over. + If ``None``, this selects the top ``k`` elements over the + flattened array. If unspecified, it defaults to ``-1``. + + Returns: + array: The top ``k`` elements from the input. + )pbdoc"); + m.def( + "broadcast_to", + [](const ScalarOrArray& a, + const std::vector& shape, + StreamOrDevice s) { return broadcast_to(to_array(a), shape, s); }, + "a"_a, + py::pos_only(), + "shape"_a, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Broadcast an array to the given shape. + + The broadcasting semantics are the same as Numpy. + + Args: + a (array): Input array. + shape (list(int)): The shape to broadcast to. + + Returns: + array: The output array with the new shape. 
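A usage sketch (shape given as a list per the signature above; the ``mx.array`` constructor is assumed):

      .. code-block:: python

          import mlx.core as mx

          v = mx.array([1, 2, 3])       # shape (3,)
          mx.broadcast_to(v, [2, 3])    # shape (2, 3); both rows equal [1, 2, 3]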
+ )pbdoc"); + m.def( + "softmax", + [](const array& a, const IntOrVec& axis, StreamOrDevice s) { + return softmax(a, get_reduce_axes(axis, a.ndim()), s); + }, + "a"_a, + py::pos_only(), + "axis"_a = none, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Perform the softmax along the given axis. + + This operation is a numerically stable version of: + + .. code-block:: + + exp(a) / sum(exp(a), axis, keepdims=True) + + Args: + a (array): Input array. + axis (int or list(int), optional): Optional axis or axes to compute + the softmax over. If unspecified this performs the softmax over + the full array. + + Returns: + array: The output of the softmax. + )pbdoc"); + m.def( + "concatenate", + [](const std::vector& arrays, + std::optional axis, + StreamOrDevice s) { + if (axis) { + return concatenate(arrays, *axis, s); + } else { + return concatenate(arrays, s); + } + }, + "arrays"_a, + py::pos_only(), + "axis"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Concatenate the arrays along the given axis. + + Args: + arrays (list(array)): Input :obj:`list` or :obj:`tuple` of arrays. + axis (int, optional): Optional axis to concatenate along. If + unspecified defaults to ``0``. + + Returns: + array: The concatenated array. + )pbdoc"); + m.def( + "pad", + [](const array& a, + const std::variant< + int, + std::tuple, + std::pair, + std::vector>>& pad_width, + const ScalarOrArray& constant_value, + StreamOrDevice s) { + if (auto pv = std::get_if(&pad_width); pv) { + return pad(a, *pv, to_array(constant_value), s); + } else if (auto pv = std::get_if>(&pad_width); pv) { + return pad(a, std::get<0>(*pv), to_array(constant_value), s); + } else if (auto pv = std::get_if>(&pad_width); pv) { + return pad(a, *pv, to_array(constant_value), s); + } else { + auto v = std::get>>(pad_width); + if (v.size() == 1) { + return pad(a, v[0], to_array(constant_value), s); + } else { + return pad(a, v, to_array(constant_value), s); + } + } + }, + "a"_a, + py::pos_only(), + "pad_width"_a, + "constant_values"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Pad an array with a constant value + + Args: + a (array): Input array. + pad_width (int, tuple(int), tuple(int, int) or list(tuple(int, int))): Number of padded + values to add to the edges of each axis:``((before_1, after_1), + (before_2, after_2), ..., (before_N, after_N))``. If a single pair + of integers is passed then ``(before_i, after_i)`` are all the same. + If a single integer or tuple with a single integer is passed then + all axes are extended by the same number on each side. + constant_value (array or scalar, optional): Optional constant value + to pad the edges of the array with. + + Returns: + array: The padded array. + )pbdoc"); + m.def( + "as_strided", + [](const array& a, + std::optional> shape, + std::optional> strides, + size_t offset, + StreamOrDevice s) { + std::vector a_shape = (shape) ? *shape : a.shape(); + std::vector a_strides; + if (strides) { + a_strides = *strides; + } else { + std::fill_n(std::back_inserter(a_strides), a_shape.size(), 1); + for (int i = a_shape.size() - 1; i > 0; i--) { + a_strides[i - 1] = a_shape[i] * a_strides[i]; + } + } + return as_strided(a, a_shape, a_strides, offset, s); + }, + "a"_a, + py::pos_only(), + "shape"_a = none, + "strides"_a = none, + "offset"_a = 0, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Create a view into the array with the given shape and strides. 
+ + The resulting array will always be as if the provided array was row + contiguous regardless of the provided arrays storage order and current + strides. + + .. note:: + Note that this function should be used with caution as it changes + the shape and strides of the array directly. This can lead to the + resulting array pointing to invalid memory locations which can + result into crashes. + + Args: + a (array): Input array + shape (list(int), optional): The shape of the resulting array. If + None it defaults to ``a.shape()``. + strides (list(int), optional): The strides of the resulting array. If + None it defaults to the reverse exclusive cumulative product of + ``a.shape()``. + offset (int): Skip that many elements from the beginning of the input + array. + + Returns: + array: The output array which is the strided view of the input. + )pbdoc"); + m.def( + "cumsum", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cumsum(a, *axis, reverse, inclusive, s); + } else { + return cumsum(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + R"pbdoc( + Return the cumulative sum of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative sum + over. If unspecified the cumulative sum of the flattened array is + returned. + reverse (bool): Perform the cumulative sum in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + )pbdoc"); + m.def( + "cumprod", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cumprod(a, *axis, reverse, inclusive, s); + } else { + return cumprod(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + R"pbdoc( + Return the cumulative product of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative product + over. If unspecified the cumulative product of the flattened array is + returned. + reverse (bool): Perform the cumulative product in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + )pbdoc"); + m.def( + "cummax", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cummax(a, *axis, reverse, inclusive, s); + } else { + return cummax(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + R"pbdoc( + Return the cumulative maximum of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative maximum + over. If unspecified the cumulative maximum of the flattened array is + returned. + reverse (bool): Perform the cumulative maximum in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. 
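A usage sketch of the default and reversed scans (the ``mx.array`` constructor is assumed; values are indicative):

      .. code-block:: python

          import mlx.core as mx

          a = mx.array([1, 3, 2, 5, 4])
          mx.cummax(a)                 # running maximum -> [1, 3, 3, 5, 5]
          mx.cummax(a, reverse=True)   # scan from the right -> [5, 5, 5, 5, 4]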
+ )pbdoc"); + m.def( + "cummin", + [](const array& a, + std::optional axis, + bool reverse, + bool inclusive, + StreamOrDevice s) { + if (axis) { + return cummin(a, *axis, reverse, inclusive, s); + } else { + return cummin(reshape(a, {-1}, s), 0, reverse, inclusive, s); + } + }, + "a"_a, + py::pos_only(), + "axis"_a = std::nullopt, + py::kw_only(), + "reverse"_a = false, + "inclusive"_a = true, + "stream"_a = none, + R"pbdoc( + Return the cumulative minimum of the elements along the given axis. + + Args: + a (array): Input array + axis (int, optional): Optional axis to compute the cumulative minimum + over. If unspecified the cumulative minimum of the flattened array is + returned. + reverse (bool): Perform the cumulative minimum in reverse. + inclusive (bool): The i-th element of the output includes the i-th + element of the input. + )pbdoc"); + m.def( + "convolve", + [](const array& a, + const array& v, + const std::string& mode, + StreamOrDevice s) { + if (a.ndim() != 1 || v.ndim() != 1) { + throw std::invalid_argument("[convolve] Inputs must be 1D."); + } + + array in = a.size() < v.size() ? v : a; + array wt = a.size() < v.size() ? a : v; + wt = slice(wt, {wt.shape(0) - 1}, {-wt.shape(0) - 1}, {-1}, s); + + in = reshape(in, {1, -1, 1}, s); + wt = reshape(wt, {1, -1, 1}, s); + + int padding = 0; + + if (mode == "full") { + padding = wt.size() - 1; + } else if (mode == "valid") { + padding = 0; + } else if (mode == "same") { + // Odd sizes use symmetric padding + if (wt.size() % 2) { + padding = wt.size() / 2; + } else { // Even sizes use asymmetric padding + int pad_l = wt.size() / 2; + int pad_r = std::max(0, pad_l - 1); + in = pad(in, {{0, 0}, {pad_l, pad_r}, {0, 0}}, array(0), s); + } + + } else { + throw std::invalid_argument("[convolve] Invalid mode."); + } + + array out = conv1d( + in, + wt, + /*stride = */ 1, + /*padding = */ padding, + /*dilation = */ 1, + /*groups = */ 1, + s); + + return reshape(out, {-1}, s); + }, + "a"_a, + "v"_a, + py::pos_only(), + "mode"_a = "full", + py::kw_only(), + "stream"_a = none, + R"pbdoc( + The discrete convolution of 1D arrays. + + If ``v`` is longer than ``a``, then they are swapped. + The conv filter is flipped following signal processing convention. + + Args: + a (array): 1D Input array. + v (array): 1D Input array. + mode (str, optional): {'full', 'valid', 'same'} + + Returns: + array: The convolved array. + )pbdoc"); + m.def( + "conv1d", + &conv1d, + "input"_a, + "weight"_a, + py::pos_only(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + 1D convolution over an input with several channels + + Note: Only the default ``groups=1`` is currently supported. + + Args: + input (array): input array of shape (``N``, ``H``, ``C_in``) + weight (array): weight array of shape (``C_out``, ``H``, ``C_in``) + stride (int, optional): kernel stride. Default: ``1``. + padding (int, optional): input padding. Default: ``0``. + dilation (int, optional): kernel dilation. Default: ``1``. + groups (int, optional): input feature groups. Default: ``1``. + + Returns: + array: The convolved array. 
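Editor's note: a small usage sketch for the two 1D convolution entry points above. The ``convolve`` output lengths follow the mode handling in the wrapper, and the ``conv1d`` shape comment is an expected value derived from the documented ``(N, H, C)`` layout rather than a verified output:

    import mlx.core as mx

    # Discrete 1D convolution (the shorter input is used as the flipped filter)
    a = mx.array([1.0, 2.0, 3.0, 4.0])
    v = mx.array([1.0, 0.0, -1.0])
    mx.convolve(a, v)                 # mode="full": length len(a) + len(v) - 1 = 6
    mx.convolve(a, v, mode="same")    # same length as the longer input: 4
    mx.convolve(a, v, mode="valid")   # only complete overlaps: length 2

    # Channels-last conv1d over a batch
    x = mx.random.uniform(shape=(1, 16, 4))   # (N, H, C_in)
    w = mx.random.uniform(shape=(8, 3, 4))    # (C_out, kernel H, C_in)
    y = mx.conv1d(x, w, stride=1, padding=1)  # expected shape: (1, 16, 8)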
+ )pbdoc"); + m.def( + "conv2d", + [](const array& input, + const array& weight, + const std::variant>& stride, + const std::variant>& padding, + const std::variant>& dilation, + int groups, + StreamOrDevice s) { + std::pair stride_pair{1, 1}; + std::pair padding_pair{0, 0}; + std::pair dilation_pair{1, 1}; + + if (auto pv = std::get_if(&stride); pv) { + stride_pair = std::pair{*pv, *pv}; + } else { + stride_pair = std::get>(stride); + } + + if (auto pv = std::get_if(&padding); pv) { + padding_pair = std::pair{*pv, *pv}; + } else { + padding_pair = std::get>(padding); + } + + if (auto pv = std::get_if(&dilation); pv) { + dilation_pair = std::pair{*pv, *pv}; + } else { + dilation_pair = std::get>(dilation); + } + + return conv2d( + input, weight, stride_pair, padding_pair, dilation_pair, groups, s); + }, + "input"_a, + "weight"_a, + py::pos_only(), + "stride"_a = 1, + "padding"_a = 0, + "dilation"_a = 1, + "groups"_a = 1, + py::kw_only(), + "stream"_a = none, + R"pbdoc( + 2D convolution over an input with several channels + + Note: Only the default ``groups=1`` is currently supported. + + Args: + input (array): input array of shape ``(N, H, W, C_in)`` + weight (array): weight array of shape ``(C_out, H, W, C_in)`` + stride (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel strides. All spatial dimensions get the same stride if + only one number is specified. Default: ``1``. + padding (int or tuple(int), optional): :obj:`tuple` of size 2 with + symmetric input padding. All spatial dimensions get the same + padding if only one number is specified. Default: ``0``. + dilation (int or tuple(int), optional): :obj:`tuple` of size 2 with + kernel dilation. All spatial dimensions get the same dilation + if only one number is specified. Default: ``1`` + groups (int, optional): input feature groups. Default: ``1``. + + Returns: + array: The convolved array. + )pbdoc"); + m.def( + "save", + &mlx_save_helper, + "file"_a, + "arr"_a, + py::pos_only(), + "retain_graph"_a = true, + py::kw_only(), + R"pbdoc( + Save the array to a binary file in ``.npy`` format. + + Args: + file (str): File to which the array is saved + arr (array): Array to be saved. + retain_graph(bool): Optional argument to retain graph + during array evaluation before saving. Default: True + + )pbdoc"); + m.def( + "savez", + [](py::object file, py::args args, const py::kwargs& kwargs) { + mlx_savez_helper(file, args, kwargs, /*compressed=*/false); + }, + "file"_a, + py::pos_only(), + py::kw_only(), + R"pbdoc( + Save several arrays to a binary file in uncompressed ``.npz`` format. + + .. code-block:: python + + import mlx.core as mx + + x = mx.ones((10, 10)) + mx.savez("my_path.npz", x=x) + + import mlx.nn as nn + from mlx.utils import tree_flatten + + model = nn.TransformerEncoder(6, 128, 4) + flat_params = tree_flatten(model.parameters()) + mx.savez("model.npz", **dict(flat_params)) + + Args: + file (file, str): Path to file to which the arrays are saved. + args (arrays): Arrays to be saved. + kwargs (arrays): Arrays to be saved. Each array will be saved + with the associated keyword as the output file name. + + )pbdoc"); + m.def( + "savez_compressed", + [](py::object file, py::args args, const py::kwargs& kwargs) { + mlx_savez_helper(file, args, kwargs, /*compressed=*/true); + }, + "file"_a, + py::pos_only(), + py::kw_only(), + R"pbdoc( + Save several arrays to a binary file in compressed ``.npz`` format. + + Args: + file (file, str): Path to file to which the arrays are saved. + args (arrays): Arrays to be saved. 
+ kwargs (arrays): Arrays to be saved. Each array will be saved + with the associated keyword as the output file name. + + )pbdoc"); + m.def( + "load", + &mlx_load_helper, + "file"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Load array(s) from a binary file in ``.npy`` or ``.npz`` format. + + Args: + file (file, str): File in which the array is saved + + Returns: + result (array, dict): The loaded array if ``.npy`` file or a dict mapping name to array if ``.npz`` file + )pbdoc"); + m.def( + "where", + [](const ScalarOrArray& condition, + const ScalarOrArray& x_, + const ScalarOrArray& y_, + StreamOrDevice s) { + auto [x, y] = to_arrays(x_, y_); + return where(to_array(condition), x, y, s); + }, + "condition"_a, + "x"_a, + "y"_a, + py::pos_only(), + py::kw_only(), + "stream"_a = none, + R"pbdoc( + Select from ``x`` or ``y`` according to ``condition``. + + The condition and input arrays must be the same shape or broadcastable + with each another. + + Args: + condition (array): The condition array. + x (array): The input selected from where condition is ``True``. + y (array): The input selected from where condition is ``False``. + + Returns: + result (array): The output containing elements selected from ``x`` and ``y``. + )pbdoc"); +} diff --git a/python/src/random.cpp b/python/src/random.cpp new file mode 100644 index 0000000000..f38fc43d91 --- /dev/null +++ b/python/src/random.cpp @@ -0,0 +1,289 @@ +#include +#include + +#include "python/src/utils.h" + +#include "mlx/ops.h" +#include "mlx/random.h" + +namespace py = pybind11; +using namespace py::literals; +using namespace mlx::core; +using namespace mlx::core::random; + +void init_random(py::module_& parent_module) { + auto m = parent_module.def_submodule( + "random", + "mlx.core.random: functionality related to random number generation"); + m.def( + "seed", + &seed, + "seed"_a, + R"pbdoc( + Seed the global PRNG. + + Args: + seed (int): Seed for the global PRNG. + )pbdoc"); + m.def( + "key", + &key, + "seed"_a, + R"pbdoc( + Get a PRNG key from a seed. + + Args: + seed (int): Seed for the PRNG. + + Returns: + array: The PRNG key array. + )pbdoc"); + m.def( + "split", + py::overload_cast(&random::split), + "key"_a, + "num"_a = 2, + "stream"_a = none, + R"pbdoc( + Split a PRNG key into sub keys. + + Args: + key (array): Input key to split. + num (int, optional): Number of sub keys. Default is 2. + + Returns: + array: The array of sub keys with ``num`` as its first dimension. + )pbdoc"); + m.def( + "uniform", + [](const ScalarOrArray& low, + const ScalarOrArray& high, + const std::vector& shape, + Dtype type, + const std::optional& key, + StreamOrDevice s) { + return uniform(to_array(low), to_array(high), shape, type, key, s); + }, + "low"_a = 0, + "high"_a = 1, + "shape"_a = std::vector{}, + "dtype"_a = float32, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Generate uniformly distributed random numbers. + + The values are sampled uniformly in the half-open interval ``[low, high)``. + The lower and upper bound can be scalars or arrays and must be + broadcastable to ``shape``. + + Args: + low (scalar or array, optional): Lower bound of the distribution. Default is ``0``. + high (scalar or array, optional): Upper bound of the distribution. Default is ``1``. + shape (list(int), optional): Shape of the output. Default is ``()``. + key (array, optional): A PRNG key. Default: None. + dtype (Dtype, optional): Type of the output. Default is ``float32``. + + Returns: + array: The output array random values. 
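Editor's note: the random bindings in this file support both the implicit global PRNG (seeded via ``seed``) and explicit keys (``key``/``split``). A minimal sketch using only the functions defined here; the sampled values are, of course, random:

    import mlx.core as mx

    mx.random.seed(0)                        # seed the global PRNG
    x = mx.random.uniform(shape=(2, 3))      # float32 samples in [0, 1)

    key = mx.random.key(42)                  # explicit PRNG key
    k1, k2 = mx.random.split(key)            # two sub keys (num=2 by default)
    y = mx.random.uniform(low=-1.0, high=1.0, shape=(2, 3), key=k1)
    z = mx.random.normal(shape=(2, 3), key=k2)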
+ )pbdoc"); + m.def( + "normal", + [](const std::vector& shape, + Dtype type, + const std::optional& key, + StreamOrDevice s) { return normal(shape, type, key, s); }, + + "shape"_a = std::vector{}, + "dtype"_a = float32, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Generate normally distributed random numbers. + + Args: + shape (list(int), optional): Shape of the output. Default is ``()``. + dtype (Dtype, optional): Type of the output. Default is ``float32``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The output array of random values. + )pbdoc"); + m.def( + "randint", + [](const ScalarOrArray& low, + const ScalarOrArray& high, + const std::vector& shape, + Dtype type, + const std::optional& key, + StreamOrDevice s) { + return randint(to_array(low), to_array(high), shape, type, key, s); + }, + "low"_a, + "high"_a, + "shape"_a = std::vector{}, + "dtype"_a = int32, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Generate random integers from the given interval. + + The values are sampled with equal probability from the integers in + half-open interval ``[low, high)``. The lower and upper bound can be + scalars or arrays and must be roadcastable to ``shape``. + + Args: + low (scalar or array): Lower bound of the interval. + high (scalar or array): Upper bound of the interval. + shape (list(int), optional): Shape of the output. Defaults to ``()``. + dtype (Dtype, optional): Type of the output. Defaults to ``int32``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The array of random integers. + )pbdoc"); + m.def( + "bernoulli", + [](const ScalarOrArray& p_, + const std::optional> shape, + const std::optional& key, + StreamOrDevice s) { + auto p = to_array(p_); + if (shape.has_value()) { + return bernoulli(p, shape.value(), key, s); + } else { + return bernoulli(p, key, s); + } + }, + "p"_a = 0.5, + "shape"_a = none, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Generate Bernoulli random values. + + The values are sampled from the bernoulli distribution with parameter + ``p``. The parameter ``p`` can be a :obj:`float` or :obj:`array` and + must be broadcastable to ``shape``. + + Args: + p (float or array, optional): Parameter of the Bernoulli + distribution. Default is 0.5. + shape (list(int), optional): Shape of the output. The default + shape is ``p.shape``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The array of random integers. + )pbdoc"); + m.def( + "truncated_normal", + [](const ScalarOrArray& lower_, + const ScalarOrArray& upper_, + const std::optional> shape_, + Dtype dtype, + const std::optional& key, + StreamOrDevice s) { + auto lower = to_array(lower_); + auto upper = to_array(upper_); + if (shape_.has_value()) { + return truncated_normal(lower, upper, shape_.value(), dtype, key, s); + } else { + return truncated_normal(lower, upper, dtype, key, s); + } + }, + "lower"_a, + "upper"_a, + "shape"_a = none, + "dtype"_a = float32, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Generate values from a truncated normal distribution. + + The values are sampled from the truncated normal distribution + on the domain ``(lower, upper)``. The bounds ``lower`` and ``upper`` + can be scalars or arrays and must be broadcastable to ``shape``. + + Args: + lower (scalar or array): Lower bound of the domain. + upper (scalar or array): Upper bound of the domain. + shape (list(int), optional): The shape of the output. + Default is ``()``. + dtype (Dtype, optinoal): The data type of the output. 
+ Default is ``float32``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The output array of random values. + )pbdoc"); + m.def( + "gumbel", + &gumbel, + "shape"_a = std::vector{}, + "dtype"_a = float32, + "stream"_a = none, + "key"_a = none, + R"pbdoc( + Sample from the standard Gumbel distribution. + + The values are sampled from a standard Gumbel distribution + which CDF ``exp(-exp(-x))``. + + Args: + shape (list(int)): The shape of the output. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The :class:`array` with shape ``shape`` and + distributed according to the Gumbel distribution + )pbdoc"); + m.def( + "categorical", + [](const array& logits, + int axis, + const std::optional> shape, + const std::optional num_samples, + const std::optional& key, + StreamOrDevice s) { + if (shape.has_value() && num_samples.has_value()) { + throw std::invalid_argument( + "[categorical] At most one of shape or num_samples can be specified."); + } else if (shape.has_value()) { + return categorical(logits, axis, shape.value(), key, s); + } else if (num_samples.has_value()) { + return categorical(logits, axis, num_samples.value(), key, s); + } else { + return categorical(logits, axis, key, s); + } + }, + "logits"_a, + "axis"_a = -1, + "shape"_a = none, + "num_samples"_a = none, + "key"_a = none, + "stream"_a = none, + R"pbdoc( + Sample from a categorical distribution. + + The values are sampled from the categorical distribution specified by + the unnormalized values in ``logits``. Note, at most one of ``shape`` + or ``num_samples`` can be specified. If both are ``None``, the output + has the same shape as ``logits`` with the ``axis`` dimension removed. + + Args: + logits (array): The *unnormalized* categorical distribution(s). + axis (int, optional): The axis which specifies the distribution. + Default is ``-1``. + shape (list(int), optional): The shape of the output. This must + be broadcast compatable with ``logits.shape`` with the ``axis`` + dimension removed. Default: ``None`` + num_samples (int, optional): The number of samples to draw from each + of the categorical distributions in ``logits``. The output will have + ``num_samples`` in the last dimension. Default: ``None``. + key (array, optional): A PRNG key. Default: None. + + Returns: + array: The ``shape``-sized output array with type ``uint32``. + )pbdoc"); +} diff --git a/python/src/utils.h b/python/src/utils.h new file mode 100644 index 0000000000..bc06fb5046 --- /dev/null +++ b/python/src/utils.h @@ -0,0 +1,71 @@ +#pragma once +#include +#include + +#include +#include +#include + +#include "mlx/array.h" + +namespace py = pybind11; + +using namespace mlx::core; + +using IntOrVec = std::variant>; +using ScalarOrArray = + std::variant, array>; +static constexpr std::monostate none{}; + +inline std::vector get_reduce_axes(const IntOrVec& v, int dims) { + std::vector axes; + if (std::holds_alternative(v)) { + axes.resize(dims); + std::iota(axes.begin(), axes.end(), 0); + } else if (auto pv = std::get_if(&v); pv) { + axes.push_back(*pv); + } else { + axes = std::get>(v); + } + return axes; +} + +inline array to_array( + const ScalarOrArray& v, + std::optional dtype = std::nullopt) { + if (auto pv = std::get_if(&v); pv) { + return array(py::cast(*pv), dtype.value_or(bool_)); + } else if (auto pv = std::get_if(&v); pv) { + auto out_t = dtype.value_or(int32); + // bool_ is an exception and is always promoted + return array(py::cast(*pv), (out_t == bool_) ? 
int32 : out_t); + } else if (auto pv = std::get_if(&v); pv) { + auto out_t = dtype.value_or(float32); + return array( + py::cast(*pv), is_floating_point(out_t) ? out_t : float32); + } else if (auto pv = std::get_if>(&v); pv) { + return array(static_cast(*pv), complex64); + } else { + return std::get(v); + } +} + +inline std::pair to_arrays( + const ScalarOrArray& a, + const ScalarOrArray& b) { + // Four cases: + // - If both a and b are arrays leave their types alone + // - If a is an array but b is not, treat b as a weak python type + // - If b is an array but a is not, treat a as a weak python type + // - If neither is an array convert to arrays but leave their types alone + if (auto pa = std::get_if(&a); pa) { + if (auto pb = std::get_if(&b); pb) { + return {*pa, *pb}; + } + return {*pa, to_array(b, pa->dtype())}; + } else if (auto pb = std::get_if(&b); pb) { + return {to_array(a, pb->dtype()), *pb}; + } else { + return {to_array(a), to_array(b)}; + } +} diff --git a/python/tests/test_array.py b/python/tests/test_array.py new file mode 100644 index 0000000000..e7edb6c205 --- /dev/null +++ b/python/tests/test_array.py @@ -0,0 +1,1041 @@ +import operator +import unittest +from itertools import permutations + +import mlx.core as mx +import numpy as np + +import mlx_tests + + +class TestVersion(mlx_tests.MLXTestCase): + def test_version(self): + v = mx.__version__ + vnums = v.split(".") + self.assertEqual(len(vnums), 3) + v = ".".join(str(int(vn)) for vn in vnums) + self.assertEqual(v, mx.__version__) + + +class TestDtypes(mlx_tests.MLXTestCase): + def test_dtypes(self): + self.assertEqual(mx.bool_.size, 1) + self.assertEqual(mx.uint8.size, 1) + self.assertEqual(mx.uint16.size, 2) + self.assertEqual(mx.uint32.size, 4) + self.assertEqual(mx.uint64.size, 8) + self.assertEqual(mx.int8.size, 1) + self.assertEqual(mx.int16.size, 2) + self.assertEqual(mx.int32.size, 4) + self.assertEqual(mx.int64.size, 8) + self.assertEqual(mx.float16.size, 2) + self.assertEqual(mx.float32.size, 4) + self.assertEqual(mx.bfloat16.size, 2) + self.assertEqual(mx.complex64.size, 8) + + self.assertEqual(str(mx.bool_), "bool") + self.assertEqual(str(mx.uint8), "uint8") + self.assertEqual(str(mx.uint16), "uint16") + self.assertEqual(str(mx.uint32), "uint32") + self.assertEqual(str(mx.uint64), "uint64") + self.assertEqual(str(mx.int8), "int8") + self.assertEqual(str(mx.int16), "int16") + self.assertEqual(str(mx.int32), "int32") + self.assertEqual(str(mx.int64), "int64") + self.assertEqual(str(mx.float16), "float16") + self.assertEqual(str(mx.float32), "float32") + self.assertEqual(str(mx.bfloat16), "bfloat16") + self.assertEqual(str(mx.complex64), "complex64") + + def test_scalar_conversion(self): + dtypes = [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", + "complex64", + ] + + for dtype in dtypes: + with self.subTest(dtype=dtype): + x = np.array(2, dtype=getattr(np, dtype)) + y = np.min(x) + + self.assertEqual(x.dtype, y.dtype) + self.assertTupleEqual(x.shape, y.shape) + + z = mx.array(y) + self.assertEqual(np.array(z), x) + self.assertEqual(np.array(z), y) + self.assertEqual(z.dtype, getattr(mx, dtype)) + self.assertListEqual(list(z.shape), list(x.shape)) + self.assertListEqual(list(z.shape), list(y.shape)) + + +class TestArray(mlx_tests.MLXTestCase): + def test_array_basics(self): + x = mx.array(1) + self.assertEqual(x.size, 1) + self.assertEqual(x.ndim, 0) + self.assertEqual(x.shape, []) + self.assertEqual(x.dtype, mx.int32) + 
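+        # Python scalars pick MLX default dtypes: bool -> bool_, int -> int32,
+        # float -> float32, complex -> complex64; passing an explicit dtype
+        # (e.g. mx.uint32 below) overrides the default.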
self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + with self.assertRaises(TypeError): + len(x) + + x = mx.array(1, mx.uint32) + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + x = mx.array(1, mx.int64) + self.assertEqual(x.item(), 1) + self.assertTrue(isinstance(x.item(), int)) + + x = mx.array(1.0) + self.assertEqual(x.size, 1) + self.assertEqual(x.ndim, 0) + self.assertEqual(x.shape, []) + self.assertEqual(x.dtype, mx.float32) + self.assertEqual(x.item(), 1.0) + self.assertTrue(isinstance(x.item(), float)) + + x = mx.array(False) + self.assertEqual(x.size, 1) + self.assertEqual(x.ndim, 0) + self.assertEqual(x.shape, []) + self.assertEqual(x.dtype, mx.bool_) + self.assertEqual(x.item(), False) + self.assertTrue(isinstance(x.item(), bool)) + + x = mx.array(complex(1, 1)) + self.assertEqual(x.ndim, 0) + self.assertEqual(x.shape, []) + self.assertEqual(x.dtype, mx.complex64) + self.assertEqual(x.item(), complex(1, 1)) + self.assertTrue(isinstance(x.item(), complex)) + + x = mx.array([True, False, True]) + self.assertEqual(x.dtype, mx.bool_) + self.assertEqual(x.ndim, 1) + self.assertEqual(x.shape, [3]) + self.assertEqual(len(x), 3) + + x = mx.array([True, False, True], mx.float32) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([0, 1, 2]) + self.assertEqual(x.dtype, mx.int32) + self.assertEqual(x.ndim, 1) + self.assertEqual(x.shape, [3]) + + x = mx.array([0, 1, 2], mx.float32) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([0.0, 1.0, 2.0]) + self.assertEqual(x.dtype, mx.float32) + self.assertEqual(x.ndim, 1) + self.assertEqual(x.shape, [3]) + + x = mx.array([1j, 1 + 0j]) + self.assertEqual(x.dtype, mx.complex64) + self.assertEqual(x.ndim, 1) + self.assertEqual(x.shape, [2]) + + # From tuple + x = mx.array((1, 2, 3), mx.int32) + self.assertEqual(x.dtype, mx.int32) + self.assertEqual(x.tolist(), [1, 2, 3]) + + def test_bool_conversion(self): + x = mx.array(True) + self.assertTrue(x) + x = mx.array(False) + self.assertFalse(x) + x = mx.array(1.0) + self.assertTrue(x) + x = mx.array(0.0) + self.assertFalse(x) + + def test_construction_from_lists(self): + x = mx.array([]) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [0]) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([[], [], []]) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [3, 0]) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([[[], []], [[], []], [[], []]]) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [3, 2, 0]) + self.assertEqual(x.dtype, mx.float32) + + # Check failure cases + with self.assertRaises(ValueError): + x = mx.array([[[], []], [[]], [[], []]]) + + with self.assertRaises(ValueError): + x = mx.array([[[], []], [[1.0, 2.0], []], [[], []]]) + + with self.assertRaises(ValueError): + x = mx.array([[0, 1], [[0, 1], 1]]) + + with self.assertRaises(ValueError): + x = mx.array([[0, 1], ["hello", 1]]) + + x = mx.array([True, False, 3]) + self.assertEqual(x.dtype, mx.int32) + + x = mx.array([True, False, 3, 4.0]) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([[True, False], [1, 3], [2, 4.0]]) + self.assertEqual(x.dtype, mx.float32) + + x = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.bool_) + self.assertEqual(x.dtype, mx.bool_) + self.assertTrue(mx.array_equal(x, mx.array([[True, True], [False, True]]))) + + x = mx.array([[1.0, 2.0], [0.0, 3.9]], mx.int32) + self.assertTrue(mx.array_equal(x, mx.array([[1, 2], [0, 3]]))) + + x = mx.array([1 + 0j, 2j, True, 0], mx.complex64) + self.assertEqual(x.tolist(), 
[1 + 0j, 2j, 1 + 0j, 0j]) + + def test_init_from_array(self): + x = mx.array(3.0) + y = mx.array(x) + + self.assertTrue(mx.array_equal(x, y)) + + y = mx.array(x, mx.int32) + self.assertEqual(y.dtype, mx.int32) + self.assertEqual(y.item(), 3) + + y = mx.array(x, mx.bool_) + self.assertEqual(y.dtype, mx.bool_) + self.assertEqual(y.item(), True) + + # y = mx.array(x, mx.complex64) + # self.assertEqual(y.dtype, mx.complex64) + # self.assertEqual(y.item(), 3.0+0j) + + def test_array_repr(self): + x = mx.array(True) + self.assertEqual(str(x), "array(true, dtype=bool)") + x = mx.array(1) + self.assertEqual(str(x), "array(1, dtype=int32)") + x = mx.array(1.0) + self.assertEqual(str(x), "array(1, dtype=float32)") + + x = mx.array([1, 0, 1]) + self.assertEqual(str(x), "array([1, 0, 1], dtype=int32)") + + x = mx.array([1] * 6) + expected = "array([1, 1, 1, 1, 1, 1], dtype=int32)" + self.assertEqual(str(x), expected) + + x = mx.array([1] * 7) + expected = "array([1, 1, 1, ..., 1, 1, 1], dtype=int32)" + self.assertEqual(str(x), expected) + + x = mx.array([[1, 2], [1, 2], [1, 2]]) + expected = "array([[1, 2],\n" " [1, 2],\n" " [1, 2]], dtype=int32)" + self.assertEqual(str(x), expected) + + x = mx.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]) + expected = ( + "array([[[1, 2],\n" + " [1, 2]],\n" + " [[1, 2],\n" + " [1, 2]]], dtype=int32)" + ) + self.assertEqual(str(x), expected) + + x = mx.array([[1, 2]] * 6) + expected = ( + "array([[1, 2],\n" + " [1, 2],\n" + " [1, 2],\n" + " [1, 2],\n" + " [1, 2],\n" + " [1, 2]], dtype=int32)" + ) + self.assertEqual(str(x), expected) + x = mx.array([[1, 2]] * 7) + expected = ( + "array([[1, 2],\n" + " [1, 2],\n" + " [1, 2],\n" + " ...,\n" + " [1, 2],\n" + " [1, 2],\n" + " [1, 2]], dtype=int32)" + ) + self.assertEqual(str(x), expected) + + x = mx.array([1], dtype=mx.int8) + expected = "array([1], dtype=int8)" + self.assertEqual(str(x), expected) + x = mx.array([1], dtype=mx.int16) + expected = "array([1], dtype=int16)" + self.assertEqual(str(x), expected) + x = mx.array([1], dtype=mx.uint8) + expected = "array([1], dtype=uint8)" + self.assertEqual(str(x), expected) + + # Fp16 is not supported in all platforms + x = mx.array([1.2], dtype=mx.float16) + expected = "array([1.2002], dtype=float16)" + self.assertEqual(str(x), expected) + + x = mx.array([1 + 1j], dtype=mx.complex64) + expected = "array([1+1j], dtype=complex64)" + self.assertEqual(str(x), expected) + x = mx.array([1 - 1j], dtype=mx.complex64) + expected = "array([1-1j], dtype=complex64)" + + x = mx.array([1 + 1j], dtype=mx.complex64) + expected = "array([1+1j], dtype=complex64)" + self.assertEqual(str(x), expected) + x = mx.array([1 - 1j], dtype=mx.complex64) + expected = "array([1-1j], dtype=complex64)" + + def test_array_to_list(self): + types = [mx.bool_, mx.uint32, mx.int32, mx.int64, mx.float32] + for t in types: + x = mx.array(1, t) + self.assertEqual(x.tolist(), 1) + + vals = [1, 2, 3, 4] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + vals = [[1, 2], [3, 4]] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + vals = [[1, 0], [0, 1]] + x = mx.array(vals, mx.bool_) + self.assertEqual(x.tolist(), vals) + + vals = [[1.5, 2.5], [3.5, 4.5]] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + vals = [[[0.5, 1.5], [2.5, 3.5]], [[4.5, 5.5], [6.5, 7.5]]] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + # Empty arrays + vals = [] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + vals = [[], []] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + # 
Complex arrays + vals = [0.5 + 0j, 1.5 + 1j, 2.5 + 0j, 3.5 + 1j] + x = mx.array(vals) + self.assertEqual(x.tolist(), vals) + + def test_array_np_conversion(self): + # Shape test + a = np.array([]) + x = mx.array(a) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [0]) + self.assertEqual(x.dtype, mx.float32) + + a = np.array([[], [], []]) + x = mx.array(a) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [3, 0]) + self.assertEqual(x.dtype, mx.float32) + + a = np.array([[[], []], [[], []], [[], []]]) + x = mx.array(a) + self.assertEqual(x.size, 0) + self.assertEqual(x.shape, [3, 2, 0]) + self.assertEqual(x.dtype, mx.float32) + + # Content test + a = 2.0 * np.ones((3, 5, 4)) + x = mx.array(a) + self.assertEqual(x.dtype, mx.float32) + self.assertEqual(x.ndim, 3) + self.assertEqual(x.shape, [3, 5, 4]) + + y = np.asarray(x) + self.assertTrue(np.allclose(a, y)) + + a = np.array(3, dtype=np.int32) + x = mx.array(a) + self.assertEqual(x.dtype, mx.int32) + self.assertEqual(x.ndim, 0) + self.assertEqual(x.shape, []) + self.assertEqual(x.item(), 3) + + # mlx to numpy test + x = mx.array([True, False, True]) + y = np.asarray(x) + self.assertEqual(y.dtype, np.bool_) + self.assertEqual(y.ndim, 1) + self.assertEqual(y.shape, (3,)) + self.assertEqual(y[0], True) + self.assertEqual(y[1], False) + self.assertEqual(y[2], True) + + # complex64 mx <-> np + cvals = [0j, 1, 1 + 1j] + x = np.array(cvals) + y = mx.array(x) + self.assertEqual(y.dtype, mx.complex64) + self.assertEqual(y.shape, [3]) + self.assertEqual(y.tolist(), cvals) + + y = mx.array([0j, 1, 1 + 1j]) + x = np.asarray(y) + self.assertEqual(x.dtype, np.complex64) + self.assertEqual(x.shape, (3,)) + self.assertEqual(x.tolist(), cvals) + + def test_array_np_dtype_conversion(self): + dtypes_list = [ + (mx.bool_, np.bool_), + (mx.uint8, np.uint8), + (mx.uint16, np.uint16), + (mx.uint32, np.uint32), + (mx.uint64, np.uint64), + (mx.int8, np.int8), + (mx.int16, np.int16), + (mx.int32, np.int32), + (mx.int64, np.int64), + (mx.float16, np.float16), + (mx.float32, np.float32), + (mx.complex64, np.complex64), + ] + + for mlx_dtype, np_dtype in dtypes_list: + a_npy = np.random.uniform(low=0, high=100, size=(32,)).astype(np_dtype) + a_mlx = mx.array(a_npy) + + self.assertEqual(a_mlx.dtype, mlx_dtype) + self.assertTrue(np.allclose(a_mlx, a_npy)) + + b_mlx = mx.random.uniform( + low=0, + high=10, + shape=(32,), + ).astype(mlx_dtype) + b_npy = np.array(b_mlx) + + self.assertEqual(b_npy.dtype, np_dtype) + + def test_dtype_promotion(self): + dtypes_list = [ + (mx.bool_, np.bool_), + (mx.uint8, np.uint8), + (mx.uint16, np.uint16), + (mx.uint32, np.uint32), + (mx.uint64, np.uint64), + (mx.int8, np.int8), + (mx.int16, np.int16), + (mx.int32, np.int32), + (mx.int64, np.int64), + (mx.float32, np.float32), + ] + + promotion_pairs = permutations(dtypes_list, 2) + + for (mlx_dt_1, np_dt_1), (mlx_dt_2, np_dt_2) in promotion_pairs: + with self.subTest(dtype1=np_dt_1, dtype2=np_dt_2): + a_npy = np.ones((3,), dtype=np_dt_1) + b_npy = np.ones((3,), dtype=np_dt_2) + + c_npy = a_npy + b_npy + + a_mlx = mx.ones((3,), dtype=mlx_dt_1) + b_mlx = mx.ones((3,), dtype=mlx_dt_2) + + c_mlx = a_mlx + b_mlx + + self.assertEqual(c_mlx.dtype, mx.array(c_npy).dtype) + + a_mlx = mx.ones((3,), dtype=mx.float16) + b_mlx = mx.ones((3,), dtype=mx.float32) + c_mlx = a_mlx + b_mlx + + self.assertEqual(c_mlx.dtype, mx.float32) + + b_mlx = mx.ones((3,), dtype=mx.int32) + c_mlx = a_mlx + b_mlx + + self.assertEqual(c_mlx.dtype, mx.float16) + + def test_dtype_python_scalar_promotion(self): 
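+        # Python scalars act as weak types: the result generally keeps the array's
+        # dtype. The exceptions mirror to_array in python/src/utils.h: a Python int
+        # promotes a bool array to int32, and a Python float promotes bool/integer
+        # arrays to float32 (float16 arrays stay float16).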
+ tests = [ + (mx.bool_, operator.mul, False, mx.bool_), + (mx.bool_, operator.mul, 0, mx.int32), + (mx.bool_, operator.mul, 1.0, mx.float32), + (mx.int8, operator.mul, False, mx.int8), + (mx.int8, operator.mul, 0, mx.int8), + (mx.int8, operator.mul, 1.0, mx.float32), + (mx.int16, operator.mul, False, mx.int16), + (mx.int16, operator.mul, 0, mx.int16), + (mx.int16, operator.mul, 1.0, mx.float32), + (mx.int32, operator.mul, False, mx.int32), + (mx.int32, operator.mul, 0, mx.int32), + (mx.int32, operator.mul, 1.0, mx.float32), + (mx.int64, operator.mul, False, mx.int64), + (mx.int64, operator.mul, 0, mx.int64), + (mx.int64, operator.mul, 1.0, mx.float32), + (mx.uint8, operator.mul, False, mx.uint8), + (mx.uint8, operator.mul, 0, mx.uint8), + (mx.uint8, operator.mul, 1.0, mx.float32), + (mx.uint16, operator.mul, False, mx.uint16), + (mx.uint16, operator.mul, 0, mx.uint16), + (mx.uint16, operator.mul, 1.0, mx.float32), + (mx.uint32, operator.mul, False, mx.uint32), + (mx.uint32, operator.mul, 0, mx.uint32), + (mx.uint32, operator.mul, 1.0, mx.float32), + (mx.uint64, operator.mul, False, mx.uint64), + (mx.uint64, operator.mul, 0, mx.uint64), + (mx.uint64, operator.mul, 1.0, mx.float32), + (mx.float32, operator.mul, False, mx.float32), + (mx.float32, operator.mul, 0, mx.float32), + (mx.float32, operator.mul, 1.0, mx.float32), + (mx.float16, operator.mul, False, mx.float16), + (mx.float16, operator.mul, 0, mx.float16), + (mx.float16, operator.mul, 1.0, mx.float16), + ] + + for dtype_in, f, v, dtype_out in tests: + x = mx.array(0, dtype_in) + y = f(x, v) + self.assertEqual(y.dtype, dtype_out) + + def test_array_comparison(self): + a = mx.array([0.0, 1.0, 5.0]) + b = mx.array([-1.0, 2.0, 5.0]) + + self.assertEqual((a < b).tolist(), [False, True, False]) + self.assertEqual((a <= b).tolist(), [False, True, True]) + self.assertEqual((a > b).tolist(), [True, False, False]) + self.assertEqual((a >= b).tolist(), [True, False, True]) + + self.assertEqual((a < 5).tolist(), [True, True, False]) + self.assertEqual((5 < a).tolist(), [False, False, False]) + self.assertEqual((5 <= a).tolist(), [False, False, True]) + self.assertEqual((a > 1).tolist(), [False, False, True]) + self.assertEqual((a >= 1).tolist(), [False, True, True]) + + def test_array_neg(self): + a = mx.array([-1.0, 4.0, 0.0]) + + self.assertEqual((-a).tolist(), [1.0, -4.0, 0.0]) + + def test_array_type_cast(self): + a = mx.array([0.1, 2.3, -1.3]) + b = [0, 2, -1] + + self.assertEqual(a.astype(mx.int32).tolist(), b) + self.assertEqual(a.astype(mx.int32).dtype, mx.int32) + + b = mx.array(b).astype(mx.float32) + self.assertEqual(b.dtype, mx.float32) + + def test_array_iteration(self): + a = mx.array([0, 1, 2]) + + for i, x in enumerate(a): + self.assertEqual(x.item(), i) + + a = mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + x, y, z = a + self.assertEqual(x.tolist(), [1.0, 2.0]) + self.assertEqual(y.tolist(), [3.0, 4.0]) + self.assertEqual(z.tolist(), [5.0, 6.0]) + + def test_indexing(self): + # Basic content check, slice indexing + a_npy = np.arange(64, dtype=np.float32) + a_mlx = mx.array(a_npy) + a_sliced_mlx = a_mlx[2:50:4] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:50:4])) + + # Basic content check, mlx array indexing + a_npy = np.arange(64, dtype=np.int32) + a_npy = a_npy.reshape((8, 8)) + a_mlx = mx.array(a_npy) + idx_npy = np.array([0, 1, 2, 7, 5], dtype=np.uint32) + idx_mlx = mx.array(idx_npy) + a_sliced_mlx = a_mlx[idx_mlx] + a_sliced_npy = np.asarray(a_sliced_mlx) + 
self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy])) + + # Basic content check, int indexing + a_sliced_mlx = a_mlx[5] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[5])) + self.assertEqual(len(a_sliced_npy.shape), len(a_npy[5].shape)) + self.assertEqual(len(a_sliced_npy.shape), 1) + self.assertEqual(a_sliced_npy.shape[0], a_npy[5].shape[0]) + + # Basic content check, negative indexing + a_sliced_mlx = a_mlx[-1] + self.assertTrue(np.array_equal(a_sliced_mlx, a_npy[-1])) + + # Basic content check, empty index + a_sliced_mlx = a_mlx[()] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[()])) + + # Basic content check, new axis + a_sliced_mlx = a_mlx[None] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[None])) + + # Multi dim indexing, all ints + self.assertEqual(a_mlx[0, 0].item(), 0) + self.assertEqual(a_mlx[0, 0].ndim, 0) + + # Multi dim indexing, all slices + a_sliced_mlx = a_mlx[2:4, 5:] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[2:4, 5:])) + + a_sliced_mlx = a_mlx[:, 0:5] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, 0:5])) + + # Slicing, strides + a_sliced_mlx = a_mlx[:, ::2] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, ::2])) + + # Slicing, -ve index + a_sliced_mlx = a_mlx[-2:, :-1] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[-2:, :-1])) + + # Slicing, start > end + a_sliced_mlx = a_mlx[8:3] + self.assertEqual(a_sliced_mlx.size, 0) + + # Slicing, Clipping past the end + a_sliced_mlx = a_mlx[7:10] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[7:10])) + + # Multi dim indexing, int and slices + a_sliced_mlx = a_mlx[0, :5] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[0, :5])) + + a_sliced_mlx = a_mlx[:, -1] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, -1])) + + # Multi dim indexing, int and array + a_sliced_mlx = a_mlx[idx_mlx, 0] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, 0])) + + # Multi dim indexing, array and slices + a_sliced_mlx = a_mlx[idx_mlx, :5] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[idx_npy, :5])) + + a_sliced_mlx = a_mlx[:, idx_mlx] + a_sliced_npy = np.asarray(a_sliced_mlx) + self.assertTrue(np.array_equal(a_sliced_npy, a_npy[:, idx_npy])) + + # Multi dim indexing with multiple arrays + def check_slices(arr_np, *idx_np): + arr_mlx = mx.array(arr_np) + idx_mlx = [ + mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np + ] + slice_mlx = arr_mlx[tuple(idx_mlx)] + self.assertTrue( + np.array_equal(arr_np[tuple(idx_np)], arr_mlx[tuple(idx_mlx)]) + ) + + a_np = np.arange(16).reshape(4, 4) + check_slices(a_np, np.array([0, 1, 2, 3]), np.array([0, 1, 2, 3])) + check_slices(a_np, np.array([0, 1, 2, 3]), np.array([1, 0, 3, 3])) + check_slices(a_np, np.array([[0, 1]]), np.array([[0], [1], [3]])) + + a_np = np.arange(64).reshape(2, 4, 2, 4) + check_slices(a_np, 0, np.array([0, 1, 2])) + check_slices(a_np, slice(0, 1), np.array([0, 1, 2])) + check_slices( + a_np, slice(0, 1), np.array([0, 1, 2]), slice(None), slice(0, 4, 2) + ) + check_slices( + a_np, 
slice(0, 1), np.array([0, 1, 2]), slice(None), np.array([1, 2, 0]) + ) + check_slices(a_np, slice(0, 1), np.array([0, 1, 2]), 1, np.array([1, 2, 0])) + check_slices( + a_np, slice(0, 1), np.array([0, 1, 2]), np.array([1, 0, 0]), slice(0, 1) + ) + check_slices( + a_np, + slice(0, 1), + np.array([[0], [1], [2]]), + np.array([[1, 0, 0]]), + slice(0, 1), + ) + check_slices( + a_np, + slice(0, 2), + np.array([[0], [1], [2]]), + slice(0, 2), + np.array([[1, 0, 0]]), + ) + for p in permutations([slice(None), slice(None), 0, np.array([1, 0])]): + check_slices(a_np, *p) + for p in permutations( + [slice(None), slice(None), 0, np.array([1, 0]), None, None] + ): + check_slices(a_np, *p) + for p in permutations([0, np.array([1, 0]), None, Ellipsis, slice(None)]): + check_slices(a_np, *p) + + # Non-contiguous arrays in slicing + a_mlx = mx.reshape(mx.arange(128), (16, 8)) + a_mlx = a_mlx[::2, :] + a_np = np.array(a_mlx) + idx_np = np.arange(8)[::2] + idx_mlx = mx.arange(8)[::2] + self.assertTrue( + np.array_equal(a_np[idx_np, idx_np], np.array(a_mlx[idx_mlx, idx_mlx])) + ) + + def test_setitem(self): + a = mx.array(0) + a[None] = 1 + self.assertEqual(a.item(), 1) + + a = mx.array([1, 2, 3]) + a[0] = 2 + self.assertEqual(a.tolist(), [2, 2, 3]) + + a[-1] = 2 + self.assertEqual(a.tolist(), [2, 2, 2]) + + a[0] = mx.array([[[1]]]) + self.assertEqual(a.tolist(), [1, 2, 2]) + + a[:] = 0 + self.assertEqual(a.tolist(), [0, 0, 0]) + + a[None] = 1 + self.assertEqual(a.tolist(), [1, 1, 1]) + + a[0:1] = 2 + self.assertEqual(a.tolist(), [2, 1, 1]) + + a[0:2] = 3 + self.assertEqual(a.tolist(), [3, 3, 1]) + + a[0:3] = 4 + self.assertEqual(a.tolist(), [4, 4, 4]) + + a[0:1] = mx.array(0) + self.assertEqual(a.tolist(), [0, 4, 4]) + + a[0:1] = mx.array([1]) + self.assertEqual(a.tolist(), [1, 4, 4]) + + with self.assertRaises(ValueError): + a[0:1] = mx.array([2, 3]) + + a[0:2] = mx.array([2, 2]) + self.assertEqual(a.tolist(), [2, 2, 4]) + + a[:] = mx.array([[[[1, 1, 1]]]]) + self.assertEqual(a.tolist(), [1, 1, 1]) + + # Array slices + def check_slices(arr_np, update_np, *idx_np): + arr_mlx = mx.array(arr_np) + update_mlx = mx.array(update_np) + idx_mlx = [ + mx.array(idx) if isinstance(idx, np.ndarray) else idx for idx in idx_np + ] + if len(idx_np) > 1: + idx_np = tuple(idx_np) + idx_mlx = tuple(idx_mlx) + else: + idx_np = idx_np[0] + idx_mlx = idx_mlx[0] + arr_np[idx_np] = update_np + arr_mlx[idx_mlx] = update_mlx + self.assertTrue(np.array_equal(arr_np, arr_mlx)) + + check_slices(np.zeros((3, 3)), 1, 0) + check_slices(np.zeros((3, 3)), 1, -1) + check_slices(np.zeros((3, 3)), 1, slice(0, 2)) + check_slices(np.zeros((3, 3)), np.array([[0, 1, 2], [3, 4, 5]]), slice(0, 2)) + + with self.assertRaises(ValueError): + a = mx.array(0) + a[0] = mx.array(1) + + check_slices(np.zeros((3, 3)), 1, np.array([0, 1, 2])) + check_slices(np.zeros((3, 3)), np.array(3), np.array([0, 1, 2])) + check_slices(np.zeros((3, 3)), np.array([3]), np.array([0, 1, 2])) + check_slices(np.zeros((3, 3)), np.array([3]), np.array([0, 1])) + check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) + check_slices(np.zeros((3, 2)), np.array([[3, 3], [4, 4]]), np.array([0, 1])) + check_slices( + np.zeros((3, 2)), np.array([[3, 3], [4, 4], [5, 5]]), np.array([0, 0, 1]) + ) + + # Multiple slices + a = mx.array(0) + a[None, None] = 1 + self.assertEqual(a.item(), 1) + + a[None, None] = mx.array(2) + self.assertEqual(a.item(), 2) + + a[None, None] = mx.array([[[3]]]) + self.assertEqual(a.item(), 3) + + a[()] = 4 + self.assertEqual(a.item(), 4) + 
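+        # Mixed advanced-indexing updates (arrays, slices, None, Ellipsis) are
+        # checked element-for-element against the NumPy reference via check_slices.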
+ a_np = np.zeros((2, 3, 4, 5)) + check_slices(a_np, 1, np.array([0, 0]), slice(0, 2), slice(0, 3), 4) + check_slices( + a_np, + np.arange(10).reshape(2, 5), + np.array([0, 0]), + np.array([0, 1]), + np.array([2, 3]), + ) + check_slices( + a_np, + np.array([[3], [4]]), + np.array([0, 0]), + np.array([0, 1]), + np.array([2, 3]), + ) + check_slices( + a_np, np.arange(5), np.array([0, 0]), np.array([0, 1]), np.array([2, 3]) + ) + check_slices(np.zeros(5), np.arange(2), None, None, np.array([2, 3])) + check_slices( + np.zeros((4, 3, 4)), + np.arange(3), + np.array([2, 3]), + slice(0, 3), + np.array([2, 3]), + ) + + with self.assertRaises(ValueError): + a = mx.zeros((4, 3, 4)) + a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(2) + + with self.assertRaises(ValueError): + a = mx.zeros((4, 3, 4)) + a[mx.array([2, 3]), None, mx.array([2, 3])] = mx.arange(3) + + check_slices(np.zeros((4, 3, 4)), 1, np.array([2, 3]), None, np.array([2, 1])) + check_slices( + np.zeros((4, 3, 4)), np.arange(4), np.array([2, 3]), None, np.array([2, 1]) + ) + check_slices( + np.zeros((4, 3, 4)), + np.arange(2 * 4).reshape(2, 1, 4), + np.array([2, 3]), + None, + np.array([2, 1]), + ) + + check_slices(np.zeros((4, 4)), 1, slice(0, 2), slice(0, 2)) + check_slices(np.zeros((4, 4)), np.arange(2), slice(0, 2), slice(0, 2)) + check_slices( + np.zeros((4, 4)), np.arange(2).reshape(2, 1), slice(0, 2), slice(0, 2) + ) + check_slices( + np.zeros((4, 4)), np.arange(4).reshape(2, 2), slice(0, 2), slice(0, 2) + ) + + with self.assertRaises(ValueError): + a = mx.zeros((2, 2, 2)) + a[..., ...] = 1 + + with self.assertRaises(ValueError): + a = mx.zeros((2, 2, 2, 2, 2)) + a[0, ..., 0, ..., 0] = 1 + + with self.assertRaises(ValueError): + a = mx.zeros((2, 2)) + a[0, 0, 0] = 1 + + check_slices(np.zeros((2, 2, 2, 2)), 1, None, Ellipsis, None) + check_slices( + np.zeros((2, 2, 2, 2)), 1, np.array([0, 1]), Ellipsis, np.array([0, 1]) + ) + check_slices( + np.zeros((2, 2, 2, 2)), + np.arange(2 * 2 * 2).reshape(2, 2, 2), + np.array([0, 1]), + Ellipsis, + np.array([0, 1]), + ) + + def test_slice_negative_step(self): + + a_np = np.arange(20) + a_mx = mx.array(a_np) + + # Basic negative slice + b_np = a_np[::-1] + b_mx = a_mx[::-1] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Bounds negative slice + b_np = a_np[-3:3:-1] + b_mx = a_mx[-3:3:-1] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Bounds negative slice + b_np = a_np[25:-50:-1] + b_mx = a_mx[25:-50:-1] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Jumping negative slice + b_np = a_np[::-3] + b_mx = a_mx[::-3] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Bounds and negative slice + b_np = a_np[-3:3:-3] + b_mx = a_mx[-3:3:-3] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Bounds and negative slice + b_np = a_np[25:-50:-3] + b_mx = a_mx[25:-50:-3] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Negatie slice and ascending bounds + b_np = a_np[0:20:-3] + b_mx = a_mx[0:20:-3] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Multi-dim negative slices + a_np = np.arange(3 * 6 * 4).reshape(3, 6, 4) + a_mx = mx.array(a_np) + + # Flip each dim + b_np = a_np[..., ::-1] + b_mx = a_mx[..., ::-1] + self.assertTrue(np.array_equal(b_np, b_mx)) + + b_np = a_np[:, ::-1, :] + b_mx = a_mx[:, ::-1, :] + self.assertTrue(np.array_equal(b_np, b_mx)) + + b_np = a_np[::-1, ...] + b_mx = a_mx[::-1, ...] 
+ self.assertTrue(np.array_equal(b_np, b_mx)) + + # Flip pairs of dims + b_np = a_np[::-1, 1:5:2, ::-2] + b_mx = a_mx[::-1, 1:5:2, ::-2] + self.assertTrue(np.array_equal(b_np, b_mx)) + + b_np = a_np[::-1, ::-2, 1:5:2] + b_mx = a_mx[::-1, ::-2, 1:5:2] + self.assertTrue(np.array_equal(b_np, b_mx)) + + # Flip all dims + b_np = a_np[::-1, ::-3, ::-2] + b_mx = a_mx[::-1, ::-3, ::-2] + self.assertTrue(np.array_equal(b_np, b_mx)) + + def test_api(self): + x = mx.array(np.random.rand(10, 10, 10)) + ops = [ + ("reshape", (100, -1)), + "square", + "sqrt", + "rsqrt", + "reciprocal", + "exp", + "log", + "sin", + "cos", + "log1p", + ("all", 1), + ("any", 1), + ("transpose", (0, 2, 1)), + ("sum", 1), + ("prod", 1), + ("min", 1), + ("max", 1), + ("logsumexp", 1), + ("mean", 1), + ("var", 1), + ("argmin", 1), + ("argmax", 1), + ] + for op in ops: + if isinstance(op, tuple): + op, *args = op + else: + args = tuple() + y1 = getattr(mx, op)(x, *args) + y2 = getattr(x, op)(*args) + self.assertEqual(y1.dtype, y2.dtype) + self.assertEqual(y1.shape, y2.shape) + self.assertTrue(mx.array_equal(y1, y2)) + + y1 = mx.split(x, 2) + y2 = x.split(2) + self.assertEqual(len(y1), 2) + self.assertEqual(len(y1), len(y2)) + self.assertTrue(mx.array_equal(y1[0], y2[0])) + self.assertTrue(mx.array_equal(y1[1], y2[1])) + + def test_memoryless_copy(self): + a_mx = mx.ones((2, 2)) + b_mx = mx.broadcast_to(a_mx, (5, 2, 2)) + + # Make np arrays without copy + a_np = np.array(a_mx, copy=False) + b_np = np.array(b_mx, copy=False) + + # Check that we get read-only array that does not own the underlying data + self.assertFalse(a_np.flags.owndata) + self.assertFalse(a_np.flags.writeable) + + # Check contents + self.assertTrue(np.array_equal(np.ones((2, 2), dtype=np.float32), a_np)) + self.assertTrue(np.array_equal(np.ones((5, 2, 2), dtype=np.float32), b_np)) + + # Check strides + self.assertSequenceEqual(b_np.strides, (0, 8, 4)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_autograd.py b/python/tests/test_autograd.py new file mode 100644 index 0000000000..96e22dd6b6 --- /dev/null +++ b/python/tests/test_autograd.py @@ -0,0 +1,263 @@ +import unittest + +import mlx.core as mx + +import mlx_tests + + +class TestAutograd(mlx_tests.MLXTestCase): + def test_jvp(self): + fun = lambda x: 2 * x + out, dout = mx.jvp(fun, [mx.array(1.0)], [mx.array(2.0)]) + self.assertEqual(out[0].item(), 2.0) + self.assertEqual(dout[0].item(), 4.0) + + fun = lambda x, y: x * y + _, out = mx.jvp( + fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0), mx.array(2.0)] + ) + self.assertEqual(out[0].item(), 4.0 * 2.0 + 2.0 * 3.0) + + fun = lambda x, y, z: (x * y, y * z) + _, out = mx.jvp( + fun, + [mx.array(2.0), mx.array(4.0), mx.array(6.0)], + [mx.array(1.0), mx.array(3.0), mx.array(1.0)], + ) + self.assertEqual(len(out), 2) + self.assertEqual(out[0].item(), 4.0 * 1.0 + 2.0 * 3.0) + self.assertEqual(out[1].item(), 4.0 * 1.0 + 6.0 * 3.0) + + def test_vjp(self): + fun = lambda x: 2 * x + out, dout = mx.vjp(fun, [mx.array(1.0)], [mx.array(2.0)]) + self.assertEqual(out[0].item(), 2.0) + self.assertEqual(dout[0].item(), 4.0) + + fun = lambda x, y: x * y + _, dout = mx.vjp(fun, [mx.array(4.0), mx.array(2.0)], [mx.array(3.0)]) + self.assertEqual(dout[0].item(), 6.0) + self.assertEqual(dout[1].item(), 12.0) + + fun = lambda x, y, z: (x * y, y * z) + _, out = mx.vjp( + fun, + [mx.array(2.0), mx.array(4.0), mx.array(6.0)], + [mx.array(1.0), mx.array(3.0)], + ) + self.assertEqual(len(out), 3) + self.assertEqual(out[0].item(), 4.0 * 1.0) + 
self.assertEqual(out[1].item(), 2.0 * 1.0 + 6.0 * 3.0) + self.assertEqual(out[2].item(), 4.0 * 3.0) + + def test_grad(self): + fun = lambda x: x * x + + value, dfdx = mx.value_and_grad(fun)(mx.array(0.5)) + self.assertEqual(value.item(), 0.25) + self.assertEqual(dfdx.item(), 1.0) + + dfdx = mx.grad(fun)(mx.array(0.5)) + self.assertEqual(dfdx.item(), 1.0) + + df2dx2 = mx.grad(mx.grad(fun))(mx.array(0.5)) + self.assertEqual(df2dx2.item(), 2.0) + df3dx3 = mx.grad(mx.grad(mx.grad(fun)))(mx.array(0.5)) + self.assertEqual(df3dx3.item(), 0.0) + + fun = lambda x, y: x * y + x = mx.array(2.0) + y = mx.array(3.0) + dfdx = mx.grad(fun, argnums=0)(x, y) + self.assertEqual(dfdx.item(), 3.0) + dfdx = mx.grad(fun, argnums=1)(x, y) + self.assertEqual(dfdx.item(), 2.0) + + # Pass non array args to functions works + fun = lambda x, y: x + value, dfdx = mx.value_and_grad(fun)(mx.array(2.0), "hello") + self.assertEqual(value.item(), 2.0) + self.assertEqual(dfdx.item(), 1.0) + + dfdx = mx.grad(fun)(mx.array(2.0), "hello") + self.assertEqual(dfdx.item(), 1.0) + + # Raises when function does not return array + fun = lambda x: "hello" + with self.assertRaises(ValueError): + mx.grad(fun)(mx.array(2.0)) + + # Raises for invalid argument number or argument type + fun = lambda x: x + with self.assertRaises(ValueError): + mx.grad(fun, argnums=2)(mx.array(2.0)) + with self.assertRaises(ValueError): + mx.grad(fun, argnums=-2)(mx.array(2.0)) + with self.assertRaises(ValueError): + mx.grad(fun)("hello") + + # Raises when output is not a scalar array + fun = lambda x: mx.sum(x, keepdims=True) + with self.assertRaises(ValueError): + mx.grad(fun)(mx.ones((2, 2))) + + def test_grad_trees(self): + fun = lambda x, y: x * y + value, dfdx = mx.value_and_grad(fun, (0, 1))(mx.array(0.5), mx.array(2.0)) + self.assertEqual(value.item(), 1.0) + self.assertTrue(isinstance(dfdx, tuple)) + self.assertEqual(dfdx[0].item(), 2.0) + self.assertEqual(dfdx[1].item(), 0.5) + + fun = lambda x, y: x * y + value, dfdx = mx.value_and_grad(fun, 1)(mx.array(0.5), mx.array(2.0)) + self.assertEqual(value.item(), 1.0) + self.assertEqual(dfdx.item(), 0.5) + + fun = lambda p: p["x"] * p["y"] + value, dfdx = mx.value_and_grad(fun)({"x": mx.array(0.5), "y": mx.array(2.0)}) + self.assertEqual(value.item(), 1.0) + self.assertEqual(dfdx["x"].item(), 2.0) + self.assertEqual(dfdx["y"].item(), 0.5) + + fun = lambda p: p["x"] * p["y"] + with self.assertRaises(ValueError): + mx.value_and_grad(fun)({"x": 0.5, "y": mx.array(2.0)}) + with self.assertRaises(ValueError): + mx.value_and_grad(fun, (0, 1))({"x": mx.array(0.5), "y": mx.array(2.0)}) + + fun = lambda p, b: mx.square(p[0]["foo"][2]) * b + value, dfdx = mx.value_and_grad(fun)( + [{"foo": [[], [], mx.array(2.0)]}], mx.array(0.5) + ) + self.assertEqual(value.item(), 2.0) + self.assertEqual(dfdx[0]["foo"][2].item(), 2.0) + + fun = lambda x: x + with self.assertRaises(TypeError): + mx.value_and_grad(fun, (None, None)) + with self.assertRaises(ValueError): + mx.value_and_grad(fun, tuple()) + + def test_auxiliary_values(self): + def fun(x, y): + l = (x * y).sum() + extra = {"loss": l, "foo": y.square() + x.square(), "bar": [1, 2, 3, y, x]} + return l, extra + + fun_value_grad = mx.value_and_grad(fun) + fun_grad = mx.grad(fun) + + (loss, a), b = fun_value_grad(mx.ones((2, 2)), mx.ones((2, 2))) + self.assertEqual(a["loss"].item(), 4) + self.assertTrue(mx.array_equal(b, mx.ones((2, 2)))) + self.assertTrue(mx.array_equal(a["foo"], 2 * mx.ones((2, 2)))) + self.assertEqual(a["bar"][:3], [1, 2, 3]) + 
self.assertTrue(mx.array_equal(a["bar"][3], mx.ones((2, 2)))) + self.assertTrue(mx.array_equal(a["bar"][4], mx.ones((2, 2)))) + + with self.assertRaises(ValueError): + _ = fun_grad(mx.ones((2, 2)), mx.ones((2, 2))) + + def test_grad_kwargs(self): + fun = lambda x, y: x * y + a, b = mx.array(0.5), mx.array(2.0) + dfdx = mx.grad(fun) + self.assertEqual(dfdx(a, b).item(), 2.0) + self.assertEqual(dfdx(a, y=b).item(), 2.0) + with self.assertRaises(ValueError): + dfdx(x=a, y=b).item() + + dfdy = mx.grad(fun, argnums=[], argnames=["y"]) + with self.assertRaises(ValueError): + dfdy(a, b) + grads = dfdy(a, y=b) + self.assertTrue(isinstance(grads, tuple)) + self.assertTrue(grads[0] is None) + self.assertTrue(isinstance(grads[1], dict)) + self.assertEqual(grads[1]["y"].item(), 0.5) + grads = dfdy(x=a, y=b) + self.assertEqual(grads[1]["y"].item(), 0.5) + self.assertEqual(len(grads[1]), 1) + + dfdxy = mx.grad(fun, argnums=[0], argnames=["y"]) + with self.assertRaises(ValueError): + dfdxy(a, b) + with self.assertRaises(ValueError): + dfdxy(x=a, y=b) + grads = dfdxy(a, y=b) + self.assertTrue(isinstance(grads, tuple)) + self.assertEqual(grads[0].item(), 2.0) + self.assertTrue(isinstance(grads[1], dict)) + self.assertEqual(grads[1]["y"].item(), 0.5) + + fun = lambda x, y, z: x * y * z + dfdxyz = mx.grad(fun, argnums=[0, 1], argnames=["z"]) + c = mx.array(4.0) + grads = dfdxyz(a, b, z=c) + self.assertTrue(isinstance(grads, tuple)) + self.assertTrue(isinstance(grads[0], tuple)) + self.assertEqual(grads[0][0].item(), 8.0) + self.assertEqual(grads[0][1].item(), 2.0) + self.assertTrue(isinstance(grads[1], dict)) + self.assertEqual(grads[1]["z"].item(), 1.0) + + fun = lambda x, y: x * y + dfdy = mx.grad(fun, argnames=["y"]) + grads = dfdy(a, y=b) + self.assertTrue(isinstance(grads, tuple)) + self.assertTrue(grads[0] is None) + self.assertTrue(isinstance(grads[1], dict)) + self.assertEqual(grads[1]["y"].item(), 0.5) + + def test_captured(self): + a = mx.array(5.0) + f = lambda x: a + x + g = lambda x: a + a + h = lambda x: x + x + + dfdx = mx.grad(f) + self.assertEqual(dfdx(a).item(), 1.0) + + dgdx = mx.grad(g) + self.assertEqual(dgdx(a).item(), 0.0) + + dhdx = mx.grad(h) + self.assertEqual(dhdx(a).item(), 2.0) + + d2fdx2 = mx.grad(dfdx) + self.assertEqual(d2fdx2(a).item(), 0.0) + + d2gdx2 = mx.grad(dgdx) + self.assertEqual(d2gdx2(a).item(), 0.0) + + d2hdx2 = mx.grad(dhdx) + self.assertEqual(d2hdx2(a).item(), 0.0) + + def test_stop_gradient(self): + shape_in = (4, 4) + w_in = mx.ones(shape_in) + x_in = mx.ones(shape_in) + cotan = mx.ones(shape_in) + + def h(w, x): + x1 = 2 * x + y = mx.stop_gradient(x1) + y1 = 3 * y + return w @ y1 + + vals, vjps = mx.vjp(h, [w_in, x_in], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], 24.0 * mx.ones(shape_in))) + self.assertTrue(mx.allclose(vjps[1], mx.zeros(shape_in))) + + g = lambda x: h(w_in, x) + vals, vjps = mx.vjp(g, [x_in], [cotan]) + mx.eval(vjps) + + self.assertTrue(mx.allclose(vjps[0], mx.zeros(shape_in))) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_device.py b/python/tests/test_device.py new file mode 100644 index 0000000000..b3abec01fb --- /dev/null +++ b/python/tests/test_device.py @@ -0,0 +1,105 @@ +import unittest + +import mlx.core as mx + +import mlx_tests + + +# Don't inherit from MLXTestCase to avoid call to setUp +class TestDefaultDevice(unittest.TestCase): + def test_mlx_default_device(self): + device = mx.default_device() + if mx.metal.is_available(): + self.assertEqual(device, mx.Device(mx.gpu)) + 
self.assertEqual(str(device), "Device(gpu, 0)") + self.assertEqual(device, mx.gpu) + self.assertEqual(mx.gpu, device) + else: + self.assertEqual(device.type, mx.Device(mx.cpu)) + with self.assertRaises(ValueError): + mx.set_default_device(mx.gpu) + + +class TestDevice(mlx_tests.MLXTestCase): + def test_device(self): + device = mx.default_device() + + cpu = mx.Device(mx.cpu) + mx.set_default_device(cpu) + self.assertEqual(mx.default_device(), cpu) + self.assertEqual(str(cpu), "Device(cpu, 0)") + + mx.set_default_device(mx.cpu) + self.assertEqual(mx.default_device(), mx.cpu) + self.assertEqual(cpu, mx.cpu) + self.assertEqual(mx.cpu, cpu) + + # Restore device + mx.set_default_device(device) + + def test_op_on_device(self): + x = mx.array(1.0) + y = mx.array(1.0) + + a = mx.add(x, y, stream=None) + b = mx.add(x, y, stream=mx.default_device()) + self.assertEqual(a.item(), b.item()) + b = mx.add(x, y, stream=mx.cpu) + self.assertEqual(a.item(), b.item()) + + if mx.metal.is_available(): + b = mx.add(x, y, stream=mx.gpu) + self.assertEqual(a.item(), b.item()) + + +class TestStream(mlx_tests.MLXTestCase): + def test_stream(self): + s1 = mx.default_stream(mx.default_device()) + self.assertEqual(s1.device, mx.default_device()) + + s2 = mx.new_stream(mx.default_device()) + self.assertEqual(s2.device, mx.default_device()) + self.assertNotEqual(s1, s2) + + if mx.metal.is_available(): + s_gpu = mx.default_stream(mx.gpu) + self.assertEqual(s_gpu.device, mx.gpu) + else: + with self.assertRaises(ValueError): + mx.default_stream(mx.gpu) + + s_cpu = mx.default_stream(mx.cpu) + self.assertEqual(s_cpu.device, mx.cpu) + + s_cpu = mx.new_stream(mx.cpu) + self.assertEqual(s_cpu.device, mx.cpu) + + if mx.metal.is_available(): + s_gpu = mx.new_stream(mx.gpu) + self.assertEqual(s_gpu.device, mx.gpu) + else: + with self.assertRaises(ValueError): + mx.new_stream(mx.gpu) + + def test_op_on_stream(self): + x = mx.array(1.0) + y = mx.array(1.0) + + a = mx.add(x, y, stream=mx.default_stream(mx.default_device())) + + if mx.metal.is_available(): + b = mx.add(x, y, stream=mx.default_stream(mx.gpu)) + self.assertEqual(a.item(), b.item()) + s_gpu = mx.new_stream(mx.gpu) + b = mx.add(x, y, stream=s_gpu) + self.assertEqual(a.item(), b.item()) + + b = mx.add(x, y, stream=mx.default_stream(mx.cpu)) + self.assertEqual(a.item(), b.item()) + s_cpu = mx.new_stream(mx.cpu) + b = mx.add(x, y, stream=s_cpu) + self.assertEqual(a.item(), b.item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_eval.py b/python/tests/test_eval.py new file mode 100644 index 0000000000..e4d48a8779 --- /dev/null +++ b/python/tests/test_eval.py @@ -0,0 +1,34 @@ +from functools import partial + +import unittest + +import mlx.core as mx + +import mlx_tests + + +class TestEval(mlx_tests.MLXTestCase): + def test_eval(self): + arrs = [mx.ones((2, 2)) for _ in range(4)] + mx.eval(*arrs) + for x in arrs: + self.assertEqual(x.tolist(), [[1, 1], [1, 1]]) + + def test_retain_graph(self): + def fun(x, retain_graph): + y = 3 * x + mx.eval(y, retain_graph=retain_graph) + return 2 * y + + dfun_dx_1 = mx.grad(partial(fun, retain_graph=False)) + dfun_dx_2 = mx.grad(partial(fun, retain_graph=True)) + + with self.assertRaises(ValueError): + dfun_dx_1(mx.array(1.0)) + + y = dfun_dx_2(mx.array(1.0)) + self.assertEqual(y.item(), 6.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_fft.py b/python/tests/test_fft.py new file mode 100644 index 0000000000..6af90e1c58 --- /dev/null +++ b/python/tests/test_fft.py @@ -0,0 
+1,90 @@ +import unittest + +import itertools +import mlx.core as mx +import numpy as np + +import mlx_tests + + +class TestFFT(mlx_tests.MLXTestCase): + def check_mx_np(self, op, a_np, axes, s): + with self.subTest(op=op, axes=axes, s=s): + op_np = getattr(np.fft, op) + op_mx = getattr(mx.fft, op) + out_np = op_np(a_np, s=s, axes=axes) + a_mx = mx.array(a_np) + out_mx = op_mx(a_mx, s=s, axes=axes) + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + + def test_fft(self): + default = mx.default_device() + mx.set_default_device(mx.cpu) + + def check_mx_np(op_mx, op_np, a_np, **kwargs): + out_np = op_np(a_np, **kwargs) + a_mx = mx.array(a_np) + out_mx = op_mx(a_mx, **kwargs) + self.assertTrue(np.allclose(out_np, out_mx, atol=1e-5, rtol=1e-6)) + + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np) + + # Check with slicing and padding + r = np.random.rand(100).astype(np.float32) + i = np.random.rand(100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=80) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, n=120) + + # Check different axes + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=0) + check_mx_np(mx.fft.fft, np.fft.fft, a_np, axis=1) + + # Check real fft + a_np = np.random.rand(100).astype(np.float32) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=80) + check_mx_np(mx.fft.rfft, np.fft.rfft, a_np, n=120) + + # Check real inverse + r = np.random.rand(100, 100).astype(np.float32) + i = np.random.rand(100, 100).astype(np.float32) + a_np = r + 1j * i + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=80) + check_mx_np(mx.fft.ifft, np.fft.ifft, a_np, n=120) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=80) + check_mx_np(mx.fft.irfft, np.fft.irfft, a_np, n=120) + + mx.set_default_device(default) + + def test_fftn(self): + default = mx.default_device() + mx.set_default_device(mx.cpu) + + r = np.random.randn(8, 8, 8).astype(np.float32) + i = np.random.randn(8, 8, 8).astype(np.float32) + a = r + 1j * i + + axes = [None, (1, 2), (2, 1), (0, 2)] + shapes = [None, (10, 5), (5, 10)] + ops = ["fft2", "ifft2", "rfft2", "irfft2", "fftn", "ifftn", "rfftn", "irfftn"] + + for op, ax, s in itertools.product(ops, axes, shapes): + x = a + if op in ["rfft2", "rfftn"]: + x = r + self.check_mx_np(op, x, axes=ax, s=s) + + mx.set_default_device(default) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_ops.py b/python/tests/test_ops.py new file mode 100644 index 0000000000..9885a7e708 --- /dev/null +++ b/python/tests/test_ops.py @@ -0,0 +1,1283 @@ +import unittest +from itertools import permutations + +import math +import mlx.core as mx +import numpy as np + +import mlx_tests + + +class TestOps(mlx_tests.MLXTestCase): + def test_full_ones_zeros(self): + x = mx.full(2, 3.0) + self.assertEqual(x.shape, [2]) + self.assertEqual(x.tolist(), [3.0, 3.0]) + + x = mx.full((2, 3), 2.0) + self.assertEqual(x.dtype, mx.float32) + self.assertEqual(x.shape, [2, 3]) + self.assertEqual(x.tolist(), [[2, 2, 2], [2, 2, 2]]) + + x = mx.full([3, 2], mx.array([False, True])) + self.assertEqual(x.dtype, mx.bool_) + self.assertEqual(x.tolist(), [[False, True], [False, True], [False, True]]) + + x = 
mx.full([3, 2], mx.array([2.0, 3.0])) + self.assertEqual(x.tolist(), [[2, 3], [2, 3], [2, 3]]) + + x = mx.zeros(2) + self.assertEqual(x.shape, [2]) + self.assertEqual(x.tolist(), [0.0, 0.0]) + + x = mx.ones(2) + self.assertEqual(x.shape, [2]) + self.assertEqual(x.tolist(), [1.0, 1.0]) + + for t in [mx.bool_, mx.int32, mx.float32]: + x = mx.zeros([2, 2], t) + self.assertEqual(x.dtype, t) + self.assertTrue(mx.array_equal(x, mx.array([[0, 0], [0, 0]]))) + y = mx.zeros_like(x) + self.assertEqual(y.dtype, t) + self.assertTrue(mx.array_equal(y, x)) + + x = mx.ones([2, 2], t) + self.assertEqual(x.dtype, t) + self.assertTrue(mx.array_equal(x, mx.array([[1, 1], [1, 1]]))) + y = mx.ones_like(x) + self.assertEqual(y.dtype, t) + self.assertTrue(mx.array_equal(y, x)) + + def test_scalar_inputs(self): + # Check combinations of python types + a = mx.add(False, True) + self.assertEqual(a.dtype, mx.bool_) + self.assertEqual(a.item(), True) + + a = mx.add(1, 2) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(a.item(), 3) + + a = mx.add(1.0, 2.0) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 3.0) + + a = mx.add(True, 2) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(a.item(), 3) + + a = mx.add(True, 2.0) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 3.0) + + a = mx.add(1, 2.0) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 3.0) + + a = mx.add(2, True) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(a.item(), 3) + + a = mx.add(2.0, True) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 3.0) + + a = mx.add(2.0, 1) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 3.0) + + # Check comibinations with mlx arrays + a = mx.add(mx.array(True), False) + self.assertEqual(a.dtype, mx.bool_) + self.assertEqual(a.item(), True) + + a = mx.add(mx.array(1), False) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(a.item(), 1.0) + + # Edge case: take the type of the scalar + a = mx.add(mx.array(True), 1) + self.assertEqual(a.dtype, mx.int32) + self.assertEqual(a.item(), 2) + + a = mx.add(mx.array(1.0), 1) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 2.0) + + a = mx.add(1, mx.array(1.0)) + self.assertEqual(a.dtype, mx.float32) + self.assertEqual(a.item(), 2.0) + + binary_ops = [ + "add", + "subtract", + "multiply", + "divide", + "equal", + "not_equal", + "less", + "greater", + "less_equal", + "greater_equal", + "maximum", + "minimum", + ] + + for op in binary_ops: + npop = getattr(np, op) + mlxop = getattr(mx, op) + + # Avoid subtract from bool and divide by 0 + for x in [-1, 0, 1, -1.0, 1.0]: + for y in [True, -1, 1, -1.0, 1.0]: + self.assertEqual(npop(x, y).item(), mlxop(x, y).item()) + + def test_add(self): + x = mx.array(1) + y = mx.array(1) + z = mx.add(x, y) + self.assertEqual(z.item(), 2) + + x = mx.array(False, mx.bool_) + z = x + 1 + self.assertEqual(z.dtype, mx.int32) + self.assertEqual(z.item(), 1) + z = 2 + x + self.assertEqual(z.dtype, mx.int32) + self.assertEqual(z.item(), 2) + + x = mx.array(1, mx.uint32) + z = x + 3 + self.assertEqual(z.dtype, mx.uint32) + self.assertEqual(z.item(), 4) + + z = 3 + x + self.assertEqual(z.dtype, mx.uint32) + self.assertEqual(z.item(), 4) + + z = x + 3.0 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4.0) + + z = 3.0 + x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4.0) + + x = mx.array(1, mx.int64) + z = x + 3 + self.assertEqual(z.dtype, mx.int64) + 
self.assertEqual(z.item(), 4) + z = 3 + x + self.assertEqual(z.dtype, mx.int64) + self.assertEqual(z.item(), 4) + z = x + 3.0 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4.0) + z = 3.0 + x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4.0) + + x = mx.array(1, mx.float32) + z = x + 3 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4) + z = 3 + x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 4) + + def test_subtract(self): + x = mx.array(4.0) + y = mx.array(3.0) + + z = mx.subtract(x, y) + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 1.0) + + z = x - 3.0 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 1.0) + + z = 5.0 - x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 1.0) + + def test_multiply(self): + x = mx.array(2.0) + y = mx.array(3.0) + + z = mx.multiply(x, y) + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 6.0) + + z = x * 3.0 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 6.0) + + z = 3.0 * x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 6.0) + + def test_divide(self): + x = mx.array(2.0) + y = mx.array(4.0) + + z = mx.divide(x, y) + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 0.5) + + z = x / 4.0 + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 0.5) + + z = 1.0 / x + self.assertEqual(z.dtype, mx.float32) + self.assertEqual(z.item(), 0.5) + + def test_comparisons(self): + a = mx.array([0.0, 1.0, 5.0]) + b = mx.array([-1.0, 2.0, 5.0]) + + self.assertEqual(mx.less(a, b).tolist(), [False, True, False]) + self.assertEqual(mx.less_equal(a, b).tolist(), [False, True, True]) + self.assertEqual(mx.greater(a, b).tolist(), [True, False, False]) + self.assertEqual(mx.greater_equal(a, b).tolist(), [True, False, True]) + + self.assertEqual(mx.less(a, 5).tolist(), [True, True, False]) + self.assertEqual(mx.less(5, a).tolist(), [False, False, False]) + self.assertEqual(mx.less_equal(5, a).tolist(), [False, False, True]) + self.assertEqual(mx.greater(a, 1).tolist(), [False, False, True]) + self.assertEqual(mx.greater_equal(a, 1).tolist(), [False, True, True]) + + a = mx.array([0.0, 1.0, 5.0, -1.0]) + b = mx.array([0.0, 2.0, 5.0, 3.0]) + self.assertEqual(mx.equal(a, b).tolist(), [True, False, True, False]) + self.assertEqual(mx.not_equal(a, b).tolist(), [False, True, False, True]) + + def test_array_equal(self): + x = mx.array([1, 2, 3, 4]) + y = mx.array([1, 2, 3, 4]) + self.assertTrue(mx.array_equal(x, y)) + + y = mx.array([1, 2, 4, 5]) + self.assertFalse(mx.array_equal(x, y)) + + y = mx.array([1, 2, 3]) + self.assertFalse(mx.array_equal(x, y)) + + # Can still be equal with different types + y = mx.array([1.0, 2.0, 3.0, 4.0]) + self.assertTrue(mx.array_equal(x, y)) + + x = mx.array([0.0, float("nan")]) + y = mx.array([0.0, float("nan")]) + self.assertFalse(mx.array_equal(x, y)) + self.assertTrue(mx.array_equal(x, y, equal_nan=True)) + + for t in [mx.float32, mx.float16, mx.bfloat16, mx.complex64]: + with self.subTest(type=t): + x = mx.array([0.0, float("nan")]).astype(t) + y = mx.array([0.0, float("nan")]).astype(t) + self.assertFalse(mx.array_equal(x, y)) + self.assertTrue(mx.array_equal(x, y, equal_nan=True)) + + def test_minimum(self): + x = mx.array([0.0, -5, 10.0]) + y = mx.array([1.0, -7.0, 3.0]) + + expected = [0, -7, 3] + self.assertListEqual(mx.minimum(x, y).tolist(), expected) + + def test_maximum(self): + x = mx.array([0.0, 
-5, 10.0]) + y = mx.array([1.0, -7.0, 3.0]) + + expected = [1, -5, 10] + self.assertListEqual(mx.maximum(x, y).tolist(), expected) + + def test_transpose_noargs(self): + x = mx.array([[0, 1, 1], [1, 0, 0]]) + + expected = [ + [0, 1], + [1, 0], + [1, 0], + ] + + self.assertListEqual(mx.transpose(x).tolist(), expected) + + def test_transpose_axis(self): + x = mx.array( + [ + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]], + ] + ) + expected = [ + [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]], + [[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]], + ] + + self.assertListEqual(mx.transpose(x, axes=(0, 2, 1)).tolist(), expected) + + def test_sum(self): + x = mx.array( + [ + [1, 2], + [3, 3], + ] + ) + self.assertEqual(mx.sum(x).item(), 9) + y = mx.sum(x, keepdims=True) + self.assertEqual(y, mx.array(9)) + self.assertEqual(y.shape, [1, 1]) + + self.assertEqual(mx.sum(x, axis=0).tolist(), [4, 5]) + self.assertEqual(mx.sum(x, axis=1).tolist(), [3, 6]) + + x_npy = np.arange(3 * 5 * 4 * 7).astype(np.float32) + x_npy = np.reshape(x_npy, (3, 5, 4, 7)) + x_mlx = mx.array(x_npy) + + for axis in (None, 0, 1, 2, 3, (0, 1), (2, 3), (1, 2, 3)): + sum_npy = np.sum(x_npy, axis=axis) + sum_mlx = np.asarray(mx.sum(x_mlx, axis=axis)) + self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape)) + self.assertTrue(np.all(sum_npy == sum_mlx)) + + x_npy = np.array([1.0, 2.0, 3.0, 4.0]).astype(np.float32) + x_mlx = mx.array(x_npy) + + y_npy = x_npy[0:4:2] + y_npy = np.broadcast_to(y_npy, (2, 2)) + + y_mlx = x_mlx[0:4:2] + y_mlx = mx.broadcast_to(y_mlx, (2, 2)) + + for axis in (None, 0, 1, (0, 1)): + sum_npy = np.sum(y_npy, axis=axis) + sum_mlx = np.asarray(mx.sum(y_mlx, axis=axis)) + self.assertListEqual(list(sum_npy.shape), list(sum_mlx.shape)) + self.assertTrue(np.all(sum_npy == sum_mlx)) + + def test_prod(self): + x = mx.array( + [ + [1, 2], + [3, 3], + ] + ) + self.assertEqual(mx.prod(x).item(), 18) + y = mx.prod(x, keepdims=True) + self.assertEqual(y, mx.array(18)) + self.assertEqual(y.shape, [1, 1]) + + self.assertEqual(mx.prod(x, axis=0).tolist(), [3, 6]) + self.assertEqual(mx.prod(x, axis=1).tolist(), [2, 9]) + + def test_min_and_max(self): + x = mx.array( + [ + [1, 2], + [3, 4], + ] + ) + self.assertEqual(mx.min(x).item(), 1) + self.assertEqual(mx.max(x).item(), 4) + y = mx.min(x, keepdims=True) + self.assertEqual(y.shape, [1, 1]) + self.assertEqual(y, mx.array(1)) + + y = mx.max(x, keepdims=True) + self.assertEqual(y.shape, [1, 1]) + self.assertEqual(y, mx.array(4)) + + self.assertEqual(mx.min(x, axis=0).tolist(), [1, 2]) + self.assertEqual(mx.min(x, axis=1).tolist(), [1, 3]) + self.assertEqual(mx.max(x, axis=0).tolist(), [3, 4]) + self.assertEqual(mx.max(x, axis=1).tolist(), [2, 4]) + + def test_argmin_argmax(self): + data = np.random.rand(10, 12, 13) + x = mx.array(data) + for op in ["argmin", "argmax"]: + for axis in range(3): + for kd in [True, False]: + a = getattr(mx, op)(x, axis, kd) + b = getattr(np, op)(data, axis, keepdims=kd) + self.assertEqual(a.tolist(), b.tolist()) + + for op in ["argmin", "argmax"]: + a = getattr(mx, op)(x, keepdims=True) + b = getattr(np, op)(data, keepdims=True) + self.assertEqual(a.tolist(), b.tolist()) + a = getattr(mx, op)(x) + b = getattr(np, op)(data) + self.assertEqual(a.item(), b) + + def test_broadcast(self): + a_npy = np.reshape(np.arange(200), (10, 20)) + a_mlx = mx.array(a_npy) + + b_npy = np.broadcast_to(a_npy, (30, 10, 20)) + b_mlx = mx.broadcast_to(a_mlx, (30, 10, 20)) + 
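+        # Each case below compares mx.broadcast_to with np.broadcast_to on both
+        # the resulting shape and the resulting values.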
self.assertListEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.array_equal(b_npy, b_mlx)) + + b_npy = np.broadcast_to(a_npy, (1, 10, 20)) + b_mlx = mx.broadcast_to(a_mlx, (1, 10, 20)) + self.assertListEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.array_equal(b_npy, b_mlx)) + + b_npy = np.broadcast_to(1, (10, 20)) + b_mlx = mx.broadcast_to(1, (10, 20)) + self.assertListEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.array_equal(b_npy, b_mlx)) + + def test_logsumexp(self): + x = mx.array( + [ + [1.0, 2.0], + [3.0, 4.0], + ] + ) + xnp = np.array(x.tolist(), dtype=np.float32) + expected = np.log(np.sum(np.exp(xnp))) + self.assertTrue(math.isclose(mx.logsumexp(x).item(), expected.item())) + + def test_mean(self): + x = mx.array( + [ + [1, 2], + [3, 4], + ] + ) + self.assertEqual(mx.mean(x).item(), 2.5) + y = mx.mean(x, keepdims=True) + self.assertEqual(y, mx.array(2.5)) + self.assertEqual(y.shape, [1, 1]) + + self.assertEqual(mx.mean(x, axis=0).tolist(), [2, 3]) + self.assertEqual(mx.mean(x, axis=1).tolist(), [1.5, 3.5]) + + def test_var(self): + x = mx.array( + [ + [1, 2], + [3, 4], + ] + ) + self.assertEqual(mx.var(x).item(), 1.25) + y = mx.var(x, keepdims=True) + self.assertEqual(y, mx.array(1.25)) + self.assertEqual(y.shape, [1, 1]) + + self.assertEqual(mx.var(x, axis=0).tolist(), [1.0, 1.0]) + self.assertEqual(mx.var(x, axis=1).tolist(), [0.25, 0.25]) + + def test_abs(self): + a = mx.array([-1.0, 1.0, -2.0, 3.0]) + result = mx.abs(a) + expected = np.abs(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_negative(self): + a = mx.array([-1.0, 1.0, -2.0, 3.0]) + result = mx.negative(a) + expected = np.negative(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_sign(self): + a = mx.array([-1.0, 1.0, 0.0, -2.0, 3.0]) + result = mx.sign(a) + expected = np.sign(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_logical_not(self): + a = mx.array([-1.0, 1.0, 0.0, 1.0, -2.0, 3.0]) + result = mx.logical_not(a) + expected = np.logical_not(a) + self.assertTrue(np.array_equal(result, expected)) + + def test_square(self): + a = mx.array([0.1, 0.5, 1.0, 10.0]) + result = mx.square(a) + expected = np.square(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_sqrt(self): + a = mx.array([0.1, 0.5, 1.0, 10.0]) + result = mx.sqrt(a) + expected = np.sqrt(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_rsqrt(self): + a = mx.array([0.1, 0.5, 1.0, 10.0]) + result = mx.rsqrt(a) + expected = 1.0 / np.sqrt(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_reciprocal(self): + a = mx.array([0.1, 0.5, 1.0, 2.0]) + result = mx.reciprocal(a) + expected = np.reciprocal(a, dtype=np.float32) + self.assertTrue(np.allclose(result, expected)) + + def test_logaddexp(self): + a = mx.array([0, 1, 2, 9.0]) + b = mx.array([1, 0, 4, 2.5]) + + result = mx.logaddexp(a, b) + expected = np.logaddexp(a, b, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_log(self): + a = mx.array([1, 0.5, 10, 100]) + result = mx.log(a) + expected = np.log(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_log2(self): + a = mx.array([0.5, 1, 2, 10, 16]) + result = mx.log2(a) + expected = np.log2(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_log10(self): + a = mx.array([0.1, 1, 10, 20, 100]) + result = 
mx.log10(a) + expected = np.log10(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_exp(self): + a = mx.array([0, 0.5, -0.5, 5]) + result = mx.exp(a) + expected = np.exp(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_erf(self): + inputs = [-5, 0.0, 0.5, 1.0, 2.0, 10.0] + x = mx.array(inputs) + expected = np.array([math.erf(i) for i in inputs]) + self.assertTrue(np.allclose(mx.erf(x), expected)) + + def test_erfinv(self): + inputs = [-5.0, -1.0, 0.5, 0.0, 0.5, 1.0, 5.0] + x = mx.array(inputs) + # Output of: + # scipy.special.erfinv([-5.0, -1.0, 0.5, 0.0, 0.5, 1.0, 5.0]) + expected = np.array( + [ + float("nan"), + -float("inf"), + 0.47693628, + 0.0, + 0.47693628, + float("inf"), + float("nan"), + ] + ).astype(np.float32) + self.assertTrue(np.allclose(mx.erfinv(x), expected, equal_nan=True)) + + def test_sin(self): + a = mx.array( + [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi] + ) + result = mx.sin(a) + expected = np.sin(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_cos(self): + a = mx.array( + [0, math.pi / 4, math.pi / 2, math.pi, 3 * math.pi / 4, 2 * math.pi] + ) + result = mx.cos(a) + expected = np.cos(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_log1p(self): + a = mx.array([1, 0.5, 10, 100]) + result = mx.log1p(a) + expected = np.log1p(a, dtype=np.float32) + + self.assertTrue(np.allclose(result, expected)) + + def test_sigmoid(self): + a = mx.array([0.0, 1.0, -1.0, 5.0, -5.0]) + result = mx.sigmoid(a) + expected = 1 / (1 + np.exp(-a, dtype=np.float32)) + self.assertTrue(np.allclose(result, expected)) + + def test_allclose(self): + a = mx.array(1.0) + b = mx.array(1.0) + + self.assertTrue(mx.allclose(a, b).item()) + + b = mx.array(1.1) + self.assertFalse(mx.allclose(a, b).item()) + self.assertTrue(mx.allclose(a, b, 0.1).item()) + self.assertFalse(mx.allclose(a, b, 0.01).item()) + self.assertTrue(mx.allclose(a, b, 0.01, 0.1).item()) + + def test_all(self): + a = mx.array([[True, False], [True, True]]) + + self.assertFalse(mx.all(a).item()) + self.assertEqual(mx.all(a, keepdims=True).shape, [1, 1]) + self.assertFalse(mx.all(a, axis=[0, 1]).item()) + self.assertEqual(mx.all(a, axis=[0]).tolist(), [True, False]) + self.assertEqual(mx.all(a, axis=[1]).tolist(), [False, True]) + self.assertEqual(mx.all(a, axis=0).tolist(), [True, False]) + self.assertEqual(mx.all(a, axis=1).tolist(), [False, True]) + + def test_any(self): + a = mx.array([[True, False], [False, False]]) + + self.assertTrue(mx.any(a).item()) + self.assertEqual(mx.any(a, keepdims=True).shape, [1, 1]) + self.assertTrue(mx.any(a, axis=[0, 1]).item()) + self.assertEqual(mx.any(a, axis=[0]).tolist(), [True, False]) + self.assertEqual(mx.any(a, axis=[1]).tolist(), [True, False]) + self.assertEqual(mx.any(a, axis=0).tolist(), [True, False]) + self.assertEqual(mx.any(a, axis=1).tolist(), [True, False]) + + def test_stop_gradient(self): + def func(x): + return mx.sum(2 * x + mx.stop_gradient(3 * x)) + + x = mx.array([0.0, 0.1, -3]) + expected = [2, 2, 2] + + self.assertListEqual(mx.grad(func)(x).tolist(), expected) + + def test_take(self): + # Shape: 4 x 3 x 2 + l = [ + [[1, 3], [-2, -2], [-3, -2]], + [[2, 4], [-3, 2], [-4, -2]], + [[2, 3], [2, 4], [2, 1]], + [[1, -5], [3, -1], [2, 3]], + ] + + a = mx.array(l) + a_npy = np.array(l) + + indices = [0, -1] + flatten_take = mx.take(a, mx.array(indices)).tolist() + flatten_take_expected = np.take(a_npy, 
np.array(indices)).tolist()
+        self.assertListEqual(flatten_take, flatten_take_expected)
+
+        indices = [-1, 2, 0]
+        axis_take = mx.take(a, mx.array(indices), axis=0).tolist()
+        axis_take_expected = np.take(a_npy, np.array(indices), axis=0).tolist()
+        self.assertListEqual(axis_take, axis_take_expected)
+
+        indices = [0, 0, -2]
+        axis_take = mx.take(a, mx.array(indices), axis=1).tolist()
+        axis_take_expected = np.take(a_npy, np.array(indices), axis=1).tolist()
+        self.assertListEqual(axis_take, axis_take_expected)
+
+        indices = [0, -1, -1]
+        axis_take = mx.take(a, mx.array(indices), axis=-1).tolist()
+        axis_take_expected = np.take(a_npy, np.array(indices), axis=-1).tolist()
+        self.assertListEqual(axis_take, axis_take_expected)
+
+        a_npy = np.arange(8 * 8 * 8, dtype=np.int32)
+        a_npy = a_npy.reshape((8, 8, 8))
+        idx_npy = np.arange(6, dtype=np.uint32)
+        idx_npy = idx_npy.reshape((2, 3))
+        a_mlx = mx.array(a_npy)
+        idx_mlx = mx.array(idx_npy)
+
+        a_npy_taken = np.take(a_npy, idx_npy)
+        a_mlx_taken = mx.take(a_mlx, idx_mlx)
+        self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
+
+        a_npy_taken = np.take(a_npy, idx_npy, axis=0)
+        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=0)
+        self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
+
+        a_npy_taken = np.take(a_npy, idx_npy, axis=1)
+        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=1)
+        self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
+
+        a_npy_taken = np.take(a_npy, idx_npy, axis=2)
+        a_mlx_taken = mx.take(a_mlx, idx_mlx, axis=2)
+        self.assertListEqual(list(a_npy_taken.shape), a_mlx_taken.shape)
+        self.assertListEqual(a_npy_taken.tolist(), a_mlx_taken.tolist())
+
+    def test_take_along_axis(self):
+        a_np = np.arange(8).reshape(2, 2, 2)
+        a_mlx = mx.array(a_np)
+        idx_np = np.array([1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0])
+        idx_mlx = mx.array(idx_np)
+
+        for ax in [None, 0, 1, 2]:
+            if ax is None:
+                shape = [-1]
+            else:
+                shape = [2] * 3
+                shape[ax] = 3
+            out_np = np.take_along_axis(a_np, idx_np.reshape(shape), axis=ax)
+            out_mlx = mx.take_along_axis(a_mlx, mx.reshape(idx_mlx, shape), axis=ax)
+            self.assertTrue(np.array_equal(out_np, np.array(out_mlx)))
+
+    def test_split(self):
+        a = mx.array([1, 2, 3])
+        splits = mx.split(a, 3)
+        for e, x in enumerate(splits):
+            self.assertEqual(x.item(), e + 1)
+
+        a = mx.array([[1, 2], [3, 4], [5, 6]])
+        x, y, z = mx.split(a, 3, axis=0)
+        self.assertEqual(x.tolist(), [[1, 2]])
+        self.assertEqual(y.tolist(), [[3, 4]])
+        self.assertEqual(z.tolist(), [[5, 6]])
+
+        a = mx.arange(8)
+        x, y, z = mx.split(a, [1, 5])
+        self.assertEqual(x.tolist(), [0])
+        self.assertEqual(y.tolist(), [1, 2, 3, 4])
+        self.assertEqual(z.tolist(), [5, 6, 7])
+
+    def test_arange_overload_dispatch(self):
+        a = mx.arange(5)
+        expected = [0, 1, 2, 3, 4]
+        self.assertListEqual(a.tolist(), expected)
+
+        a = mx.arange(1, 5)
+        expected = [1, 2, 3, 4]
+        self.assertListEqual(a.tolist(), expected)
+
+        a = mx.arange(-3, step=-1)
+        expected = [0, -1, -2]
+        self.assertListEqual(a.tolist(), expected)
+
+        a = mx.arange(stop=2, step=0.5)
+        expected = [0, 0.5, 1.0, 1.5]
+        self.assertListEqual(a.tolist(), expected)
+
+        with self.assertRaises(TypeError):
+            mx.arange(start=1, step=2)
+
+        a = mx.arange(stop=3)
+        expected = [0, 1, 2]
+        self.assertListEqual(a.tolist(), expected)
+
+    def test_arange_inferred_dtype(self):
+        a = mx.arange(5)
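+        # The checks below exercise dtype inference: integer-only arguments
+        # produce int32, any float argument promotes the result to float32,
+        # and an explicit dtype= takes precedence.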
self.assertEqual(a.dtype, mx.int32) + + a = mx.arange(5.0) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1, 3.0) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1, 3, dtype=mx.float32) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1, 5, 1) + self.assertEqual(a.dtype, mx.int32) + + a = mx.arange(1.0, 5, 1) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1, 5.0, 1) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1, 5, 1.0) + self.assertEqual(a.dtype, mx.float32) + + a = mx.arange(1.0, 3.0, 0.2, dtype=mx.int32) + self.assertEqual(a.dtype, mx.int32) + + def test_arange_corner_cases_cast(self): + a = mx.arange(0, 3, 0.2, dtype=mx.int32) + expected = [0] * 15 + self.assertListEqual(a.tolist(), expected) + self.assertEqual(a.dtype, mx.int32) + + a = mx.arange(-1, -4, -0.9, dtype=mx.int32) + expected = [-1] * 4 + self.assertListEqual(a.tolist(), expected) + self.assertEqual(a.dtype, mx.int32) + + a = mx.arange(-1, -20, -1.2, dtype=mx.int32) + expected = [ + -1, + -2, + -3, + -4, + -5, + -6, + -7, + -8, + -9, + -10, + -11, + -12, + -13, + -14, + -15, + -16, + ] + self.assertListEqual(a.tolist(), expected) + self.assertEqual(a.dtype, mx.int32) + + def test_unary_ops(self): + def test_ops(npop, mlxop, x, y, atol): + r_np = npop(x) + r_mlx = mlxop(y) + mx.eval(r_mlx) + + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + x = np.random.rand(18, 28, 38) + for op in ["abs", "exp", "log", "square", "sqrt"]: + with self.subTest(op=op): + float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] + + for dtype, atol in float_dtypes: + with self.subTest(dtype=dtype): + x_ = x.astype(getattr(np, dtype)) + y_ = mx.array(x_) + test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + + def test_trig_ops(self): + def test_ops(npop, mlxop, x, y, atol): + r_np = npop(x) + r_mlx = mlxop(y) + mx.eval(r_mlx) + + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + x = np.random.rand(9, 12, 18) + xi = np.random.rand(9, 12, 18) + base_ops = ["sin", "cos", "tan"] + hyperbolic_ops = ["sinh", "cosh", "tanh"] + all_fwd_ops = base_ops + hyperbolic_ops + + for op in all_fwd_ops: + with self.subTest(op=op): + float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] + + for dtype, atol in float_dtypes: + with self.subTest(dtype=dtype): + x_ = x.astype(getattr(np, dtype)) + y_ = mx.array(x_) + test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + + with self.subTest(op=op): + float_dtypes = [("complex64", 1e-5)] + + for dtype, atol in float_dtypes: + with self.subTest(dtype=dtype): + x_ = x + 1.0j * xi + x_ = x_.astype(getattr(np, dtype)) + y_ = mx.array(x_) + test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + + with self.subTest(op="arc" + op): + float_dtypes = [("float16", 1e-3), ("float32", 1e-6)] + op_inv = "arc" + op + + for dtype, atol in float_dtypes: + with self.subTest(dtype=dtype): + np_op_fwd = getattr(np, op) + x_ = np_op_fwd(x).astype(getattr(np, dtype)) + y_ = mx.array(x_) + test_ops(getattr(np, op_inv), getattr(mx, op_inv), x_, y_, atol) + + # Test grads + np_vjp_funcs = { + "sin": lambda primal, cotan: cotan * np.cos(primal), + "cos": lambda primal, cotan: -cotan * np.sin(primal), + "tan": lambda primal, cotan: cotan / (np.cos(primal) ** 2), + "sinh": lambda primal, cotan: cotan * np.cosh(primal), + "cosh": lambda primal, cotan: cotan * np.sinh(primal), + "tanh": lambda primal, cotan: cotan / (np.cosh(primal) ** 2), + "arcsin": lambda primal, cotan: cotan / np.sqrt(1.0 - primal**2), + "arccos": lambda primal, cotan: -cotan / np.sqrt(1.0 - 
primal**2), + "arctan": lambda primal, cotan: cotan / (1.0 + primal**2), + "arcsinh": lambda primal, cotan: cotan / np.sqrt(primal**2 + 1), + "arccosh": lambda primal, cotan: cotan / np.sqrt(primal**2 - 1), + "arctanh": lambda primal, cotan: cotan / (1.0 - primal**2), + } + with self.subTest(name="grads"): + for op in all_fwd_ops: + with self.subTest(op=op): + primal_np = xi.astype(np.float32) + primal_mx = mx.array(primal_np) + x_ = x.astype(np.float32) + y_ = mx.array(x_) + op_ = op + atol_ = 1e-5 + + np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x) + mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0] + test_ops(np_vjp, mx_vjp, x_, y_, atol_) + + with self.subTest(op="arc" + op): + np_op_fwd = getattr(np, op) + primal_np = np_op_fwd(xi).astype(np.float32) + + # To avoid divide by zero error + if op == "cosh": + primal_np[np.isclose(primal_np, 1.0)] += 1e-3 + elif op == "cos": + primal_np[np.isclose(primal_np, 1.0)] -= 1e-3 + + primal_mx = mx.array(primal_np) + x_ = x.astype(np.float32) + y_ = mx.array(x_) + op_ = "arc" + op + atol_ = 1e-5 + + np_vjp = lambda x: np_vjp_funcs[op_](primal_np, x) + mx_vjp = lambda x: mx.vjp(getattr(mx, op_), [primal_mx], [x])[1][0] + test_ops(np_vjp, mx_vjp, x_, y_, atol_) + + def test_binary_ops(self): + def test_ops(npop, mlxop, x, y, atol): + r_np = npop(x, x) + r_mlx = mlxop(y, y) + mx.eval(r_mlx) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + r_np = npop(x[:1], x) + r_mlx = mlxop(y[:1], y) + mx.eval(r_mlx) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + r_np = npop(x[:, :1], x) + r_mlx = mlxop(y[:, :1], y) + mx.eval(r_mlx) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + r_np = npop(x[:, :, :1], x) + r_mlx = mlxop(y[:, :, :1], y) + mx.eval(r_mlx) + self.assertTrue(np.allclose(r_np, r_mlx, atol=atol)) + + x = np.maximum(np.random.rand(18, 28, 38), 0.1) + y = mx.array(x) + mx.eval(y) + for op in [ + "add", + "subtract", + "multiply", + "divide", + "maximum", + "minimum", + "power", + ]: + with self.subTest(op=op): + int_dtypes = [ + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + ] + + float_dtypes = ["float16", "float32"] + + dtypes = ( + float_dtypes + if op in ("divide", "power") + else (int_dtypes + float_dtypes) + ) + for dtype in dtypes: + atol = 1e-3 if dtype == "float16" else 1e-6 + with self.subTest(dtype=dtype): + x_ = x.astype(getattr(np, dtype)) + y_ = mx.array(x_) + test_ops(getattr(np, op), getattr(mx, op), x_, y_, atol) + + def test_irregular_binary_ops(self): + # Check transposed binary ops + dims = [2, 3, 4, 5] + size = 3 + trial_mul = 2 + np.random.seed(0) + for d in dims: + anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) + bnp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) + for _ in range(trial_mul * d): + amlx = mx.array(anp) + bmlx = mx.array(bnp) + a_t = np.random.permutation(d).tolist() + b_t = np.random.permutation(d).tolist() + outnp = np.add(anp.transpose(a_t), bnp.transpose(b_t)) + outmlx = mx.add(mx.transpose(amlx, a_t), mx.transpose(bmlx, b_t)) + self.assertTrue(np.array_equal(outnp, outmlx)) + + # Check broadcast binary ops + for d in dims: + anp = np.random.randint(-20, 20, (size**d,)).reshape([size] * d) + for n_bsx in range(d): + bnp = np.random.randint(-20, 20, (size**n_bsx,)).reshape( + [size] * n_bsx + ) + for _ in range(trial_mul * d): + amlx = mx.array(anp) + bmlx = mx.array(bnp) + b_shape = [1] * (d - n_bsx) + [size] * n_bsx + np.random.shuffle(b_shape) + outnp = np.add(anp, 
bnp.reshape(b_shape)) + outmlx = mx.add(amlx, mx.reshape(bmlx, b_shape)) + self.assertTrue(np.array_equal(outnp, outmlx)) + + # Check strided binary ops + for d in dims: + a = np.random.randint(-20, 20, (10,) * d) + b = np.random.randint(-20, 20, (10,) * d) + a_ = mx.array(a) + b_ = mx.array(b) + for t in permutations(range(d)): + for s in range(d): + idx = tuple( + [slice(None)] * s + + [slice(None, None, 2)] + + [slice(None)] * (d - s - 1) + ) + c = a.transpose(t)[idx] + b[idx] + c_ = mx.transpose(a_, t)[idx] + b_[idx] + self.assertTrue(np.array_equal(c, c_)) + + def test_softmax(self): + cases = [(np.float32, 1e-6), (np.float16, 1e-3)] + + for dtype, atol in cases: + a_npy = np.random.randn(16, 8, 32).astype(dtype) + a_mlx = mx.array(a_npy) + + def np_softmax(x, axis): + ex = np.exp(x - np.max(x, axis=axis, keepdims=True)) + return ex / np.sum(ex, axis=axis, keepdims=True) + + for axes in (None, 0, 1, 2, (0, 1), (1, 2), (0, 2), (0, 1, 2)): + b_npy = np_softmax(a_npy, axes) + b_mlx = mx.softmax(a_mlx, axes) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=atol)) + + for s in [100, 2049, 4097, 8193]: + a = np.full(s, -np.inf) + a[-1] = 0.0 + a = mx.softmax(mx.array(a)) + self.assertFalse(np.any(np.isnan(a))) + self.assertTrue((a[:-1] == 0).all()) + self.assertEqual(a[-1], 1) + + def test_concatenate(self): + a_npy = np.random.randn(32, 32, 32) + b_npy = np.random.randn(32, 32, 32) + a_mlx = mx.array(a_npy) + b_mlx = mx.array(b_npy) + + for axis in (None, 0, 1, 2): + for p in permutations([0, 1, 2]): + c_npy = np.concatenate([a_npy, np.transpose(b_npy, p)], axis=axis) + c_mlx = mx.concatenate([a_mlx, mx.transpose(b_mlx, p)], axis=axis) + self.assertEqual(list(c_npy.shape), list(c_mlx.shape)) + self.assertTrue(np.allclose(c_npy, c_mlx, atol=1e-6)) + + def test_pad(self): + pad_width_and_values = [ + ([(1, 1), (1, 1), (1, 1)], 0), + ([(1, 1), (1, 1), (1, 1)], 5), + ([(3, 0), (0, 2), (5, 7)], 0), + ([(3, 0), (0, 2), (5, 7)], -7), + ([(0, 0), (0, 0), (0, 0)], 0), + ] + + for pw, v in pad_width_and_values: + with self.subTest(pad_width=pw, value=v): + a_npy = np.random.randn(16, 16, 16).astype(np.float32) + a_mlx = mx.array(a_npy) + + b_npy = np.pad(a_npy, pw, constant_values=v) + b_mlx = mx.pad(a_mlx, pw, constant_values=v) + + self.assertEqual(list(b_npy.shape), list(b_mlx.shape)) + self.assertTrue(np.allclose(b_npy, b_mlx, atol=1e-6)) + + a = mx.zeros((1, 1, 1)) + self.assertEqual(mx.pad(a, 1).shape, [3, 3, 3]) + self.assertEqual(mx.pad(a, (1,)).shape, [3, 3, 3]) + self.assertEqual(mx.pad(a, [1]).shape, [3, 3, 3]) + self.assertEqual(mx.pad(a, (1, 2)).shape, [4, 4, 4]) + self.assertEqual(mx.pad(a, [(1, 2)]).shape, [4, 4, 4]) + self.assertEqual(mx.pad(a, ((1, 2),)).shape, [4, 4, 4]) + self.assertEqual(mx.pad(a, ((1, 2), (2, 1), (2, 2))).shape, [4, 4, 5]) + + # Test grads + a_fwd = mx.array(np.random.rand(16, 16).astype(np.float32)) + a_bwd = mx.ones((22, 22)) + f = lambda x: mx.pad(x, ((4, 2), (2, 4))) + + _, df = mx.vjp(f, [a_fwd], [a_bwd]) + self.assertTrue(mx.allclose(a_bwd[4:-2, 2:-4], df[0]).item()) + + def test_where(self): + a = mx.array([[1, 2], [3, 4]]) + out = mx.where(True, a, 1) + out_np = np.where(True, a, 1) + self.assertTrue(np.array_equal(out, out_np)) + + out = mx.where(True, 1, a) + out_np = np.where(True, 1, a) + self.assertTrue(np.array_equal(out, out_np)) + + condition = mx.array([[True, False], [False, True]]) + b = mx.array([5, 6]) + out = mx.where(condition, a, b) + out_np = np.where(condition, a, b) + self.assertTrue(np.array_equal(out, out_np)) + + def 
test_as_strided(self): + x_npy = np.random.randn(128).astype(np.float32) + x_mlx = mx.array(x_npy) + + shapes = [(10, 10), (5, 5), (2, 20), (10,)] + strides = [(3, 3), (7, 1), (1, 5), (4,)] + for shape, stride in zip(shapes, strides): + for offset in [0, 1, 3]: + y_npy = np.lib.stride_tricks.as_strided( + x_npy[offset:], shape, np.multiply(stride, 4) + ) + y_mlx = mx.as_strided(x_mlx, shape, stride, offset) + self.assertTrue(np.array_equal(y_npy, y_mlx)) + + def test_scans(self): + a_npy = np.random.randn(32, 32, 32).astype(np.float32) + a_mlx = mx.array(a_npy) + + for op in ["cumsum", "cumprod"]: + npop = getattr(np, op) + mxop = getattr(mx, op) + for axis in (None, 0, 1, 2): + c_npy = npop(a_npy, axis=axis) + c_mlx = mxop(a_mlx, axis=axis) + self.assertTrue(np.allclose(c_npy, c_mlx, rtol=1e-4, atol=1e-4)) + + for op in ["cumsum", "cumprod", "cummax", "cummin"]: + c1 = mxop(a_mlx, axis=2) + c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=False) + self.assertTrue(mx.array_equal(c1[:, :, :-1], c2[:, :, 1:])) + c1 = mxop(a_mlx, axis=1) + c2 = mxop(a_mlx, axis=1, inclusive=False, reverse=False) + self.assertTrue(mx.array_equal(c1[:, :-1, :], c2[:, 1:, :])) + c1 = mxop(a_mlx, axis=0) + c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=False) + self.assertTrue(mx.array_equal(c1[:-1, :, :], c2[1:, :, :])) + + rev_idx = mx.arange(31, -1, -1) + c1 = mxop(a_mlx[:, :, rev_idx], axis=2)[:, :, rev_idx] + c2 = mxop(a_mlx, axis=2, inclusive=True, reverse=True) + self.assertTrue(mx.array_equal(c1, c2)) + c1 = mxop(a_mlx[:, rev_idx, :], axis=1)[:, rev_idx, :] + c2 = mxop(a_mlx, axis=1, inclusive=True, reverse=True) + self.assertTrue(mx.array_equal(c1, c2)) + c1 = mxop(a_mlx[rev_idx, :, :], axis=0)[rev_idx, :, :] + c2 = mxop(a_mlx, axis=0, inclusive=True, reverse=True) + self.assertTrue(mx.array_equal(c1, c2)) + + rev_idx = mx.arange(31, -1, -1) + c1 = mxop(a_mlx[:, :, rev_idx], axis=2)[:, :, rev_idx][:, :, 1:] + c2 = mxop(a_mlx, axis=2, inclusive=False, reverse=True)[:, :, :-1] + self.assertTrue(mx.array_equal(c1, c2)) + c1 = mxop(a_mlx[:, rev_idx, :], axis=1)[:, rev_idx, :][:, 1:, :] + c2 = mxop(a_mlx, axis=1, inclusive=False, reverse=True)[:, :-1, :] + self.assertTrue(mx.array_equal(c1, c2)) + c1 = mxop(a_mlx[rev_idx, :, :], axis=0)[rev_idx, :, :][1:, :, :] + c2 = mxop(a_mlx, axis=0, inclusive=False, reverse=True)[:-1, :, :] + self.assertTrue(mx.array_equal(c1, c2)) + + def test_squeeze_expand(self): + a = mx.zeros((2, 1, 2, 1)) + self.assertEqual(mx.squeeze(a).shape, [2, 2]) + self.assertEqual(mx.squeeze(a, 1).shape, [2, 2, 1]) + self.assertEqual(mx.squeeze(a, [1, 3]).shape, [2, 2]) + self.assertEqual(a.squeeze().shape, [2, 2]) + self.assertEqual(a.squeeze(1).shape, [2, 2, 1]) + self.assertEqual(a.squeeze([1, 3]).shape, [2, 2]) + + a = mx.zeros((2, 2)) + self.assertEqual(mx.squeeze(a).shape, [2, 2]) + + self.assertEqual(mx.expand_dims(a, 0).shape, [1, 2, 2]) + self.assertEqual(mx.expand_dims(a, (0, 1)).shape, [1, 1, 2, 2]) + self.assertEqual(mx.expand_dims(a, [0, -1]).shape, [1, 2, 2, 1]) + + def test_sort(self): + shape = (3, 4, 5) + for dtype in ("int32", "float32"): + for axis in (None, 0, 1, 2): + with self.subTest(dtype=dtype, axis=axis): + np.random.seed(0) + np_dtype = getattr(np, dtype) + a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) + a_mx = mx.array(a_np) + + b_np = np.sort(a_np, axis=axis) + b_mx = mx.sort(a_mx, axis=axis) + + self.assertTrue(np.array_equal(b_np, b_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + c_np = np.argsort(a_np, axis=axis) + c_mx = 
mx.argsort(a_mx, axis=axis) + d_np = np.take_along_axis(a_np, c_np, axis=axis) + d_mx = mx.take_along_axis(a_mx, c_mx, axis=axis) + + self.assertTrue(np.array_equal(d_np, d_mx)) + self.assertEqual(c_mx.dtype, mx.uint32) + + def test_partition(self): + shape = (3, 4, 5) + for dtype in ("int32", "float32"): + for axis in (None, 0, 1, 2): + for kth in (-2, 2): + with self.subTest(dtype=dtype, axis=axis, kth=kth): + np.random.seed(0) + np_dtype = getattr(np, dtype) + a_np = np.random.uniform(0, 100, size=shape).astype(np_dtype) + a_mx = mx.array(a_np) + + b_np = np.partition(a_np, kth, axis=axis) + b_mx = mx.partition(a_mx, kth, axis=axis) + + c_np = np.take(b_np, (kth,), axis=axis) + c_mx = np.take(np.array(b_mx), (kth,), axis=axis) + + self.assertTrue(np.array_equal(c_np, c_mx)) + self.assertEqual(b_mx.dtype, a_mx.dtype) + + top_k_mx = mx.topk(a_mx, kth, axis=axis) + self.assertTrue(np.all(c_np <= top_k_mx)) + self.assertEqual(top_k_mx.dtype, a_mx.dtype) + + if kth >= 0: + d_np = np.take(b_mx, np.arange(kth), axis=axis) + self.assertTrue(np.all(d_np <= c_mx)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_reduce.py b/python/tests/test_reduce.py new file mode 100644 index 0000000000..b75f8a0ba5 --- /dev/null +++ b/python/tests/test_reduce.py @@ -0,0 +1,118 @@ +import unittest +from itertools import permutations, combinations + +import mlx.core as mx +import numpy as np + +import mlx_tests + + +class TestReduce(mlx_tests.MLXTestCase): + def test_axis_permutation_sums(self): + x_npy = np.random.randn(5, 5, 5, 5, 5).astype(np.float32) + x_mlx = mx.array(x_npy) + for t in permutations(range(5)): + with self.subTest(t=t): + y_npy = np.transpose(x_npy, t) + y_mlx = mx.transpose(x_mlx, t) + for n in range(1, 6): + for a in combinations(range(5), n): + with self.subTest(a=a): + z_npy = np.sum(y_npy, axis=a) + z_mlx = mx.sum(y_mlx, axis=a) + mx.eval(z_mlx) + self.assertTrue( + np.allclose(z_npy, np.array(z_mlx), atol=1e-4) + ) + + def test_expand_sums(self): + x_npy = np.random.randn(5, 1, 5, 1, 5, 1).astype(np.float32) + x_mlx = mx.array(x_npy) + for m in range(1, 4): + for ax in combinations([1, 3, 5], m): + shape = np.array([5, 1, 5, 1, 5, 1]) + shape[list(ax)] = 5 + shape = shape.tolist() + with self.subTest(shape=shape): + y_npy = np.broadcast_to(x_npy, shape) + y_mlx = mx.broadcast_to(x_mlx, shape) + for n in range(1, 7): + for a in combinations(range(6), n): + with self.subTest(a=a): + z_npy = np.sum(y_npy, axis=a) / 1000 + z_mlx = mx.sum(y_mlx, axis=a) / 1000 + mx.eval(z_mlx) + self.assertTrue( + np.allclose(z_npy, np.array(z_mlx), atol=1e-4) + ) + + def test_dtypes(self): + int_dtypes = [ + "int8", + "int16", + "int32", + "uint8", + "uint16", + "uint32", + ] + float_dtypes = ["float32"] + + for dtype in int_dtypes + float_dtypes: + with self.subTest(dtype=dtype): + x = np.random.uniform(0, 2, size=(3, 3, 3)).astype(getattr(np, dtype)) + y = mx.array(x) + + for op in ("sum", "prod", "min", "max"): + with self.subTest(op=op): + + np_op = getattr(np, op) + mlx_op = getattr(mx, op) + + for axes in (None, 0, 1, 2, (0, 1), (0, 2), (1, 2), (0, 1, 2)): + with self.subTest(axes=axes): + if op in ("sum", "prod"): + r_np = np_op( + x, axis=axes, dtype=(getattr(np, dtype)) + ) + else: + r_np = np_op(x, axis=axes) + r_mlx = mlx_op(y, axis=axes) + mx.eval(r_mlx) + self.assertTrue(np.allclose(r_np, r_mlx, atol=1e-4)) + + def test_arg_reduce(self): + dtypes = [ + "uint8", + "uint16", + "uint32", + "uint64", + "int8", + "int16", + "int32", + "int64", + "float16", + "float32", 
+ ] + for dtype in dtypes: + with self.subTest(dtype=dtype): + + data = np.random.rand(10, 12, 13).astype(getattr(np, dtype)) + x = mx.array(data) + for op in ["argmin", "argmax"]: + for axis in range(3): + for kd in [True, False]: + a = getattr(mx, op)(x, axis, kd) + b = getattr(np, op)(data, axis, keepdims=kd) + self.assertEqual(a.tolist(), b.tolist()) + + for op in ["argmin", "argmax"]: + a = getattr(mx, op)(x, keepdims=True) + b = getattr(np, op)(data, keepdims=True) + self.assertEqual(a.tolist(), b.tolist()) + a = getattr(mx, op)(x) + b = getattr(np, op)(data) + self.assertEqual(a.item(), b) + + +if __name__ == "__main__": + unittest.main(failfast=True) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt new file mode 100644 index 0000000000..0879aa0f6f --- /dev/null +++ b/tests/CMakeLists.txt @@ -0,0 +1,38 @@ +FetchContent_Declare( + doctest + GIT_REPOSITORY "https://github.com/onqtam/doctest" + GIT_TAG "b7c21ec5ceeadb4951b00396fc1e4642dd347e5f" +) +FetchContent_MakeAvailable(doctest) + +add_executable(tests ${PROJECT_SOURCE_DIR}/tests/tests.cpp) + +if (MLX_BUILD_METAL) + set( + METAL_TEST_SOURCES + metal_tests.cpp + ) +endif() + +target_sources(tests PRIVATE + allocator_tests.cpp + array_tests.cpp + arg_reduce_tests.cpp + autograd_tests.cpp + blas_tests.cpp + creations_tests.cpp + device_tests.cpp + eval_tests.cpp + fft_tests.cpp + graph_optimize_tests.cpp + load_tests.cpp + ops_tests.cpp + random_tests.cpp + scheduler_tests.cpp + utils_tests.cpp + vmap_tests.cpp + ${METAL_TEST_SOURCES} +) + +target_link_libraries(tests PRIVATE mlx doctest) +add_test(NAME tests COMMAND tests) diff --git a/tests/arg_reduce_tests.cpp b/tests/arg_reduce_tests.cpp new file mode 100644 index 0000000000..2bfd8968f5 --- /dev/null +++ b/tests/arg_reduce_tests.cpp @@ -0,0 +1,205 @@ +#include + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" +#include "mlx/primitives.h" + +using namespace mlx::core; + +void test_arg_reduce_small( + Device d, + const array& x, + ArgReduce::ReduceType r, + std::vector out_shape, + int axis, + std::vector expected_output) { + auto s = default_stream(d); + auto y = + array(out_shape, uint32, std::make_unique(s, r, axis), {x}); + y.eval(); + const uint32_t* ydata = y.data(); + for (int i = 0; i < y.size(); i++) { + CHECK_EQ(expected_output[i], ydata[i]); + } +} + +void test_arg_reduce_against_cpu( + const array& x, + ArgReduce::ReduceType r, + std::vector out_shape, + int axis) { + auto y1 = array( + out_shape, + uint32, + std::make_unique(default_stream(Device::cpu), r, axis), + {x}); + auto y2 = array( + out_shape, + uint32, + std::make_unique(default_stream(Device::gpu), r, axis), + {x}); + y1.eval(); + y2.eval(); + CHECK(array_equal(y1, y2).item()); +} + +TEST_CASE("test arg reduce small") { + auto x = array( + {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, + 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5}, + {2, 3, 4}); + x.eval(); + test_arg_reduce_small( + Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3}); + test_arg_reduce_small( + Device::cpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2}); + test_arg_reduce_small( + Device::cpu, + x, + ArgReduce::ArgMin, + {3, 4}, + 0, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + test_arg_reduce_small( + Device::cpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1}); + test_arg_reduce_small( + Device::cpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0}); + test_arg_reduce_small( + Device::cpu, + x, + ArgReduce::ArgMax, + {3, 4}, + 0, + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + if (!metal::is_available()) 
{
+    INFO("Skipping arg reduction gpu tests");
+    return;
+  }
+
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2});
+  test_arg_reduce_small(
+      Device::gpu,
+      x,
+      ArgReduce::ArgMin,
+      {3, 4},
+      0,
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0});
+  test_arg_reduce_small(
+      Device::gpu,
+      x,
+      ArgReduce::ArgMax,
+      {3, 4},
+      0,
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+}
+
+TEST_CASE("test arg reduce against cpu") {
+  if (!metal::is_available()) {
+    INFO("Skipping arg reduction gpu tests");
+    return;
+  }
+
+  auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55});
+  x.eval();
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 92}, 2);
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 55}, 1);
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {92, 55}, 0);
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 92}, 2);
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 55}, 1);
+  test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {92, 55}, 0);
+
+  auto y = random::uniform(array(0.0), array(1.0), {1234});
+  y.eval();
+  test_arg_reduce_against_cpu(y, ArgReduce::ArgMin, {}, 0);
+  test_arg_reduce_against_cpu(y, ArgReduce::ArgMax, {}, 0);
+}
+
+void test_arg_reduce_small_bool(
+    Device d,
+    ArgReduce::ReduceType r,
+    std::vector<int> out_shape,
+    int axis,
+    std::vector<uint32_t> expected_output) {
+  auto s = default_stream(d);
+  auto x = array(
+      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
+       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
+      {2, 3, 4});
+  x.eval();
+  auto y =
+      array(out_shape, uint32, std::make_unique<ArgReduce>(s, r, axis), {x});
+  y.eval();
+  const uint32_t* ydata = y.data<uint32_t>();
+  for (int i = 0; i < y.size(); i++) {
+    CHECK_EQ(expected_output[i], ydata[i]);
+  }
+}
+
+TEST_CASE("test arg reduce bool") {
+  if (!metal::is_available()) {
+    INFO("Skipping arg reduction gpu tests");
+    return;
+  }
+  auto x = array(
+      {false, true, true, false, false, false, false, true,
+       true, false, true, true, false, true, true, false,
+       false, false, false, true, true, false, true, true},
+      {2, 3, 4});
+  x.eval();
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0});
+  test_arg_reduce_small(
+      Device::gpu,
+      x,
+      ArgReduce::ArgMin,
+      {3, 4},
+      0,
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0});
+  test_arg_reduce_small(
+      Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1});
+  test_arg_reduce_small(
+      Device::gpu,
+      x,
+      ArgReduce::ArgMax,
+      {3, 4},
+      0,
+      {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+}
+
+TEST_CASE("test arg reduce edge cases") {
+  auto a = argmin(array(1.0));
+  CHECK_EQ(a.item<uint32_t>(), 0);
+  auto b = argmax(array(1.0));
+  CHECK_EQ(b.item<uint32_t>(), 0);
+  CHECK_THROWS(argmin(array({})));
+  CHECK_THROWS(argmax(array({})));
+}
+
+TEST_CASE("test arg reduce irregular strides") {
+  auto x = array(
+      {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5,
+       0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5},
+      {2, 3, 4});
+  x = transpose(x, {2, 0, 1});
+  x.eval();
+  test_arg_reduce_small(
+      Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2});
+
+  if (!metal::is_available()) {
+    INFO("Skipping arg reduction gpu tests");
+    return;
+  }
+}
diff --git a/tests/blas_tests.cpp b/tests/blas_tests.cpp
new file mode 100644
index 0000000000..256a876a37
--- /dev/null
+++ b/tests/blas_tests.cpp
@@ -0,0 +1,108 @@
+#include
+
+#include "doctest/doctest.h"
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+TEST_CASE("test matmul") {
+  auto a = array(1);
+  auto b = array({1.0});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = array({1.0});
+  b = array({1.0});
+  auto out = matmul(a, b);
+  CHECK_EQ(out.shape(), std::vector<int>{});
+  CHECK_EQ(out.size(), 1);
+  CHECK_EQ(out.dtype(), float32);
+  CHECK_EQ(out.item<float>(), 1.0f);
+
+  a = ones({2, 4});
+  b = ones({2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2, 4});
+  b = ones({3, 2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2, 4});
+  b = ones({4, 3, 2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2});
+  b = ones({4, 2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2, 3});
+  b = ones({4, 2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2, 4, 3});
+  b = ones({4, 2});
+  CHECK_THROWS_AS(matmul(a, b), std::invalid_argument);
+
+  a = ones({2, 4});
+  b = ones({4, 2});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());
+
+  a = ones({2, 4}, int32);
+  b = ones({4, 2}, float32);
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({2, 2}, 4.0f)).item<bool>());
+
+  // Check single dimensions
+  a = ones({4});
+  b = ones({4, 2});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());
+
+  a = ones({2, 4});
+  b = ones({4});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({2}, 4.0f)).item<bool>());
+
+  a = ones({4});
+  b = ones({4});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({}, 4.0f)).item<bool>());
+
+  // Test transposed arrays
+  a = array({1.0f, 1.0f, 1.0f, 1.0f}, {1, 4});
+  b = array({1.0f, 1.0f, 1.0f, 1.0f}, {4, 1});
+  out = matmul(transpose(a), transpose(b));
+  CHECK(array_equal(out, ones({4, 4})).item<bool>());
+
+  a = array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+  b = array({1.0f, 2.0f, 1.0f, 2.0f}, {2, 2});
+  out = matmul(transpose(a), b);
+  CHECK(
+      array_equal(out, array({4.0f, 8.0f, 6.0f, 12.0f}, {2, 2})).item<bool>());
+
+  out = matmul(a, transpose(b));
+  CHECK(
+      array_equal(out, array({5.0f, 5.0f, 11.0f, 11.0f}, {2, 2})).item<bool>());
+
+  out = matmul(transpose(a), transpose(b));
+  CHECK(
+      array_equal(out, array({7.0f, 7.0f, 10.0f, 10.0f}, {2, 2})).item<bool>());
+
+  // Test broadcasting for both arrays
+  a = ones({5, 4, 2});
+  b = ones({2, 3});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({5, 4, 3}, 2.0f)).item<bool>());
+
+  a = ones({5, 1, 4, 2});
+  b = ones({1, 7, 2, 3});
+  out = matmul(a, b);
+  CHECK(array_equal(out, full({5, 7, 4, 3}, 2.0f)).item<bool>());
+
+  // Test batched matmul with transpose
+  a = ones({2, 2, 4});
+  b = ones({2, 4, 2});
+  out = matmul(transpose(a, {0, 2, 1}), transpose(b, {0, 2, 1}));
+  CHECK(array_equal(out, full({2, 4, 4}, 2.0f)).item<bool>());
+}
diff --git a/tests/metal_tests.cpp b/tests/metal_tests.cpp
new file mode 100644
index 0000000000..76f62dce0f
--- /dev/null
+++ b/tests/metal_tests.cpp
@@ -0,0 +1,438 @@
+#include
+#include "doctest/doctest.h"
+
+#include "mlx/backend/metal/device.h"
+#include "mlx/backend/metal/metal.h"
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+static const std::array<Dtype, 5> types =
+    {bool_, uint32, int32, int64, float32};
+
+TEST_CASE("test metal device") {
+  // Make sure the device and library can
load + CHECK(metal::is_available()); + auto& device = metal::device(Device::gpu); +} + +TEST_CASE("test metal arange") { + for (auto t : types) { + if (t == bool_) { + continue; + } + auto out_cpu = arange(1, 100, 2, t, Device::cpu); + auto out_gpu = arange(1, 100, 2, t, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + + out_cpu = arange(1, 5, 0.25, t, Device::cpu); + out_gpu = arange(1, 5, 0.25, t, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } +} + +TEST_CASE("test metal full") { + for (auto t : types) { + auto out_cpu = full({4, 4}, 2, t, Device::cpu); + auto out_gpu = full({4, 4}, 2, t, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + + // Check broadcasting works + { + auto x = full({2, 2}, array({3, 4}, {2, 1}), Device::gpu); + CHECK( + array_equal(x, array({3, 3, 4, 4}, {2, 2}), Device::cpu).item()); + x = full({2, 2}, array({3, 4}, {1, 2}), Device::gpu); + CHECK( + array_equal(x, array({3, 4, 3, 4}, {2, 2}), Device::cpu).item()); + } + + // Check zeros and ones + { + auto x = zeros({2, 2}, float32, Device::gpu); + auto y = array({0.0, 0.0, 0.0, 0.0}, {2, 2}); + CHECK(array_equal(x, y, Device::cpu).item()); + + x = ones({2, 2}, float32, Device::gpu); + y = array({1.0, 1.0, 1.0, 1.0}, {2, 2}); + CHECK(array_equal(x, y, Device::cpu).item()); + } +} + +TEST_CASE("test metal astype") { + array x = array({-4, -3, -2, -1, 0, 1, 2, 3}); + // Check all types work + for (auto t : types) { + auto out_cpu = astype(x, t, Device::cpu); + auto out_gpu = astype(x, t, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + + x = transpose(reshape(x, {2, 2, 2}), {1, 2, 0}); + for (auto t : types) { + auto out_cpu = astype(x, t, Device::cpu); + auto out_gpu = astype(x, t, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } +} + +TEST_CASE("test metal reshape") { + array x = array({0, 1, 2, 3, 4, 5, 6, 7}); + auto out_cpu = reshape(x, {2, 2, 2}); + auto out_gpu = reshape(x, {2, 2, 2}, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + + x = transpose(reshape(x, {2, 2, 2}), {1, 2, 0}); + out_cpu = reshape(x, {4, 2}); + out_gpu = reshape(x, {4, 2}, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + + out_cpu = reshape(x, {8}); + out_gpu = reshape(x, {8}, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); +} + +TEST_CASE("test metal reduce") { + { + array a(true); + CHECK_EQ(all(a, Device::gpu).item(), true); + CHECK_EQ(any(a, Device::gpu).item(), true); + + a = array(std::initializer_list{}); + CHECK_EQ(all(a, Device::gpu).item(), true); + CHECK_EQ(any(a, Device::gpu).item(), false); + } + + { + std::vector vals(33, 1); + array a(vals.data(), {33}); + CHECK_EQ(all(a, Device::gpu).item(), true); + + vals[32] = 0; + a = array(vals.data(), {33}); + CHECK_EQ(all(a, Device::gpu).item(), false); + } + + { + std::vector vals(33, 0); + array a(vals.data(), {33}); + CHECK_EQ(any(a, Device::gpu).item(), false); + + vals[32] = 1; + a = array(vals.data(), {33}); + CHECK_EQ(any(a, Device::gpu).item(), true); + } + + { + std::vector vals(1 << 14, 0); + array a(vals.data(), {1 << 14}); + CHECK_EQ(all(a, Device::gpu).item(), false); + CHECK_EQ(any(a, Device::gpu).item(), false); + + vals[4] = 1; + vals[999] = 1; + vals[2000] = 1; + a = array(vals.data(), {1 << 14}); + CHECK_EQ(all(a, Device::gpu).item(), false); + CHECK_EQ(any(a, Device::gpu).item(), true); + } + + // sum and prod + { + array a = array({true, 
false, true}); + CHECK_EQ(sum(a, Device::gpu).item(), 2); + CHECK_EQ(prod(a, Device::gpu).item(), false); + + a = array({true, true, true}); + CHECK_EQ(sum(a, Device::gpu).item(), 3); + CHECK_EQ(prod(a, Device::gpu).item(), true); + + a = full({2, 2, 2}, 2.0f); + CHECK_EQ(sum(a, Device::gpu).item(), 16.0f); + CHECK_EQ(prod(a, Device::gpu).item(), 256.0f); + + a = full({500, 2, 2}, 1u); + CHECK_EQ(sum(a, Device::gpu).item(), 2000); + CHECK_EQ(prod(a, Device::gpu).item(), 1u); + + a = full({500, 2, 2}, 1); + CHECK_EQ(sum(a, Device::gpu).item(), 2000); + CHECK_EQ(prod(a, Device::gpu).item(), 1); + } + + // reducing only some axes and irregular layouts + { + array a(1.0f); + a = broadcast_to(a, {2, 2, 2}); + CHECK_EQ(sum(a, Device::gpu).item(), 8.0f); + + a = ones({2, 4, 8, 16}); + for (auto ax : {0, 1, 2, 3}) { + auto out_gpu = sum(a, ax, false, Device::gpu); + auto out_cpu = sum(a, ax, false, Device::cpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + + for (auto ax : {1, 2, 3}) { + auto out_gpu = sum(a, {0, ax}, false, Device::gpu); + auto out_cpu = sum(a, {0, ax}, false, Device::cpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + for (auto ax : {2, 3}) { + auto out_gpu = sum(a, {0, 1, ax}, false, Device::gpu); + auto out_cpu = sum(a, {0, 1, ax}, false, Device::cpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + } +} + +TEST_CASE("test metal binary ops") { + // scalar-scalar + { + array a(2.0f); + array b(4.0f); + auto out = add(a, b, Device::gpu); + CHECK_EQ(out.item(), 6.0f); + } + + // scalar-vector and vector-scalar + { + array a(2.0f); + array b({2.0f, 4.0f, 6.0f}); + auto out = add(a, b, Device::gpu); + auto expected = array({4.0f, 6.0f, 8.0f}); + CHECK(array_equal(out, expected, Device::cpu).item()); + out = add(b, a, Device::gpu); + CHECK(array_equal(out, expected, Device::cpu).item()); + } + + // vector-vector + { + array a({0.0f, 1.0f, 2.0f}); + array b({3.0f, 4.0f, 5.0f}); + auto out = add(a, b, Device::gpu); + auto expected = array({3.0f, 5.0f, 7.0f}); + CHECK(array_equal(out, expected, Device::cpu).item()); + } + + // general + { + array a({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2}); + array b({0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f}, {2, 2, 2}); + a = transpose(a, {0, 2, 1}); + b = transpose(b, {1, 0, 2}); + auto out_gpu = add(a, b, Device::gpu); + auto out_cpu = add(a, b, Device::cpu); + auto expected = + array({0.0f, 3.0f, 5.0f, 8.0f, 6.0f, 9.0f, 11.0f, 14.0f}, {2, 2, 2}); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + CHECK(array_equal(out_gpu, expected, Device::cpu).item()); + } + + // Check all types work + for (auto t : types) { + auto a = astype(array({0, 1, 2}), t); + auto b = astype(array({3, 4, 5}), t); + auto out_cpu = add(a, b, Device::cpu); + auto out_gpu = add(a, b, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + + // Check subtraction + { + auto a = array({3, 2, 1}); + auto b = array({1, 1, 1}); + auto out = subtract(a, b, Device::gpu); + CHECK(array_equal(out, array({2, 1, 0}), Device::cpu).item()); + } + + // Check multiplication + { + auto a = array({1, 2, 3}); + auto b = array({2, 2, 2}); + auto out = multiply(a, b, Device::gpu); + CHECK(array_equal(out, array({2, 4, 6}), Device::cpu).item()); + } + + // Check division + { + auto x = array(1.0f); + auto y = array(1.0f); + CHECK_EQ(divide(x, y, Device::gpu).item(), 1.0f); + + x = array(1.0f); + y = array(0.5); + CHECK_EQ(divide(x, y, Device::gpu).item(), 2.0f); + + x = 
array(1.0f); + y = array(0.0f); + CHECK(std::isinf(divide(x, y, Device::gpu).item())); + + x = array(0.0f); + y = array(0.0f); + CHECK(std::isnan(divide(x, y, Device::gpu).item())); + } + + // Check maximum and minimum + { + auto x = array(1.0f); + auto y = array(0.0f); + CHECK_EQ(maximum(x, y, Device::gpu).item(), 1.0f); + CHECK_EQ(minimum(x, y, Device::gpu).item(), 0.0f); + y = array(2.0f); + CHECK_EQ(maximum(x, y, Device::gpu).item(), 2.0f); + CHECK_EQ(minimum(x, y, Device::gpu).item(), 1.0f); + } + + // Check equal + { + array x(1.0f); + array y(1.0f); + CHECK(equal(x, y, Device::gpu).item()); + x = array(0.0f); + CHECK(!equal(x, y, Device::gpu).item()); + } + + // Greater and less + { + array x(1.0f); + array y(0.0f); + CHECK(greater(x, y, Device::gpu).item()); + CHECK(greater_equal(x, y, Device::gpu).item()); + CHECK(!greater(y, x, Device::gpu).item()); + CHECK(!greater_equal(y, x, Device::gpu).item()); + y = array(1.0f); + CHECK(!greater(x, y, Device::gpu).item()); + CHECK(greater_equal(x, y, Device::gpu).item()); + + x = array(0.0f); + y = array(1.0f); + CHECK(less(x, y, Device::gpu).item()); + CHECK(less_equal(x, y, Device::gpu).item()); + CHECK(!less(y, x, Device::gpu).item()); + CHECK(!less_equal(y, x, Device::gpu).item()); + y = array(0.0f); + CHECK(!less(x, y, Device::gpu).item()); + CHECK(less_equal(x, y, Device::gpu).item()); + } + + // Check logaddexp + { + constexpr float inf = std::numeric_limits::infinity(); + array x(inf); + array y(2.0f); + auto out = logaddexp(x, y, Device::gpu); + CHECK_EQ(out.item(), inf); + + x = array(-inf); + out = logaddexp(x, y, Device::gpu); + CHECK_EQ(out.item(), 2.0f); + + y = array(-inf); + out = logaddexp(x, y, Device::gpu); + CHECK_EQ(out.item(), -inf); + } +} + +TEST_CASE("test metal unary ops") { + // contiguous + { + array x({-1.0f, 0.0f, 1.0f}); + auto expected = array({1.0f, 0.0f, 1.0f}); + CHECK(array_equal(abs(x, Device::gpu), expected, Device::cpu).item()); + } + + // general + { + array x({-1.0f, 0.0f, 1.0f, 1.0f, -1.0f, 1.0f, 3.0f, -3.0f}); + auto y = slice(x, {0}, {8}, {2}); + auto expected = array({1.0f, 1.0f, 1.0f, 3.0f}); + CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item()); + + y = slice(x, {4}, {8}); + expected = array({1.0f, 1.0f, 3.0f, 3.0f}); + CHECK(array_equal(abs(y, Device::gpu), expected, Device::cpu).item()); + } + + // Test negative + { + array x(1.0f); + CHECK_EQ(negative(x, Device::gpu).item(), -1.0f); + } + + // Check all types work + for (auto t : types) { + if (t == bool_) { + continue; + } + auto in = astype(array({1}), t); + auto out_cpu = negative(in, Device::cpu); + auto out_gpu = negative(in, Device::gpu); + CHECK(array_equal(out_gpu, out_cpu, Device::cpu).item()); + } + + // Test log1p + { + constexpr float inf = std::numeric_limits::infinity(); + array x(-1.0f); + CHECK_EQ(log1p(x, Device::gpu).item(), -inf); + + x = array(0.0f); + CHECK_EQ(log1p(x, Device::gpu).item(), 0.0f); + + x = array(1e-9f); + CHECK_EQ(log1p(x, Device::gpu).item(), 1e-9f); + + x = array(-2.0f); + CHECK(std::isnan(log1p(x, Device::gpu).item())); + } +} + +TEST_CASE("test metal random") { + { + auto key = random::key(0); + auto x = random::bits({}, 4, key, Device::gpu); + auto y = random::bits({}, 4, key, Device::gpu); + CHECK_EQ(x.item(), 1797259609u); + CHECK_EQ(x.item(), y.item()); + } + + { + auto key = random::key(1); + auto x = random::bits({}, 4, key, Device::gpu); + CHECK_EQ(x.item(), 507451445u); + } + + { + auto key = random::key(0); + auto x = random::bits({3, 1}, 4, key, Device::gpu); + auto expected 
= array({4146024105u, 1351547692u, 2718843009u}, {3, 1}); + CHECK(array_equal(x, expected, Device::cpu).item()); + } +} + +TEST_CASE("test metal matmul") { + { + auto a = ones({2, 2}); + auto b = ones({2, 2}); + auto out = matmul(a, b, Device::gpu); + CHECK(array_equal(out, full({2, 2}, 2.0f), Device::cpu).item()); + } + + // Batched matmul + { + auto a = ones({3, 2, 2}); + auto b = ones({3, 2, 2}); + auto out = matmul(a, b, Device::gpu); + CHECK(array_equal(out, full({3, 2, 2}, 2.0f), Device::cpu).item()); + } + + // Broadcast batched matmul + { + auto a = ones({1, 3, 2, 2}); + auto b = ones({3, 1, 2, 2}); + auto out = matmul(a, b, Device::gpu); + CHECK(array_equal(out, full({3, 3, 2, 2}, 2.0f), Device::cpu).item()); + } +} diff --git a/tests/ops_tests.cpp b/tests/ops_tests.cpp new file mode 100644 index 0000000000..a9a74b829d --- /dev/null +++ b/tests/ops_tests.cpp @@ -0,0 +1,1926 @@ +#include +#include + +#include "doctest/doctest.h" + +#include "mlx/mlx.h" + +using namespace mlx::core; + +TEST_CASE("test copy") { + array x(1.0); + auto y = copy(x); + CHECK_EQ(y.shape(), std::vector{}); + CHECK_NE(y.id(), x.id()); + CHECK_EQ(y.item(), 1.0f); + + x = array({1, 2}, {2, 1}); + y = copy(x); + CHECK_EQ(y.shape(), std::vector{2, 1}); + CHECK_EQ(y.dtype(), int32); + CHECK_NE(y.id(), x.id()); + CHECK(array_equal(y, x).item()); +} + +TEST_CASE("test reshape") { + array x(1.0); + CHECK_EQ(reshape(x, {}).shape(), std::vector{}); + CHECK_THROWS_AS(reshape(x, {2}), std::invalid_argument); + auto y = reshape(x, {1, 1, 1}); + CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + y = reshape(x, {-1, 1, 1}); + CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + y = reshape(x, {1, 1, -1}); + CHECK_EQ(y.shape(), std::vector{1, 1, 1}); + CHECK_THROWS_AS(reshape(x, {1, -1, -1}), std::invalid_argument); + CHECK_THROWS_AS(reshape(x, {2, -1}), std::invalid_argument); + + x = zeros({2, 2, 2}); + y = reshape(x, {8}); + CHECK_EQ(y.shape(), std::vector{8}); + CHECK_THROWS_AS(reshape(x, {7}), std::invalid_argument); + y = reshape(x, {-1}); + CHECK_EQ(y.shape(), std::vector{8}); + y = reshape(x, {-1, 2}); + CHECK_EQ(y.shape(), std::vector{4, 2}); + CHECK_THROWS_AS(reshape(x, {-1, 7}), std::invalid_argument); + + // Works with empty array + x = array({}); + y = reshape(x, {0, 0, 0}); + CHECK_EQ(y.shape(), std::vector{0, 0, 0}); + y.eval(); + CHECK_EQ(y.size(), 0); + CHECK_THROWS_AS(reshape(x, {}), std::invalid_argument); + CHECK_THROWS_AS(reshape(x, {1}), std::invalid_argument); + y = reshape(x, {1, 5, 0}); + CHECK_EQ(y.shape(), std::vector{1, 5, 0}); +} + +TEST_CASE("test squeeze and expand") { + array x = zeros({2, 1, 2, 1, 2, 1}); + CHECK_EQ(squeeze(x).shape(), std::vector{2, 2, 2}); + CHECK_EQ(squeeze(x, {1, 3, 5}).shape(), std::vector{2, 2, 2}); + CHECK_EQ(squeeze(x, {-1, -3, -5}).shape(), std::vector{2, 2, 2}); + CHECK_EQ(squeeze(x, 1).shape(), std::vector{2, 2, 1, 2, 1}); + CHECK_EQ(squeeze(x, -1).shape(), std::vector{2, 1, 2, 1, 2}); + + CHECK_THROWS(squeeze(x, 0)); + CHECK_THROWS(squeeze(x, 2)); + CHECK_THROWS(squeeze(x, {1, 3, 1})); + CHECK_THROWS(squeeze(x, {1, 3, -3})); + + x = zeros({2, 2}); + CHECK_EQ(expand_dims(x, 0).shape(), std::vector{1, 2, 2}); + CHECK_EQ(expand_dims(x, -1).shape(), std::vector{2, 2, 1}); + CHECK_EQ(expand_dims(x, 1).shape(), std::vector{2, 1, 2}); + CHECK_EQ(expand_dims(x, {0, 1, 2}).shape(), std::vector{1, 1, 1, 2, 2}); + CHECK_EQ( + expand_dims(x, {0, 1, 2, 5, 6, 7}).shape(), + std::vector{1, 1, 1, 2, 2, 1, 1, 1}); + + CHECK_THROWS(expand_dims(x, 3)); + CHECK_THROWS(expand_dims(x, -4)); + 
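+  // Repeated axes should throw as well, including when a negative axis wraps onto one already given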
CHECK_THROWS(expand_dims(x, {0, 1, 0})); + CHECK_THROWS(expand_dims(x, {0, 1, -4})); +} + +TEST_CASE("test slice") { + array x = array(3); + auto out = slice(x, {}, {}); + CHECK_EQ(out.item(), 3); + CHECK_THROWS_AS(slice(x, {1}, {2}), std::invalid_argument); + CHECK_THROWS_AS(slice(x, {}, {2}), std::invalid_argument); + CHECK_THROWS_AS(slice(x, {0}, {}), std::invalid_argument); + + x = array({3}); + out = slice(x, {0}, {1}); + CHECK_EQ(out.item(), 3); + out = slice(x, {-1}, {1}); + CHECK_EQ(out.item(), 3); + + out = slice(x, {-3}, {10}); + CHECK_EQ(out.item(), 3); + + out = slice(x, {1}, {0}); + eval(out); + CHECK_EQ(out.shape(), std::vector{0}); + + out = slice(x, {0}, {1}, {1}); + CHECK_EQ(out.item(), 3); + + out = slice(x, {0}, {1}, {10}); + CHECK_EQ(out.item(), 3); + + x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4}); + out = slice(x, {0, 0}, {2, 2}); + CHECK(array_equal(out, array({0, 1, 4, 5}, {2, 2})).item()); + + out = slice(x, {0, 0}, {0, 2}); + CHECK(array_equal(out, reshape(array({}), {0, 2})).item()); + + out = slice(x, {0, 2}, {2, 3}); + CHECK(array_equal(out, array({2, 6}, {2, 1})).item()); + + out = slice(x, {0, 0}, {2, 4}, {1, 2}); + CHECK(array_equal(out, array({0, 2, 4, 6}, {2, 2})).item()); +} + +TEST_CASE("test split") { + array x = array(1); + CHECK_THROWS(split(x, 0)); + + x = array({3}); + CHECK_EQ(split(x, 1)[0].item(), 3); + + x = array({0, 1, 2}); + CHECK_THROWS(split(x, 3, 1)); + CHECK_THROWS(split(x, 3, -2)); + + auto out = split(x, 3, 0); + CHECK_EQ(out.size(), 3); + + out = split(x, 3, -1); + CHECK_EQ(out.size(), 3); + for (auto i = 0; i < 3; ++i) { + CHECK_EQ(out[i].shape(), std::vector{1}); + CHECK_EQ(out[i].dtype(), int32); + CHECK_EQ(out[i].item(), i); + } + + x = array({0, 1, 2, 3, 4, 5}, {2, 3}); + out = split(x, 2); + CHECK(array_equal(out[0], array({0, 1, 2}, {1, 3})).item()); + CHECK(array_equal(out[1], array({3, 4, 5}, {1, 3})).item()); + out = split(x, 3, 1); + CHECK(array_equal(out[0], array({0, 3}, {2, 1})).item()); + CHECK(array_equal(out[1], array({1, 4}, {2, 1})).item()); + CHECK(array_equal(out[2], array({2, 5}, {2, 1})).item()); + + x = zeros({8, 12}); + out = split(x, 2); + CHECK_EQ(out.size(), 2); + CHECK_EQ(out[0].shape(), std::vector{4, 12}); + CHECK_EQ(out[1].shape(), std::vector{4, 12}); + out = split(x, 3, 1); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].shape(), std::vector{8, 4}); + CHECK_EQ(out[1].shape(), std::vector{8, 4}); + CHECK_EQ(out[2].shape(), std::vector{8, 4}); + + out = split(x, std::vector{}); + CHECK_EQ(out.size(), 1); + CHECK_EQ(out[0].shape(), x.shape()); + + out = split(x, {3, 7}); + CHECK_EQ(out.size(), 3); + CHECK_EQ(out[0].shape(), std::vector{3, 12}); + CHECK_EQ(out[1].shape(), std::vector{4, 12}); + CHECK_EQ(out[2].shape(), std::vector{1, 12}); + + out = split(x, std::vector{20}); + CHECK_EQ(out.size(), 2); + CHECK_EQ(out[0].shape(), std::vector{8, 12}); + CHECK_EQ(out[1].shape(), std::vector{0, 12}); + + // Negative indices + out = split(x, std::vector{-5}); + CHECK_EQ(out[0].shape(), std::vector{3, 12}); + CHECK_EQ(out[1].shape(), std::vector{5, 12}); + + // Different axis + out = split(x, std::vector{2, 8}, 1); + CHECK_EQ(out[0].shape(), std::vector{8, 2}); + CHECK_EQ(out[1].shape(), std::vector{8, 6}); + CHECK_EQ(out[2].shape(), std::vector{8, 4}); + + // Out of order indices + x = arange(5); + out = split(x, std::vector{2, 1, 2}); + CHECK(array_equal(out[0], array({0, 1})).item()); + CHECK(array_equal(out[1], array({})).item()); + CHECK(array_equal(out[2], array({1})).item()); + CHECK(array_equal(out[3], 
array({2, 3, 4})).item()); +} + +TEST_CASE("test transpose") { + array x(1); + auto y = transpose(x); + CHECK_EQ(y.shape(), std::vector{}); + CHECK_EQ(y.item(), 1); + CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument); + + x = array({1}, {1}); + y = transpose(x); + CHECK_EQ(y.shape(), std::vector{1}); + CHECK_EQ(y.item(), 1); + + // Negative indices + y = transpose(x, {-1}); + CHECK_EQ(y.shape(), std::vector{1}); + CHECK_EQ(y.item(), 1); + + CHECK_THROWS_AS(transpose(x, {1}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument); + + // Works with empty array + x = array({}); + y = transpose(x); + CHECK_EQ(y.shape(), std::vector{0}); + y.eval(); + CHECK_EQ(y.size(), 0); + + x = array({1, 2, 3, 4, 5, 6}, {2, 3}); + y = transpose(x); + CHECK_EQ(y.shape(), std::vector{3, 2}); + y = transpose(x, {-1, 0}); + CHECK_EQ(y.shape(), std::vector{3, 2}); + y = transpose(x, {-1, -2}); + CHECK_EQ(y.shape(), std::vector{3, 2}); + y.eval(); + CHECK(array_equal(y, array({1, 4, 2, 5, 3, 6}, {3, 2})).item()); + y = transpose(x, {0, 1}); + CHECK_EQ(y.shape(), std::vector{2, 3}); + CHECK(array_equal(y, x).item()); + y = transpose(x, {0, -1}); + CHECK_EQ(y.shape(), std::vector{2, 3}); + CHECK(array_equal(y, x).item()); + + CHECK_THROWS_AS(transpose(x, {}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {0}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {0, 0}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {0, 0, 0}), std::invalid_argument); + CHECK_THROWS_AS(transpose(x, {0, 1, 1}), std::invalid_argument); + + x = array({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, {2, 3, 2}); + y = transpose(x); + CHECK_EQ(y.shape(), std::vector{2, 3, 2}); + auto expected = array({1, 7, 3, 9, 5, 11, 2, 8, 4, 10, 6, 12}, {2, 3, 2}); + CHECK(array_equal(y, expected).item()); + + y = transpose(x, {0, 1, 2}); + CHECK_EQ(y.shape(), std::vector{2, 3, 2}); + CHECK(array_equal(y, x).item()); + y = transpose(x, {1, 0, 2}); + CHECK_EQ(y.shape(), std::vector{3, 2, 2}); + expected = array({1, 2, 7, 8, 3, 4, 9, 10, 5, 6, 11, 12}, {3, 2, 2}); + CHECK(array_equal(y, expected).item()); + y = transpose(x, {0, 2, 1}); + CHECK_EQ(y.shape(), std::vector{2, 2, 3}); + expected = array({1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12}, {2, 2, 3}); + CHECK(array_equal(y, expected).item()); + + // Check reshaping a transposed array + x = array({0, 1, 2, 3, 4, 5, 6, 7}, {4, 2}); + x = reshape(transpose(x), {2, 2, 2}); + expected = array({0, 2, 4, 6, 1, 3, 5, 7}, {2, 2, 2}); + CHECK(array_equal(x, expected).item()); + + // Check maintaining contiguous status + x = array({0, 1, 2, 3, 4, 5, 6, 7}, {1, 4, 1, 2}); + CHECK(x.flags().row_contiguous); + x = transpose(x, {2, 1, 0, 3}); + eval(x); + CHECK(x.flags().row_contiguous); +} + +TEST_CASE("test comparison ops") { + // Empty array + { + array x({}); + array y({}); + auto z = x == y; + CHECK_EQ(z.dtype(), bool_); + CHECK_EQ(z.shape(), std::vector{0}); + } + + // Basic cases + { + array x(1.0); + array y(1.0); + CHECK(equal(x, y).item()); + CHECK((x == y).item()); + CHECK((x == 1.0f).item()); + CHECK((1.0f == y).item()); + + CHECK(!(x != y).item()); + CHECK(!not_equal(x, y).item()); + CHECK(!(1.0f != y).item()); + CHECK(!(x != 1.0f).item()); + + CHECK(array_equal(x, y).item()); + + x = array(0.0); + CHECK(!equal(x, y).item()); + CHECK(!array_equal(x, y).item()); + CHECK(not_equal(x, y).item()); + } + + // Greater and less + { + array x(1.0); + array y(0.0); + CHECK(greater(x, y).item()); + 
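+    // The comparison operator overloads should agree with the named ops, with scalars on either side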
CHECK((x > 0.0f).item()); + CHECK((1.0f > y).item()); + CHECK(greater_equal(x, y).item()); + CHECK((1.0f >= y).item()); + CHECK(!(x > 1.0f).item()); + CHECK((x >= 1.0f).item()); + + CHECK(less(y, x).item()); + CHECK((y < 1.0).item()); + CHECK((y <= 1.0f).item()); + CHECK(!(x < 1.0).item()); + CHECK((x <= 1.0f).item()); + } + + // Check array_equal works + { + auto x = zeros({5, 5}); + auto y = zeros({5, 5}); + CHECK(array_equal(x, y).item()); + + x = zeros({1, 1}); + CHECK(!array_equal(x, y).item()); + + x = ones({5, 5}); + CHECK(!array_equal(x, y).item()); + + x = array({0.0f, 1.0f, NAN}); + y = array({0.0f, 1.0f, NAN}); + CHECK(!array_equal(x, y).item()); + CHECK(array_equal(x, y, true).item()); + } + + // Check other types + { + auto x = zeros({5, 5}, int32); + auto y = zeros({5, 5}, int32); + CHECK(array_equal(x, y).item()); + + x = ones({5, 5}, bool_); + y = ones({5, 5}, bool_); + CHECK(array_equal(x, y).item()); + } + + // Check type promotion + { + array x(1.0f); + array y(1); + CHECK_EQ(equal(x, y).item(), true); + + x = array(true, bool_); + CHECK_EQ(equal(x, y).item(), true); + } + + // Broadcasting works + { + auto x = zeros({1, 2}); + auto y = zeros({2, 1}); + auto z = equal(x, y); + CHECK_EQ(z.dtype(), bool_); + CHECK_EQ(z.shape(), std::vector{2, 2}); + auto expected = array({true, true, true, true}, {2, 2}); + CHECK(array_equal(z, expected).item()); + + x = array({1.0, 2.0}, {1, 2}); + y = array({1.0, 2.0}, {2, 1}); + z = equal(x, y); + CHECK_EQ(z.dtype(), bool_); + CHECK_EQ(z.shape(), std::vector{2, 2}); + expected = array({true, false, false, true}, {2, 2}); + CHECK(array_equal(z, expected).item()); + + expected = array({false, true, false, false}, {2, 2}); + z = greater(x, y); + CHECK(array_equal(z, expected).item()); + + expected = array({true, true, false, true}, {2, 2}); + z = greater_equal(x, y); + CHECK(array_equal(z, expected).item()); + + expected = array({false, false, true, false}, {2, 2}); + z = less(x, y); + CHECK(array_equal(z, expected).item()); + + expected = array({true, false, true, true}, {2, 2}); + z = less_equal(x, y); + CHECK(array_equal(z, expected).item()); + } +} + +TEST_CASE("test all close") { + array x(1.0f); + array y(1.0f); + CHECK(allclose(x, y).item()); + + y = array(1.1f); + CHECK_FALSE(allclose(x, y).item()); + CHECK(allclose(x, y, 0.1).item()); + CHECK_FALSE(allclose(x, y, 0.01).item()); + CHECK(allclose(x, y, 0.01, 0.1).item()); +} + +TEST_CASE("test reduction ops") { + // Check shapes and throws correctly + { + auto x = array(1); + auto out = sum(x); + CHECK_EQ(out.ndim(), 0); + CHECK_THROWS_AS(sum(x, 0), std::out_of_range); + CHECK_THROWS_AS(sum(x, -1), std::out_of_range); + out = sum(x, std::vector{}); + CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.size(), 1); + + x = array({}); + out = sum(x); + CHECK_EQ(out.shape(), std::vector{}); + CHECK_EQ(out.size(), 1); + out = sum(x, true); + CHECK_EQ(out.shape(), std::vector{1}); + out = sum(x, std::vector{}); + CHECK_EQ(out.shape(), x.shape()); + + x = zeros({2}); + out = sum(x); + CHECK_EQ(out.ndim(), 0); + out = sum(x, -1); + CHECK_EQ(out.ndim(), 0); + out = sum(x, -1, true); + CHECK_EQ(out.ndim(), 1); + CHECK_EQ(out.shape(), std::vector{1}); + + CHECK_THROWS_AS(sum(x, 1), std::out_of_range); + CHECK_THROWS_AS(sum(x, -2), std::out_of_range); + CHECK_THROWS_AS(sum(x, {0, 0}), std::invalid_argument); + CHECK_THROWS_AS(sum(x, {-1, 0}), std::invalid_argument); + + x = zeros({2, 3, 4}); + out = sum(x, {0, 2}); + CHECK_EQ(out.shape(), std::vector{3}); + out = sum(x, std::vector{}); + 
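+    // An empty axis vector reduces over nothing, so the input shape is preserved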
CHECK_EQ(out.shape(), x.shape()); + + out = sum(x, {0, -1}); + CHECK_EQ(out.shape(), std::vector{3}); + + out = sum(x, {0, -1}, true); + CHECK_EQ(out.shape(), std::vector{1, 3, 1}); + + out = sum(x, true); + CHECK_EQ(out.shape(), std::vector{1, 1, 1}); + + out = sum(x); + CHECK_EQ(out.shape(), std::vector{}); + + CHECK_THROWS_AS(sum(x, 3), std::out_of_range); + CHECK_THROWS_AS(sum(x, -4), std::out_of_range); + CHECK_THROWS_AS(sum(x, {0, 1, -2}), std::invalid_argument); + } + + // Test sum + { + auto x = array({}); + CHECK_EQ(sum(x).item(), 0.0f); + + x = array({1, 2, 3}); + CHECK_EQ(sum(x).item(), 6); + CHECK(array_equal(sum(x, std::vector{}), x).item()); + + x = ones({2, 3}); + CHECK(array_equal(sum(x, 1), full({2}, 3.0f)).item()); + CHECK(array_equal(sum(x, 0), full({3}, 2.0f)).item()); + CHECK_EQ(sum(x, {0, 1}).item(), 6.0f); + + x = ones({2, 3, 4}); + CHECK(array_equal(sum(x, 0), full({3, 4}, 2.0f)).item()); + CHECK(array_equal(sum(x, 1), full({2, 4}, 3.0f)).item()); + CHECK(array_equal(sum(x, 2), full({2, 3}, 4.0f)).item()); + CHECK(array_equal(sum(x, {0, 1}), full({4}, 6.0f)).item()); + CHECK(array_equal(sum(x, {0, 2}), full({3}, 8.0f)).item()); + CHECK(array_equal(sum(x, {1, 2}), full({2}, 12.0f)).item()); + + // Output for bool gets higher precision + x = array({true, true, true}); + CHECK_EQ(sum(x).item(), 3); + + x = array(2.0f); + x = broadcast_to(x, {2, 2, 2}); + CHECK_EQ(sum(x).item(), 16.0f); + + // Tests with non-uniform results after reduction + x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3}); + CHECK(array_equal(sum(x, 0), full({3}, 3.0f)).item()); + CHECK(array_equal(sum(x, 1), array({3.0f, 6.0f}, {2})).item()); + } + + // Test prod + { + auto x = array({}); + CHECK_EQ(prod(x).item(), 1.0f); + + x = array({2, 2, 2}); + CHECK_EQ(prod(x).item(), 8); + CHECK(array_equal(prod(x, std::vector{}), x).item()); + + x = full({2, 3}, 2.0f); + CHECK(array_equal(prod(x, 1), full({2}, 8.0f)).item()); + CHECK(array_equal(prod(x, 0), full({3}, 4.0f)).item()); + CHECK_EQ(prod(x, {0, 1}).item(), 64.0f); + + x = full({2, 3, 4}, 2.0f); + CHECK(array_equal(prod(x, 0), full({3, 4}, 4.0f)).item()); + CHECK(array_equal(prod(x, 1), full({2, 4}, 8.0f)).item()); + CHECK(array_equal(prod(x, 2), full({2, 3}, 16.0f)).item()); + CHECK(array_equal(prod(x, {0, 1}), full({4}, 64.0f)).item()); + CHECK(array_equal(prod(x, {0, 2}), full({3}, 256.0f)).item()); + CHECK(array_equal(prod(x, {1, 2}), full({2}, 4096.0f)).item()); + + // Tests with non-uniform results after reduction + x = array({1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f}, {2, 3}); + CHECK(array_equal(prod(x, 0), full({3}, 2.0f)).item()); + CHECK(array_equal(prod(x, 1), array({1.0f, 8.0f}, {2})).item()); + + x = array({true, true, true, false, true, false}, {2, 3}); + CHECK(array_equal(prod(x, 0), array({false, true, false})).item()); + CHECK(array_equal(prod(x, 1), array({true, false})).item()); + } + + // Test all + { + auto x = array({}); + CHECK_EQ(all(x).item(), true); + + x = array({2, 2, 2}); + CHECK_EQ(all(x).item(), true); + auto out = all(x, std::vector{}); + CHECK(array_equal(out, array({true, true, true})).item()); + + x = array({0, 2, 2}); + CHECK_EQ(all(x).item(), false); + + x = array({true, true, true, false, true, false}, {2, 3}); + CHECK(array_equal(all(x, 1), array({true, false})).item()); + CHECK(array_equal(all(x, 0), array({false, true, false})).item()); + } + + // Test any + { + auto x = array({}); + CHECK_EQ(any(x).item(), false); + + x = array({0, 0, 0}); + CHECK_EQ(any(x).item(), false); + + x = array({0, 2, 0}); + 
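+    // A single non-zero entry is enough for any() to be true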
CHECK_EQ(any(x).item(), true); + auto out = any(x, std::vector{}); + CHECK(array_equal(out, array({false, true, false})).item()); + + x = array({true, false, true, false, false, false}, {2, 3}); + CHECK(array_equal(any(x, 1), array({true, false})).item()); + CHECK(array_equal(any(x, 0), array({true, false, true})).item()); + } + + // Test max and min + { + auto x = array({}); + CHECK_THROWS(max(x)); + CHECK_THROWS(min(x)); + + x = array({1.0f, 2.0f, 3.0f}); + CHECK_EQ(max(x).item(), 3.0f); + CHECK_EQ(min(x).item(), 1.0f); + + x = array({-2.0f, -1.0f}); + CHECK_EQ(max(x).item(), -1.0f); + CHECK_EQ(min(x).item(), -2.0f); + + constexpr float inf = std::numeric_limits::infinity(); + x = array({inf}); + CHECK_EQ(min(x).item(), inf); + + x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + CHECK(array_equal(max(x, 0), array({4.0f, 5.0f, 6.0f})).item()); + CHECK(array_equal(max(x, 1), array({3.0f, 6.0f})).item()); + CHECK(array_equal(min(x, 0), array({1.0f, 2.0f, 3.0f})).item()); + CHECK(array_equal(min(x, 1), array({1.0f, 4.0f})).item()); + + x = array({1u, 2u, 3u}); + CHECK_EQ(max(x).item(), 3u); + CHECK_EQ(min(x).item(), 1u); + + x = array({1u, 2u, 3u, 4u, 5u, 6u}, {2, 3}); + CHECK(array_equal(max(x, 0), array({4u, 5u, 6u})).item()); + CHECK(array_equal(max(x, 1), array({3u, 6u})).item()); + CHECK(array_equal(min(x, 0), array({1u, 2u, 3u})).item()); + CHECK(array_equal(min(x, 1), array({1u, 4u})).item()); + + x = array({true, false, true, false, false, false}, {2, 3}); + CHECK(array_equal(max(x, 1), array({true, false})).item()); + CHECK(array_equal(max(x, 0), array({true, false, true})).item()); + + x = array({true, true, true, false, true, false}, {2, 3}); + CHECK(array_equal(min(x, 1), array({true, false})).item()); + CHECK(array_equal(min(x, 0), array({false, true, false})).item()); + } + + // Test logsumexp + { + auto x = array({}); + CHECK_THROWS(logsumexp(x)); + + constexpr float inf = std::numeric_limits::infinity(); + + x = array({-inf, -inf}); + WARN_EQ(logsumexp(x).item(), -inf); + + x = array({0.0f, -inf}); + CHECK_EQ(logsumexp(x).item(), 0.0f); + + x = array({0.0f, inf}); + WARN_EQ(logsumexp(x).item(), inf); + + x = reshape(arange(6, float32), {2, 3}); + + std::vector nums = {0.0f, 1.0f, 2.0f, 3.0f}; + x = array(nums.data(), {2, 2}); + auto y = logsumexp(x, {0, 1}, true); + CHECK_EQ(y.shape(), std::vector{1, 1}); + auto result = std::log( + std::exp(nums[0]) + std::exp(nums[1]) + std::exp(nums[2]) + + std::exp(nums[3])); + CHECK(y.item() == doctest::Approx(result)); + auto expected = array( + {std::log(std::exp(nums[0]) + std::exp(nums[2])), + std::log(std::exp(nums[1]) + std::exp(nums[3]))}); + CHECK(allclose(logsumexp(x, 0), expected).item()); + + expected = array( + {std::log(std::exp(nums[0]) + std::exp(nums[1])), + std::log(std::exp(nums[2]) + std::exp(nums[3]))}); + CHECK(allclose(logsumexp(x, 1), expected).item()); + } + + // Test softmax + { + auto x = array({0., 0., 0., 0.}); + auto y = array({0.25, 0.25, 0.25, 0.25}); + CHECK(array_equal(y, softmax(x)).item()); + CHECK(array_equal(y, softmax(x, -1)).item()); + CHECK(array_equal(y, softmax(x, std::vector{-1})).item()); + CHECK(array_equal(y, softmax(x, std::vector{0})).item()); + } +} + +TEST_CASE("test irregular binary ops") { + // 1D strided + { + auto x = full({128}, 1.0f); + auto y = full({64}, 1.0f); + x = slice(x, {0}, {128}, {4}); + y = slice(y, {0}, {64}, {2}); + CHECK(array_equal(add(x, y), full({32}, 2.0f)).item()); + } + + // 2D broadcasts + { + auto x = full({32, 32}, 4.0f); + auto y = full({32}, 4.0f); 
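+    // The length-32 vector broadcasts over the last axis of the 32x32 operand, and over the first axis once reshaped to {32, 1}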
+ CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item()); + y = reshape(y, {32, 1}); + CHECK(array_equal(add(x, y), full({32, 32}, 8.0f)).item()); + CHECK(array_equal(subtract(y, x), zeros({32, 32})).item()); + } +} + +TEST_CASE("test arithmetic unary ops") { + // Test negative + { + array x(1.0f); + CHECK_EQ(negative(x).item(), -1.0f); + CHECK_EQ((-x).item(), -1.0f); + + // works on empty array + CHECK(array_equal(-array({}), array({})).item()); + + // Throws on bool + CHECK_THROWS(negative(array(true))); + } + + // Test logical not + { + array x(false); + CHECK_EQ(logical_not(x).item(), true); + + x = array(1.0f); + auto y = logical_not(x); + CHECK_EQ(y.dtype(), bool_); + CHECK_EQ(y.item(), false); + + x = array(0); + y = logical_not(x); + CHECK_EQ(y.dtype(), bool_); + CHECK_EQ(y.item(), true); + } + + // Test abs + { + array x({-1.0f, 0.0f, 1.0f}); + CHECK(array_equal(abs(x), array({1.0f, 0.0f, 1.0f})).item()); + + // works on empty array + CHECK(array_equal(abs(array({})), array({})).item()); + + // int32 + x = array({-1, 0, 1}); + CHECK(array_equal(abs(x), array({1, 0, 1})).item()); + + // uint32 + x = array({1u, 0u, 1u}); + CHECK(array_equal(abs(x), array({1u, 0u, 1u})).item()); + + // bool + x = array({false, true}); + CHECK(array_equal(abs(x), array({false, true})).item()); + } + + // Test sign + { + array x({-1.0f, 0.0f, 1.0f}); + CHECK(array_equal(sign(x), x).item()); + + // works on empty array + CHECK(array_equal(sign(array({})), array({})).item()); + + // int32 + x = array({-1, 0, 1}); + CHECK(array_equal(sign(x), x).item()); + + // uint32 + x = array({1u, 0u, 1u}); + CHECK(array_equal(sign(x), x).item()); + + // bool + x = array({false, true}); + CHECK(array_equal(sign(x), x).item()); + } + + constexpr float neginf = -std::numeric_limits::infinity(); + + // Test exponential + { + array x(0.0); + CHECK_EQ(exp(x).item(), 1.0); + + x = array(2.0); + CHECK_EQ(exp(x).item(), std::exp(2.0f)); + + CHECK(array_equal(exp(array({})), array({})).item()); + + x = array(neginf); + CHECK_EQ(exp(x).item(), 0.0f); + + // Integer input type + x = array(2); + CHECK_EQ(x.dtype(), int32); + CHECK_EQ(exp(x).item(), std::exp(2.0f)); + + // Input is irregularly strided + x = broadcast_to(array(1.0f), {2, 2, 2}); + CHECK(array_equal(exp(x), full({2, 2, 2}, std::exp(1.0f))).item()); + + x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; + auto expected = array({std::exp(0.0f), std::exp(2.0f)}, {2, 1}); + CHECK(array_equal(exp(x), expected).item()); + } + + // Test sine + { + array x(0.0); + CHECK_EQ(sin(x).item(), 0.0); + + x = array(M_PI_2); + CHECK(sin(x).item() == doctest::Approx(std::sin(M_PI_2))); + + CHECK(array_equal(sin(array({})), array({})).item()); + + // Integer input type + x = array(0); + CHECK_EQ(x.dtype(), int32); + CHECK_EQ(sin(x).item(), std::sin(0.0f)); + + // Input is irregularly strided + x = broadcast_to(array(1.0f), {2, 2, 2}); + CHECK(allclose(sin(x), full({2, 2, 2}, std::sin(1.0f))).item()); + + x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; + auto expected = array({std::sin(0.0f), std::sin(2.0f)}, {2, 1}); + CHECK(allclose(sin(x), expected).item()); + } + + // Test cos + { + array x(0.0); + CHECK_EQ(cos(x).item(), doctest::Approx(1.0)); + + x = array(M_PI_2); + CHECK(cos(x).item() == doctest::Approx(std::cos(M_PI_2))); + + CHECK(array_equal(cos(array({})), array({})).item()); + + // Integer input type + x = array(0); + CHECK_EQ(x.dtype(), int32); + CHECK(cos(x).item() == doctest::Approx(std::cos(0.0f))); + + // Input is irregularly strided + x = 
broadcast_to(array(1.0f), {2, 2, 2}); + CHECK(allclose(cos(x), full({2, 2, 2}, std::cos(1.0f))).item()); + + x = split(array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2}), 2, 1)[0]; + auto expected = array({std::cos(0.0f), std::cos(2.0f)}, {2, 1}); + CHECK(allclose(cos(x), expected).item()); + } + + // Test log + { + array x(0.0); + CHECK_EQ(log(x).item(), neginf); + + x = array(1.0); + CHECK_EQ(log(x).item(), log(1.0f)); + + // Integer input type + x = array(1); + CHECK_EQ(log(x).dtype(), float32); + CHECK_EQ(log(x).item(), log(1.0f)); + + // Input is irregularly strided + x = broadcast_to(array(1.0f), {2, 2, 2}); + CHECK(array_equal(log(x), full({2, 2, 2}, std::log(1.0f))).item()); + + x = split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0]; + auto expected = array({std::log(1.0f), std::log(3.0f)}, {2, 1}); + CHECK(array_equal(log(x), expected).item()); + } + + // Test log2 + { + array x(0.0); + CHECK_EQ(log2(x).item(), neginf); + + x = array(1.0); + CHECK_EQ(log2(x).item(), 0.0f); + + x = array(1024.0f); + CHECK_EQ(log2(x).item(), 10.0f); + } + + // Test log10 + { + array x(0.0); + CHECK_EQ(log10(x).item(), neginf); + + x = array(1.0); + CHECK_EQ(log10(x).item(), 0.0f); + + x = array(1000.0f); + CHECK_EQ(log10(x).item(), 3.0f); + } + + // Test log1p + { + array x(-1.0f); + CHECK_EQ(log1p(x).item(), neginf); + + x = array(1.0f); + CHECK_EQ(log1p(x).item(), std::log1pf(1.0f)); + + // Integer input type + x = array(1); + CHECK_EQ(log1p(x).dtype(), float32); + CHECK_EQ(log1p(x).item(), std::log1pf(1.0f)); + + // Input is irregularly strided + x = broadcast_to(array(1.0f), {2, 2, 2}); + CHECK( + array_equal(log1p(x), full({2, 2, 2}, std::log1pf(1.0f))).item()); + + x = split(array({1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}), 2, 1)[0]; + auto expected = array({std::log1pf(1.0f), std::log1pf(3.0f)}, {2, 1}); + CHECK(array_equal(log1p(x), expected).item()); + } + + // Test sigmoid + { + array x(0.0); + CHECK_EQ(sigmoid(x).item(), 0.5f); + + // Integer input type + x = array(0); + CHECK_EQ(sigmoid(x).dtype(), float32); + CHECK_EQ(sigmoid(x).item(), 0.5f); + + constexpr auto inf = std::numeric_limits::infinity(); + x = array(inf); + CHECK_EQ(sigmoid(x).item(), 1.0f); + x = array(-inf); + CHECK_EQ(sigmoid(x).item(), 0.0f); + } + + // Test square + { + array x(3.0); + CHECK_EQ(square(x).item(), 9.0); + + x = array(2); + CHECK_EQ(square(x).item(), 4); + + x = full({3, 3}, 2.0f); + CHECK(array_equal(square(x), full({3, 3}, 4.0f)).item()); + } + + // Test sqrt and rsqrt + { + array x(4.0); + CHECK_EQ(sqrt(x).item(), 2.0); + CHECK_EQ(rsqrt(x).item(), 0.5); + + x = full({3, 3}, 9.0f); + CHECK(array_equal(sqrt(x), full({3, 3}, 3.0f)).item()); + + x = array(4, int32); + CHECK_EQ(sqrt(x).item(), 2.0f); + CHECK_EQ(rsqrt(x).item(), 0.5f); + } + + // Test reciprocal + { + array x(8.0); + CHECK_EQ(reciprocal(x).item(), 0.125f); + + x = array(2); + auto out = reciprocal(x); + CHECK_EQ(out.dtype(), float32); + CHECK_EQ(out.item(), 0.5f); + + x = full({3, 3}, 2.0f); + CHECK(array_equal(reciprocal(x), full({3, 3}, 0.5f)).item()); + } +} + +TEST_CASE("test error functions") { + constexpr float inf = std::numeric_limits::infinity(); + array x(0.0f); + CHECK_EQ(erf(x).item(), 0.0f); + x = array(inf); + CHECK_EQ(erf(x).item(), 1.0f); + x = array(-inf); + CHECK_EQ(erf(x).item(), -1.0f); + + x = array(1, int32); + CHECK_EQ(erf(x).dtype(), float32); + + x = array(0.0f); + CHECK_EQ(erfinv(x).item(), 0.0f); + x = array(1.0f); + CHECK_EQ(erfinv(x).item(), inf); + x = array(-1.0f); + CHECK_EQ(erfinv(x).item(), -inf); + + x = array(1, int32); 
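+  // Integer inputs should be promoted to float32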
+ CHECK_EQ(erfinv(x).dtype(), float32); + + x = array(2.0f); + CHECK(std::isnan(erfinv(x).item())); + x = array(-2.0f); + CHECK(std::isnan(erfinv(x).item())); + + auto vals = {0.9f, 0.5f, 0.1f, -0.1f, -0.5f, -0.9f}; + // Expected values are generated from scipy's error function: + // python -c "import scipy.special as ss; + // vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9]; + // print([ss.erf(x) for x in vals])" + { + auto expected = { + 0.7969082124228322, + 0.5204998778130465, + 0.1124629160182849, + -0.1124629160182849, + -0.5204998778130465, + -0.7969082124228322}; + for (int i = 0; i < vals.size(); ++i) { + x = array(vals.begin()[i]); + CHECK_EQ(erf(x).item(), doctest::Approx(expected.begin()[i])); + } + } + + // Expected values are generated from scipy's inverse error function: + // python -c "import scipy.special as ss; + // vals = [0.9, 0.5, 0.1, -0.1, -0.5, -0.9]; + // print([ss.erfinv(x) for x in vals])" + { + auto expected = { + 1.1630871536766738, + 0.4769362762044699, + 0.08885599049425778, + -0.08885599049425769, + -0.4769362762044699, + -1.1630871536766743}; + for (int i = 0; i < vals.size(); ++i) { + x = array(vals.begin()[i]); + CHECK_EQ(erfinv(x).item(), doctest::Approx(expected.begin()[i])); + } + } + + // float16_t + { + array x(0.0f, float16); + auto out = erf(x); + CHECK_EQ(out.dtype(), float16); + CHECK_EQ(out.item(), 0.0f); + + out = erfinv(x); + CHECK_EQ(out.dtype(), float16); + CHECK_EQ(out.item(), 0.0f); + } + + // bfloat + { + array x(0.0f, bfloat16); + auto out = erf(x); + CHECK_EQ(out.dtype(), bfloat16); + CHECK_EQ(out.item(), 0.0f); + + out = erfinv(x); + CHECK_EQ(out.dtype(), bfloat16); + CHECK_EQ(out.item(), 0.0f); + } +} + +TEST_CASE("test arithmetic binary ops") { + array x(1.0); + array y(1.0); + auto z = add(x, y); + CHECK_EQ(z.item(), 2.0); + z = x + y; + CHECK_EQ(z.item(), 2.0); + z = add(z, x); + CHECK_EQ(z.item(), 3.0); + z.eval(); // No-op + CHECK_EQ(z.item(), 3.0); + + // Chain a few adds: + auto out = x; + for (int i = 0; i < 10; ++i) { + out = add(out, x); + } + CHECK_EQ(out.item(), 11.0); + + // Works for different shapes + x = array({1.0, 2.0, 3.0}, {1, 3}); + y = array({1.0, 2.0, 3.0}, {1, 3}); + z = add(x, y); + CHECK_EQ(z.shape(), std::vector{1, 3}); + auto eq = array_equal(z, array({2.0, 4.0, 6.0}, {1, 3})); + CHECK(eq.item()); + + // Works with scalars + x = array({1.0, 2.0, 3.0}, {1, 3}); + y = x + 2.0; + CHECK_EQ(y.dtype(), float32); + eq = array_equal(y, array({3.0, 4.0, 5.0}, {1, 3})); + CHECK(eq.item()); + y = 2.0 + x; + CHECK_EQ(y.dtype(), float32); + eq = array_equal(y, array({3.0, 4.0, 5.0}, {1, 3})); + CHECK(eq.item()); + + // Check type promotion + y = 2 + x; + CHECK_EQ(y.dtype(), float32); + + y = 2.0 + array({1, 2, 3}); + CHECK_EQ(y.dtype(), float32); + CHECK(array_equal(y, array({3.0, 4.0, 5.0})).item()); + + // Broadcasting works + x = broadcast_to(array({1.0}), {10}); + y = broadcast_to(array({2.0}), {10}); + z = add(x, y); + CHECK(array_equal(z, full({10}, 3.0)).item()); + + x = array({1.0, 2.0}, {1, 2}); + y = array({1.0, 2.0}, {2, 1}); + z = add(x, y); + CHECK_EQ(z.shape(), std::vector{2, 2}); + eq = array_equal(z, array({2.0, 3.0, 3.0, 4.0}, {2, 2})); + CHECK(eq.item()); + + x = ones({3, 2, 1}); + z = x + 2.0; + CHECK_EQ(z.shape(), std::vector{3, 2, 1}); + eq = array_equal(z, array({3.0, 3.0, 3.0, 3.0, 3.0, 3.0}, {3, 2, 1})); + CHECK(eq.item()); + + // Works for empty arrays + x = array({}); + y = array({}); + z = x + y; + z.eval(); + CHECK_EQ(z.size(), 0); + CHECK_EQ(z.shape(), std::vector{0}); + + // Check subtraction + x = 
array({3, 2, 1}); + y = array({1, 1, 1}); + CHECK(array_equal(x - y, array({2, 1, 0})).item()); + + // Check multiplication + x = array({1, 2, 3}); + y = array({2, 2, 2}); + CHECK(array_equal(x * y, array({2, 4, 6})).item()); + + // Check division + x = array(1); + y = array(1); + CHECK_EQ(divide(x, y).item(), 1.0f); + + x = array(1); + y = array(0.5); + CHECK_EQ(divide(x, y).item(), 2.0f); + + x = array(1); + y = array(4); + CHECK_EQ(divide(x, y).item(), 0.25f); + + x = array(true); + y = array(true); + CHECK_EQ(divide(x, y).item(), 1.0f); + + x = array(false); + y = array(true); + CHECK_EQ(divide(x, y).item(), 0.0f); + + x = array(true); + y = array(false); + CHECK(std::isinf(divide(x, y).item())); + + x = array(false); + y = array(false); + CHECK(std::isnan(divide(x, y).item())); + + // Check maximum and minimum + x = array(1.0f); + y = array(0.0f); + CHECK_EQ(maximum(x, y).item(), 1.0f); + CHECK_EQ(minimum(x, y).item(), 0.0f); + y = array(2.0f); + CHECK_EQ(maximum(x, y).item(), 2.0f); + CHECK_EQ(minimum(x, y).item(), 1.0f); + + // Check logaddexp + x = array(0.0f); + y = array(0.0f); + CHECK_EQ(logaddexp(x, y).item(), std::log(2.0f)); + + x = array(0u); + y = array(10000u); + CHECK_EQ(logaddexp(x, y).item(), 10000.0f); + + constexpr float inf = std::numeric_limits::infinity(); + x = array(inf); + y = array(3.0f); + CHECK_EQ(logaddexp(x, y).item(), inf); + + x = array(-inf); + y = array(3.0f); + CHECK_EQ(logaddexp(x, y).item(), 3.0f); + + x = array(-inf); + y = array(-inf); + CHECK_EQ(logaddexp(x, y).item(), -inf); + + x = array(inf); + y = array(inf); + CHECK_EQ(logaddexp(x, y).item(), inf); + + x = array(-inf); + y = array(inf); + CHECK_EQ(logaddexp(x, y).item(), inf); +} + +TEST_CASE("test broadcast") { + auto s = broadcast_shapes({1}, {1, 2}); + CHECK_EQ(s, std::vector{1, 2}); + + s = broadcast_shapes({1, 2}, {1}); + CHECK_EQ(s, std::vector{1, 2}); + + s = broadcast_shapes({2, 2}, {}); + CHECK_EQ(s, std::vector{2, 2}); + + s = broadcast_shapes({}, {1, 1}); + CHECK_EQ(s, std::vector{1, 1}); + + s = broadcast_shapes({1, 2, 1}, {2}); + CHECK_EQ(s, std::vector{1, 2, 2}); + + s = broadcast_shapes({2}, {1, 2, 1}); + CHECK_EQ(s, std::vector{1, 2, 2}); + + s = broadcast_shapes({2, 2, 2}, {1, 2, 1}); + CHECK_EQ(s, std::vector{2, 2, 2}); + + s = broadcast_shapes({2, 2, 2, 1}, {1, 2, 1}); + CHECK_EQ(s, std::vector{2, 2, 2, 1}); + + s = broadcast_shapes({0}, {0, 0}); + CHECK_EQ(s, std::vector{0, 0}); + + CHECK_EQ(broadcast_shapes({}, {0}), std::vector{0}); + + s = broadcast_shapes({5, 0}, {0, 5, 0}); + CHECK_EQ(s, std::vector{0, 5, 0}); + + CHECK_EQ(broadcast_shapes({}, {0}), std::vector{0}); + CHECK_EQ(broadcast_shapes({1}, {0}), std::vector{0}); + CHECK_EQ(broadcast_shapes({1}, {0}), std::vector{0}); + CHECK_EQ(broadcast_shapes({1}, {0, 0}), std::vector{0, 0}); + CHECK_EQ(broadcast_shapes({1, 1}, {0}), std::vector{1, 0}); + CHECK_EQ(broadcast_shapes({1, 1}, {0, 0}), std::vector{0, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {1, 0}), std::vector{2, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {2, 0}), std::vector{2, 0}); + CHECK_EQ(broadcast_shapes({2, 1}, {1, 2, 0}), std::vector{1, 2, 0}); + CHECK_THROWS_AS(broadcast_shapes({2}, {0}), std::invalid_argument); + CHECK_THROWS_AS(broadcast_shapes({2, 1}, {0, 0}), std::invalid_argument); + + CHECK_THROWS_AS(broadcast_shapes({3}, {2}), std::invalid_argument); + CHECK_THROWS_AS(broadcast_shapes({1, 3}, {2}), std::invalid_argument); + CHECK_THROWS_AS(broadcast_shapes({3}, {1, 2}), std::invalid_argument); + CHECK_THROWS_AS( + broadcast_shapes({1, 3, 2}, {1, 
2, 2}), std::invalid_argument); + + auto x = full({1, 1}, 2.3f); + CHECK_EQ(broadcast_to(x, {1, 1}).item(), 2.3f); + + x = broadcast_to(x, {5, 1}); + CHECK_EQ(x.shape(), std::vector{5, 1}); + x.eval(); + CHECK_EQ(x.strides(), std::vector{0, 0}); + + CHECK_THROWS_AS(broadcast_to(x, {1, 5}), std::invalid_argument); + x = broadcast_to(x, {5, 5}); + CHECK_EQ(x.shape(), std::vector{5, 5}); + + x = zeros({2, 1, 2}); + x = broadcast_to(x, {4, 2, 1, 2}); + CHECK_EQ(x.shape(), std::vector{4, 2, 1, 2}); + x.eval(); + CHECK_EQ(x.strides(), std::vector{0, 2, 0, 1}); + + // Broadcast on empty arrays works as expected + x = array({}); + CHECK_THROWS_AS(broadcast_to(x, {1}), std::invalid_argument); + + // Broadcast to empty array works as expected + x = array({1}); + auto y = broadcast_to(x, {0}); + eval(y); + CHECK_EQ(y.size(), 0); + CHECK_EQ(y.shape(), std::vector{0}); + + x = array({1, 2}, {2, 1}); + y = broadcast_to(x, {2, 0}); + eval(y); + CHECK_EQ(y.size(), 0); + CHECK_EQ(y.shape(), std::vector{2, 0}); + + // Check repeat application works + x = zeros({2}); + x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2}); + CHECK_EQ(x.shape(), std::vector{2, 2}); + x.eval(); + CHECK_EQ(x.strides(), std::vector{0, 1}); + x = broadcast_to(broadcast_to(x, {2, 2}), {2, 2, 2}); + CHECK_EQ(x.shape(), std::vector{2, 2, 2}); + x.eval(); + CHECK_EQ(x.strides(), std::vector{0, 0, 1}); + + // Broadcast on transposed arrray works + x = array({0, 1, 2, 3, 4, 5}, {2, 3}); + x = broadcast_to(transpose(x), {2, 3, 2}); + CHECK_EQ(x.shape(), std::vector{2, 3, 2}); + y = broadcast_to(array({0, 3, 1, 4, 2, 5}, {3, 2}), {2, 3, 2}); + CHECK(array_equal(x, y).item()); + + // Reshape on broadcasted array works + x = array(1.0); + x = broadcast_to(x, {2}); + x = reshape(x, {1, 2}); + CHECK(array_equal(x, ones({1, 2})).item()); +} + +TEST_CASE("test gather") { + // More indices than dimensions + CHECK_THROWS(gather(array(0), array({1}), 0, {1})); + + // Mismatch dimensions and indices + CHECK_THROWS(gather(array({0}), {array({0})}, {0, 1}, {1})); + CHECK_THROWS(gather(array({0}), array({0}), -1, {1})); + + // Repeat dimensions + CHECK_THROWS( + gather(array({0}, {1, 1}), {array({0}), array({0})}, {0, 0}, {1, 1})); + + // Slice sizes incorrect + CHECK_THROWS(gather(array({0}), array({0}), 0, {2})); + CHECK_THROWS(gather(array({0}), array({0}), 0, {0, 0})); + CHECK_THROWS(gather(array({0}), array({0}), 0, {-1})); + + // Wrong index type + CHECK_THROWS(gather(array({0}), array({0.0f}), 0, {0})); + CHECK_THROWS( + gather(array({0}, {1, 1}), {array({0}), array({0.0f})}, {0, 1}, {1, 1})); + + // Index arrays must be broadcastable + CHECK_THROWS(gather( + array({0}, {1, 1}), + {array({0, 0, 0}, {3}), array({0, 0}, {2})}, + {0, 1}, + {1, 1})); + + // Basic test of correctness with 1D input + auto x = arange(20); + auto y = arange(10); + auto out = gather(x, y, 0, {1}); + CHECK_EQ(out.shape(), std::vector{10, 1}); + CHECK(array_equal(reshape(out, {-1}), y).item()); + + out = gather(x, array({15}, uint32), 0, {1}); + CHECK_EQ(out.shape(), std::vector{1, 1}); + CHECK_EQ(out.item(), 15); + + // No index gather works + out = gather(x, {}, std::vector{}, {10}); + CHECK_EQ(out.shape(), std::vector{10}); + CHECK(array_equal(out, arange(10)).item()); + + // Basic test of correctness with 2D input + x = arange(128); + x = reshape(x, {4, 32}); + y = array({0, 1}, uint32); + out = gather(x, y, 0, {1, 32}); + CHECK_EQ(out.shape(), std::vector{2, 1, 32}); + CHECK(array_equal(reshape(out, {64}), arange(64)).item()); + + x = reshape(x, {64, 2}); + y = 
array({0}, uint32); + out = gather(x, y, 0, {64, 1}); + CHECK_EQ(out.shape(), std::vector{1, 64, 1}); + CHECK(array_equal(out, reshape(arange(0, 128, 2), {1, 64, 1})).item()); + + // Basic test of correctness with 3D input + x = arange(256); + x = reshape(x, {8, 4, 8}); + y = array({0}, uint32); + out = gather(x, y, 0, {8, 1, 1}); + CHECK_EQ(out.shape(), std::vector{1, 8, 1, 1}); + CHECK( + array_equal(out, reshape(arange(0, 256, 32), {1, 8, 1, 1})).item()); + + x = broadcast_to(array({1, 2}), {20, 2}); + out = gather(x, array({5}), 0, {1, 1}); + CHECK_EQ(out.item(), 1); + out = gather(x, {array({5}), array({1})}, {0, 1}, {1, 1}); + CHECK_EQ(out.item(), 2); +} + +TEST_CASE("test take") { + // Empty takes + auto empty = astype(array({}), int32); + auto z = take(array({1}), empty); + CHECK_EQ(z.shape(), std::vector{0}); + empty = reshape(empty, {1, 0, 1}); + z = take(array({1}), empty); + CHECK_EQ(z.shape(), std::vector{1, 0, 1}); + + CHECK_THROWS(take(array({}), array(1))); + + z = take(array({}), empty); + CHECK_EQ(z.size(), 0); + + // Take a single row + auto x = reshape(arange(256), {8, 4, 8}); + z = take(x, array({0}, uint32), 0); + CHECK_EQ(z.shape(), std::vector{1, 4, 8}); + z = reshape(z, {32}); + CHECK(array_equal(z, arange(32)).item()); + + z = take(x, array({1}, uint32), 0); + z = reshape(z, {32}); + CHECK(array_equal(z, arange(32, 64)).item()); + + // Take multiple rows + x = arange(256); + x = reshape(x, {8, 4, 8}); + z = take(x, array({0, 1}, uint32), 0); + z = reshape(z, {64}); + CHECK(array_equal(z, arange(64)).item()); + + // Take along middle axis + x = reshape(arange(8), {2, 2, 2}); + z = take(x, array({0}), 1); + CHECK(array_equal(z, array({0, 1, 4, 5}, {2, 1, 2})).item()); + + // Irregular strides test + auto a = array({1, 2, 3}, float32); + auto indices = broadcast_to(array(0), {10}); + auto b = take(a, indices); + CHECK(array_equal(b, ones({10})).item()); + + // Take with 0 dim index + z = take(array({0, 1, 2}), array(0)); + CHECK_EQ(z.item(), 0); + CHECK_EQ(z.ndim(), 0); + + // Check take with float indices crashes + CHECK_THROWS(take(array({}), array({}))); + CHECK_THROWS(take(a, array({1.0, 2.0, 3.0}))); + + // Check axis + a = array({1, 2, 3, 4}, {2, 2}); + CHECK_THROWS(take(a, array({1}), -3)); + CHECK_THROWS(take(a, array({1}), 2)); + + // Check negative indices + a = array({1, 2, 3, 4}, {2, 2}); + CHECK_EQ(take(a, array({-1})).item(), 4); + CHECK(array_equal(take(a, array({1, -1})), array({2, 4})).item()); + CHECK(array_equal(take(a, array(-1), 0), array({3, 4})).item()); + + // Check shapes + a = zeros({2, 1, 1}); + auto out = take(a, array({1}), 0); + CHECK(array_equal(out, zeros({1, 1, 1})).item()); + out = take(a, array({0}), 1); + CHECK(array_equal(out, zeros({2, 1, 1})).item()); + out = take(a, array({0}), 1); + CHECK(array_equal(out, zeros({2, 1, 1})).item()); + a = zeros({1, 2, 1}); + out = take(a, array({0}), 0); + CHECK(array_equal(out, zeros({1, 2, 1})).item()); + out = take(a, array({0}), 1); + CHECK(array_equal(out, zeros({1, 1, 1})).item()); + out = take(a, array({0, 1}), 1); + CHECK(array_equal(out, zeros({1, 2, 1})).item()); +} + +TEST_CASE("test take along axis") { + // No zero dim arrays + auto a = array(1); + CHECK_THROWS(take_along_axis(a, array(0), 0)); + + // Index and array size mismatches + a = arange(5); + CHECK_THROWS(take_along_axis(a, array({1}), 1)); + CHECK_THROWS(take_along_axis(a, array({1}, {1, 1}), 0)); + CHECK_THROWS(take_along_axis(a, array(1), -1)); + + auto out = take_along_axis(a, array({1}), 0); + CHECK_EQ(out.item(), 1); 
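+  // A negative axis wraps around, so -1 addresses the last (here the only) axis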
+ out = take_along_axis(a, array({1}), -1); + CHECK_EQ(out.item(), 1); + + // Indices have wrong shape + a = zeros({2, 3, 4}); + CHECK_THROWS(take(a, zeros({1, 3, 4}), 1)); + CHECK_THROWS(take(a, zeros({2, 3, 7}), 1)); + CHECK_THROWS(take(a, zeros({2, 3, 2}), 0)); + + // Empty arrays + a = reshape(array({}), {1, 0}); + CHECK_THROWS(take_along_axis(a, array({1}), 0)); + + out = take_along_axis(a, reshape(array({1}), {1, 1}), 0); + eval(out); // Make sure it runs + CHECK_EQ(out.shape(), std::vector{1, 0}); + + auto inds = reshape(astype(array({}), int32), {1, 0}); + out = take_along_axis(a, inds, 0); + eval(out); // Make sure it runs + CHECK_EQ(out.shape(), std::vector{1, 0}); + + a = array({1, 2, 3, 4}, {2, 2}); + inds = array({0, 1}, {1, 2}); + out = take_along_axis(a, inds, 0); + CHECK(array_equal(out, array({1, 4}, {1, 2})).item()); + + inds = array({0, 1, 0, 1, 0, 0, 1, 0}, {4, 2}, int32); + out = take_along_axis(a, inds, 0); + CHECK(array_equal(out, array({1, 4, 1, 4, 1, 2, 3, 2}, {4, 2})).item()); + + inds = array({0, 1}, {2, 1}); + out = take_along_axis(a, inds, 1); + CHECK(array_equal(out, array({1, 4}, {2, 1})).item()); + + // Broadcasting works + inds = array({0}, {1, 1}); + out = take_along_axis(a, inds, 0); + CHECK(array_equal(out, array({1, 2}, {1, 2})).item()); + out = take_along_axis(a, inds, 1); + CHECK(array_equal(out, array({1, 3}, {2, 1})).item()); + + inds = array({0, 1, 1, 0, 0, 1}, {2, 3}, int32); + out = take_along_axis(a, inds, 1); + CHECK(array_equal(out, array({1, 2, 2, 3, 3, 4}, {2, 3})).item()); + + a = reshape(arange(8), {2, 2, 2}); + inds = array({0, 1, 0, 0, 1, 0, 0, 1}, {2, 2, 2}); + out = take_along_axis(a, inds, 0); + CHECK(array_equal(out, array({0, 5, 2, 3, 4, 1, 2, 7}, {2, 2, 2})) + .item()); + out = take_along_axis(a, inds, 1); + CHECK(array_equal(out, array({0, 3, 0, 1, 6, 5, 4, 7}, {2, 2, 2})) + .item()); + out = take_along_axis(a, inds, 2); + CHECK(array_equal(out, array({0, 1, 2, 2, 5, 4, 6, 7}, {2, 2, 2})) + .item()); +} + +TEST_CASE("test scatter") { + // More indices than dimensions + CHECK_THROWS(scatter(array(0), array({1}), array(1), 0)); + + // Mismatch dimensions and indices + CHECK_THROWS(scatter(array({0}), {array({0})}, array({1}, {1, 1}), {0, 1})); + CHECK_THROWS(scatter(array({0}), array({0}), array({1}, {1, 1}), -1)); + + // Repeat dimensions + CHECK_THROWS(scatter( + array({0}, {1, 1}), {array({0}), array({0})}, array({1}), {0, 0})); + + // Update sizes incorrect + CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}), 0)); + CHECK_THROWS(scatter(array({0}), array({0}), array({0, 1}, {2, 1}), 0)); + CHECK_THROWS(scatter(array({0}, {1}), array({0}), array({0, 1}, {1, 2}), 0)); + + // Wrong index type + CHECK_THROWS(scatter(array({0}), array({0.0f}), array({0}, {1, 1}), 0)); + CHECK_THROWS(scatter( + array({0}, {1, 1}), + {array({0}), array({0.0f})}, + array({1}, {1, 1, 1}), + {0, 1})); + + // Index arrays must be broadcastable + CHECK_THROWS(scatter( + array({0}, {1, 1}), + {array({0, 0, 0}, {3}), array({0, 0}, {2})}, + ones({3, 2, 1, 1}), + {0, 1})); + + // Single element scatter + auto in = zeros({4}, float32); + auto inds = arange(2); + auto updates = ones({2, 1}, float32); + auto out = scatter(in, inds, updates, 0); + CHECK(array_equal(out, array({1.0f, 1.0f, 0.0f, 0.0f})).item()); + + // Single element scatter add + in = ones({4}, float32); + inds = array({0, 0, 3}); + updates = ones({3, 1}, float32); + out = scatter_add(in, inds, updates, 0); + CHECK(array_equal(out, array({3.0f, 1.0f, 1.0f, 2.0f})).item()); + + // Single 
element scatter prod + in = ones({4}, float32); + inds = array({0, 0, 3}); + updates = full({3, 1}, 2.0f, float32); + out = scatter_prod(in, inds, updates, 0); + CHECK(array_equal(out, array({4.0f, 1.0f, 1.0f, 2.0f})).item()); + + // Single element scatter max + in = ones({4}, float32); + inds = array({0, 0, 3}); + updates = array({1.0f, 6.0f, -2.0f}, {3, 1}); + out = scatter_max(in, inds, updates, 0); + CHECK(array_equal(out, array({6.0f, 1.0f, 1.0f, 1.0f})).item()); + + // Single element scatter min + in = ones({4}, float32); + inds = array({0, 0, 3}); + updates = array({1.0f, -6.0f, 2.0f}, {3, 1}); + out = scatter_min(in, inds, updates, 0); + CHECK(array_equal(out, array({-6.0f, 1.0f, 1.0f, 1.0f})).item()); + + // Empty scatter + in = arange(4, float32); + inds = astype(array({}), uint32); + updates = reshape(array({}), {0, 1}); + out = scatter(in, inds, updates, 0); + CHECK(array_equal(out, in).item()); + + // Array scatters + in = zeros({4, 4}, float32); + inds = array({0, 1, 2, 3}); + updates = reshape(arange(16, float32), {4, 1, 4}); + out = scatter(in, inds, updates, 0); + CHECK(array_equal(out, reshape(arange(16, float32), {4, 4})).item()); + + // Irregular strided index and reduce collison test + in = zeros({10}, float32); + inds = broadcast_to(array(3), {10}); + updates = ones({10, 1}, float32); + out = scatter_add(in, inds, updates, 0); + CHECK_EQ(take(out, array(3)).item(), 10); + + // 1 element array with 0 dim index + in = array({1}, int32); + updates = array({2}, int32); + out = scatter_max(in, array(0), updates, 0); + CHECK_EQ(out.item(), 2); + + // No index arrays or axes + out = scatter_max(array(1), {}, array(2), std::vector{}); + CHECK_EQ(out.item(), 2); + + // Irregularaly strided updates test + in = ones({3, 3}); + updates = broadcast_to(array({0, 0, 0}), {1, 3, 3}); + inds = array({0}); + out = scatter(in, inds, updates, 0); + CHECK(array_equal(out, zeros({3, 3})).item()); + + // Along different axis + in = zeros({2, 3}); + updates = array({1, 2, 3, 4}, {2, 2, 1}); + inds = array({0, 2}); + out = scatter(in, inds, updates, 1); + auto expected = array({1, 0, 3, 2, 0, 4}, {2, 3}); + CHECK(array_equal(out, expected).item()); + + // Multiple index arrays + in = zeros({2, 2}); + updates = array({1, 2}, {2, 1, 1}); + inds = array({0, 1}); + out = scatter(in, {inds, inds}, updates, {0, 1}); + CHECK(array_equal(out, array({1, 0, 0, 2}, {2, 2})).item()); + + // Broadcasted indices + in = zeros({2, 2}); + updates = array({5, 2, 9, 1}, {2, 2, 1, 1}); + auto inds0 = array({0, 1}, {2, 1}); + auto inds1 = array({0, 1}, {1, 2}); + out = scatter(in, {inds0, inds1}, updates, {0, 1}); + CHECK(array_equal(out, array({5, 2, 9, 1}, {2, 2})).item()); + + // Brodacasted operand + in = broadcast_to(array({0, 0}), {2, 2}); + updates = array({1, 1}, {2, 1, 1}); + inds = array({0, 1}); + out = scatter_add(in, inds, updates, 0); + CHECK(array_equal(out, array({1, 0, 1, 0}, {2, 2})).item()); +} + +TEST_CASE("test complex ops") { + // Creation ops + { + auto x = full({2, 2}, complex64_t{1, 1}); + CHECK_EQ(x.dtype(), complex64); + std::initializer_list expected = { + {1, 1}, {1, 1}, {1, 1}, {1, 1}}; + CHECK(array_equal(x, array(expected, {2, 2})).item()); + } + + // Unary ops + { + std::initializer_list vals = {{0, 1}, {1, 0}, {1, 1}}; + auto x = array(vals); + + auto y = abs(x); + CHECK_EQ(y.dtype(), float32); + CHECK(array_equal(y, array({1.0f, 1.0f, std::sqrt(2.0f)})).item()); + + y = negative(x); + std::initializer_list expected = {{0, -1}, {-1, 0}, {-1, -1}}; + CHECK(array_equal(y, 
array(expected)).item()); + + y = exp(x); + { + std::initializer_list expected = { + {0.54030231, 0.84147098}, {2.71828183, 0.}, {1.46869394, 2.28735529}}; + CHECK(allclose(y, array(expected)).item()); + } + + y = sin(x); + { + std::initializer_list expected = { + {0., 1.17520119}, {0.84147098, 0.}, {1.29845758, 0.63496391}}; + CHECK(allclose(y, array(expected)).item()); + } + + y = cos(x); + { + std::initializer_list expected = { + {1.54308063, -0.}, {0.54030231, -0.}, {0.83373003, -0.98889771}}; + CHECK(allclose(y, array(expected)).item()); + } + } + + // Binary ops + { + std::initializer_list vals_x = {{0, 1}, {1, 0}, {1, 1}}; + auto x = array(vals_x); + + std::initializer_list vals_y = {{2, 0}, {1, 1}, {0, 1}}; + auto y = array(vals_y); + + auto z = add(x, y); + { + std::initializer_list expected = {{2, 1}, {2, 1}, {1, 2}}; + CHECK(array_equal(z, array(expected)).item()); + } + + z = subtract(x, y); + { + std::initializer_list expected = {{-2, 1}, {0, -1}, {1, 0}}; + CHECK(array_equal(z, array(expected)).item()); + } + + z = multiply(x, y); + { + std::initializer_list expected = {{0, 2}, {1, 1}, {-1, 1}}; + CHECK(array_equal(z, array(expected)).item()); + } + + z = maximum(x, y); + { + std::initializer_list expected = {{2, 0}, {1, 1}, {1, 1}}; + CHECK(array_equal(z, array(expected)).item()); + } + } + + // Reductions + if (default_device() == Device::cpu) { + std::initializer_list vals = {{0, 0}, {1, 0}, {0, 1}}; + auto x = array(vals); + CHECK_EQ(max(x).item(), complex64_t{1, 0}); + CHECK_EQ(min(x).item(), complex64_t{0, 0}); + CHECK_EQ(sum(x).item(), complex64_t{1, 1}); + CHECK_EQ(prod(x).item(), complex64_t{0, 0}); + } +} + +TEST_CASE("test as_strided op") { + auto x = arange(10); + auto y = as_strided(x, {3, 3}, {1, 1}, 0); + auto expected = array({0, 1, 2, 1, 2, 3, 2, 3, 4}, {3, 3}); + CHECK(array_equal(y, expected).item()); + + y = as_strided(x, {3, 3}, {0, 3}, 0); + expected = array({0, 3, 6, 0, 3, 6, 0, 3, 6}, {3, 3}); + CHECK(array_equal(y, expected).item()); + + x = reshape(x, {2, 5}); // 0 1 2 3 ... + x = transpose(x, {1, 0}); // 0 5 1 6 2 7 ... 
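+  // With the transposed input flattened row-major to {0, 5, 1, 6, 2, 7, ...}, element (i, j) reads data[offset + 2 * i + j], so consecutive rows overlap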
+ y = as_strided(x, {3, 3}, {2, 1}, 1); + expected = array({5, 1, 6, 6, 2, 7, 7, 3, 8}, {3, 3}); + CHECK(array_equal(y, expected).item()); +} + +TEST_CASE("test scan op") { + auto x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + auto y = cumsum(x, 1, false, true); + auto expected = array({1.0f, 3.0f, 6.0f, 4.0f, 9.0f, 15.0f}, {2, 3}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 1, false, false); + expected = array({0.0f, 1.0f, 3.0f, 0.0f, 4.0f, 9.0f}, {2, 3}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 1, true, true); + expected = array({6.0f, 5.0f, 3.0f, 15.0f, 11.0f, 6.0f}, {2, 3}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 1, true, false); + expected = array({5.0f, 3.0f, 0.0f, 11.0f, 6.0f, 0.0f}, {2, 3}); + CHECK(array_equal(y, expected).item()); + + x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2}); + y = cumsum(x, 0, false, true); + expected = + array({1.0f, 2.0f, 3.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f}, {2, 2, 2}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 1, false, true); + expected = + array({1.0f, 2.0f, 4.0f, 6.0f, 5.0f, 6.0f, 12.0f, 14.0f}, {2, 2, 2}); + CHECK(array_equal(y, expected).item()); + + x = array({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2}); + y = cumsum(x, 0, true, true); + expected = + array({6.0f, 8.0f, 10.0f, 12.0f, 5.0f, 6.0f, 7.0f, 8.0f}, {2, 2, 2}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 1, true, true); + expected = + array({4.0f, 6.0f, 3.0f, 4.0f, 12.0f, 14.0f, 7.0f, 8.0f}, {2, 2, 2}); + CHECK(array_equal(y, expected).item()); + + x = reshape(x, {4, 2}); + y = cumsum(x, 0, false, false); + expected = array({0.0f, 0.0f, 1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f}, {4, 2}); + CHECK(array_equal(y, expected).item()); + + y = cumsum(x, 0, true, false); + expected = + array({15.0f, 18.0f, 12.0f, 14.0f, 7.0f, 8.0f, 0.0f, 0.0f}, {4, 2}); + CHECK(array_equal(y, expected).item()); + + // Check the vmap implementation + auto fun = [](array x) { return cumsum(x, 0, false, true); }; + y = vmap(fun, 0, 0)(x); + expected = array({1.0f, 3.0f, 3.0f, 7.0f, 5.0f, 11.0f, 7.0f, 15.0f}, {4, 2}); + CHECK(array_equal(y, expected).item()); + + y = vmap(fun, 1, 1)(x); + expected = array({1.0f, 2.0f, 4.0f, 6.0f, 9.0f, 12.0f, 16.0f, 20.0f}, {4, 2}); + CHECK(array_equal(y, expected).item()); +} + +TEST_CASE("test pad") { + auto x = zeros({1, 2, 3}); + CHECK_EQ(pad(x, 1).shape(), std::vector{3, 4, 5}); + CHECK_EQ(pad(x, {0, 1}).shape(), std::vector{2, 3, 4}); + CHECK_EQ(pad(x, {{1, 1}, {1, 2}, {3, 1}}).shape(), std::vector{3, 5, 7}); +} + +TEST_CASE("test power") { + CHECK_EQ(power(array(1), array(2)).item(), 1); + CHECK_EQ((array(1) ^ 2).item(), 1); + CHECK_EQ((1 ^ array(2)).item(), 1); + CHECK_EQ((array(-1) ^ 2).item(), 1); + CHECK_EQ((array(-1) ^ 3).item(), -1); + + // TODO Throws but exception not caught from calling thread + // CHECK_THROWS((x^-1).item()); + + CHECK_EQ((array(true) ^ array(false)).item(), true); + CHECK_EQ((array(false) ^ array(false)).item(), true); + CHECK_EQ((array(true) ^ array(true)).item(), true); + CHECK_EQ((array(false) ^ array(true)).item(), false); + + auto x = array(2.0f); + CHECK_EQ((x ^ 0.5).item(), std::pow(2.0f, 0.5f)); + CHECK_EQ((x ^ 2.0f).item(), 4.0f); + + CHECK(std::isnan((array(-1.0f) ^ 0.5).item())); + + auto a = complex64_t{0.5, 0.5}; + auto b = complex64_t{0.5, 0.5}; + auto expected = std::pow(a, b); + auto out = (array(a) ^ array(b)).item(); + CHECK(abs(out.real() - expected.real()) < 1e-7); + CHECK(abs(out.imag() - 
+}
+
+TEST_CASE("test where") {
+  array condition(true);
+  array x(1.0f);
+  array y(0.0f);
+  auto out = where(condition, x, y);
+  CHECK_EQ(out.dtype(), float32);
+  CHECK_EQ(out.item<float>(), 1.0f);
+
+  x = array({1, 2}, {2, 1});
+  y = array({3, 4}, {1, 2});
+  CHECK(array_equal(where(condition, x, y), broadcast_to(x, {2, 2}))
+            .item<bool>());
+
+  condition = array(false);
+  CHECK(array_equal(where(condition, x, y), broadcast_to(y, {2, 2}))
+            .item<bool>());
+
+  condition = array({true, false});
+  out = where(condition, x, y);
+  auto expected = array({1, 4, 2, 4}, {2, 2});
+  CHECK(array_equal(where(condition, x, y), expected).item<bool>());
+
+  condition = array({true, false, false, true}, {2, 2});
+  out = where(condition, x, y);
+  expected = array({1, 4, 3, 2}, {2, 2});
+  CHECK(array_equal(where(condition, x, y), expected).item<bool>());
+
+  x = array(1);
+  y = array(2);
+  out = where(condition, x, y);
+  expected = array({1, 2, 2, 1}, {2, 2});
+  CHECK(array_equal(where(condition, x, y), expected).item<bool>());
+}
diff --git a/tests/random_tests.cpp b/tests/random_tests.cpp
new file mode 100644
index 0000000000..92ed802de9
--- /dev/null
+++ b/tests/random_tests.cpp
@@ -0,0 +1,545 @@
+#include <cstdint>
+
+#include "doctest/doctest.h"
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+TEST_CASE("test random key") {
+  auto key = random::key(0);
+  CHECK(array_equal(key, array({0, 0})).item<bool>());
+
+  key = random::key(1);
+  CHECK(array_equal(key, array({0, 1})).item<bool>());
+
+  int64_t seed = static_cast<int64_t>(1) << 32;
+  key = random::key(seed);
+  CHECK(array_equal(key, array({1, 0})).item<bool>());
+
+  key = random::key(seed + 1);
+  CHECK(array_equal(key, array({1, 1})).item<bool>());
+}
+
+TEST_CASE("test global rng") {
+  random::seed(4);
+  auto x = random::bits({});
+  auto y = random::bits({});
+
+  random::seed(4);
+  auto a = random::bits({});
+  auto b = random::bits({});
+
+  CHECK_EQ(x.item<uint32_t>(), a.item<uint32_t>());
+  CHECK_EQ(y.item<uint32_t>(), b.item<uint32_t>());
+}
+
+TEST_CASE("test random split") {
+  auto [key, subkey] = random::split(random::key(0));
+  CHECK(array_equal(key, array({4146024105u, 967050713u})).item<bool>());
+  CHECK(array_equal(subkey, array({2718843009u, 1272950319u})).item<bool>());
+
+  auto keys = random::split(random::key(0), 3);
+  auto expected = array(
+      {2467461003u,
+       428148500u,
+       3186719485u,
+       3840466878u,
+       2562233961u,
+       1946702221u},
+      {3, 2});
+  CHECK(array_equal(keys, expected).item<bool>());
+}
+
+TEST_CASE("test random bits") {
+  // Test shapes, types, and sizes
+  {
+    auto key = random::key(0);
+    auto x = random::bits({}, key);
+    CHECK_EQ(x.size(), 1);
+    CHECK_EQ(x.dtype(), uint32);
+
+    x = random::bits({0}, key);
+    CHECK(array_equal(x, array({})).item<bool>());
+
+    // Check wrong key type or shape
+    key = array({0, 0});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0, 0}, {1, 2});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0u, 0u, 0u}, {3, 1});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0u, 0u}, {2, 1});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+  }
+
+  // Expected bits in the following tests were generated from
+  // Jax's Threefry 2x32 implementation using the following in
+  // python:
+  //
+  // ```
+  // import jax
+  // import jax.prng
+  // shape = (SET THIS)
+  // seed = (SET THIS)
+  // width = (SET THIS)
+  // key = jax.random.PRNGKey(seed)
+  // print(jax.prng.threefry_prng_impl.random_bits(key, width, shape))
+
+  {
+    auto key = random::key(0);
+    auto x = random::bits({}, key);
+    auto y = random::bits({}, key);
+    CHECK_EQ(x.item<uint32_t>(), 1797259609u);
+    CHECK_EQ(x.item<uint32_t>(), y.item<uint32_t>());
+
+    x = random::bits({}, 2, key);
+    CHECK_EQ(x.item<uint16_t>(), 345);
+
+    x = random::bits({}, 1, key);
+    CHECK_EQ(x.item<uint8_t>(), 89);
+  }
+
+  {
+    auto key = random::key(1);
+    auto x = random::bits({}, key);
+    CHECK_EQ(x.item<uint32_t>(), 507451445u);
+
+    x = random::bits({}, 2, key);
+    CHECK_EQ(x.item<uint16_t>(), 6197);
+
+    x = random::bits({}, 1, key);
+    CHECK_EQ(x.item<uint8_t>(), 53);
+
+    CHECK_THROWS(random::bits({}, 0, key));
+    CHECK_THROWS(random::bits({}, 5, key));
+    CHECK_THROWS(random::bits({}, -1, key));
+  }
+
+  {
+    auto key = random::key(0);
+    auto x = random::bits({3, 1}, key);
+    auto expected = array({4146024105u, 1351547692u, 2718843009u}, {3, 1});
+    CHECK(array_equal(x, expected).item<bool>());
+
+    x = random::bits({5}, 2, key);
+    expected = array({20137, 63263, 64300, 20622, 16513}, uint16);
+    CHECK(array_equal(x, expected).item<bool>());
+    expected = array({20137, 63263, 64300, 20622, 16513, 41486}, uint16);
+    x = random::bits({6}, 2, key);
+    CHECK(array_equal(x, expected).item<bool>());
+    expected = array({20137, 63263, 1497, 14756, 16513, 41486, 44591}, uint16);
+    x = random::bits({7}, 2, key);
+    CHECK(array_equal(x, expected).item<bool>());
+    x = random::bits({8}, 2, key);
+    expected =
+        array({20137, 63263, 1497, 14756, 16513, 41486, 44591, 19423}, uint16);
+    CHECK(array_equal(x, expected).item<bool>());
+  }
+
+  {
+    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
+    auto shape = std::vector<int>{3};
+    auto fn = [&shape](array k) { return random::bits(shape, k); };
+
+    auto expected = array(
+        {4146024105u,
+         1351547692u,
+         2718843009u,
+         3725146706u,
+         1802982961u,
+         1349634643u},
+        {2, 3});
+    CHECK(array_equal(vmap(fn)(key), expected).item<bool>());
+    expected = array(
+        {2441914641u,
+         1110694964u,
+         3819641963u,
+         2441914641u,
+         1110694964u,
+         3819641963u},
+        {2, 3});
+    CHECK(array_equal(vmap(fn, 1)(key), expected).item<bool>());
+
+    // Vmap twice
+    key = array(
+        {0u,
+         0u,
+         1u,
+         1u,
+         2u,
+         2u,
+
+         3u,
+         3u,
+         4u,
+         4u,
+         5u,
+         5u},
+        {3, 2, 2});
+    shape = {2};
+    auto out = vmap(vmap(fn))(key);
+    expected = array(
+        {928981903u,
+         3453687069u,
+         3606183818u,
+         460005496u,
+
+         2799733733u,
+         856293553u,
+         4081856343u,
+         3445925136u,
+
+         2775548010u,
+         1430281703u,
+         305173070u,
+         2615843348u},
+        {3, 2, 2});
+    CHECK(array_equal(out, expected).item<bool>());
+
+    out = vmap(vmap(fn, 1), 0)(key);
+    expected = array(
+        {1948878966u,
+         4237131848u,
+         1948878966u,
+         4237131848u,
+
+         2531170506u,
+         1858648356u,
+         2531170506u,
+         1858648356u,
+
+         740561898u,
+         4234094099u,
+         740561898u,
+         4234094099u},
+        {3, 2, 2});
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+
+  // Vmap smaller type
+  {
+    auto key = array({0u, 0u, 1u, 1u}, {2, 2});
+    auto fn = [](array k) { return random::bits({5}, 2, k); };
+
+    auto expected = array(
+        {4146024105u,
+         1351547692u,
+         2718843009u,
+         3725146706u,
+         1802982961u,
+         1349634643u},
+        {2, 3});
+    auto out = vmap(fn)(key);
+    auto x1 = random::bits({5}, 2, take(key, array(0), 0));
+    auto x2 = random::bits({5}, 2, take(key, array(1), 0));
+
+    CHECK(array_equal(take(out, array(0), 0), x1).item<bool>());
+    CHECK(array_equal(take(out, array(1), 0), x2).item<bool>());
+  }
+}
+
+TEST_CASE("test random uniform") {
+  // Test shapes, types, and sizes
+  {
+    auto x = random::uniform({});
+    CHECK_EQ(x.size(), 1);
+    CHECK_EQ(x.dtype(), float32);
+
+    if (is_available(float16)) {
+      x = random::uniform({}, float16);
+      CHECK_EQ(x.size(), 1);
+      CHECK_EQ(x.dtype(), float16);
+    }
+
+    x = random::uniform({0});
+    CHECK(array_equal(x, array({})).item<bool>());
+
+    // Non float type throws
+    CHECK_THROWS_AS(random::uniform({}, int32), std::invalid_argument);
+
+    // Check broadcasting
+    x = random::uniform(zeros({3, 1}), ones({1, 3}), {3, 3});
+    CHECK_EQ(x.shape(), std::vector<int>{3, 3});
+    CHECK_THROWS_AS(
+        random::uniform(zeros({3, 3}), 1.0, {1, 3}), std::invalid_argument);
+    CHECK_THROWS_AS(
+        random::uniform(zeros({3, 3}), 1.0, {2, 3}), std::invalid_argument);
+    CHECK_THROWS_AS(
+        random::uniform(zeros({3, 1}), ones({1, 3}), {1, 3}),
+        std::invalid_argument);
+
+    // Check wrong key type or shape
+    auto key = array({0, 0});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0, 0}, {1, 2});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0u, 0u, 0u}, {3, 1});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+    key = array({0u, 0u}, {2, 1});
+    CHECK_THROWS_AS(random::uniform({}, key), std::invalid_argument);
+  }
+
+  // Expected bits in the following tests were generated from
+  // Jax's Threefry 2x32 implementation using the following in
+  // python:
+  //
+  // ```
+  // import jax
+  // import jax.prng
+  // shape = (SET THIS)
+  // seed = (SET THIS)
+  // key = jax.random.PRNGKey(seed)
+  // print(jax.prng.threefry_prng_impl.random_bits(key, 32, shape))
+
+  constexpr auto to_float = [](uint32_t n) {
+    return static_cast<float>(n) / UINT32_MAX;
+  };
+
+  {
+    auto key = random::key(0);
+    auto x = random::uniform({}, key);
+    auto y = random::uniform({}, key);
+    auto expected = to_float(1797259609);
+    CHECK_EQ(x.item<float>(), expected);
+    CHECK_EQ(x.item<float>(), y.item<float>());
+  }
+
+  {
+    auto key = random::key(1);
+    auto x = random::uniform({}, key);
+    auto expected = to_float(507451445);
+    CHECK_EQ(x.item<float>(), expected);
+  }
+
+  {
+    auto key = random::key(0);
+    auto x = random::uniform({3, 1}, key);
+    auto expected = array(
+        {to_float(4146024105), to_float(1351547692), to_float(2718843009)},
+        {3, 1});
+    CHECK(array_equal(x, expected).item<bool>());
+  }
+
+  // Check vmap
+  {
+    auto key = random::key(0);
+    auto fun = [](array k, array low) {
+      return random::uniform(low, 1, {3}, float32, k);
+    };
+    auto out = vmap(fun, -1)(key, zeros({2, 3}));
+    CHECK_EQ(out.shape(), std::vector<int>{2, 3});
+
+    key = zeros({2, 2}, uint32);
+    out = vmap(fun)(key, zeros({2, 3}));
+    CHECK_EQ(out.shape(), std::vector<int>{2, 3});
+  }
+
+  // Check bounds are respected
+  {
+    auto key = random::key(128291);
+    auto out = random::uniform(array(-1.0f), array(1.0f), {100}, float32, key);
+    CHECK(all(less(out, array(1.0f))).item<bool>());
+    CHECK(all(greater_equal(out, array(-1.0f))).item<bool>());
+  }
+}
+
+TEST_CASE("test random normal") {
+  // Test shapes, types, and sizes
+  {
+    auto x = random::normal({});
+    CHECK_EQ(x.size(), 1);
+    CHECK_EQ(x.dtype(), float32);
+
+    x = random::uniform({0});
+    CHECK(array_equal(x, array({})).item<bool>());
+
+    // Non float type throws
+    CHECK_THROWS_AS(random::normal({}, int32), std::invalid_argument);
+
+    // Check wrong key type or shape
+    auto key = array({0, 0});
+    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
+    key = array({0, 0}, {1, 2});
+    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
+    key = array({0u, 0u, 0u}, {3, 1});
+    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
+    key = array({0u, 0u}, {2, 1});
+    CHECK_THROWS_AS(random::normal({}, key), std::invalid_argument);
+  }
+
+  {
+    constexpr float inf = std::numeric_limits<float>::infinity();
+    auto key = random::key(128291);
+    auto out = random::normal({100}, key);
+    CHECK(all(less(abs(out), array(inf))).item<bool>());
+  }
+}
+
randint") { + CHECK_THROWS_AS( + random::randint(array(3), array(5), {1}, float32), std::invalid_argument); + + auto x = random::randint(0, 10, {}, uint32); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), uint32); + + x = random::randint(0, 2, {}, bool_); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), bool_); + + x = random::randint(0, 2, {}, int32); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), int32); + + x = random::randint(0, 2, {}, int64); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), int64); + + // Check all in bounds + auto low = -10.0; + auto high = 20.0; + x = random::randint(low, high, {1000, 1000}); + CHECK((all(low <= x).item() && all(x < high).item())); + + // Check high < low => all equals to low + low = 20.0; + high = -10.0; + x = random::randint(low, high, {3, 3}); + CHECK(all(equal(x, array(low))).item()); + + // Check wrong key type or shape + auto key = array({0, 0}, {1, 2}); + CHECK_THROWS_AS( + random::randint(low, high, {}, float32, key), std::invalid_argument); +} + +TEST_CASE("test random bernoulli") { + auto x = random::bernoulli(); + + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), bool_); + + // Bernoulli parameter can have floating point type + if (is_available(float16)) { + x = random::bernoulli(array(0.5, float16)); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), bool_); + } + + CHECK_THROWS(random::bernoulli(array(1, int32))); + + // Negative numbers allowed in Jax + x = random::bernoulli(array(-1.0)); + CHECK_FALSE(x.item()); + + x = random::bernoulli(array(5.0)); + CHECK(x.item()); + + // Return array with correct shape + x = random::bernoulli(0.5, {3, 3}); + CHECK_EQ(x.shape(), std::vector({3, 3})); + + // Try with p = {} + x = random::bernoulli(array({})); + CHECK_EQ(x.size(), 0); + + // Try broadcasting + auto p = array({0.1, 0.2, 0.3}); + p = reshape(p, {1, 3}); + x = random::bernoulli(p, {4, 3}); + CHECK_EQ(x.shape(), std::vector({4, 3})); + + CHECK_THROWS_AS(random::bernoulli(array({}), {3, 3}), std::invalid_argument); + + p = array({0.1, 0.2, 0.3}); + // Ask for the wrong shape => throws + CHECK_THROWS_AS(random::bernoulli(p, {2}), std::invalid_argument); + + // Check wrong key type or shape + auto key = array({0, 0}, {1, 2}); + CHECK_THROWS_AS(random::bernoulli(array(0.5), key), std::invalid_argument); +} + +TEST_CASE("Test truncated normal") { + auto x = random::truncated_normal(array(-2.0), array(2.0)); + + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), float32); + + if (is_available(float16)) { + x = random::truncated_normal(array(-2.0), array(2.0), {}, float16); + CHECK_EQ(x.size(), 1); + CHECK_EQ(x.dtype(), float16); + } + + // Requested shape + x = random::truncated_normal(array(-2.0), array(2.0), {3, 4}); + CHECK_EQ(x.shape(), std::vector({3, 4})); + + // Empty array + x = random::truncated_normal(array({}), array({})); + CHECK_EQ(x.size(), 0); + + // Broadcast + auto lower = reshape(array({-2.0, -3.0}), {1, 2}); + auto higher = reshape(array({0.0, 3.0, 1.5}), {3, 1}); + x = random::truncated_normal(lower, higher); + + // All in bounds + CHECK_EQ(x.shape(), std::vector({3, 2})); + CHECK((all(x <= higher).item() && all(lower <= x).item())); + + // high < low => all equal to low + x = random::truncated_normal(array(2.0), array(-2.0)); + CHECK(all(x == array(2.0)).item()); + + // Non broadcastable => throws + CHECK_THROWS_AS( + random::truncated_normal(lower, higher, {4, 2}), std::invalid_argument); + + auto key = array({0, 0}, {1, 2}); + CHECK_THROWS_AS( + random::truncated_normal(array(-2.0), array(2.0), {1, 1}, float32, key), + 
+      std::invalid_argument);
+}
+
+TEST_CASE("test categorical") {
+  auto logits = zeros({10, 20});
+
+  using random::categorical;
+
+  // Invalid axes
+  CHECK_THROWS(categorical(logits, 2));
+  CHECK_THROWS(categorical(logits, -3));
+
+  // Invalid requested shapes
+  CHECK_THROWS(categorical(logits, 1, std::vector<int>{1}));
+  CHECK_THROWS(categorical(logits, 1, std::vector<int>{11}));
+  CHECK_THROWS(categorical(logits, 1, {10, 1}));
+
+  CHECK_EQ(categorical(logits, -1).shape(), std::vector<int>{10});
+  CHECK_EQ(categorical(logits, 0).shape(), std::vector<int>{20});
+  CHECK_EQ(categorical(logits, 1).shape(), std::vector<int>{10});
+
+  auto out = categorical(logits);
+  CHECK_EQ(out.shape(), std::vector<int>{10});
+  CHECK_EQ(out.dtype(), uint32);
+  CHECK(max(out).item<uint32_t>() < 20);
+
+  out = categorical(logits, 0, {5, 20});
+  CHECK_EQ(out.shape(), std::vector<int>{5, 20});
+  CHECK(max(out).item<uint32_t>() < 10);
+
+  float inf = std::numeric_limits<float>::infinity();
+  logits = array({1.0f, -2.0f, inf, 4.0f, 3.0f});
+  CHECK_EQ(categorical(logits).item<uint32_t>(), 2);
+
+  logits = array({-inf, -2.0f, -inf, -inf});
+  CHECK_EQ(categorical(logits).item<uint32_t>(), 1);
+
+  logits = zeros({5, 4, 3});
+  CHECK_EQ(categorical(logits, -1, 7).shape(), std::vector<int>{5, 4, 7});
+  CHECK_EQ(categorical(logits, -2, 7).shape(), std::vector<int>{5, 3, 7});
+  CHECK_EQ(categorical(logits, -3, 7).shape(), std::vector<int>{4, 3, 7});
+}
diff --git a/tests/tests.cpp b/tests/tests.cpp
new file mode 100644
index 0000000000..0ad8b60180
--- /dev/null
+++ b/tests/tests.cpp
@@ -0,0 +1,22 @@
+#define DOCTEST_CONFIG_IMPLEMENT
+#include "doctest/doctest.h"
+
+#include <cstdlib>
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+int main(int argc, char** argv) {
+  doctest::Context context;
+
+  const char* device = std::getenv("DEVICE");
+  if (device != nullptr && std::string(device) == "cpu") {
+    set_default_device(Device::cpu);
+  } else if (metal::is_available()) {
+    set_default_device(Device::gpu);
+  }
+
+  context.applyCommandLine(argc, argv);
+  return context.run();
+}
diff --git a/tests/vmap_tests.cpp b/tests/vmap_tests.cpp
new file mode 100644
index 0000000000..6ccc25d473
--- /dev/null
+++ b/tests/vmap_tests.cpp
@@ -0,0 +1,248 @@
+#include "doctest/doctest.h"
+
+#include "mlx/mlx.h"
+
+using namespace mlx::core;
+
+TEST_CASE("test simple vmap") {
+  // vmap reshape
+  {
+    auto vfun = vmap([](array input) { return reshape(input, {2, 2}); });
+    auto x = zeros({3, 4});
+    CHECK(array_equal(vfun(x), zeros({3, 2, 2})).item<bool>());
+
+    x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
+    vfun = vmap([](array input) { return reshape(input, {4}); });
+    auto expected = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 4});
+    CHECK(array_equal(vfun(x), expected).item<bool>());
+
+    vfun = vmap([](array input) { return reshape(input, {4}); }, 1);
+    expected = array({0, 1, 4, 5, 2, 3, 6, 7}, {2, 4});
+    CHECK(array_equal(vfun(x), expected).item<bool>());
+
+    vfun = vmap([](array input) { return reshape(input, {4}); }, 1, 1);
+    expected = array({0, 2, 1, 3, 4, 6, 5, 7}, {4, 2});
+    CHECK(array_equal(vfun(x), expected).item<bool>());
+  }
+
+  // vmap broadcast
+  {
+    auto fun = [](array input) { return broadcast_to(input, {4, 2}); };
+
+    CHECK_THROWS_AS(vmap(fun, 0, -1), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, -1, 0), std::invalid_argument);
+
+    auto vfun = vmap(fun, -1, -1);
+    auto x = zeros({2});
+    CHECK(array_equal(vfun(x), zeros({4, 2})).item<bool>());
+
+    vfun = vmap(fun);
+    x = zeros({3, 2});
+    CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
+
+    vfun = vmap(fun, 0, 1);
+    CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
+
+    vfun = vmap(fun, 0, 2);
+    CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
+
+    vfun = vmap(fun, 0, 2);
+    x = zeros({2, 3});
+    CHECK_THROWS_AS(vfun(x), std::invalid_argument);
+
+    x = zeros({2, 3});
+    vfun = vmap(fun, 1);
+    CHECK(array_equal(vfun(x), zeros({3, 4, 2})).item<bool>());
+
+    vfun = vmap(fun, 1, 1);
+    CHECK(array_equal(vfun(x), zeros({4, 3, 2})).item<bool>());
+
+    vfun = vmap(fun, 1, 2);
+    CHECK(array_equal(vfun(x), zeros({4, 2, 3})).item<bool>());
+  }
+
+  // vmap transpose
+  {
+    auto fun = [](array input) { return transpose(input); };
+    auto vfun = vmap(fun);
+    auto x = array({0, 1, 2, 3, 4, 5}, {3, 2});
+    CHECK(array_equal(vfun(x), x).item<bool>());
+
+    vfun = vmap(fun, 0, 1);
+    CHECK(array_equal(vfun(x), transpose(x)).item<bool>());
+
+    x = array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2});
+    vfun = vmap(fun);
+    CHECK(array_equal(vfun(x), transpose(x, {0, 2, 1})).item<bool>());
+
+    vfun = vmap(fun, 1, 1);
+    CHECK(array_equal(vfun(x), transpose(x, {2, 1, 0})).item<bool>());
+
+    vfun = vmap(fun, 2, 2);
+    CHECK(array_equal(vfun(x), transpose(x, {1, 0, 2})).item<bool>());
+
+    // vmap twice
+    x = array(
+        {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, {2, 2, 2, 2});
+    vfun = vmap(vmap(fun));
+    CHECK(array_equal(vfun(x), transpose(x, {0, 1, 3, 2})).item<bool>());
+  }
+
+  // vmap add
+  {
+    auto fun = [](std::vector<array> inputs) {
+      auto out = add(inputs[0], inputs[1]);
+      return std::vector<array>{out};
+    };
+
+    auto vfun = vmap(fun);
+    array x({1.0, 2.0}, {2, 1});
+    array y({2.0, 3.0}, {2, 1});
+    auto out = vfun({x, y})[0];
+    CHECK(array_equal(out, array({3.0, 5.0}, {2, 1})).item<bool>());
+
+    x = ones({2, 1, 3});
+    y = ones({3, 2});
+    vfun = vmap(fun, {2, 0});
+    out = vfun({x, y})[0];
+    CHECK(array_equal(out, full({3, 2, 2}, 2.0)).item<bool>());
+
+    x = array(1.);
+    y = ones({3, 2});
+    vfun = vmap(fun, {-1, 0});
+    out = vfun({x, y})[0];
+    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
+
+    x = ones({3, 2});
+    y = array(1.);
+    vfun = vmap(fun, {0, -1});
+    out = vfun({x, y})[0];
+    CHECK(array_equal(out, full({3, 2}, 2.0)).item<bool>());
+
+    CHECK_THROWS_AS(vmap(fun, {-1, -1}, {0}), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, {-1, 0}, {-1}), std::invalid_argument);
+    CHECK_THROWS_AS(vmap(fun, {0, -1}, {-1}), std::invalid_argument);
+
+    x = array(1.);
+    y = array(1.);
+    vfun = vmap(fun, {-1, -1}, {-1});
+    out = vfun({x, y})[0];
+    CHECK(array_equal(out, array(2.)).item<bool>());
+
+    x = ones({3, 2, 1});
+    y = ones({3, 2, 1});
+    vfun = vmap(vmap(fun));
+    out = vfun({x, y})[0];
+    CHECK(array_equal(out, x + y).item<bool>());
+  }
+
+  // vmap with capturing closure
+  {
+    auto x = add(add(ones({2}), zeros({2})), zeros({2}));
+    auto fun = [x](const array& input) { return add(input, x); };
+
+    auto vfun = vmap(fun);
+    auto y = ones({3, 2});
+    CHECK(array_equal(vfun(y), full({3, 2}, 2.0f)).item<bool>());
+  }
+  {
+    auto x = ones({4});
+    auto z = x + x;
+    auto vfun = vmap(
+        [z](std::vector<array> inputs) {
+          return std::vector<array>{add(z, inputs[1])};
+        },
+        {-1, 0});
+    auto y = ones({3, 4});
+    CHECK(array_equal(vfun({x, y})[0], full({3, 4}, 3.0)).item<bool>());
+  }
+}
+
+TEST_CASE("test vmap with eval") {
+  auto fun = [](std::vector<array> inputs) {
+    auto x = inputs[0] + 1;
+    auto y = inputs[1] + 2;
+    eval(x);
+    auto out = add(x, y);
+    return std::vector<array>{out};
+  };
+
+  auto vfun = vmap(fun);
+  array x({1.0, 2.0}, {2, 1});
+  array y({2.0, 3.0}, {2, 1});
+  CHECK_THROWS(vfun({x, y}));
+
+  // Ok to eval functions of non-vmapped input
+  x = array(1.0);
+  vfun = vmap(fun, {-1, 0});
+  CHECK(array_equal(vfun({x, y})[0], array({6.0f, 7.0f}, {2, 1})).item<bool>());
+
+  // Not ok to eval function of vmapped input even with retain graph
+  auto fun2 = [](std::vector<array> inputs) {
+    auto x = inputs[0] + 1;
+    auto y = inputs[1] + 2;
+    eval({x}, true);
+    auto out = add(x, y);
+    return std::vector<array>{out};
+  };
+  x = array({1.0, 2.0}, {2, 1});
+  CHECK_THROWS(vmap(fun2)({x, y}));
+}
+
+TEST_CASE("test vmap comparison ops") {
+  // vmap equal
+  {
+    auto fun = [](std::vector<array> inputs) {
+      return std::vector<array>{equal(inputs[0], inputs[1])};
+    };
+    auto vfun = vmap(fun);
+    auto x = zeros({2, 3}, float32);
+    auto y = zeros({2, 3}, float32);
+    auto out = vfun({x, y})[0];
+    CHECK(all(out).item<bool>());
+
+    vfun = vmap(fun, {0, -1});
+    x = zeros({2, 3}, float32);
+    y = zeros({3}, float32);
+    out = vfun({x, y})[0];
+    CHECK(all(out).item<bool>());
+
+    vfun = vmap(fun, {0, -1});
+    x = array({0, 0, 0, 1, 1, 1}, {2, 3});
+    y = zeros({3}, float32);
+    out = vfun({x, y})[0];
+    auto expected = array({true, true, true, false, false, false}, {2, 3});
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+}
+
+TEST_CASE("test vmap creation ops") {
+  // vmap astype
+  {
+    auto fun = [](array in) { return astype(in, int32); };
+    auto x = zeros({2, 3}, float32);
+    auto out = vmap(fun)(x);
+    CHECK_EQ(out.dtype(), int32);
+    CHECK(array_equal(out, zeros({2, 3}, int32)).item<bool>());
+  }
+
+  // vmap full
+  {
+    auto fun = [](array in) { return full({2}, in); };
+    auto x = array({1, 2, 3});
+    auto out = vmap(fun)(x);
+    auto expected = array({1, 1, 2, 2, 3, 3}, {3, 2});
+    CHECK(array_equal(out, expected).item<bool>());
+
+    x = array({1, 2, 3}, {3, 1});
+    out = vmap(fun)(x);
+    expected = array({1, 1, 2, 2, 3, 3}, {3, 2});
+    CHECK(array_equal(out, expected).item<bool>());
+
+    x = array({1, 2, 3}, {1, 3});
+    CHECK_THROWS_AS(vmap(fun)(x), std::invalid_argument);
+    out = vmap(fun, 1, 1)(x);
+    expected = array({1, 2, 3, 1, 2, 3}, {2, 3});
+    CHECK(array_equal(out, expected).item<bool>());
+  }
+}