Skip to content

Commit 825deb7

Browse files
committed
feat: replace EmbedPublicKey by option
1 parent 946d42c commit 825deb7

5 files changed

+140
-280
lines changed

ipns/record.go

+53-31
Original file line numberDiff line numberDiff line change
@@ -194,14 +194,21 @@ const (
194194
)
195195

196196
type options struct {
197-
compatibleWithV1 bool
197+
v1Compatibility bool
198+
embedPublicKey *bool
198199
}
199200

200201
type Option func(*options)
201202

202-
func CompatibleWithV1(compatible bool) Option {
203-
return func(opts *options) {
204-
opts.compatibleWithV1 = compatible
203+
func WithV1Compatibility(compatible bool) Option {
204+
return func(o *options) {
205+
o.v1Compatibility = compatible
206+
}
207+
}
208+
209+
func WithPublicKey(embedded bool) Option {
210+
return func(o *options) {
211+
o.embedPublicKey = &embedded
205212
}
206213
}
207214

@@ -214,7 +221,9 @@ func processOptions(opts ...Option) *options {
214221
}
215222

216223
// NewRecord creates a new IPNS [Record] and signs it with the given private key.
217-
// This function does not embed the public key. To do so, call [EmbedPublicKey].
224+
// By default, we embed the public key for key types whose peer IDs do not encode
225+
// the public key, such as RSA and ECDSA key types. This can be changed with the
226+
// option [WithPublicKey].
218227
func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl time.Duration, opts ...Option) (*Record, error) {
219228
options := processOptions(opts...)
220229

@@ -243,7 +252,7 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
243252
SignatureV2: sig2,
244253
}
245254

246-
if options.compatibleWithV1 {
255+
if options.v1Compatibility {
247256
pb.Value = []byte(value)
248257
typ := ipns_pb.IpnsEntry_EOL
249258
pb.ValidityType = &typ
@@ -263,6 +272,24 @@ func NewRecord(sk ic.PrivKey, value path.Path, seq uint64, eol time.Time, ttl ti
263272
pb.SignatureV1 = sig1
264273
}
265274

275+
embedPublicKey := false
276+
if options.embedPublicKey == nil {
277+
embedPublicKey, err = needToEmbedPublicKey(sk.GetPublic())
278+
if err != nil {
279+
return nil, err
280+
}
281+
} else {
282+
embedPublicKey = *options.embedPublicKey
283+
}
284+
285+
if embedPublicKey {
286+
pkBytes, err := ic.MarshalPublicKey(sk.GetPublic())
287+
if err != nil {
288+
return nil, err
289+
}
290+
pb.PubKey = pkBytes
291+
}
292+
266293
return &Record{
267294
pb: pb,
268295
node: node,
@@ -342,6 +369,26 @@ func recordDataForSignatureV2(data []byte) ([]byte, error) {
342369
return dataForSig, nil
343370
}
344371

372+
func needToEmbedPublicKey(pk ic.PubKey) (bool, error) {
373+
// First try extracting the peer ID from the public key.
374+
pid, err := peer.IDFromPublicKey(pk)
375+
if err != nil {
376+
return false, fmt.Errorf("cannot convert public key to peer ID: %w", err)
377+
}
378+
379+
_, err = pid.ExtractPublicKey()
380+
if err == nil {
381+
// Can be extracted, therefore no need to embed the public key.
382+
return false, nil
383+
}
384+
385+
if errors.Is(err, peer.ErrNoPublicKey) {
386+
return true, nil
387+
}
388+
389+
return false, fmt.Errorf("cannot extract ID from public key: %w", err)
390+
}
391+
345392
// compare compares two IPNS Records. It returns:
346393
//
347394
// - -1 if a is older than b
@@ -396,31 +443,6 @@ func compare(a, b *Record) (int, error) {
396443
return 0, nil
397444
}
398445

399-
// EmbedPublicKey embeds the given public key in the given [Record]. While not
400-
// strictly required, some nodes (e.g., DHT servers), may reject IPNS Records
401-
// that do not embed their public keys as they may not be able to validate them
402-
// efficiently.
403-
func EmbedPublicKey(r *Record, pk ic.PubKey) error {
404-
// Try extracting the public key from the ID. If we can, do not embed it.
405-
pid, err := peer.IDFromPublicKey(pk)
406-
if err != nil {
407-
return err
408-
}
409-
if _, err := pid.ExtractPublicKey(); err != peer.ErrNoPublicKey {
410-
// Either a *real* error or nil.
411-
return err
412-
}
413-
414-
// We failed to extract the public key from the peer ID, embed it.
415-
pkBytes, err := ic.MarshalPublicKey(pk)
416-
if err != nil {
417-
return err
418-
}
419-
420-
r.pb.PubKey = pkBytes
421-
return nil
422-
}
423-
424446
// ExtractPublicKey extracts a [crypto.PubKey] matching the given [peer.ID] from
425447
// the IPNS Record, if possible.
426448
func ExtractPublicKey(r *Record, pid peer.ID) (ic.PubKey, error) {

ipns/record_test.go

+75-26
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"testing"
66
"time"
77

8+
"github.com/gogo/protobuf/proto"
89
ipns_pb "github.com/ipfs/boxo/ipns/pb"
910
"github.com/ipfs/boxo/path"
1011
"github.com/ipfs/boxo/util"
@@ -107,7 +108,7 @@ func TestNewRecord(t *testing.T) {
107108
t.Run("V1+V2 with option", func(t *testing.T) {
108109
t.Parallel()
109110

110-
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, CompatibleWithV1(true))
111+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl, WithV1Compatibility(true))
111112
require.NotEmpty(t, rec.pb.SignatureV1)
112113

113114
_, err := rec.PubKey()
@@ -116,51 +117,50 @@ func TestNewRecord(t *testing.T) {
116117
fieldsMatch(t, rec, testPath, seq, eol, ttl)
117118
fieldsMatchV1(t, rec, testPath, seq, eol, ttl)
118119
})
119-
}
120-
121-
func TestEmbedPublicKey(t *testing.T) {
122-
t.Parallel()
123-
124-
sk, pk, pid := mustKeyPair(t, ic.RSA)
125-
126-
seq := uint64(0)
127-
eol := time.Now().Add(time.Hour)
128-
ttl := time.Minute * 10
129120

130-
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
121+
t.Run("Public key embedded by default for RSA and ECDSA keys", func(t *testing.T) {
122+
t.Parallel()
131123

132-
_, err := rec.PubKey()
133-
require.ErrorIs(t, err, ErrPublicKeyNotFound)
124+
for _, keyType := range []int{ic.RSA, ic.ECDSA} {
125+
sk, _, _ := mustKeyPair(t, keyType)
126+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
127+
fieldsMatch(t, rec, testPath, seq, eol, ttl)
134128

135-
err = EmbedPublicKey(rec, pk)
136-
require.NoError(t, err)
129+
pk, err := rec.PubKey()
130+
require.NoError(t, err)
131+
require.True(t, pk.Equals(sk.GetPublic()))
132+
}
133+
})
137134

138-
recPK, err := rec.PubKey()
139-
require.NoError(t, err)
135+
t.Run("Public key not embedded by default for Ed25519 and Secp256k1 keys", func(t *testing.T) {
136+
t.Parallel()
140137

141-
recPID, err := peer.IDFromPublicKey(recPK)
142-
require.NoError(t, err)
138+
for _, keyType := range []int{ic.Ed25519, ic.Secp256k1} {
139+
sk, _, _ := mustKeyPair(t, keyType)
140+
rec := mustNewRecord(t, sk, testPath, seq, eol, ttl)
141+
fieldsMatch(t, rec, testPath, seq, eol, ttl)
143142

144-
require.Equal(t, pid, recPID)
143+
_, err := rec.PubKey()
144+
require.ErrorIs(t, err, ErrPublicKeyNotFound)
145+
}
146+
})
145147
}
146148

147149
func TestExtractPublicKey(t *testing.T) {
148150
t.Parallel()
149151

150152
t.Run("Returns expected public key when embedded in Peer ID", func(t *testing.T) {
151153
sk, pk, pid := mustKeyPair(t, ic.Ed25519)
152-
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
154+
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))
153155

154156
pk2, err := ExtractPublicKey(rec, pid)
155157
require.Nil(t, err)
156158
require.Equal(t, pk, pk2)
157159
})
158160

159-
t.Run("Returns expected public key when embedded in Record", func(t *testing.T) {
161+
t.Run("Returns expected public key when embedded in Record (by default)", func(t *testing.T) {
160162
sk, pk, pid := mustKeyPair(t, ic.RSA)
161163
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
162-
err := EmbedPublicKey(rec, pk)
163-
require.Nil(t, err)
164164

165165
pk2, err := ExtractPublicKey(rec, pid)
166166
require.Nil(t, err)
@@ -169,7 +169,7 @@ func TestExtractPublicKey(t *testing.T) {
169169

170170
t.Run("Errors when not embedded in Record or Peer ID", func(t *testing.T) {
171171
sk, _, pid := mustKeyPair(t, ic.RSA)
172-
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10)
172+
rec := mustNewRecord(t, sk, testPath, 0, time.Now().Add(time.Hour), time.Minute*10, WithPublicKey(false))
173173

174174
pk, err := ExtractPublicKey(rec, pid)
175175
require.Error(t, err)
@@ -248,3 +248,52 @@ func TestCBORDataSerialization(t *testing.T) {
248248
assert.Equal(t, expected, f)
249249
}
250250
}
251+
252+
func TestUnmarshal(t *testing.T) {
253+
t.Parallel()
254+
255+
t.Run("Errors on invalid bytes", func(t *testing.T) {
256+
_, err := UnmarshalRecord([]byte("blah blah blah"))
257+
require.ErrorIs(t, err, ErrBadRecord)
258+
})
259+
260+
t.Run("Errors if record is too long", func(t *testing.T) {
261+
data := make([]byte, MaxRecordSize+1)
262+
_, err := UnmarshalRecord(data)
263+
require.ErrorIs(t, err, ErrRecordSize)
264+
})
265+
266+
t.Run("Errors with V1-only records", func(t *testing.T) {
267+
pb := ipns_pb.IpnsEntry{}
268+
data, err := proto.Marshal(&pb)
269+
require.NoError(t, err)
270+
_, err = UnmarshalRecord(data)
271+
require.ErrorIs(t, err, ErrDataMissing)
272+
})
273+
274+
t.Run("Errors on bad data", func(t *testing.T) {
275+
pb := ipns_pb.IpnsEntry{
276+
Data: []byte("definitely not cbor"),
277+
}
278+
data, err := proto.Marshal(&pb)
279+
require.NoError(t, err)
280+
_, err = UnmarshalRecord(data)
281+
require.ErrorIs(t, err, ErrBadRecord)
282+
})
283+
}
284+
285+
func TestKey(t *testing.T) {
286+
for _, v := range [][]string{
287+
{"RSA", "QmRp2LvtSQtCkUWCpi92ph5MdQyRtfb9jHbkNgZzGExGuG", "/ipns/k2k4r8kpauqq30hoj9oktej5btbgz1jeos16d3te36xd78trvak0jcor"},
288+
{"Ed25519", "12D3KooWSzRuSFHgLsKr6jJboAPdP7xMga2YBgBspYuErxswcgvt", "/ipns/k51qzi5uqu5dmjjgoe7s21dncepi970722cn30qlhm9qridas1c9ktkjb6ejux"},
289+
{"ECDSA", "QmSBUTocZ9LxE53Br9PDDcPWnR1FJQRv94U96Wkt8eypAw", "/ipns/k2k4r8ku8cnc1sl2h5xn7i07dma9abfnkqkxi4a6nd1xq0knoxe7b0y4"},
290+
{"Secp256k1", "16Uiu2HAmUymv6JpFwNZppdKUMxGJuHsTeicXgHGKbBasu4Ruj3K1", "/ipns/kzwfwjn5ji4puw3jc1qw4b073j74xvq21iziuqw4rem21pr7f0l4dj8i9yb978s"},
291+
} {
292+
t.Run(v[0], func(t *testing.T) {
293+
pid, err := peer.Decode(v[1])
294+
require.NoError(t, err)
295+
key := Key(pid)
296+
require.Equal(t, v[2], key)
297+
})
298+
}
299+
}

ipns/validation_test.go

+11-10
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,12 @@ func TestOrdering(t *testing.T) {
7070
func TestValidator(t *testing.T) {
7171
t.Parallel()
7272

73-
check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error) {
73+
check := func(t *testing.T, sk ic.PrivKey, keybook peerstore.KeyBook, key string, val []byte, eol time.Time, exp error, opts ...Option) {
7474
validator := Validator{keybook}
7575
data := val
7676
if data == nil {
7777
// do not call mustNewRecord because that validates the record!
78-
rec, err := NewRecord(sk, testPath, 1, eol, 0)
78+
rec, err := NewRecord(sk, testPath, 1, eol, 0, opts...)
7979
require.NoError(t, err)
8080
data = mustMarshal(t, rec)
8181
}
@@ -99,9 +99,10 @@ func TestValidator(t *testing.T) {
9999
check(t, sk, kb, RoutingKey(pid), nil, ts.Add(time.Hour*-1), ErrExpiredRecord)
100100
check(t, sk, kb, RoutingKey(pid), []byte("bad data"), ts.Add(time.Hour), ErrBadRecord)
101101
check(t, sk, kb, "/ipns/"+"bad key", nil, ts.Add(time.Hour), ErrKeyFormat)
102-
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
103-
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound)
104-
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature)
102+
check(t, sk, emptyKB, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
103+
check(t, sk2, kb, RoutingKey(pid2), nil, ts.Add(time.Hour), ErrPublicKeyNotFound, WithPublicKey(false))
104+
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrPublicKeyMismatch)
105+
check(t, sk2, kb, RoutingKey(pid), nil, ts.Add(time.Hour), ErrSignature, WithPublicKey(false))
105106
check(t, sk, kb, "//"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
106107
check(t, sk, kb, "/wrong/"+string(pid), nil, ts.Add(time.Hour), ErrInvalidPath)
107108
})
@@ -128,14 +129,14 @@ func TestValidator(t *testing.T) {
128129
kb, err := pstoremem.NewPeerstore()
129130
require.NoError(t, err)
130131

131-
sk, pk, pid := mustKeyPair(t, ic.RSA)
132-
rec := mustNewRecord(t, sk, testPath, 1, eol, 0)
132+
sk, _, pid := mustKeyPair(t, ic.RSA)
133+
rec := mustNewRecord(t, sk, testPath, 1, eol, 0, WithPublicKey(false))
133134

134135
// Fails with RSA key without embedded public key.
135136
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, ErrPublicKeyNotFound)
136137

137138
// Embeds public key, must work now.
138-
require.NoError(t, EmbedPublicKey(rec, pk))
139+
rec = mustNewRecord(t, sk, testPath, 1, eol, 0)
139140
check(t, sk, kb, RoutingKey(pid), mustMarshal(t, rec), eol, nil)
140141

141142
// Force bad public key. Validation fails.
@@ -163,8 +164,8 @@ func TestValidate(t *testing.T) {
163164

164165
v := Validator{}
165166

166-
rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, CompatibleWithV1(true))
167-
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, CompatibleWithV1(true))
167+
rec1 := mustNewRecord(t, sk, path.FromString("/path/1"), 1, eol, 0, WithV1Compatibility(true))
168+
rec2 := mustNewRecord(t, sk, path.FromString("/path/2"), 2, eol, 0, WithV1Compatibility(true))
168169

169170
best, err := v.Select(ipnsRoutingKey, [][]byte{mustMarshal(t, rec1), mustMarshal(t, rec2)})
170171
require.NoError(t, err)

0 commit comments

Comments
 (0)