diff --git a/decoder.go b/decoder.go index dd3993b..cc8eaa1 100644 --- a/decoder.go +++ b/decoder.go @@ -106,6 +106,8 @@ const ( TypeSet ValueType = 2 TypeZSet ValueType = 3 TypeHash ValueType = 4 + TypeZSet2 ValueType = 5 + TypeModule ValueType = 6 TypeHashZipmap ValueType = 9 TypeListZiplist ValueType = 10 @@ -118,7 +120,8 @@ const ( const ( rdb6bitLen = 0 rdb14bitLen = 1 - rdb32bitLen = 2 + rdb32bitLen = 0x80 + rdb64bitLen = 0x81 rdbEncVal = 3 rdbFlagAux = 0xfa @@ -269,7 +272,7 @@ func (d *decode) readObject(key []byte, typ ValueType, expiry int64) error { d.event.Sadd(key, member) } d.event.EndSet(key) - case TypeZSet: + case TypeZSet, TypeZSet2: cardinality, _, err := d.readLength() if err != nil { return err @@ -280,9 +283,17 @@ func (d *decode) readObject(key []byte, typ ValueType, expiry int64) error { if err != nil { return err } - score, err := d.readFloat64() - if err != nil { - return err + var score float64; + if typ == TypeZSet2 { + score, err = d.readDouble64(); + if err != nil { + return err + } + } else { + score, err = d.readFloat64(); + if err != nil { + return err + } } d.event.Zadd(key, score, member) } @@ -315,6 +326,8 @@ func (d *decode) readObject(key []byte, typ ValueType, expiry int64) error { return d.readZiplistZset(key, expiry) case TypeHashZiplist: return d.readZiplistHash(key, expiry) + case TypeModule: + return fmt.Errorf("rdb: unable to read Redis Modules RDB objects (key %s)", key) default: return fmt.Errorf("rdb: unknown object type %d for key %s", typ, key) } @@ -628,7 +641,7 @@ func (d *decode) checkHeader() error { } version, _ := strconv.ParseInt(string(header[5:]), 10, 64) - if version < 1 || version > 7 { + if version < 1 || version > 8 { return fmt.Errorf("rdb: invalid RDB version number %d", version) } @@ -699,6 +712,14 @@ func (d *decode) readUint32() (uint32, error) { return binary.LittleEndian.Uint32(d.intBuf), nil } +func (d *decode) readUint32Big() (uint32, error) { + _, err := io.ReadFull(d.r, d.intBuf[:4]) + if err != nil { + return 0, err + } + return binary.BigEndian.Uint32(d.intBuf), nil +} + func (d *decode) readUint64() (uint64, error) { _, err := io.ReadFull(d.r, d.intBuf) if err != nil { @@ -707,12 +728,21 @@ func (d *decode) readUint64() (uint64, error) { return binary.LittleEndian.Uint64(d.intBuf), nil } -func (d *decode) readUint32Big() (uint32, error) { - _, err := io.ReadFull(d.r, d.intBuf[:4]) +func (d *decode) readUint64Big() (uint64, error) { + _, err := io.ReadFull(d.r, d.intBuf) if err != nil { return 0, err } - return binary.BigEndian.Uint32(d.intBuf), nil + return binary.BigEndian.Uint64(d.intBuf), nil +} + +func (d *decode) readDouble64() (float64, error) { + _, err := io.ReadFull(d.r, d.intBuf) + if err != nil { + return 0, err + } + bits := binary.LittleEndian.Uint64(d.intBuf); + return float64(math.Float64frombits(bits)), nil; } // Doubles are saved as strings prefixed by an unsigned @@ -747,6 +777,9 @@ func (d *decode) readFloat64() (float64, error) { panic("not reached") } +/** + * https://rdb.fnordig.de/file_format.html#length-encoding + */ func (d *decode) readLength() (uint32, bool, error) { b, err := d.r.ReadByte() if err != nil { @@ -771,8 +804,14 @@ func (d *decode) readLength() (uint32, bool, error) { default: // When the first two bits are 10, the next 6 bits are discarded. // The next 4 bytes are the length. - length, err := d.readUint32Big() - return length, false, err + if b == rdb64bitLen { + length, err := d.readUint64Big() + return uint32(length), false, err + } else { + length, err := d.readUint32Big() + return length, false, err + } + } panic("not reached") diff --git a/encoder.go b/encoder.go index 7902a7d..eac5d56 100644 --- a/encoder.go +++ b/encoder.go @@ -77,7 +77,7 @@ func (e *Encoder) EncodeLength(l uint32) (err error) { _, err = e.w.Write([]byte{byte(l>>8) | rdb14bitLen<<6, byte(l)}) default: b := make([]byte, 5) - b[0] = rdb32bitLen << 6 + b[0] = rdb32bitLen binary.BigEndian.PutUint32(b[1:], l) _, err = e.w.Write(b) }