Skip to content

Commit

Permalink
awni's commit files
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Nov 29, 2023
1 parent e411fca commit 8ca7f9e
Show file tree
Hide file tree
Showing 130 changed files with 30,159 additions and 0 deletions.
87 changes: 87 additions & 0 deletions .clang-format
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
3 changes: 3 additions & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
include CMakeLists.txt
recursive-include mlx/ *
include python/src/*
198 changes: 198 additions & 0 deletions benchmarks/cpp/irregular_strides.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#include <iostream>
#include <sstream>

#include "mlx/mlx.h"
#include "time_utils.h"

using namespace mlx::core;

void time_irregular_binary_ops_1D() {
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
auto b = random::uniform({size});
eval(a, b);
a = slice(a, {0}, {size}, {step});
b = slice(b, {0}, {size}, {step});
TIMEM("1D strided", add, a, b, device);
}

void time_irregular_binary_ops_2D() {
auto device = default_device();
int size = 2048;
auto a = random::uniform({size, size});
auto b = random::uniform({size, size});
eval(a, b);
TIMEM("2D regular", add, a, b, device);

b = transpose(b);
eval(b);
TIMEM("2D transpose", add, a, b, device);

b = random::uniform({size});
eval(b);
TIMEM("2D broadcast dim 0", add, a, b, device);

b = reshape(b, {size, 1});
eval(b);
TIMEM("2D broadcast dim 1", add, a, b, device);
}

void time_irregular_binary_ops_3D() {
auto device = default_device();
int d0 = 32;
int d1 = 512;
int d2 = 512;
auto a = random::uniform({d0, d1, d2});
auto b = random::uniform({d0, d1, d2});
TIMEM("3D regular", add, a, b, device);

b = transpose(b, {0, 2, 1});
TIMEM("3D transpose", add, a, b, device);

b = random::uniform({d1, d2});
TIMEM("3D broadcast dim 0", add, a, b, device);

b = random::uniform({d0, 1, d2});
TIMEM("3D broadcast dim 1", add, a, b, device);

b = random::uniform({d0, d1, 1});
TIMEM("3D broadcast dim 2", add, a, b, device);

b = random::uniform({d2});
TIMEM("3D broadcast dims 0, 1", add, a, b, device);

b = random::uniform({d1, 1});
TIMEM("3D broadcast dims 0, 2", add, a, b, device);

b = random::uniform({d0, 1, 1});
TIMEM("3D broadcast dims 1, 2", add, a, b, device);
}

void time_irregular_binary_ops_4D() {
auto device = default_device();
std::vector<int> shape = {8, 8, 512, 512};
auto a = random::uniform(shape);
auto b = random::uniform(shape);

TIMEM("4D regular", add, a, b, device);

b = transpose(b, {0, 1, 3, 2});
TIMEM("4D transpose", add, a, b, device);

std::string om = "4D broadcast dims ";
for (int i = 0; i < shape.size(); ++i) {
shape[i] = 1;
b = random::uniform(shape);
std::ostringstream msg;
msg << om << i;
TIMEM(msg.str(), add, a, b, device);

for (int j = i + 1; j < shape.size(); ++j) {
shape[j] = 1;
std::ostringstream msg;
msg << om << i << ", " << j;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[j] = a.shape(j);

for (int k = j + 1; k < shape.size(); ++k) {
shape[k] = 1;
std::ostringstream msg;
msg << om << i << ", " << j << ", " << k;
b = random::uniform(shape);
TIMEM(msg.str(), add, a, b, device);
shape[k] = a.shape(k);
}
}
shape[i] = a.shape(i);
}
}

void time_irregular_reshape() {
auto device = default_device();
std::vector<int> shape;
auto reshape_fn = [&shape, device](const array& a) {
return reshape(a, shape, device);
};

int size = 64;
int d = 2 * size;

auto a = random::uniform({d, d, d});

shape = {8 * size, size, size};
TIMEM("3D contiguous", reshape_fn, a);

a = transpose(a);
shape = {8 * size, size, size};
TIMEM("3D transpose", reshape_fn, a);

a = transpose(a, {1, 2, 0});
shape = {8 * size, size, size};
TIMEM("3D transpose dims 1 2", reshape_fn, a);

a = broadcast_to(random::uniform({d, d}), {d, d, d});
TIMEM("3D broadcast dim 0", reshape_fn, a);

a = broadcast_to(random::uniform({d, 1, d}), {d, d, d});
TIMEM("3D broadcast dim 1", reshape_fn, a);

a = broadcast_to(random::uniform({d, d, 1}), {d, d, d});
TIMEM("3D broadcast dim 2", reshape_fn, a);

a = broadcast_to(random::uniform({d}), {d, d, d});
TIMEM("3D broadcast dims 0, 1", reshape_fn, a);

a = broadcast_to(random::uniform({d, 1}), {d, d, d});
TIMEM("3D broadcast dims 0, 2", reshape_fn, a);

a = broadcast_to(random::uniform({d, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2", reshape_fn, a);

a = broadcast_to(random::uniform({1, 1, 1}), {d, d, d});
TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}

void time_irregular_astype_1D() {
auto device = default_device();
int size = 1000000;
int step = 2;
auto a = random::uniform({size});
a = slice(a, {0}, {size}, {step});
TIMEM("1D strided", astype, a, int32, device);
}

void time_irregular_astype_2D() {
auto device = default_device();
int size = 2048;
std::vector<int> shape = {size, size};

auto a = random::uniform(shape);
TIMEM("2D regular", astype, a, int32, device);

a = transpose(a);
TIMEM("2D transpose", astype, a, int32, device);

a = broadcast_to(random::uniform({size}), shape);
TIMEM("2D broadcast dim 0", astype, a, int32, device);

a = broadcast_to(random::uniform({size, 1}), shape);
TIMEM("2D broadcast dim 1", astype, a, int32, device);
}

int main(int argc, char** argv) {
if (argc > 1) {
bool use_gpu = !strcmp(argv[1], "gpu");
set_default_device(use_gpu ? Device::gpu : Device::cpu);
}
std::cout << "Benchmarks for " << default_device() << std::endl;
time_irregular_binary_ops_1D();
time_irregular_binary_ops_2D();
time_irregular_binary_ops_3D();
time_irregular_binary_ops_4D();
time_irregular_reshape();
time_irregular_astype_1D();
time_irregular_astype_2D();
}
Loading

0 comments on commit 8ca7f9e

Please sign in to comment.