Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/pypto/codegen/distributed/distributed_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class DistributedCodegen : public CodegenBase {
void EmitAllocIntermediatesFunction(const ir::FunctionPtr& host_orch);

// Helpers
void RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func);
[[nodiscard]] std::string ParamDirectionToTensorArgType(ir::ParamDirection dir) const;
[[nodiscard]] std::vector<ir::FunctionPtr> SortFunctionsByRoleAndLevel() const;
void ClassifyFunctions();
Expand Down
46 changes: 45 additions & 1 deletion include/pypto/ir/transforms/utils/scope_outline_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,51 @@ class ScopeOutliner : public IRMutator {
}

// Apply pointer-based substitution after store results are materialized.
auto transformed_body = Substitute(pre_sub_body, var_substitution_map);
//
// We can't reuse `Substitute` here because IRMutator::VisitExpr_(VarPtr)
// mints a *fresh* Var when an old Var's type embeds a remapped shape Var
// (see mutator.cpp:225-239). For a tensor input whose shape references
// another input scalar, this means the body ends up referencing a Var
// that's NOT the one we just pushed into `input_params`; the codegen's
// param-binding loop then can't find a tensor view for it. We need
// visibility into the post-substitution remap state to pull out the
// freshened param Vars and update `input_params` accordingly.
class TrackingSubstituteMutator : public IRMutator {
public:
explicit TrackingSubstituteMutator(const std::unordered_map<const Var*, VarPtr>& var_map) {
for (const auto& [k, v] : var_map) {
var_remap_[k] = v;
}
}
const std::unordered_map<const Expr*, ExprPtr>& GetVarRemap() const { return var_remap_; }
};
TrackingSubstituteMutator subst_mutator(var_substitution_map);
auto transformed_body = subst_mutator.VisitStmt(pre_sub_body);

// Reconcile param/output Vars with any freshened versions created during
// substitution. ResolveVarRemapHit memoizes the resolved (final) Var back
// into var_remap_ keyed by the original Var, so the chain
// old → seed (initial param/outlined) → freshened (after type remap)
// collapses to old → freshened. Pick that out and replace the stale
// entry in input_params / outlined_output_vars / return_types.
const auto& post_remap = subst_mutator.GetVarRemap();
auto resolve_to_freshened = [&](const VarPtr& original, const VarPtr& seeded) -> VarPtr {
auto it = post_remap.find(original.get());
if (it == post_remap.end()) return seeded;
auto freshened = AsVarLike(it->second);
if (!freshened) return seeded;
return freshened;
};
for (size_t i = 0; i < input_vars.size(); ++i) {
input_params[i] = resolve_to_freshened(input_vars[i], input_params[i]);
}
for (size_t i = 0; i < output_vars.size(); ++i) {
bool is_store = store_output_set.count(output_vars[i].get()) > 0;
if (is_store) continue; // Store-target outlined Vars aren't seeded into the substitution map.
auto freshened = resolve_to_freshened(output_vars[i], outlined_output_vars[i]);
outlined_output_vars[i] = freshened;
return_types[i] = freshened->GetType();
}
Comment thread
luohuan19 marked this conversation as resolved.

// Build outlined function body (transformed body + return statement)
StmtPtr outlined_body;
Expand Down
39 changes: 19 additions & 20 deletions python/pypto/backend/pto_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,21 +607,26 @@ def _generate_config_file(
return "\n".join(lines) + "\n"


class _CallCollector(_ir_core.IRVisitor):
    """Walk an IR fragment and record every GlobalVar callee name seen.

    Names are appended in visitation order; duplicates are preserved so the
    caller can decide whether/how to dedupe.
    """

    def __init__(self) -> None:
        super().__init__()
        #: Callee names in the order their Call nodes were visited.
        self.callee_names: list[str] = []

    def visit_call(self, op: _ir_core.Call) -> None:
        callee = op.op
        # Only direct calls to named global functions are of interest;
        # calls through non-GlobalVar ops are traversed but not recorded.
        if isinstance(callee, _ir_core.GlobalVar):
            self.callee_names.append(callee.name)
        # Continue the default traversal so nested calls are reached too.
        super().visit_call(op)


def _extract_group_member_names(
    group_func: _ir_core.Function,
) -> list[str]:
    """Extract function names called by a Group function from its body.

    Uses a full IR traversal (``_CallCollector``) rather than a flat
    statement scan, so calls nested inside control flow — not just
    top-level ``EvalStmt``/``AssignStmt`` — are found as well.

    Args:
        group_func: The Group function whose body is searched.

    Returns:
        GlobalVar callee names in visitation order (duplicates preserved).
    """
    collector = _CallCollector()
    collector.visit_stmt(group_func.body)
    return collector.callee_names


def _extract_peer_function_names(
Expand Down Expand Up @@ -1046,15 +1051,9 @@ def _collect_chip_task_functions(

while work:
func = work.pop()
for stmt in _ir_core.flatten_to_stmts(func.body):
call = None
if isinstance(stmt, _ir_core.EvalStmt):
call = stmt.expr
elif isinstance(stmt, _ir_core.AssignStmt):
call = stmt.value
if not (isinstance(call, _ir_core.Call) and isinstance(call.op, _ir_core.GlobalVar)):
continue
callee_name = call.op.name
collector = _CallCollector()
collector.visit_stmt(func.body)
for callee_name in collector.callee_names:
if callee_name in visited:
continue
callee = program.get_function(callee_name)
Expand Down
28 changes: 20 additions & 8 deletions src/codegen/distributed/distributed_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "pypto/core/logging.h"
#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
#include "pypto/ir/kind_traits.h"
#include "pypto/ir/program.h"
#include "pypto/ir/scalar_expr.h"
#include "pypto/ir/stmt.h"
Expand Down Expand Up @@ -184,10 +185,13 @@ void DistributedCodegen::EmitFunction(const ir::FunctionPtr& func) {
emitter_.EmitLine(sig.str());
emitter_.IncreaseIndent();

// Register parameter names
for (const auto& param : func->params_) {
declared_vars_.insert(SanitizeName(param->name_hint_));
}
// Register parameter names and emit local bindings for scalar params.
// All orchestrator parameters live in the tensors dict; tensor params are
// referenced via tensors["name"] at call sites, but scalar params (e.g.
// pl.Scalar[pl.BOOL]) may appear in bare-name contexts such as ``if``
// conditions. Emitting ``name = tensors["name"]`` at the top of the
// function body ensures the bare name resolves correctly.
RegisterParamsAndEmitScalarBindings(func);

// Emit body
if (func->body_) {
Expand All @@ -210,10 +214,8 @@ void DistributedCodegen::EmitEntryFunction() {
emitter_.EmitLine("def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep):");
emitter_.IncreaseIndent();

// Register parameter names
for (const auto& param : entry_func_->params_) {
declared_vars_.insert(SanitizeName(param->name_hint_));
}
// Register parameter names and emit local bindings for scalar params.
RegisterParamsAndEmitScalarBindings(entry_func_);

// Emit body
if (entry_func_->body_) {
Expand All @@ -224,6 +226,16 @@ void DistributedCodegen::EmitEntryFunction() {
emitter_.EmitLine("");
}

void DistributedCodegen::RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func) {
  // Register every parameter name as declared, and for scalar-typed
  // parameters additionally emit a `name = tensors["name"]` binding so the
  // bare name resolves inside the generated function body.
  for (const auto& param : func->params_) {
    const std::string sanitized = SanitizeName(param->name_hint_);
    declared_vars_.insert(sanitized);
    // Tensor params are referenced via tensors["name"] at call sites;
    // only scalars need a local alias.
    if (ir::As<ir::ScalarType>(param->GetType())) {
      emitter_.EmitLine(sanitized + " = tensors[\"" + sanitized + "\"]");
    }
  }
}

// ========================================================================
// Statement visitors
// ========================================================================
Expand Down
Loading