Skip to content

Commit 9aa5d42

Browse files
authored
banderwagon: avoid allocations in scalar field conversions (#30)
* avoid allocations Signed-off-by: Ignacio Hagopian <[email protected]> * avoid extra allocs also in simple fr conversion Signed-off-by: Ignacio Hagopian <[email protected]> * avoid further allocs in ElementToBytes Signed-off-by: Ignacio Hagopian <[email protected]> Signed-off-by: Ignacio Hagopian <[email protected]>
1 parent a661476 commit 9aa5d42

File tree

2 files changed

+32
-41
lines changed

2 files changed

+32
-41
lines changed

banderwagon/element.go

+22-27
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ var Generator = Element{inner: bandersnatch.PointProj{
1616
Y: bandersnatch.GetEdwardsCurve().Base.Y,
1717
Z: fp.One(),
1818
}}
19+
1920
var Identity = Element{inner: bandersnatch.PointProj{
2021
X: fp.Zero(),
2122
Y: fp.One(),
@@ -34,7 +35,7 @@ func (p Element) Bytes() [sizePointCompressed]byte {
3435
affine_representation.FromProj(&p.inner)
3536

3637
// Serialisation takes the x co-ordinate and multiplies it by the sign of y
37-
var x = affine_representation.X
38+
x := affine_representation.X
3839
if !affine_representation.Y.LexicographicallyLargest() {
3940
x.Neg(&x)
4041
}
@@ -44,15 +45,15 @@ func (p Element) Bytes() [sizePointCompressed]byte {
4445
// Serialises multiple group elements using a batch multi inversion
4546
func ElementsToBytes(elements []*Element) [][sizePointCompressed]byte {
4647
// Collect all z co-ordinates
47-
var zs []fp.Element
48+
zs := make([]fp.Element, len(elements))
4849
for i := 0; i < int(len(elements)); i++ {
49-
zs = append(zs, elements[i].inner.Z)
50+
zs[i] = elements[i].inner.Z
5051
}
5152

5253
// Invert z co-ordinates
5354
zInvs := fp.BatchInvert(zs)
5455

55-
var serialised_points [][sizePointCompressed]byte
56+
serialised_points := make([][sizePointCompressed]byte, len(elements))
5657

5758
// Multiply x and y by zInv
5859
for i := 0; i < int(len(elements)); i++ {
@@ -69,11 +70,10 @@ func ElementsToBytes(elements []*Element) [][sizePointCompressed]byte {
6970
X.Neg(&X)
7071
}
7172

72-
serialised_points = append(serialised_points, X.Bytes())
73+
serialised_points[i] = X.Bytes()
7374
}
7475

7576
return serialised_points
76-
7777
}
7878

7979
func (p *Element) setBytes(buf []byte, trusted bool) error {
@@ -116,49 +116,41 @@ func (p *Element) SetBytesTrusted(buf []byte) error {
116116

117117
// computes X/Y
118118
func (p Element) mapToBaseField() fp.Element {
119-
120119
var res fp.Element
121120
res.Div(&p.inner.X, &p.inner.Y)
122121
return res
123122
}
124123

125-
func (p Element) MapToScalarField() fr.Element {
124+
func (p Element) MapToScalarField(res *fr.Element) {
126125
basefield := p.mapToBaseField()
127126
baseFieldBytes := basefield.BytesLE()
128127

129-
var res fr.Element
130128
res.SetBytesLE(baseFieldBytes[:])
131-
132-
return res
133129
}
134130

135131
// Maps each point to a field element in the scalar field
136-
func MultiMapToScalarField(elements []*Element) []fr.Element {
132+
func MultiMapToScalarField(result []*fr.Element, elements []*Element) {
133+
if len(result) != len(elements) {
134+
panic("MultiMapToScalarField expects the result slice to be the same length of elements")
135+
}
136+
137137
// Collect all y co-ordinates
138-
var ys []fp.Element
138+
ys := make([]fp.Element, len(elements))
139139
for i := 0; i < int(len(elements)); i++ {
140-
ys = append(ys, elements[i].inner.Y)
140+
ys[i] = elements[i].inner.Y
141141
}
142142

143143
// Invert y co-ordinates
144144
yInvs := fp.BatchInvert(ys)
145145

146-
var scalars []fr.Element
147-
148146
// Multiply x by yInv
149147
for i := 0; i < int(len(elements)); i++ {
150148
var mappedElement fp.Element
151149

152150
mappedElement.Mul(&elements[i].inner.X, &yInvs[i])
153151
byts := mappedElement.BytesLE()
154-
155-
var res fr.Element
156-
res.SetBytesLE(byts[:])
157-
scalars = append(scalars, res)
152+
result[i].SetBytesLE(byts[:])
158153
}
159-
160-
return scalars
161-
162154
}
163155

164156
// TODO: change this to not use pointers
@@ -191,7 +183,7 @@ func (p *Element) Equal(other *Element) bool {
191183
func subgroup_check(x fp.Element) error {
192184
var res, one, ax_sq fp.Element
193185
one.SetOne()
194-
var A = bandersnatch.GetEdwardsCurve().A
186+
A := bandersnatch.GetEdwardsCurve().A
195187

196188
// 1 - ax^2
197189
ax_sq.Square(&x)
@@ -209,24 +201,27 @@ func (p *Element) Identity() *Element {
209201
*p = Identity
210202
return p
211203
}
204+
212205
func (p *Element) Double(p1 *Element) *Element {
213206
p.inner.Double(&p1.inner)
214207
return p
215208
}
209+
216210
func (p *Element) Add(p1, p2 *Element) *Element {
217211
p.inner.Add(&p1.inner, &p2.inner)
218212
return p
219213
}
214+
220215
func (p *Element) AddMixed(p1 *Element, p2 bandersnatch.PointAffine) *Element {
221216
p.inner.MixedAdd(&p1.inner, &p2)
222217
return p
223218
}
219+
224220
func (p *Element) Sub(p1, p2 *Element) *Element {
225221
var neg_p2 Element
226222
neg_p2.Neg(p2)
227223

228224
return p.Add(p1, &neg_p2)
229-
230225
}
231226

232227
func (p *Element) IsOnCurve() bool {
@@ -244,6 +239,7 @@ func (p *Element) Normalise() {
244239
p.inner.Y.Set(&point_aff.Y)
245240
p.inner.Z.SetOne()
246241
}
242+
247243
func (p *Element) Set(p1 *Element) *Element {
248244
p.inner.X.Set(&p1.inner.X)
249245
p.inner.Y.Set(&p1.inner.Y)
@@ -255,6 +251,7 @@ func (p *Element) Neg(p1 *Element) *Element {
255251
p.inner.Neg(&p1.inner)
256252
return p
257253
}
254+
258255
func (p *Element) ScalarMul(p1 *Element, scalar_mont *fr.Element) *Element {
259256
p.inner.ScalarMul(&p1.inner, scalar_mont)
260257
return p
@@ -269,7 +266,6 @@ func (p *Element) ScalarMul(p1 *Element, scalar_mont *fr.Element) *Element {
269266
//
270267
// we could increase storage by 2x and save CPU time by serialising the projective point
271268
func UnsafeReadUncompressedPoint(r io.Reader) *Element {
272-
273269
affine_point := bandersnatch.ReadUncompressedPoint(r)
274270
var proj_repr bandersnatch.PointProj
275271
proj_repr.FromAffine(&affine_point)
@@ -281,7 +277,6 @@ func UnsafeReadUncompressedPoint(r io.Reader) *Element {
281277

282278
// Writes an uncompressed affine point to an io.Writer
283279
func (element *Element) UnsafeWriteUncompressedPoint(w io.Writer) (int, error) {
284-
285280
// Convert underlying point to affine representation
286281
var p bandersnatch.PointAffine
287282
p.FromProj(&element.inner)

banderwagon/element_test.go

+10-14
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ import (
77

88
"github.com/crate-crypto/go-ipa/bandersnatch"
99
"github.com/crate-crypto/go-ipa/bandersnatch/fp"
10+
"github.com/crate-crypto/go-ipa/bandersnatch/fr"
1011
)
1112

1213
func TestEncodingFixedVectors(t *testing.T) {
13-
1414
expected_bit_strings := [16]string{
1515
"4a2c7486fd924882bf02c6908de395122843e3e05264d7991e18e7985dad51e9",
1616
"43aa74ef706605705989e8fd38df46873b7eae5921fbed115ac9d937399ce4d5",
@@ -90,6 +90,7 @@ func TestTwoTorsionEqual(t *testing.T) {
9090
point.Double(&point)
9191
}
9292
}
93+
9394
func TestPointAtInfinityComponent(t *testing.T) {
9495
// These are all points which will be shown to be on the curve
9596
// but are not in the correct subgroup
@@ -124,11 +125,9 @@ func TestPointAtInfinityComponent(t *testing.T) {
124125
panic("point should not be in the correct subgroup as it has an infinity component")
125126
}
126127
}
127-
128128
}
129129

130130
func TestAddSubDouble(t *testing.T) {
131-
132131
var A, B Element
133132

134133
A.Add(&Generator, &Generator)
@@ -149,7 +148,6 @@ func TestAddSubDouble(t *testing.T) {
149148
}
150149

151150
func TestSerde(t *testing.T) {
152-
153151
var point Element
154152
var point_aff bandersnatch.PointAffine
155153

@@ -164,11 +162,9 @@ func TestSerde(t *testing.T) {
164162
if !point_aff.Equal(&got) {
165163
panic("deserialised point does not equal serialised point ")
166164
}
167-
168165
}
169166

170167
func TestBatchElementsToBytes(t *testing.T) {
171-
172168
var A, B Element
173169

174170
A.Add(&Generator, &Generator)
@@ -183,34 +179,34 @@ func TestBatchElementsToBytes(t *testing.T) {
183179
got_serialised_b := serialised_points[1]
184180
if expected_serialised_a != got_serialised_a {
185181
panic("expected serialised point of A is incorrect ")
186-
187182
}
188183
if expected_serialised_b != got_serialised_b {
189184
panic("expected serialised point of B is incorrect ")
190185
}
191-
192186
}
193187

194188
func TestMultiMapToBaseField(t *testing.T) {
195-
196189
var A, B Element
197190

198191
A.Add(&Generator, &Generator)
199192
B.Double(&Generator)
200193
B.Double(&B)
201194

202-
expected_a := A.MapToScalarField()
203-
expected_b := B.MapToScalarField()
195+
var expected_a, expected_b fr.Element
196+
A.MapToScalarField(&expected_a)
197+
B.MapToScalarField(&expected_b)
204198

205-
scalars := MultiMapToScalarField([]*Element{&A, &B})
199+
var ARes, BRes fr.Element
200+
scalars := []*fr.Element{&ARes, &BRes}
201+
MultiMapToScalarField(scalars, []*Element{&A, &B})
206202

207203
got_a := scalars[0]
208204
got_b := scalars[1]
209-
if expected_a != got_a {
205+
if expected_a != *got_a {
210206
panic("expected scalar for point `A` is incorrect ")
211207
}
212208

213-
if expected_b != got_b {
209+
if expected_b != *got_b {
214210
panic("expected scalar for point `A` is incorrect ")
215211
}
216212
}

0 commit comments

Comments
 (0)