@@ -2133,10 +2133,13 @@ hloShardingFromTensorShardingAttr(mlir::sdy::TensorShardingAttr attr,
       xla::sdy::convertToHloSharding(attr, get_mesh_attr, manual_axes));
 }
 
-extern "C" mlir::sdy::TensorShardingAttr
-hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
-                                mlir::sdy::MeshAttr meshAttr, int64_t rank,
-                                const bool *isClosed, const int64_t *priority) {
+// XXX: This is incorrect for multiple meshes. We need to use the current mesh
+// to generate this instead of the global mesh. Currently we are storing only a
+// single mesh, so we can just use this.
+extern "C" mlir::sdy::TensorShardingAttr hloShardingToTensorShardingAttr(
+    mlir::MLIRContext *context, const xla::HloSharding *hloSharding,
+    mlir::StringAttr meshName, mlir::sdy::MeshAttr meshAttr, int64_t rank,
+    const bool *isClosed, const int64_t *priority) {
   const SmallDenseMap<int64_t, StringRef> deviceIdToMaximalMeshName =
       SmallDenseMap<int64_t, StringRef>();
   mlir::sdy::TensorShardingAttr tensorShardingAttr =
@@ -2157,7 +2160,9 @@ hloShardingToTensorShardingAttr(const xla::HloSharding *hloSharding,
         isClosed[i], dimPriority));
   }
 
-  return tensorShardingAttr;
+  return mlir::sdy::TensorShardingAttr::get(
+      context, meshName, tensorShardingAttr.getDimShardings(),
+      tensorShardingAttr.getReplicatedAxes());
 }
 
 #pragma endregion
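
For reference, a minimal caller sketch of the updated entry point. Everything here except the signature shown in the diff above is an assumption: `ctx`, `hlo`, and `mesh` stand in for an MLIRContext, HloSharding, and MeshAttr obtained elsewhere, the mesh name "mesh" is a placeholder, and passing 0 as a dimension priority is only illustrative.

```cpp
// Hypothetical usage sketch, not part of the commit: call the new
// hloShardingToTensorShardingAttr for a rank-2 tensor whose dimensions are
// both closed. ctx, hlo, and mesh are assumed to be constructed elsewhere.
mlir::sdy::TensorShardingAttr buildSharding(mlir::MLIRContext *ctx,
                                            const xla::HloSharding *hlo,
                                            mlir::sdy::MeshAttr mesh) {
  constexpr int64_t rank = 2;
  const bool isClosed[rank] = {true, true};
  const int64_t priority[rank] = {0, 0};  // placeholder priorities (assumption)
  // The mesh name "mesh" is a placeholder; the XXX comment in the diff notes
  // that only a single mesh is stored at the moment.
  mlir::StringAttr meshName = mlir::StringAttr::get(ctx, "mesh");
  return hloShardingToTensorShardingAttr(ctx, hlo, meshName, mesh, rank,
                                         isClosed, priority);
}
```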