Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/pypto/codegen/distributed/distributed_codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class DistributedCodegen : public CodegenBase {
void EmitAllocIntermediatesFunction(const ir::FunctionPtr& host_orch);

// Helpers
void RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func);
[[nodiscard]] std::string ParamDirectionToTensorArgType(ir::ParamDirection dir) const;
[[nodiscard]] std::vector<ir::FunctionPtr> SortFunctionsByRoleAndLevel() const;
void ClassifyFunctions();
Expand Down
59 changes: 58 additions & 1 deletion include/pypto/ir/transforms/utils/scope_outline_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,64 @@ class ScopeOutliner : public IRMutator {
}

// Apply pointer-based substitution after store results are materialized.
auto transformed_body = Substitute(pre_sub_body, var_substitution_map);
//
// We can't reuse `Substitute` here because IRMutator::VisitExpr_(VarPtr)
// mints a *fresh* Var when an old Var's type embeds a remapped shape Var
// (see mutator.cpp:225-239). For a tensor input whose shape references
// another input scalar, this means the body ends up referencing a Var
// that's NOT the one we just pushed into `input_params`; the codegen's
// param-binding loop then can't find a tensor view for it. We need
// visibility into the post-substitution remap state to pull out the
// freshened param Vars and update `input_params` accordingly.
// Thin IRMutator subclass that seeds the protected var_remap_ table from the
// caller's substitution map and exposes it read-only afterwards, so the caller
// can observe which Vars were freshened during the mutation pass.
class TrackingSubstituteMutator : public IRMutator {
public:
// Seed var_remap_ with the requested substitutions. Keys are const Var*
// but the table is keyed by const Expr*; the implicit upcast is intended.
explicit TrackingSubstituteMutator(const std::unordered_map<const Var*, VarPtr>& var_map) {
for (const auto& [k, v] : var_map) {
var_remap_[k] = v;
}
}
// Post-mutation remap state, including entries memoized by the base
// mutator while freshening Vars (see the comment above this class).
const std::unordered_map<const Expr*, ExprPtr>& GetVarRemap() const { return var_remap_; }
};
TrackingSubstituteMutator subst_mutator(var_substitution_map);
auto transformed_body = subst_mutator.VisitStmt(pre_sub_body);

// Reconcile param/output Vars with any freshened versions created during
// substitution. ResolveVarRemapHit memoizes the resolved (final) Var back
// into var_remap_ keyed by the original Var, so the chain
// old → seed (initial param/outlined) → freshened (after type remap)
// collapses to old → freshened. Pick that out and replace the stale
// entry in input_params / outlined_output_vars / return_types.
const auto& post_remap = subst_mutator.GetVarRemap();
auto resolve_to_freshened = [&](const VarPtr& original, const VarPtr& seeded) -> VarPtr {
auto it = post_remap.find(original.get());
if (it == post_remap.end()) return seeded;
auto freshened = AsVarLike(it->second);
if (!freshened) return seeded;
return freshened;
};
for (size_t i = 0; i < input_vars.size(); ++i) {
input_params[i] = resolve_to_freshened(input_vars[i], input_params[i]);
}
for (size_t i = 0; i < output_vars.size(); ++i) {
bool is_store = store_output_set.count(output_vars[i].get()) > 0;
if (is_store) {
// Store targets aren't seeded into var_substitution_map, but their types
// may still embed remapped shape vars that trigger freshening during
// substitution. Check only the outlined var key; the original key may
// alias an input param (InOut parameters) and produce a false match.
auto it = post_remap.find(outlined_output_vars[i].get());
if (it != post_remap.end()) {
if (auto freshened = AsVarLike(it->second)) {
outlined_output_vars[i] = freshened;
return_types[i] = freshened->GetType();
}
}
continue;
}
auto freshened = resolve_to_freshened(output_vars[i], outlined_output_vars[i]);
outlined_output_vars[i] = freshened;
return_types[i] = freshened->GetType();
}
Comment thread
luohuan19 marked this conversation as resolved.

// Build outlined function body (transformed body + return statement)
StmtPtr outlined_body;
Expand Down
39 changes: 19 additions & 20 deletions python/pypto/backend/pto_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,21 +607,26 @@ def _generate_config_file(
return "\n".join(lines) + "\n"


class _CallCollector(_ir_core.IRVisitor):
    """IR visitor that records the name of every GlobalVar callee it reaches."""

    def __init__(self) -> None:
        super().__init__()
        # Callee names in visitation order; duplicates are kept.
        self.callee_names: list[str] = []

    def visit_call(self, op: _ir_core.Call) -> None:
        callee = op.op
        if isinstance(callee, _ir_core.GlobalVar):
            self.callee_names.append(callee.name)
        # Continue the default traversal so nested calls are also collected.
        super().visit_call(op)

def _extract_group_member_names(
    group_func: _ir_core.Function,
) -> list[str]:
    """Extract function names called by a Group function from its body.

    Walks the whole body with ``_CallCollector`` (rather than scanning only
    top-level statements) so calls nested inside control flow are found too.

    Args:
        group_func: The Group function whose body is scanned.

    Returns:
        GlobalVar callee names in visitation order (duplicates preserved).
    """
    # NOTE: the previous hand-rolled statement scan was replaced by the
    # visitor; this version removes the dead leftover of that old scan.
    collector = _CallCollector()
    collector.visit_stmt(group_func.body)
    return collector.callee_names


def _extract_peer_function_names(
Expand Down Expand Up @@ -1046,15 +1051,9 @@ def _collect_chip_task_functions(

while work:
func = work.pop()
for stmt in _ir_core.flatten_to_stmts(func.body):
call = None
if isinstance(stmt, _ir_core.EvalStmt):
call = stmt.expr
elif isinstance(stmt, _ir_core.AssignStmt):
call = stmt.value
if not (isinstance(call, _ir_core.Call) and isinstance(call.op, _ir_core.GlobalVar)):
continue
callee_name = call.op.name
collector = _CallCollector()
collector.visit_stmt(func.body)
for callee_name in collector.callee_names:
if callee_name in visited:
continue
callee = program.get_function(callee_name)
Expand Down
28 changes: 20 additions & 8 deletions src/codegen/distributed/distributed_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "pypto/core/logging.h"
#include "pypto/ir/expr.h"
#include "pypto/ir/function.h"
#include "pypto/ir/kind_traits.h"
#include "pypto/ir/program.h"
#include "pypto/ir/scalar_expr.h"
#include "pypto/ir/stmt.h"
Expand Down Expand Up @@ -184,10 +185,13 @@ void DistributedCodegen::EmitFunction(const ir::FunctionPtr& func) {
emitter_.EmitLine(sig.str());
emitter_.IncreaseIndent();

// Register parameter names
for (const auto& param : func->params_) {
declared_vars_.insert(SanitizeName(param->name_hint_));
}
// Register parameter names and emit local bindings for scalar params.
// All orchestrator parameters live in the tensors dict; tensor params are
// referenced via tensors["name"] at call sites, but scalar params (e.g.
// pl.Scalar[pl.BOOL]) may appear in bare-name contexts such as ``if``
// conditions. Emitting ``name = tensors["name"]`` at the top of the
// function body ensures the bare name resolves correctly.
RegisterParamsAndEmitScalarBindings(func);

// Emit body
if (func->body_) {
Expand All @@ -210,10 +214,8 @@ void DistributedCodegen::EmitEntryFunction() {
emitter_.EmitLine("def entry(orch, _args, config, *, tensors, callables, sub_ids, _keep):");
emitter_.IncreaseIndent();

// Register parameter names
for (const auto& param : entry_func_->params_) {
declared_vars_.insert(SanitizeName(param->name_hint_));
}
// Register parameter names and emit local bindings for scalar params.
RegisterParamsAndEmitScalarBindings(entry_func_);

// Emit body
if (entry_func_->body_) {
Expand All @@ -224,6 +226,16 @@ void DistributedCodegen::EmitEntryFunction() {
emitter_.EmitLine("");
}

// Registers every parameter's sanitized name as declared, and for scalar
// parameters additionally emits a `name = tensors["name"]` binding so the
// bare name resolves inside the generated function body.
void DistributedCodegen::RegisterParamsAndEmitScalarBindings(const ir::FunctionPtr& func) {
  for (const auto& param : func->params_) {
    const std::string sanitized = SanitizeName(param->name_hint_);
    declared_vars_.insert(sanitized);
    // Only scalar params need a local alias; tensor params are referenced
    // through the tensors dict at their call sites.
    const bool is_scalar = static_cast<bool>(ir::As<ir::ScalarType>(param->GetType()));
    if (is_scalar) {
      emitter_.EmitLine(sanitized + " = tensors[\"" + sanitized + "\"]");
    }
  }
}

// ========================================================================
// Statement visitors
// ========================================================================
Expand Down
Loading