Skip to content

Commit 522b30a

Browse files
committed
Add luacode.Value.IdenticalTo method
This makes equality vs. identity checks more explicit and handles a few weird edge cases with them. We no longer use `unique.Handle` to build up the constant table, but it's okay because we can use a hashing scheme to reduce the number of comparisons per constant table insert. This is roughly equivalent to what was happening before, but with less runtime machinery.
1 parent d6e3c48 commit 522b30a

File tree

6 files changed

+176
-23
lines changed

6 files changed

+176
-23
lines changed

internal/luacode/code.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ func (p *parser) codeIndexed(fs *funcState, t, k expressionDescriptor) (expressi
343343
isKstr := k.kind == expressionKindConstant &&
344344
!k.hasJumps() &&
345345
k.constantIndex() <= maxArgB &&
346-
fs.constantTable[k.constantIndex()].Value().isShortString()
346+
fs.Constants[k.constantIndex()].isShortString()
347347
if t.kind == expressionKindUpvalue && !isKstr {
348348
// [OpGetTabUp] can only index short strings.
349349
// Copy the table from an upvalue to a register.

internal/luacode/funcstate.go

+18-16
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ package luacode
77
import (
88
"errors"
99
"fmt"
10-
"slices"
11-
"unique"
10+
"hash/maphash"
1211
)
1312

1413
// funcState is the mutable state associated with a [Prototype]
@@ -41,9 +40,12 @@ type funcState struct {
4140
// when returning.
4241
needClose bool
4342

44-
// constantTable will be copied into the [Prototype] Constants field,
45-
// but uses [unique.Handle] to speed up lookups during build.
46-
constantTable []unique.Handle[Value]
43+
// constantsIndex is a mapping of [Value]
44+
// to their indices in the [Prototype] Constants table.
45+
// The key for a [Value] v is determined by hashValue(v).
46+
constantsIndex map[uint64][]int
47+
// constantsIndexSeed is used to hash the values for constantsIndex.
48+
constantsIndexSeed maphash.Seed
4749

4850
lineInfoWriter lineInfoWriter
4951
}
@@ -69,11 +71,6 @@ type blockControl struct {
6971
//
7072
// Equivalent to `luaK_finish` in upstream Lua.
7173
func (fs *funcState) finish() error {
72-
fs.Constants = make([]Value, len(fs.constantTable))
73-
for i, handle := range fs.constantTable {
74-
fs.Constants[i] = handle.Value()
75-
}
76-
7774
for i, instruction := range fs.Code {
7875
if i > 0 && fs.Code[i-1].IsOutTop() != instruction.IsInTop() {
7976
return fmt.Errorf("internal error: instruction %d: %v follows %v",
@@ -410,16 +407,21 @@ func (fs *funcState) markToBeClosed() {
410407
fs.needClose = true
411408
}
412409

413-
// addConstant either inserts a constant into the function's Constants table
410+
// addConstant either inserts a constant into the [Prototype] Constants table
414411
// and returns the index of the constant
415412
// or returns the index of an existing identical constant in the table.
416413
//
417414
// Equivalent to `addk` in upstream Lua.
418415
func (fs *funcState) addConstant(k Value) int {
419-
kHandle := unique.Make(k)
420-
if i := slices.Index(fs.constantTable, kHandle); i >= 0 {
421-
return i
416+
kHash := k.hash(fs.constantsIndexSeed)
417+
entries := fs.constantsIndex[kHash]
418+
for _, i := range entries {
419+
if k.IdenticalTo(fs.Constants[i]) {
420+
return i
421+
}
422422
}
423-
fs.constantTable = append(fs.constantTable, kHandle)
424-
return len(fs.constantTable) - 1
423+
fs.Constants = append(fs.Constants, k)
424+
i := len(fs.Constants) - 1
425+
fs.constantsIndex[kHash] = append(entries, i)
426+
return i
425427
}

internal/luacode/parser.go

+6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package luacode
77
import (
88
"errors"
99
"fmt"
10+
"hash/maphash"
1011
"io"
1112
"slices"
1213
"strings"
@@ -175,12 +176,17 @@ func (p *parser) openFunction(prev *funcState, f *Prototype) *funcState {
175176
prev: prev,
176177
Prototype: f,
177178

179+
constantsIndex: make(map[uint64][]int),
180+
178181
previousLine: f.LineDefined,
179182
firstLocal: len(p.activeVariables),
180183
firstLabel: len(p.labels),
181184
}
182185
if prev != nil {
183186
prev.Functions = append(prev.Functions, f)
187+
fs.constantsIndexSeed = prev.constantsIndexSeed
188+
} else {
189+
fs.constantsIndexSeed = maphash.MakeSeed()
184190
}
185191
p.enterBlock(fs, false)
186192
return fs

internal/luacode/prototype_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
)
1414

1515
var prototypeDiffOptions = cmp.Options{
16+
cmp.Comparer(Value.IdenticalTo),
1617
lineInfoCompareOption,
1718
cmpopts.EquateEmpty(),
1819
}

internal/luacode/value.go

+49-6
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package luacode
66

77
import (
8+
"hash/maphash"
89
"math"
910
"strconv"
1011
"strings"
@@ -36,8 +37,8 @@ func (t valueType) noVariant() valueType {
3637
// Value is a subset of Lua values that can be used as constants:
3738
// nil, booleans, floats, integers, and strings.
3839
// The zero value is nil.
39-
// Values can be compared for equality with the == operator.
4040
type Value struct {
41+
_ [0]func() // Prevent comparing with "==".
4142
bits uint64
4243
s string
4344
t valueType
@@ -159,11 +160,11 @@ func (v Value) Unquoted() (s string, isString bool) {
159160
case math.IsInf(f, -1):
160161
return "-inf", false
161162
default:
162-
s = strconv.FormatFloat(f, 'g', -1, 64)
163-
if !strings.ContainsAny(s, ".e") {
164-
s += ".0"
165-
}
166-
return s, false
163+
s = strconv.FormatFloat(f, 'g', -1, 64)
164+
if !strings.ContainsAny(s, ".e") {
165+
s += ".0"
166+
}
167+
return s, false
167168
}
168169
case valueTypeInteger:
169170
i, _ := v.Int64(OnlyIntegral)
@@ -200,6 +201,12 @@ func (v Value) Equal(v2 Value) bool {
200201
case valueTypeNil, valueTypeFalse, valueTypeTrue:
201202
return v.t == v2.t
202203
case valueTypeFloat:
204+
if v2.IsInteger() {
205+
// Float must have integer value to be equal.
206+
i1, ok := v.Int64(OnlyIntegral)
207+
i2, _ := v2.Int64(OnlyIntegral)
208+
return ok && i1 == i2
209+
}
203210
f1, _ := v.Float64()
204211
f2, ok := v2.Float64()
205212
return ok && f1 == f2
@@ -214,6 +221,42 @@ func (v Value) Equal(v2 Value) bool {
214221
}
215222
}
216223

224+
// IdenticalTo reports whether two values represent the same value.
225+
// This is mostly the same as [Value.Equal],
226+
// but will report true for two NaNs, for example.
227+
func (v Value) IdenticalTo(v2 Value) bool {
228+
if v.t != v2.t {
229+
return false
230+
}
231+
switch v.t.noVariant() {
232+
case valueTypeNil, valueTypeBoolean:
233+
return true
234+
case valueTypeString:
235+
return v.s == v2.s
236+
default:
237+
return v.bits == v2.bits
238+
}
239+
}
240+
241+
// hash returns a hash value for v
242+
// such that if v1.IdenticalTo(v2),
243+
// then v1.hash(seed) == v2.hash(seed).
244+
func (v Value) hash(seed maphash.Seed) uint64 {
245+
var h maphash.Hash
246+
h.SetSeed(seed)
247+
h.WriteByte(byte(v.t))
248+
switch v.t.noVariant() {
249+
case valueTypeNumber:
250+
for i := range 64 / 8 {
251+
h.WriteByte(byte(v.bits >> (i * 8)))
252+
}
253+
case valueTypeString:
254+
s, _ := v.Unquoted()
255+
h.WriteString(s)
256+
}
257+
return h.Sum64()
258+
}
259+
217260
// FloatToIntegerMode is an enumeration of rounding modes for [FloatToInteger].
218261
type FloatToIntegerMode int
219262

internal/luacode/value_test.go

+101
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,104 @@ func TestValueUnquoted(t *testing.T) {
3434
}
3535
}
3636
}
37+
38+
func TestValueEqual(t *testing.T) {
39+
tests := []struct {
40+
v1, v2 Value
41+
want bool
42+
}{
43+
{Value{}, Value{}, true},
44+
{BoolValue(false), Value{}, false},
45+
{BoolValue(true), Value{}, false},
46+
{IntegerValue(0), Value{}, false},
47+
{FloatValue(0), Value{}, false},
48+
{StringValue(""), Value{}, false},
49+
{BoolValue(false), BoolValue(false), true},
50+
{BoolValue(true), BoolValue(true), true},
51+
{BoolValue(true), BoolValue(false), false},
52+
{IntegerValue(42), IntegerValue(42), true},
53+
{IntegerValue(42), IntegerValue(-42), false},
54+
{IntegerValue(42), FloatValue(42), true},
55+
{IntegerValue(math.MaxInt64 - 1023), FloatValue(math.MaxInt64 - 1023), true},
56+
{IntegerValue(math.MinInt64), FloatValue(math.MinInt64), true},
57+
{IntegerValue(42), FloatValue(-42), false},
58+
{FloatValue(3.14), FloatValue(3.14), true},
59+
{FloatValue(math.NaN()), FloatValue(42), false},
60+
{FloatValue(math.NaN()), FloatValue(math.NaN()), false},
61+
{StringValue(""), StringValue(""), true},
62+
{StringValue(""), StringValue("123"), false},
63+
{StringValue("123"), StringValue("123"), true},
64+
{StringValue("123"), StringValue("456"), false},
65+
{StringValue("123"), IntegerValue(123), false},
66+
67+
// Float values that can't be represented as an integer.
68+
{IntegerValue(math.MaxInt64), FloatValue(math.MaxInt64), false},
69+
{IntegerValue(math.MinInt64), FloatValue(math.MinInt64 - 1025), false},
70+
}
71+
72+
for _, test := range tests {
73+
if got := test.v1.Equal(test.v2); got && !test.want {
74+
t.Errorf("%v == %v", test.v1, test.v2)
75+
} else if !got && test.want {
76+
t.Errorf("%v != %v", test.v1, test.v2)
77+
}
78+
79+
if got := test.v2.Equal(test.v1); got && !test.want {
80+
t.Errorf("%v == %v", test.v2, test.v1)
81+
} else if !got && test.want {
82+
t.Errorf("%v != %v", test.v2, test.v1)
83+
}
84+
}
85+
}
86+
87+
func TestValueIdenticalTo(t *testing.T) {
88+
identicalTests := []Value{
89+
{},
90+
BoolValue(false),
91+
BoolValue(true),
92+
IntegerValue(42),
93+
IntegerValue(0),
94+
IntegerValue(-42),
95+
FloatValue(3.14),
96+
FloatValue(42),
97+
FloatValue(math.Inf(1)),
98+
FloatValue(math.Inf(-1)),
99+
FloatValue(math.NaN()),
100+
FloatValue(math.MinInt64),
101+
FloatValue(math.MaxInt64),
102+
StringValue(""),
103+
StringValue("abc"),
104+
}
105+
notIdenticalTests := []struct {
106+
v1, v2 Value
107+
}{
108+
{BoolValue(false), Value{}},
109+
{BoolValue(true), Value{}},
110+
{IntegerValue(0), Value{}},
111+
{FloatValue(0), Value{}},
112+
{StringValue(""), Value{}},
113+
{BoolValue(false), BoolValue(true)},
114+
{IntegerValue(123), IntegerValue(-123)},
115+
{FloatValue(123), FloatValue(-123)},
116+
{FloatValue(3.14), FloatValue(-3.14)},
117+
{FloatValue(math.Inf(1)), FloatValue(math.Inf(-1))},
118+
{StringValue("123"), StringValue("456")},
119+
{IntegerValue(123), FloatValue(123)},
120+
{IntegerValue(123), StringValue("123")},
121+
{FloatValue(123), StringValue("123")},
122+
}
123+
124+
for _, v := range identicalTests {
125+
if !v.IdenticalTo(v) {
126+
t.Errorf("(%v).IdenticalTo(%v) = false; want true", v, v)
127+
}
128+
}
129+
for _, test := range notIdenticalTests {
130+
if test.v1.IdenticalTo(test.v2) {
131+
t.Errorf("(%v).IdenticalTo(%v) = true; want false", test.v1, test.v2)
132+
}
133+
if test.v2.IdenticalTo(test.v1) {
134+
t.Errorf("(%v).IdenticalTo(%v) = true; want false", test.v2, test.v1)
135+
}
136+
}
137+
}

0 commit comments

Comments
 (0)