Skip to content

Commit

Permalink
[RELAX][PASS] Convert layout pass and ops enhanced to support sub ind…
Browse files Browse the repository at this point in the history
…exing (#17568)

Convert layout pass and ops enhanced to support sub indexing

Majority of the operations made compatible with custom layouts.
Incompatible ops will fallback to regular layout.

Conv1D, Conv3D, Pool1D, Pool3D, AdaptiveAvgPool1D, AdaptiveAvgPool3D
are left unchanged now. 2D networks are expected to work now.
  • Loading branch information
srkreddy1238 authored Jan 19, 2025
1 parent 8b59368 commit d641354
Show file tree
Hide file tree
Showing 15 changed files with 3,398 additions and 42 deletions.
3 changes: 2 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,9 +1688,10 @@ def index_map(
mapping: Callable,
*,
inverse_index_map: Optional[Callable] = None,
index_dtype: str = "int64",
) -> IndexMap:
"""Create a TIR Index mapping"""
return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map)
return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map, index_dtype=index_dtype)


def target(
Expand Down
4 changes: 4 additions & 0 deletions src/relax/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,10 @@ InferLayoutOutput InferLayoutResize2d(const Call& call,
} else {
// We dont have a desired layout for resize2d, propagate from the input instead.
data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
// Not handling sub indexing now.
if (data_layout->layout.ndim() != data_layout->layout.ndim_primal()) {
data_layout = LayoutDecision(InitialLayout(4));
}
new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), data_layout->layout).name();
}
return InferLayoutOutput({data_layout, InitialNLayout(call->args[1])}, {data_layout},
Expand Down
75 changes: 52 additions & 23 deletions src/relax/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,30 +308,59 @@ InferLayoutOutput InferLayoutConv2d(const Call& call,
Layout desired_data_layout = (*it).second[0];
Layout desired_weight_layout = (*it).second[1];
Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
<< "Axis swap only";
ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
<< "Axis swap only";
data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
new_attrs->data_layout = (*it).second[0];
new_attrs->kernel_layout = (*it).second[1];
new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
} else {
// We don't have a desired layout for conv2d.
// We can just propagate the layout from the input.
data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
output_layout = data_layout;
new_attrs->data_layout =
TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
new_attrs->kernel_layout =
TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
new_attrs->out_layout =
TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
tir::Layout input_layout(attrs->data_layout, DataType::Int(64));
tir::Layout kernel_layout(attrs->kernel_layout, DataType::Int(64));
tir::Layout out_layout(attrs->out_layout, DataType::Int(64));

if ((desired_data_layout.ndim() == input_layout.ndim()) &&
(desired_weight_layout.ndim() == kernel_layout.ndim()) &&
(desired_output_layout.ndim() == out_layout.ndim())) {
// Just a transpose
data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
new_attrs->data_layout = (*it).second[0];
new_attrs->kernel_layout = (*it).second[1];
new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
} else {
// Layout Transform
auto data_si = GetStructInfo(call->args[0]);
auto kernel_si = GetStructInfo(call->args[1]);
TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
Optional<ShapeExpr> kernel_shape = GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());

bool can_data_proved =
CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
kernel_shape.value()->values);

if (can_data_proved && can_kernel_proved) {
data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
weight_layout =
TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
new_attrs->data_layout = (*it).second[0];
new_attrs->kernel_layout = (*it).second[1];
new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
}
}
}

// We don't have a desired layout for conv2d or desired layouts not compatible.
// We can just propagate the layout from the input.
data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
output_layout = data_layout;
new_attrs->data_layout =
TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
new_attrs->kernel_layout =
TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
new_attrs->out_layout =
TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
}

Expand Down
26 changes: 24 additions & 2 deletions src/relax/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ InferLayoutOutput InferLayoutSoftmax(const Call& call,
ICHECK(attrs) << "Invalid Call";

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);

// TODO(Siva): We could handle if the axis is not the sub indexed one.
if (layout->layout.ndim() != layout->layout.ndim_primal()) {
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
int ndim = tensor_sinfo->ndim;
layout = LayoutDecision(InitialLayout(ndim));
}

ObjectPtr<SoftmaxAttrs> new_attrs = make_object<SoftmaxAttrs>(*attrs);
new_attrs->axis = FindAxis(layout->layout, attrs->axis);
return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
Expand Down Expand Up @@ -290,8 +300,18 @@ InferLayoutOutput InferLayoutBatchNorm(const Call& call,
ICHECK(attrs) << "Invalid Call";

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);

// While dealing with sub layouts, its adviced to deal with batchnorm
// on other ways like decomposing or fusion methods.
// This handling is fail safe fallback.
const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
int ndim = input_sinfo->ndim;
if (layout->layout.ndim() != layout->layout.ndim_primal()) {
layout = LayoutDecision(InitialLayout(ndim));
}

ObjectPtr<BatchNormAttrs> new_attrs = make_object<BatchNormAttrs>(*attrs);
new_attrs->axis = FindAxis(layout->layout, attrs->axis);
new_attrs->axis = FindAxis(layout->layout, (attrs->axis + ndim) % ndim);
return InferLayoutOutput(
{layout, initial_layouts[1], initial_layouts[2], initial_layouts[3], initial_layouts[4]},
{{layout, initial_layouts[3], initial_layouts[4]}}, Attrs(new_attrs));
Expand Down Expand Up @@ -353,9 +373,11 @@ InferLayoutOutput InferLayoutLayerNorm(const Call& call,

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<LayerNormAttrs> new_attrs = make_object<LayerNormAttrs>(*attrs);
const auto* input_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
int ndim = input_sinfo->ndim;
std::vector<Integer> new_axis;
for (const auto& axis : attrs->axes) {
new_axis.push_back(FindAxis(layout->layout, axis->value));
new_axis.push_back(FindAxis(layout->layout, (axis->value + ndim) % ndim));
}
new_attrs->axes = std::move(new_axis);
return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
Expand Down
32 changes: 32 additions & 0 deletions src/relax/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,23 @@ InferLayoutOutput InferLayoutPool2d(const Call& call,

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<Pool2DAttrs> new_attrs = make_object<Pool2DAttrs>(*attrs);

if (layout->layout.ndim() != layout->layout.ndim_primal()) {
tir::Layout in_layout(attrs->layout, DataType::Int(64));
auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
auto data_si = GetStructInfo(call->args[0]);
TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
// Not handling out_layout being different from in_layout now. Any use case ?
new_attrs->layout = desired_layout.name();
new_attrs->out_layout = desired_layout.name();
return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
} else {
layout = InitialLayout(4);
}
}

new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
Expand Down Expand Up @@ -583,6 +600,21 @@ InferLayoutOutput InferLayoutAdaptiveAvgPool2D(const Call& call,

LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<AdaptivePool2DAttrs> new_attrs = make_object<AdaptivePool2DAttrs>(*attrs);
if (layout->layout.ndim() != layout->layout.ndim_primal()) {
tir::Layout in_layout(attrs->layout, DataType::Int(64));
auto desired_layout = TransposeSubLayoutLike(attrs->layout, InitialLayout(4), layout->layout);
auto data_si = GetStructInfo(call->args[0]);
TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
Optional<ShapeExpr> data_shape = GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
if (CanProveLayoutTransform(in_layout, desired_layout, data_shape.value()->values)) {
// Not handling out_layout being different from in_layout now. Any use case ?
new_attrs->layout = desired_layout.name();
new_attrs->out_layout = desired_layout.name();
return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
} else {
layout = InitialLayout(4);
}
}
new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(4), layout->layout).name();
new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(4), layout->layout).name();
return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
Expand Down
22 changes: 22 additions & 0 deletions src/relax/op/op_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,5 +185,27 @@ InferLayoutOutput InferLayoutUnaryEwise(const Call& call,
return InferLayoutOutput({layout}, {layout}, Attrs(call->attrs));
}

bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
Array<PrimExpr> shape) {
bool can_prove = true;
try {
tir::BijectiveLayout todesired(input_layout, desired_layout);
Array<PrimExpr> desired_shape = todesired.ForwardShape(shape);
Array<PrimExpr> back_shape = todesired.BackwardShape(desired_shape);
arith::Analyzer analyzer;
for (size_t i = 0; i < shape.size(); ++i) {
if (tir::is_const_int(shape[i])) {
if (!analyzer.CanProveEqual(shape[i], back_shape[i])) {
can_prove = false;
break;
}
}
}
} catch (std::exception& err) {
return false;
}
return can_prove;
}

} // namespace relax
} // namespace tvm
10 changes: 10 additions & 0 deletions src/relax/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,16 @@ Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_ind
*/
Array<Expr> GetCallArgs(const Call& call);

/**
* \brief Checks the given shape can be proved from the source layout to dst layout
* \param input_layout is the layout of given shape
* \param desired_layout is the target layout the shape to be transformed
* \param shape array
* \return true or false depending on the compatibility
*/
bool CanProveLayoutTransform(const Layout& input_layout, const Layout& desired_layout,
Array<PrimExpr> shape);

} // namespace relax
} // namespace tvm

Expand Down
15 changes: 15 additions & 0 deletions src/relax/op/tensor/binary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,21 @@ InferLayoutOutput InferLayoutBinaryEwise(const Call& call,
ICHECK(!x1_sinfo->IsUnknownNdim() && !x2_sinfo->IsUnknownNdim())
<< "Unknown dim tensors should not be handled by this function";

Optional<ShapeExpr> shape1 = GetRef<ShapeExpr>(x1_sinfo->shape.as<ShapeExprNode>());
Optional<ShapeExpr> shape2 = GetRef<ShapeExpr>(x2_sinfo->shape.as<ShapeExprNode>());
// Lets handle sub indexing as long as primal dims are matching
if (layout1->layout.ndim_primal() == layout2->layout.ndim_primal()) {
if ((layout1->layout.ndim() >= layout2->layout.ndim()) && shape2.defined()) {
if (CanProveLayoutTransform(layout2->layout, layout1->layout, shape2.value()->values)) {
return InferLayoutOutput({layout1, layout1}, {layout1}, Attrs(call->attrs));
}
} else if (shape1.defined()) {
if (CanProveLayoutTransform(layout1->layout, layout2->layout, shape1.value()->values)) {
return InferLayoutOutput({layout2, layout2}, {layout2}, Attrs(call->attrs));
}
}
}

if (x1_sinfo->ndim <= x2_sinfo->ndim) {
if (x1_sinfo->ndim == 0) {
LayoutDecision out_layout = layout2;
Expand Down
4 changes: 4 additions & 0 deletions src/relax/op/tensor/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,10 @@ InferLayoutOutput InferLayoutStridedSlice(const Call& call,
<< "but expression " << call << " has argument "
<< call->args[0] << " of unknown dimensionality.";
LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
// Can't handle sub indexed layouts.
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
existing_layout = LayoutDecision(InitialLayout(tensor_sinfo->ndim));
}

auto opt_axes_tuple = UnpackTupleOfPrimValue<Integer>(GetStructInfo(call->args[1]));
CHECK(opt_axes_tuple) << "Layout inference of " << call->op
Expand Down
41 changes: 39 additions & 2 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,10 @@ InferLayoutOutput InferLayoutExpandDims(const Call& call,

LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
int ndim = tensor_sinfo->ndim;
// Can't handle sub indexed layouts.
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
existing_layout = LayoutDecision(InitialLayout(ndim));
}
int n_new_dim = attrs->axis.size();
int output_ndim = ndim + n_new_dim;
std::vector<bool> is_new_dim(output_ndim, false);
Expand Down Expand Up @@ -622,6 +626,12 @@ InferLayoutOutput InferLayoutPermuteDims(const Call& call,
int ndim = tensor_sinfo->ndim;

LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);

// permute_dims can't handle sub indexed layouts.
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
existing_layout = LayoutDecision(InitialLayout(ndim));
}

Array<Integer> order;
if (attrs->axes.defined()) {
order = attrs->axes.value();
Expand Down Expand Up @@ -942,10 +952,33 @@ InferLayoutOutput InferLayoutSplit(const Call& call,
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";

LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
StructInfo out_sinfo = InferStructInfoSplit(call, BlockBuilder::Create(IRModule()));
const auto* out_tuple = out_sinfo.as<TupleStructInfoNode>();

/*
* Fallback if the outputs can't be represented in input sub indexed layout
* This can happen after sub indexing, if we can't split the corresponding primal axis
*/
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
for (const auto& si : out_tuple->fields) {
ICHECK(si->IsInstance<TensorStructInfoNode>())
<< "Fields of TupleStructInfo must be TensorStructInfo"
"output structinfo, but got "
<< si;
auto sinfo = Downcast<TensorStructInfo>(si);
Optional<ShapeExpr> shape_expr = GetRef<ShapeExpr>(sinfo->shape.as<ShapeExprNode>());
CHECK(shape_expr.defined());
auto shape_arr = shape_expr.value();
if (!CanProveLayoutTransform(InitialLayout(tensor_sinfo->ndim), existing_layout->layout,
shape_arr->values)) {
existing_layout = InitialLayout(tensor_sinfo->ndim);
break;
}
}
}

ObjectPtr<SplitAttrs> new_attrs = make_object<SplitAttrs>(*attrs);
new_attrs->axis = FindAxis(existing_layout->layout, attrs->axis);
ICHECK(out_tuple != nullptr) << "Invalid Call";
NLayout tuple_layouts(Array<NLayout>(out_tuple->fields.size(), existing_layout));
return InferLayoutOutput({existing_layout}, {tuple_layouts}, Attrs(new_attrs));
Expand Down Expand Up @@ -1092,6 +1125,10 @@ InferLayoutOutput InferLayoutSqueeze(const Call& call,
}

LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
// Can't handle sub indexed layouts.
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
existing_layout = LayoutDecision(InitialLayout(ndim));
}
String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
Array<Integer> new_axis;
for (size_t i = 0; i < new_axis_str.size(); ++i) {
Expand Down
24 changes: 17 additions & 7 deletions src/relax/op/tensor/statistical.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,35 @@ InferLayoutOutput InferLayoutStatistical(const Call& call,

std::string axis_str(ndim, '0');
for (const auto& iter : axis) {
axis_str[(iter->value + ndim) % ndim] = '1';
axis_str[(iter->value + ndim) % ndim] = '#';
}
for (int i = 0, j = 0; i < ndim; ++i) {
if (axis_str[i] != '1') {
if (axis_str[i] != '#') {
axis_str[i] = 'A' + j++;
}
}

LayoutDecision exisiting_layout = GetLayoutDecision(var_layout_map, call->args[0]);
String new_axis_str = TransposeStrLike(axis_str, InitialLayout(ndim), exisiting_layout->layout);
auto new_axis_str = TransposeSubLayoutStrLike(axis_str, InitialLayout(ndim).name(),
exisiting_layout->layout.name());
std::string output_layout_ref = new_axis_str;
new_axis_str.erase(std::remove_if(new_axis_str.begin(), new_axis_str.end(),
[](unsigned char c) { return std::isdigit(c); }),
new_axis_str.end());

Array<Integer> new_axis;
for (size_t i = 0; i < new_axis_str.size(); ++i) {
if (new_axis_str.at(i) == '1') {
if (new_axis_str.at(i) == '#') {
new_axis.push_back(Integer(i));
}
}
std::string output_layout = new_axis_str;
output_layout.erase(std::remove(output_layout.begin(), output_layout.end(), '1'),
output_layout.end());
std::string output_layout;
for (size_t i = 0; i < output_layout_ref.length(); ++i) {
if ((isdigit(output_layout_ref[i]) && (output_layout_ref[i + 1] == '#')) ||
(output_layout_ref[i] == '#'))
continue;
output_layout.push_back(output_layout_ref[i]);
}

ObjectPtr<StatisticalAttrs> new_attrs = make_object<StatisticalAttrs>(*attrs);
new_attrs->axis = new_axis;
Expand Down
Loading

0 comments on commit d641354

Please sign in to comment.