Skip to content

Commit bcb0282

Browse files
authored
feat: JLL changes for sdy.sharding_constraint (#799)
1 parent 9a150f7 commit bcb0282

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

deps/ReactantExtra/API.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -2133,10 +2133,13 @@ hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr,
21332133
xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes));
21342134
}
21352135

2136-
extern "C" mlir::sdy::TensorShardingAttr
2137-
hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
2138-
mlir::sdy::MeshAttr meshAttr, int64_t rank,
2139-
const bool *isClosed, const int64_t *priority) {
2136+
// XXX: This is incorrect for multiple meshes. We need to use the current mesh
2137+
// to generate this instead of the global mesh Currently we are storing only a
2138+
// single mesh, so we can just use this.
2139+
extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
2140+
mlir::MLIRContext *context, const xla::HloSharding *hloSharding,
2141+
mlir::StringAttr meshName, mlir::sdy::MeshAttr meshAttr, int64_t rank,
2142+
const bool *isClosed, const int64_t *priority) {
21402143
const SmallDenseMap<int64_t, StringRef> deviceIdToMaximalMeshName =
21412144
SmallDenseMap<int64_t, StringRef>();
21422145
mlir::sdy::TensorShardingAttr tensorShardingAttr =
@@ -2157,7 +2160,9 @@ hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
21572160
isClosed[i], dimPriority));
21582161
}
21592162

2160-
return tensorShardingAttr;
2163+
return mlir::sdy::TensorShardingAttr::get(
2164+
context, meshName, tensorShardingAttr.getDimShardings(),
2165+
tensorShardingAttr.getReplicatedAxes());
21612166
}
21622167

21632168
#pragma endregion

0 commit comments

Comments
 (0)