diff --git a/include/pypto/codegen/distributed/distributed_codegen.h b/include/pypto/codegen/distributed/distributed_codegen.h
index 35cc2556a..dff9dd3d4 100644
--- a/include/pypto/codegen/distributed/distributed_codegen.h
+++ b/include/pypto/codegen/distributed/distributed_codegen.h
@@ -109,6 +109,7 @@ class DistributedCodegen : public CodegenBase {
   void EmitAllocIntermediatesFunction(const ir::FunctionPtr& host_orch);
 
   // Helpers
+  void RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func);
   [[nodiscard]] std::string ParamDirectionToTensorArgType(ir::ParamDirection dir) const;
   [[nodiscard]] std::vector<ir::FunctionPtr> SortFunctionsByRoleAndLevel() const;
   void ClassifyFunctions();
diff --git a/include/pypto/ir/transforms/utils/scope_outline_utils.h b/include/pypto/ir/transforms/utils/scope_outline_utils.h
index a7ad27ab2..21c6fb3bf 100644
--- a/include/pypto/ir/transforms/utils/scope_outline_utils.h
+++ b/include/pypto/ir/transforms/utils/scope_outline_utils.h
@@ -633,7 +633,64 @@ class ScopeOutliner : public IRMutator {
     }
 
     // Apply pointer-based substitution after store results are materialized.
-    auto transformed_body = Substitute(pre_sub_body, var_substitution_map);
+    //
+    // We can't reuse `Substitute` here because IRMutator::VisitExpr_(VarPtr)
+    // mints a *fresh* Var when an old Var's type embeds a remapped shape Var
+    // (see mutator.cpp:225-239). For a tensor input whose shape references
+    // another input scalar, this means the body ends up referencing a Var
+    // that's NOT the one we just pushed into `input_params`; the codegen's
+    // param-binding loop then can't find a tensor view for it. We need
+    // visibility into the post-substitution remap state to pull out the
+    // freshened param Vars and update `input_params` accordingly.
+    class TrackingSubstituteMutator : public IRMutator {
+     public:
+      explicit TrackingSubstituteMutator(const std::unordered_map<const Var*, ExprPtr>& var_map) {
+        for (const auto& [k, v] : var_map) {
+          var_remap_[k] = v;
+        }
+      }
+      const std::unordered_map<const Var*, ExprPtr>& GetVarRemap() const { return var_remap_; }
+    };
+    TrackingSubstituteMutator subst_mutator(var_substitution_map);
+    auto transformed_body = subst_mutator.VisitStmt(pre_sub_body);
+
+    // Reconcile param/output Vars with any freshened versions created during
+    // substitution. ResolveVarRemapHit memoizes the resolved (final) Var back
+    // into var_remap_ keyed by the original Var, so the chain
+    //   old → seed (initial param/outlined) → freshened (after type remap)
+    // collapses to old → freshened. Read that out and replace the stale
+    // entry in input_params / outlined_output_vars / return_types.
+    const auto& post_remap = subst_mutator.GetVarRemap();
+    auto resolve_to_freshened = [&](const VarPtr& original, const VarPtr& seeded) -> VarPtr {
+      auto it = post_remap.find(original.get());
+      if (it == post_remap.end()) return seeded;
+      auto freshened = AsVarLike(it->second);
+      if (!freshened) return seeded;
+      return freshened;
+    };
+    for (size_t i = 0; i < input_vars.size(); ++i) {
+      input_params[i] = resolve_to_freshened(input_vars[i], input_params[i]);
+    }
+    for (size_t i = 0; i < output_vars.size(); ++i) {
+      bool is_store = store_output_set.count(output_vars[i].get()) > 0;
+      if (is_store) {
+        // Store targets aren't seeded into var_substitution_map, but their types
+        // may still embed remapped shape vars that trigger freshening during
+        // substitution. Check only the outlined var key; the original key may
+        // alias an input param (InOut parameters) and produce a false match.
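+        // (An InOut tensor's source Var appears in both input_vars and
+        // output_vars, so post_remap may already map it to the *input*
+        // param's freshened copy rather than this output slot's Var.)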
+        auto it = post_remap.find(outlined_output_vars[i].get());
+        if (it != post_remap.end()) {
+          if (auto freshened = AsVarLike(it->second)) {
+            outlined_output_vars[i] = freshened;
+            return_types[i] = freshened->GetType();
+          }
+        }
+        continue;
+      }
+      auto freshened = resolve_to_freshened(output_vars[i], outlined_output_vars[i]);
+      outlined_output_vars[i] = freshened;
+      return_types[i] = freshened->GetType();
+    }
 
     // Build outlined function body (transformed body + return statement)
     StmtPtr outlined_body;
diff --git a/python/pypto/backend/pto_backend.py b/python/pypto/backend/pto_backend.py
index 80b2c1610..aa4c98353 100644
--- a/python/pypto/backend/pto_backend.py
+++ b/python/pypto/backend/pto_backend.py
@@ -607,21 +607,26 @@ def _generate_config_file(
     return "\n".join(lines) + "\n"
 
 
+class _CallCollector(_ir_core.IRVisitor):
+    """Collect all GlobalVar callee names reachable from a function body."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.callee_names: list[str] = []
+
+    def visit_call(self, op: _ir_core.Call) -> None:
+        if isinstance(op.op, _ir_core.GlobalVar):
+            self.callee_names.append(op.op.name)
+        super().visit_call(op)
+
+
 def _extract_group_member_names(
     group_func: _ir_core.Function,
 ) -> list[str]:
     """Extract function names called by a Group function from its body."""
-    names: list[str] = []
-    stmts = _ir_core.flatten_to_stmts(group_func.body)
-    for stmt in stmts:
-        call = None
-        if isinstance(stmt, _ir_core.EvalStmt):
-            call = stmt.expr
-        elif isinstance(stmt, _ir_core.AssignStmt):
-            call = stmt.value
-        if isinstance(call, _ir_core.Call) and isinstance(call.op, _ir_core.GlobalVar):
-            names.append(call.op.name)
-    return names
+    collector = _CallCollector()
+    collector.visit_stmt(group_func.body)
+    return collector.callee_names
 
 
 def _extract_peer_function_names(
@@ -1046,15 +1051,9 @@ def _collect_chip_task_functions(
 
     while work:
         func = work.pop()
-        for stmt in _ir_core.flatten_to_stmts(func.body):
-            call = None
-            if isinstance(stmt, _ir_core.EvalStmt):
-                call = stmt.expr
-            elif isinstance(stmt, _ir_core.AssignStmt):
-                call = stmt.value
-            if not (isinstance(call, _ir_core.Call) and isinstance(call.op, _ir_core.GlobalVar)):
-                continue
-            callee_name = call.op.name
+        collector = _CallCollector()
+        collector.visit_stmt(func.body)
+        for callee_name in collector.callee_names:
             if callee_name in visited:
                 continue
             callee = program.get_function(callee_name)
diff --git a/src/codegen/distributed/distributed_codegen.cpp b/src/codegen/distributed/distributed_codegen.cpp
index 63802fac3..9ebc6ef44 100644
--- a/src/codegen/distributed/distributed_codegen.cpp
+++ b/src/codegen/distributed/distributed_codegen.cpp
@@ -25,6 +25,7 @@
 #include "pypto/core/logging.h"
 #include "pypto/ir/expr.h"
 #include "pypto/ir/function.h"
+#include "pypto/ir/kind_traits.h"
 #include "pypto/ir/program.h"
 #include "pypto/ir/scalar_expr.h"
 #include "pypto/ir/stmt.h"
@@ -184,10 +185,13 @@ void DistributedCodegen::EmitFunction(const ir::FunctionPtr& func) {
   emitter_.EmitLine(sig.str());
   emitter_.IncreaseIndent();
 
-  // Register parameter names
-  for (const auto& param : func->params_) {
-    declared_vars_.insert(SanitizeName(param->name_hint_));
-  }
+  // Register parameter names and emit local bindings for scalar params.
+  // All orchestrator parameters live in the tensors dict; tensor params are
+  // referenced via tensors["name"] at call sites, but scalar params (e.g.
+  // pl.Scalar[pl.BOOL]) may appear in bare-name contexts such as ``if``
+  // conditions.
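+  // Without a local binding, evaluating such a bare name in the generated
+  // Python would raise NameError.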
+  // Emitting ``name = tensors["name"]`` at the top of the function body
+  // ensures the bare name resolves correctly.
+  RegisterParamsAndEmitScalarBindings(func);
 
   // Emit body
   if (func->body_) {
@@ -210,10 +214,8 @@ void DistributedCodegen::EmitEntryFunction() {
   emitter_.EmitLine("def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep):");
   emitter_.IncreaseIndent();
 
-  // Register parameter names
-  for (const auto& param : entry_func_->params_) {
-    declared_vars_.insert(SanitizeName(param->name_hint_));
-  }
+  // Register parameter names and emit local bindings for scalar params.
+  RegisterParamsAndEmitScalarBindings(entry_func_);
 
   // Emit body
   if (entry_func_->body_) {
@@ -224,6 +226,16 @@ void DistributedCodegen::EmitEntryFunction() {
   emitter_.EmitLine("");
 }
 
+void DistributedCodegen::RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func) {
+  for (const auto& param : func->params_) {
+    std::string name = SanitizeName(param->name_hint_);
+    declared_vars_.insert(name);
+    if (ir::As<ir::ScalarType>(param->GetType())) {
+      emitter_.EmitLine(name + " = tensors[\"" + name + "\"]");
+    }
+  }
+}
+
 // ========================================================================
 // Statement visitors
 // ========================================================================
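
Note (illustration, not part of the patch): a minimal sketch of what the
emitted entry function now begins with, assuming a hypothetical orchestrator
with a tensor param `x` and a scalar param `flag` (both names invented for
this example):

    def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep):
        # Scalar param: bound locally so the bare name below resolves.
        flag = tensors["flag"]
        # Tensor params such as `x` get no local binding; call sites keep
        # reading tensors["x"] directly.
        if flag:
            ...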