Skip to content

Commit 7d4d502

Browse files
authored
Support function contexts in IRBuilder (#5967)
Add a `visitFunctionStart` function to IRBuilder and make it responsible for setting the function's body when the context is closed. This will simplify outlining, will be necessary to support branches to function scope properly, and removes an extra block around function bodies in the new wat parser.
1 parent c290aad commit 7d4d502

File tree

6 files changed

+136
-118
lines changed

6 files changed

+136
-118
lines changed

src/parser/context-defs.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ Result<> ParseDefsCtx::addFunc(Name,
5757
std::optional<LocalsT>,
5858
Index pos) {
5959
CHECK_ERR(withLoc(pos, irBuilder.visitEnd()));
60-
auto body = irBuilder.build();
61-
CHECK_ERR(withLoc(pos, body));
62-
wasm.functions[index]->body = *body;
6360
return Ok{};
6461
}
6562

src/parser/contexts.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -816,9 +816,10 @@ struct ParseDefsCtx : TypeParserCtx<ParseDefsCtx> {
816816

817817
IRBuilder irBuilder;
818818

819-
void setFunction(Function* func) {
819+
Result<> visitFunctionStart(Function* func) {
820820
this->func = func;
821-
irBuilder.setFunction(func);
821+
CHECK_ERR(irBuilder.visitFunctionStart(func));
822+
return Ok{};
822823
}
823824

824825
ParseDefsCtx(std::string_view in,

src/parser/wat-parser.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,7 @@ Result<> parseModule(Module& wasm, std::string_view input) {
157157

158158
for (Index i = 0; i < decls.funcDefs.size(); ++i) {
159159
ctx.index = i;
160-
ctx.setFunction(wasm.functions[i].get());
161-
CHECK_ERR(ctx.irBuilder.makeBlock(Name{}, ctx.func->getResults()));
160+
CHECK_ERR(ctx.visitFunctionStart(wasm.functions[i].get()));
162161
WithPosition with(ctx, decls.funcDefs[i].pos);
163162
auto parsed = func(ctx);
164163
CHECK_ERR(parsed);

src/wasm-ir-builder.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
5151
// Handle the boundaries of control flow structures. Users may choose to use
5252
// the corresponding `makeXYZ` function below instead of `visitXYZStart`, but
5353
// either way must call `visitEnd` and friends at the appropriate times.
54+
[[nodiscard]] Result<> visitFunctionStart(Function* func);
5455
[[nodiscard]] Result<> visitBlockStart(Block* block);
5556
[[nodiscard]] Result<> visitIfStart(If* iff, Name label = {});
5657
[[nodiscard]] Result<> visitElse();
@@ -170,8 +171,6 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
170171
// [[nodiscard]] Result<> makeStringSliceWTF();
171172
// [[nodiscard]] Result<> makeStringSliceIter();
172173

173-
void setFunction(Function* func) { this->func = func; }
174-
175174
// Private functions that must be public for technical reasons.
176175
[[nodiscard]] Result<> visitExpression(Expression*);
177176
[[nodiscard]] Result<> visitBlock(Block*);
@@ -189,6 +188,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
189188
// to have.
190189
struct ScopeCtx {
191190
struct NoScope {};
191+
struct FuncScope {
192+
Function* func;
193+
};
192194
struct BlockScope {
193195
Block* block;
194196
};
@@ -203,8 +205,8 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
203205
struct LoopScope {
204206
Loop* loop;
205207
};
206-
using Scope =
207-
std::variant<NoScope, BlockScope, IfScope, ElseScope, LoopScope>;
208+
using Scope = std::
209+
variant<NoScope, FuncScope, BlockScope, IfScope, ElseScope, LoopScope>;
208210

209211
// The control flow structure we are building expressions for.
210212
Scope scope;
@@ -217,6 +219,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
217219
ScopeCtx() : scope(NoScope{}) {}
218220
ScopeCtx(Scope scope) : scope(scope) {}
219221

222+
static ScopeCtx makeFunc(Function* func) {
223+
return ScopeCtx(FuncScope{func});
224+
}
220225
static ScopeCtx makeBlock(Block* block) {
221226
return ScopeCtx(BlockScope{block});
222227
}
@@ -229,6 +234,12 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
229234
static ScopeCtx makeLoop(Loop* loop) { return ScopeCtx(LoopScope{loop}); }
230235

231236
bool isNone() { return std::get_if<NoScope>(&scope); }
237+
Function* getFunction() {
238+
if (auto* funcScope = std::get_if<FuncScope>(&scope)) {
239+
return funcScope->func;
240+
}
241+
return nullptr;
242+
}
232243
Block* getBlock() {
233244
if (auto* blockScope = std::get_if<BlockScope>(&scope)) {
234245
return blockScope->block;
@@ -254,6 +265,9 @@ class IRBuilder : public UnifiedExpressionVisitor<IRBuilder, Result<>> {
254265
return nullptr;
255266
}
256267
Type getResultType() {
268+
if (auto* func = getFunction()) {
269+
return func->type.getSignature().results;
270+
}
257271
if (auto* block = getBlock()) {
258272
return block->type;
259273
}

src/wasm/wasm-ir-builder.cpp

Lines changed: 35 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,15 @@ Result<> IRBuilder::visitArrayNew(ArrayNew* curr) {
281281
return Ok{};
282282
}
283283

284+
Result<> IRBuilder::visitFunctionStart(Function* func) {
285+
if (!scopeStack.empty()) {
286+
return Err{"unexpected start of function"};
287+
}
288+
scopeStack.push_back(ScopeCtx::makeFunc(func));
289+
this->func = func;
290+
return Ok{};
291+
}
292+
284293
Result<> IRBuilder::visitBlockStart(Block* curr) {
285294
scopeStack.push_back(ScopeCtx::makeBlock(curr));
286295
return Ok{};
@@ -327,12 +336,12 @@ Result<Expression*> IRBuilder::finishScope(Block* block) {
327336
auto hoisted = hoistLastValue();
328337
CHECK_ERR(hoisted);
329338
auto hoistedType = scope.exprStack.back()->type;
330-
if (hoistedType.size() != block->type.size()) {
339+
if (hoistedType.size() != type.size()) {
331340
// We cannot propagate the hoisted value directly because it does not
332341
// have the correct number of elements. Break it up if necessary and
333342
// construct our returned tuple from parts.
334343
CHECK_ERR(packageHoistedValue(*hoisted));
335-
std::vector<Expression*> elems(block->type.size());
344+
std::vector<Expression*> elems(type.size());
336345
for (size_t i = 0; i < elems.size(); ++i) {
337346
auto elem = pop();
338347
CHECK_ERR(elem);
@@ -369,11 +378,11 @@ Result<Expression*> IRBuilder::finishScope(Block* block) {
369378
} else {
370379
// More than one expression, so we need a block. Allocate one if we weren't
371380
// already given one.
372-
if (!block) {
373-
block = wasm.allocator.alloc<Block>();
374-
block->type = type;
381+
if (block) {
382+
block->list.set(scope.exprStack);
383+
} else {
384+
block = builder.makeBlock(scope.exprStack, type);
375385
}
376-
block->list.set(scope.exprStack);
377386
ret = block;
378387
}
379388
scopeStack.pop_back();
@@ -395,50 +404,45 @@ Result<> IRBuilder::visitElse() {
395404
}
396405

397406
Result<> IRBuilder::visitEnd() {
398-
auto& scope = getScope();
407+
auto scope = getScope();
399408
if (scope.isNone()) {
400409
return Err{"unexpected end"};
401410
}
402-
if (auto* block = scope.getBlock()) {
403-
auto expr = finishScope(block);
404-
CHECK_ERR(expr);
411+
auto expr = finishScope(scope.getBlock());
412+
CHECK_ERR(expr);
413+
414+
// If the scope expression cannot be directly labeled, we may need to wrap it
415+
// in a block.
416+
auto maybeWrapForLabel = [&](Expression* curr) -> Expression* {
417+
if (auto label = scope.getLabel()) {
418+
return builder.makeBlock(label, {curr}, curr->type);
419+
}
420+
return curr;
421+
};
422+
423+
if (auto* func = scope.getFunction()) {
424+
func->body = *expr;
425+
} else if (auto* block = scope.getBlock()) {
405426
assert(*expr == block);
406427
// TODO: Track branches so we can know whether this block is a target and
407428
// finalize more efficiently.
408429
block->finalize(block->type);
409430
push(block);
410-
return Ok{};
411431
} else if (auto* loop = scope.getLoop()) {
412-
auto expr = finishScope();
413-
CHECK_ERR(expr);
414432
loop->body = *expr;
415433
loop->finalize(loop->type);
416434
push(loop);
417-
return Ok{};
418-
}
419-
auto label = scope.getLabel();
420-
Expression* scopeExpr = nullptr;
421-
if (auto* iff = scope.getIf()) {
422-
auto expr = finishScope();
423-
CHECK_ERR(expr);
435+
} else if (auto* iff = scope.getIf()) {
424436
iff->ifTrue = *expr;
425437
iff->ifFalse = nullptr;
426438
iff->finalize(iff->type);
427-
scopeExpr = iff;
439+
push(maybeWrapForLabel(iff));
428440
} else if (auto* iff = scope.getElse()) {
429-
auto expr = finishScope();
430-
CHECK_ERR(expr);
431441
iff->ifFalse = *expr;
432442
iff->finalize(iff->type);
433-
scopeExpr = iff;
434-
}
435-
assert(scopeExpr && "unexpected scope kind");
436-
if (label) {
437-
// We cannot directly name an If in Binaryen IR, so we need to wrap it in
438-
// a block.
439-
push(builder.makeBlock(label, {scopeExpr}, scopeExpr->type));
443+
push(maybeWrapForLabel(iff));
440444
} else {
441-
push(scopeExpr);
445+
WASM_UNREACHABLE("unexpected scope kind");
442446
}
443447
return Ok{};
444448
}

0 commit comments

Comments
 (0)