Skip to content

Commit 0d7e4f3

Browse files
authored
Feat/expr in var origin (#45)
1 parent 765dd19 commit 0d7e4f3

24 files changed

+1012
-332
lines changed

Numscript.g4

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ valueExpr:
2121
| monetaryLit # monetaryLiteral
2222
| left = valueExpr op = DIV right = valueExpr # infixExpr
2323
| left = valueExpr op = (PLUS | MINUS) right = valueExpr # infixExpr
24-
| LPARENS valueExpr RPARENS # parenthesizedExpr;
24+
| LPARENS valueExpr RPARENS # parenthesizedExpr
25+
| functionCall # application;
2526

2627
functionCallArgs: valueExpr ( COMMA valueExpr)*;
2728
functionCall:
2829
fnName = (OVERDRAFT | IDENTIFIER) LPARENS functionCallArgs? RPARENS;
2930

30-
varOrigin: EQ functionCall;
31+
varOrigin: EQ valueExpr;
3132
varDeclaration:
3233
type_ = IDENTIFIER name = VARIABLE_NAME varOrigin?;
3334
varsDeclaration: VARS LBRACE varDeclaration* RBRACE;

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ require (
3232
github.com/tidwall/match v1.1.1 // indirect
3333
github.com/tidwall/pretty v1.2.1 // indirect
3434
github.com/tidwall/sjson v1.2.5 // indirect
35-
golang.org/x/exp v0.0.0-20240707233637-46b078467d37
35+
golang.org/x/exp v0.0.0-20240707233637-46b078467d37 // indirect
3636
)

internal/analysis/check.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ func (res *CheckResult) check() {
162162
}
163163

164164
if varDecl.Origin != nil {
165-
res.checkVarOrigin(*varDecl.Origin, varDecl)
165+
res.checkExpression(*varDecl.Origin, varDecl.Type.Name)
166166
}
167167
}
168168
}
@@ -302,18 +302,20 @@ func (res *CheckResult) checkDuplicateVars(variableName parser.Variable, decl pa
302302
}
303303
}
304304

305-
func (res *CheckResult) checkVarOrigin(fnCall parser.FnCall, decl parser.VarDeclaration) {
306-
resolution, ok := Builtins[fnCall.Caller.Name]
307-
if ok {
308-
resolution, ok := resolution.(VarOriginFnCallResolution)
309-
if ok {
310-
res.fnCallResolution[decl.Origin.Caller] = resolution
311-
res.assertHasType(decl.Name, resolution.Return, decl.Type.Name)
305+
func (res *CheckResult) checkFnCall(fnCall parser.FnCall) string {
306+
returnType := TypeAny
307+
308+
if resolution, ok := Builtins[fnCall.Caller.Name]; ok {
309+
if resolution, ok := resolution.(VarOriginFnCallResolution); ok {
310+
res.fnCallResolution[fnCall.Caller] = resolution
311+
returnType = resolution.Return
312312
}
313313
}
314314

315315
// this must come after resolution
316316
res.checkFnCallArity(&fnCall)
317+
318+
return returnType
317319
}
318320

319321
func (res *CheckResult) checkExpression(lit parser.ValueExpr, requiredType string) {
@@ -379,6 +381,9 @@ func (res *CheckResult) checkTypeOf(lit parser.ValueExpr, typeHint string) strin
379381
case *parser.StringLiteral:
380382
return TypeString
381383

384+
case *parser.FnCall:
385+
return res.checkFnCall(*lit)
386+
382387
default:
383388
return TypeAny
384389
}

internal/analysis/hover.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func hoverOnVar(varDecl parser.VarDeclaration, position parser.Position) Hover {
7575
}
7676

7777
if varDecl.Origin != nil {
78-
hover := hoverOnFnCall(*varDecl.Origin, position)
78+
hover := hoverOnExpression(*varDecl.Origin, position)
7979
if hover != nil {
8080
return hover
8181
}
@@ -183,6 +183,8 @@ func hoverOnExpression(lit parser.ValueExpr, position parser.Position) Hover {
183183
return hover
184184
}
185185

186+
case *parser.FnCall:
187+
return hoverOnFnCall(*lit, position)
186188
}
187189

188190
return nil

internal/cmd/run.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ var runOutFormatOpt string
2929
var overdraftFeatureFlag bool
3030
var oneOfFeatureFlag bool
3131
var accountInterpolationFlag bool
32+
var midScriptFunctionCallFeatureFlag bool
3233

3334
type inputOpts struct {
3435
Script string `json:"script"`
@@ -129,6 +130,9 @@ func run(path string) {
129130
if accountInterpolationFlag {
130131
featureFlags[interpreter.ExperimentalAccountInterpolationFlag] = struct{}{}
131132
}
133+
if midScriptFunctionCallFeatureFlag {
134+
featureFlags[interpreter.ExperimentalMidScriptFunctionCall] = struct{}{}
135+
}
132136

133137
result, err := interpreter.RunProgram(context.Background(), parseResult.Value, opt.Variables, interpreter.StaticStore{
134138
Balances: opt.Balances,
@@ -208,9 +212,11 @@ func getRunCmd() *cobra.Command {
208212
cmd.Flags().BoolVar(&runStdinFlag, "stdin", false, "Take input from stdin (same format as the --raw option)")
209213

210214
// Feature flag
211-
cmd.Flags().BoolVar(&overdraftFeatureFlag, interpreter.ExperimentalOverdraftFunctionFeatureFlag, false, "feature flag to enable the overdraft() function")
212-
cmd.Flags().BoolVar(&oneOfFeatureFlag, interpreter.ExperimentalOneofFeatureFlag, false, "feature flag to enable the oneof combinator")
215+
216+
cmd.Flags().BoolVar(&overdraftFeatureFlag, interpreter.ExperimentalOverdraftFunctionFeatureFlag, false, "enables the experimental overdraft() function")
217+
cmd.Flags().BoolVar(&oneOfFeatureFlag, interpreter.ExperimentalOneofFeatureFlag, false, "enable the experimental oneof combinator")
213218
cmd.Flags().BoolVar(&accountInterpolationFlag, interpreter.ExperimentalAccountInterpolationFlag, false, "enables an account interpolation syntax, e.g. @users:$id:pending")
219+
cmd.Flags().BoolVar(&midScriptFunctionCallFeatureFlag, interpreter.ExperimentalMidScriptFunctionCall, false, "allows to use function call as expression, and to use any expression when definining variables")
214220

215221
// Output options
216222
cmd.Flags().StringVar(&runOutFormatOpt, "output-format", OutputFormatPretty, "Set the output format. Available options: pretty, json.")

internal/interpreter/balances.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package interpreter
2+
3+
import (
4+
"math/big"
5+
6+
"github.com/formancehq/numscript/internal/utils"
7+
)
8+
9+
func (b Balances) fetchAccountBalances(account string) AccountBalance {
10+
return defaultMapGet(b, account, func() AccountBalance {
11+
return AccountBalance{}
12+
})
13+
}
14+
15+
// Get the (account, asset) tuple from the Balances
16+
// if the tuple is not present, it will write a big.NewInt(0) in it and return it
17+
func (b Balances) fetchBalance(account string, asset string) *big.Int {
18+
accountBalances := b.fetchAccountBalances(account)
19+
20+
return defaultMapGet(accountBalances, asset, func() *big.Int {
21+
return new(big.Int)
22+
})
23+
}
24+
25+
func (b Balances) has(account string, asset string) bool {
26+
accountBalances := defaultMapGet(b, account, func() AccountBalance {
27+
return AccountBalance{}
28+
})
29+
30+
_, ok := accountBalances[asset]
31+
return ok
32+
}
33+
34+
// given a BalanceQuery, return a new query which only contains needed (asset, account) pairs
35+
// (that is, the ones that aren't already cached)
36+
func (b Balances) filterQuery(q BalanceQuery) BalanceQuery {
37+
filteredQuery := BalanceQuery{}
38+
for accountName, queriedCurrencies := range q {
39+
filteredCurrencies := utils.Filter(queriedCurrencies, func(currency string) bool {
40+
return !b.has(accountName, currency)
41+
})
42+
43+
if len(filteredCurrencies) > 0 {
44+
filteredQuery[accountName] = filteredCurrencies
45+
}
46+
47+
}
48+
return filteredQuery
49+
}
50+
51+
// Merge balances by adding balances in the "update" arg
52+
func (b Balances) mergeBalance(update Balances) {
53+
// merge queried balance
54+
for acc, accBalances := range update {
55+
cachedAcc := defaultMapGet(b, acc, func() AccountBalance {
56+
return AccountBalance{}
57+
})
58+
59+
for curr, amt := range accBalances {
60+
cachedAcc[curr] = amt
61+
}
62+
}
63+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package interpreter
2+
3+
import (
4+
"math/big"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestFilterQuery(t *testing.T) {
11+
fullBalance := Balances{
12+
"alice": AccountBalance{
13+
"EUR/2": big.NewInt(1),
14+
"USD/2": big.NewInt(2),
15+
},
16+
"bob": AccountBalance{
17+
"BTC": big.NewInt(3),
18+
},
19+
}
20+
21+
filteredQuery := fullBalance.filterQuery(BalanceQuery{
22+
"alice": []string{"GBP/2", "YEN", "EUR/2"},
23+
"bob": []string{"BTC"},
24+
"charlie": []string{"ETH"},
25+
})
26+
27+
require.Equal(t, BalanceQuery{
28+
"alice": []string{"GBP/2", "YEN"},
29+
"charlie": []string{"ETH"},
30+
}, filteredQuery)
31+
}

internal/interpreter/batch_balances_query.go

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ import (
55

66
"github.com/formancehq/numscript/internal/parser"
77
"github.com/formancehq/numscript/internal/utils"
8-
"golang.org/x/exp/maps"
98
)
109

1110
// traverse the script to batch in advance required balance queries
@@ -57,41 +56,28 @@ func (st *programState) batchQuery(account string, asset string) {
5756
}
5857

5958
previousValues := st.CurrentBalanceQuery[account]
60-
if !slices.Contains[[]string, string](previousValues, asset) {
59+
if !slices.Contains(previousValues, asset) {
6160
st.CurrentBalanceQuery[account] = append(previousValues, asset)
6261
}
6362
}
6463

6564
func (st *programState) runBalancesQuery() error {
66-
filteredQuery := BalanceQuery{}
67-
for accountName, queriedCurrencies := range st.CurrentBalanceQuery {
68-
69-
cachedCurrenciesForAccount := defaultMapGet(st.CachedBalances, accountName, func() AccountBalance {
70-
return AccountBalance{}
71-
})
72-
73-
for _, queriedCurrency := range queriedCurrencies {
74-
isAlreadyCached := slices.Contains(maps.Keys(cachedCurrenciesForAccount), queriedCurrency)
75-
if !isAlreadyCached {
76-
filteredQuery[accountName] = queriedCurrencies
77-
}
78-
}
79-
80-
}
65+
filteredQuery := st.CachedBalances.filterQuery(st.CurrentBalanceQuery)
8166

8267
// avoid updating balances if we don't need to fetch new data
8368
if len(filteredQuery) == 0 {
8469
return nil
8570
}
8671

87-
balances, err := st.Store.GetBalances(st.ctx, filteredQuery)
72+
queriedBalances, err := st.Store.GetBalances(st.ctx, filteredQuery)
8873
if err != nil {
8974
return err
9075
}
9176
// reset batch query
9277
st.CurrentBalanceQuery = BalanceQuery{}
9378

94-
st.CachedBalances = balances
79+
st.CachedBalances.mergeBalance(queriedBalances)
80+
9581
return nil
9682
}
9783

internal/interpreter/evaluate_expr.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ func (st *programState) evaluateExpr(expr parser.ValueExpr) (Value, InterpreterE
8484
return nil, nil
8585
}
8686

87+
case *parser.FnCall:
88+
if !st.varOriginPosition {
89+
err := st.checkFeatureFlag(ExperimentalMidScriptFunctionCall)
90+
if err != nil {
91+
return nil, err
92+
}
93+
}
94+
95+
return st.handleFnCall(nil, *expr)
96+
8797
default:
8898
utils.NonExhaustiveMatchPanic[any](expr)
8999
return nil, nil

0 commit comments

Comments
 (0)