Skip to content

#sdy add temporary pass in XLA SDY import to not allow meshes with different axis shapes. #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions shardy/round_trip_import/import_shardy_attrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.

#include <cassert>
#include <cstdint>
#include <iterator>
#include <memory>
#include <optional>

Expand Down Expand Up @@ -143,6 +144,26 @@ void convertShardyAttrs(FuncOp funcOp, IRRewriter& rewriter) {
});
}

LogicalResult hasDifferentMeshShapes(
SmallVector<int64_t>& meshesWithAxisSizes, MeshAttr meshAttr,
ModuleOp moduleOp) {
ArrayRef<MeshAxisAttr> axes = meshAttr.getAxes();
SmallVector<int64_t> sizes;
if (!axes.empty()) {
sizes.reserve(axes.size());
llvm::transform(axes, std::back_inserter(sizes),
[](MeshAxisAttr axis) { return axis.getSize(); });
if (meshesWithAxisSizes.empty()) {
meshesWithAxisSizes = sizes;
} else if (meshesWithAxisSizes != sizes) {
return moduleOp.emitError(
"JAX does not support multiple meshes with different "
"axis sizes.");
}
}
return mlir::success();
}

class SdyRoundTripImportShardyAttrsPass
: public PassWrapper<SdyRoundTripImportShardyAttrsPass,
OperationPass<ModuleOp>> {
Expand All @@ -167,8 +188,17 @@ class SdyRoundTripImportShardyAttrsPass
// Insert the meshes before any functions.
rewriter.setInsertionPointToStart(moduleOp.getBody());
SymbolTable symbolTable(moduleOp);
// TODO(b/402371282): allow different meshes with different axis sizes
// during import. Either support propagation through different mesh shapes
// or error during propagation.
SmallVector<int64_t> meshesWithAxisSizes;
for (NamedAttribute mesh : sdyMeshes) {
auto meshAttr = cast<MeshAttr>(mesh.getValue());
if (hasDifferentMeshShapes(meshesWithAxisSizes, meshAttr, moduleOp)
.failed()) {
signalPassFailure();
return;
}
symbolTable.insert(
rewriter.create<MeshOp>(moduleOp.getLoc(), mesh.getName(), meshAttr));
}
Expand Down
38 changes: 19 additions & 19 deletions shardy/round_trip_import/test/sdy_round_trip_import_pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@

// CHECK-LABEL: module @multiple_func_result_shardings
module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {xla.sdy.meshes =
"{mesh = #sdy.mesh<[\\\22a\\\22=8, \\\22b\\\22=8, \\\22c\\\22=8]>, mesh2 = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=1]>}"}} {
// CHECK: sdy.mesh @mesh = <["a"=8, "b"=8, "c"=8]>
"{mesh = #sdy.mesh<[\\\22a\\\22=1, \\\22b\\\22=4, \\\22c\\\22=4]>, mesh2 = #sdy.mesh<[\\\22x\\\22=1, \\\22y\\\22=4, \\\22z\\\22=4]>}"}} {
// CHECK: sdy.mesh @mesh = <["a"=1, "b"=4, "c"=4]>

// CHECK: sdy.mesh @mesh2 = <["a"=1, "b"=4, "c"=1]>
// CHECK: sdy.mesh @mesh2 = <["x"=1, "y"=4, "z"=4]>

// CHECK-LABEL: func @func_results_with_sharding
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
// CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>},
// CHECK-SAME: %arg1: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>},
// CHECK-SAME: %arg2: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>}
// CHECK-SAME: ) -> (
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p1]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p1]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p0]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p2]>},
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p3]>}) {
// CHECK-SAME: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p3]>}) {
// CHECK-NEXT: return %arg0, %arg1, %arg0, %arg1, %arg1, %arg2
// CHECK-NEXT: }
func.func @func_results_with_sharding(
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p2]>"}},
%arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p1]>"}},
%arg1: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{}]>"}},
%arg2: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22c\\\22}p0]>"}}
) -> (tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>) {
%0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%0 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%1 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%2 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p0]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%4 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%arg2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
return %0, %1, %2, %3, %1, %4 : tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>, tensor<32xi32>
}

Expand Down Expand Up @@ -83,19 +83,19 @@ module @multiple_func_result_shardings attributes {mhlo.frontend_attributes = {x
}

// CHECK-LABEL: func @discard_shardings_on_unknown_ops(
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p0]>})
// CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}p4]>}) {
// CHECK-SAME: %arg0: tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"b"}p0]>})
// CHECK-SAME: -> (tensor<32xi32> {sdy.sharding = #sdy.sharding<@mesh, [{"c"}p4]>}) {
func.func @discard_shardings_on_unknown_ops(
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22a\\\22}p0]>"}}
%arg0: tensor<32xi32> {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding<@mesh, [{\\\22b\\\22}p0]>"}}
) -> tensor<32xi32> {
// CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg0 : tensor<32xi32>
// CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"a"}p2]> : tensor<32xi32>
// CHECK-NEXT: %[[SHARDING:.*]] = sdy.sharding_constraint %[[ADD]] <@mesh, [{"c"}p2]> : tensor<32xi32>
// CHECK-NEXT: %[[UNKNOWN:.*]] = stablehlo.custom_call @UnknownCustomCall(%[[SHARDING]]) : (tensor<32xi32>) -> tensor<32xi32>
// CHECK-NEXT: return %[[UNKNOWN]]
%0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p1]>]>"}} : tensor<32xi32>
%1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22a\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%0 = stablehlo.add %arg0, %arg0 {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p1]>]>"}} : tensor<32xi32>
%1 = stablehlo.custom_call @Sharding(%0) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p2]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%2 = stablehlo.custom_call @UnknownCustomCall(%1) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22b\\\22}p3]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
%3 = stablehlo.custom_call @xla.sdy.FuncResultSharding(%2) {mhlo.frontend_attributes = {xla.sdy.sharding = "#sdy.sharding_per_value<[<@mesh, [{\\\22c\\\22}p4]>]>"}} : (tensor<32xi32>) -> tensor<32xi32>
return %3 : tensor<32xi32>
}

Expand Down