fix(ir,codegen): reconcile freshened Vars and emit scalar param bindings #1329
lyfne123 merged 3 commits into hw-native-sys:main
Orchestrator parameters live in the tensors dict, but scalar params (e.g. pl.Scalar[pl.BOOL]) may appear in bare-name contexts such as if-conditions. Emit `name = tensors["name"]` at the top of each function body so the bare name resolves correctly. Also refactor duplicated call-collection logic in pto_backend.py into a reusable _CallCollector IRVisitor, and extract the C++ param registration loop into RegisterParamsAndEmitScalarBindings().
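The binding step described above can be illustrated with a minimal sketch. It is not the project's codegen: `Param` and `emit_function` are hypothetical names invented here, and the sketch only assumes that every parameter arrives through a single `tensors` dict and that scalar params need a bare-name alias at the top of the body.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for a function parameter (not a pypto type)."""
    name: str
    is_scalar: bool


def emit_function(func_name, params, body_lines):
    """Emit source for a function whose params all arrive via one `tensors` dict.

    Scalar params also get `name = tensors["name"]` at the top of the body,
    so bare-name uses such as `if flag:` resolve.
    """
    lines = [f"def {func_name}(tensors):"]
    for p in params:
        if p.is_scalar:
            lines.append(f'    {p.name} = tensors["{p.name}"]')
    lines.extend(f"    {line}" for line in body_lines)
    return "\n".join(lines)


src = emit_function(
    "orch",
    [Param("x", is_scalar=False), Param("flag", is_scalar=True)],
    ["if flag:", '    return tensors["x"] + 1', 'return tensors["x"]'],
)

# Execute the generated source to confirm the bare name `flag` resolves.
namespace = {}
exec(src, namespace)
result_true = namespace["orch"]({"x": 10, "flag": True})
result_false = namespace["orch"]({"x": 10, "flag": False})
```

Without the emitted binding line, the generated `if flag:` would raise `NameError`, because `flag` only exists as a key of `tensors`.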
IRMutator::VisitExpr_(VarPtr) mints fresh Vars when a Var's type embeds a remapped shape Var. For tensor inputs whose shape references another input scalar, the outlined body ends up referencing a Var not present in input_params, causing codegen param-binding failures. Replace the opaque Substitute() call with a TrackingSubstituteMutator that exposes post-substitution var_remap_ state, then reconcile input_params / outlined_output_vars / return_types with any freshened Var instances.
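The freshening problem and the reconciliation step can be sketched with a toy model. Everything below is an assumption made for illustration: `Var` and `TrackingSubstituter` are hypothetical Python stand-ins, not the C++ `IRMutator` or `TrackingSubstituteMutator`, and the remap table here is keyed by object identity only.

```python
class Var:
    """Hypothetical IR variable; identity (not name) distinguishes instances."""

    def __init__(self, name, shape=()):
        self.name = name
        self.shape = tuple(shape)  # entries are ints or other Vars


class TrackingSubstituter:
    """Toy substitution pass whose remap table stays readable afterwards."""

    def __init__(self, seed):
        self.var_remap = dict(seed)  # old Var -> replacement Var

    def visit(self, var):
        if var in self.var_remap:
            # Follow the chain (old -> seed -> freshened) and memoize its end.
            resolved = self.visit(self.var_remap[var])
            self.var_remap[var] = resolved
            return resolved
        new_shape = tuple(self.visit(d) if isinstance(d, Var) else d for d in var.shape)
        if any(a is not b for a, b in zip(new_shape, var.shape)):
            # The type embeds a remapped shape Var, so a fresh Var is minted.
            fresh = Var(var.name, new_shape)
            self.var_remap[var] = fresh
            return fresh
        return var


# Outer scope: scalar n and tensor t whose shape references n.
n = Var("n")
t = Var("t", [n])
# Outlined params; t_param's type still embeds the outer n.
n_param = Var("n")
t_param = Var("t", [n])

sub = TrackingSubstituter({n: n_param, t: t_param})
body = [sub.visit(t), sub.visit(n)]

input_params = [n_param, t_param]
# Reconcile: replace any param that was freshened during substitution.
reconciled = [sub.var_remap.get(p, p) for p in input_params]
```

In this toy run the body ends up holding a freshened `t` that is not `t_param`, which mirrors the mismatch the commit describes. Reading the remap table after substitution is what lets the signature be brought back in line with the body.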
📝 Walkthrough
This PR adds scalar parameter binding to distributed code generation, refactors call discovery using a visitor pattern in the Python backend, and improves variable tracking during scope outline substitution to handle freshened variables created during type transformation.
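As an analogy for the visitor-based call discovery mentioned above, here is a small sketch built on the standard library's `ast.NodeVisitor`. It is not the `_CallCollector` from `pto_backend.py`, which walks the project's own IR; `CallCollector` and `collect_calls` are names assumed for this example only.

```python
import ast


class CallCollector(ast.NodeVisitor):
    """Collects called function names in first-seen order, without duplicates."""

    def __init__(self):
        self.callees = []

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id not in self.callees:
            self.callees.append(node.func.id)
        self.generic_visit(node)  # keep walking into nested calls


def collect_calls(source):
    """Parse Python source and return the names of the functions it calls."""
    collector = CallCollector()
    collector.visit(ast.parse(source))
    return collector.callees


calls = collect_calls("def f(x):\n    if g(x):\n        return h(g(x))\n    return x\n")
```

Keeping the traversal in one visitor class means each caller only decides what to do with the collected names, rather than repeating the tree walk.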
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Code Review
This pull request enhances IR transformations and code generation by implementing a TrackingSubstituteMutator to manage freshened variables during scope outlining and ensuring scalar parameters are correctly bound in the distributed codegen. It also refactors call collection in the Python backend using a visitor pattern. Feedback was provided to improve the variable reconciliation logic in scope_outline_utils.h, specifically to handle store targets whose types might be affected by shape remapping, ensuring consistency between function signatures and their bodies.
🧹 Nitpick comments (2)
include/pypto/ir/transforms/utils/scope_outline_utils.h (2)
645-655: 💤 Low value

Consider promoting `TrackingSubstituteMutator` out of the method body.

Defining a class with a base-class inheritance inside a 400-line member function works, but it's unusual and:
- bloats `OutlineScope` further at a different level of abstraction,
- prevents any unit test from exercising the substitution-tracking behavior in isolation,
- silently relies on `IRMutator::var_remap_` being non-private, which is an internal coupling worth making explicit.

A cleaner shape would be either (a) a private nested class of `ScopeOutliner` (or a peer in `outline_utils`), or (b) a protected `IRMutator::GetVarRemap()` accessor in the base so this whole subclass disappears. Option (b) is the most localized change and keeps the encapsulation story for `IRMutator` honest.

♻️ Sketch for option (b)

```diff
 // in include/pypto/ir/transforms/base/mutator.h (IRMutator)
+protected:
+  const std::unordered_map<const Expr*, ExprPtr>& GetVarRemap() const { return var_remap_; }
+  void SeedVarRemap(const std::unordered_map<const Var*, VarPtr>& seed) {
+    for (const auto& [k, v] : seed) var_remap_[k] = v;
+  }

-    class TrackingSubstituteMutator : public IRMutator {
-     public:
-      explicit TrackingSubstituteMutator(const std::unordered_map<const Var*, VarPtr>& var_map) {
-        for (const auto& [k, v] : var_map) {
-          var_remap_[k] = v;
-        }
-      }
-      const std::unordered_map<const Expr*, ExprPtr>& GetVarRemap() const { return var_remap_; }
-    };
-    TrackingSubstituteMutator subst_mutator(var_substitution_map);
+    IRMutator subst_mutator;
+    subst_mutator.SeedVarRemap(var_substitution_map);
     auto transformed_body = subst_mutator.VisitStmt(pre_sub_body);
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@include/pypto/ir/transforms/utils/scope_outline_utils.h` around lines 645 - 655, The local subclass TrackingSubstituteMutator is defined inside a large member function and directly reads IRMutator::var_remap_, so either (preferred) add a protected accessor IRMutator::GetVarRemap() to expose the remap for subclasses and then rewrite TrackingSubstituteMutator to call that accessor (removing direct access to var_remap_), or move TrackingSubstituteMutator out of the method body into a private nested class of the ScopeOutliner (or a peer class in outline_utils) so it can be unit-tested; update the constructor and GetVarRemap() use sites (TrackingSubstituteMutator and the code that constructs subst_mutator / calls VisitStmt) to use the new location or accessor and ensure no code accesses IRMutator::var_remap_ directly.
645-680: 🏗️ Heavy lift

The memoization behavior is correctly implemented, but consider two optional improvements:

The chain collapsing mechanism (old → seed → freshened → old → freshened) works as intended: `ResolveVarRemapHit` (mutator.cpp:222) explicitly memoizes the resolved value back into `var_remap_` keyed by the original Expr pointer, so subsequent lookups in the reconciliation loops will find the final freshened Var.

Optional refactor 1: Hoist `TrackingSubstituteMutator` to a private nested class of `ScopeOutliner` (rather than a local class inside the method). This improves readability and allows reuse across methods or future tests.

Optional refactor 2: Consider adding a protected `IRMutator::GetVarRemap()` accessor to the base class instead of relying on subclass access to the `var_remap_` member. This would reduce coupling to internal member names and provide a cleaner API for future transformations that need read-only access to the remapping after substitution.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@include/pypto/ir/transforms/utils/scope_outline_utils.h` around lines 645 - 680, Move the local TrackingSubstituteMutator class out of the method and make it a private nested class of ScopeOutliner (declare it in the same header alongside ScopeOutliner, keep its constructor signature taking const std::unordered_map<const Var*, VarPtr>& var_map and its GetVarRemap() method) and update the method to instantiate ScopeOutliner::TrackingSubstituteMutator; additionally add a protected accessor on IRMutator, e.g. protected: const std::unordered_map<const Expr*, ExprPtr>& GetVarRemap() const, so subclasses (and callers) can read the remap without relying on the concrete member name var_remap_, then adjust TrackingSubstituteMutator to use the base-class accessor and ensure existing call sites (subst_mutator.GetVarRemap()) continue to work.
📒 Files selected for processing (4)
- include/pypto/codegen/distributed/distributed_codegen.h
- include/pypto/ir/transforms/utils/scope_outline_utils.h
- python/pypto/backend/pto_backend.py
- src/codegen/distributed/distributed_codegen.cpp
Store-target outlined Vars are not seeded into var_substitution_map, but their types may still embed remapped shape Vars that trigger freshening during body substitution. Previously these were unconditionally skipped, leaving a potential mismatch between the return type and the body Var. Now check the outlined Var key in post_remap (not the original key, which may alias an input param for InOut parameters) and update outlined_output_vars / return_types when freshening occurred.
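The key-selection point above can be sketched as follows. This is a hedged illustration, not the code in `scope_outline_utils.h`: `Var`, `reconcile_outputs`, and the dict-based `post_remap` are hypothetical, and the sketch assumes `post_remap` maps each Var that was replaced during substitution to its final replacement.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # eq=False keeps identity-based hashing and comparison
class Var:
    """Hypothetical IR variable carrying a printable type."""
    name: str
    type: str


def reconcile_outputs(outlined_output_vars, return_types, post_remap):
    """Swap in freshened output Vars and refresh return types to match.

    Lookups use the outlined Var as the key; the original Var may alias
    an input param (InOut) and would resolve to the wrong entry.
    """
    new_vars, new_types = [], []
    for var, ty in zip(outlined_output_vars, return_types):
        fresh = post_remap.get(var, var)
        new_vars.append(fresh)
        new_types.append(fresh.type if fresh is not var else ty)
    return new_vars, new_types


# InOut case: the original Var is seeded to an input param, while the
# outlined store-target Var was freshened because its shape Var was remapped.
original = Var("buf", "Tensor[n]")
in_param = Var("buf", "Tensor[n_param]")
outlined_out = Var("buf_out", "Tensor[n]")
fresh_out = Var("buf_out", "Tensor[n_param]")
post_remap = {original: in_param, outlined_out: fresh_out}

out_vars, out_types = reconcile_outputs([outlined_out], ["Tensor[n]"], post_remap)
wrong_hit = post_remap[original]  # what a lookup by the original key would return
```

Looking up by the original key in this example returns the input param rather than the freshened output Var, which is the aliasing hazard the commit message calls out.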
Summary
- Reconcile freshened Vars after scope outline substitution. `IRMutator::VisitExpr_(VarPtr)` mints fresh Vars when a Var's type embeds a remapped shape Var. For tensor inputs whose shape references another input scalar, the outlined body ended up referencing a Var not present in `input_params`, causing codegen param-binding failures. Replaced the opaque `Substitute()` call with a `TrackingSubstituteMutator` that exposes post-substitution `var_remap_` state, then reconciled `input_params` / `outlined_output_vars` / `return_types` with any freshened Var instances.
- Orchestrator parameters live in the `tensors` dict, but scalar params (e.g. `pl.Scalar[pl.BOOL]`) may appear in bare-name contexts such as if-conditions. Emit `name = tensors["name"]` at the top of each function body so the bare name resolves correctly. Also refactored duplicated call-collection logic in `pto_backend.py` into a reusable `_CallCollector` IRVisitor, and extracted the C++ param registration loop into `RegisterParamsAndEmitScalarBindings()`.