Skip to content

Commit

Permalink
update dynamic cast
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruyan committed Feb 12, 2025
1 parent 4ac3508 commit bce8028
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 10 deletions.
6 changes: 5 additions & 1 deletion paddle/cinn/ir/ir_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,11 @@ void IrNode::convert_int64_to_int32() {
if (type_ == UInt(64)) type_ = UInt(32);

for (Expr &operand : operands) {
operand->convert_int64_to_int32();
if (operand->node_type() == IrNodeTy::Load) {
operand = ir::Cast::Make(Int(32), operand);
} else {
operand->convert_int64_to_int32();
}
}
}

Expand Down
32 changes: 23 additions & 9 deletions paddle/cinn/optim/longlong2int_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,26 @@ class CastLonglong2IntMutator : public ir::IRMutator<> {
auto node = expr->As<ir::Load>();
std::for_each(node->indices.begin(),
node->indices.end(),
[&](cinn::ir::Expr& e) { ir::TryElevateInt64ToInt32({e}); });
[&](cinn::ir::Expr& e) { ir::IRMutator<>::Visit(&e, &e); });
ir::IRMutator<>::Visit(&node->tensor, &node->tensor);
}

void Visit(const ir::Select* op, Expr* expr) override {
auto node = expr->As<ir::Select>();
auto cond = node->condition;
if (cond.is_cmp()) {
ir::TryElevateInt64ToInt32({cond->operand(0), cond->operand(1)});
if (cond.is_cmp() && cond->operand(0).is_index() &&
cond->operand(1).is_index()) {
ir::IRMutator<>::Visit(&cond->operands[0], &cond->operands[0]);
ir::IRMutator<>::Visit(&cond->operands[1], &cond->operands[1]);
}
ir::IRMutator<>::Visit(&node->true_value, &node->true_value);
ir::IRMutator<>::Visit(&node->false_value, &node->false_value);
}
void Visit(const ir::IntImm* op, Expr* expr) override {
ir::TryElevateInt64ToInt32({*expr});
}
void Visit(const ir::_Var_* op, Expr* expr) override {
ir::TryElevateInt64ToInt32({*expr});
}
};

class LongLong2IntStmtPass : public StmtPass {
Expand All @@ -148,18 +155,25 @@ class LongLong2IntExprPass : public ExprPass {
} // namespace

LogicalResult LongLong2IntStmtPass::Run(ir::stmt::StmtRef stmt) {
auto CastStore = [](StmtRef stmt) {
CastLonglong2IntMutator narrow;
// store and if_then_else stmt may has recursive load, so we need to use
// mutator to change those type.
auto CastStore = [&](StmtRef stmt) {
Store store_stmt = stmt.as<Store>();
for (Expr index : store_stmt->indices()) {
ir::TryElevateInt64ToInt32({index});
narrow(&index);
}
ir::Expr value = store_stmt->value();
narrow(&value);
};

auto CastIfThenElse = [](StmtRef stmt) {
auto CastIfThenElse = [&](StmtRef stmt) {
IfThenElse if_stmt = stmt.as<IfThenElse>();
Expr cond = if_stmt->condition();
if (cond.is_cmp()) {
ir::TryElevateInt64ToInt32({cond->operand(0), cond->operand(1)});
if (cond.is_cmp() && cond->operand(0).is_index() &&
cond->operand(1).is_index()) {
narrow(&cond->operands[0]);
narrow(&cond->operands[1]);
}
};

Expand Down

0 comments on commit bce8028

Please sign in to comment.