Skip to content
This repository was archived by the owner on Jun 17, 2024. It is now read-only.

Commit 8108ce7

Browse files
authored
Fix bitpacking (#13)
1 parent 93226fe commit 8108ce7

File tree

2 files changed

+41
-27
lines changed

2 files changed

+41
-27
lines changed

src/bitpacking.zig

+20
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,23 @@ test "bitpack" {
5454
BP.decode(3, &packed_ints, &output);
5555
try std.testing.expectEqual(.{2} ** 1024, output);
5656
}
57+
58+
// Round-trip check for bit-packing a non-constant input: 1024 u8 values that
// cycle through 0..6 are encoded at 3 bits per value and decoded again; the
// decoded array must equal the original input.
test "bitpack range" {
    const std = @import("std");
    const fl = @import("./fastlanez.zig");
    const BP = BitPacking(fl.FastLanez(u8));

    // Bits used per packed value; the largest input value (6) fits in 3 bits.
    const bit_width = 3;

    // Fill the input with the repeating pattern 0, 1, ..., 6, 0, 1, ...
    var values: [1024]u8 = undefined;
    for (&values, 0..) |*slot, idx| {
        slot.* = @intCast(idx % 7);
    }

    // 1024 values * bit_width bits / 8 bits per byte == 128 * bit_width bytes.
    var encoded: [128 * bit_width]u8 = undefined;
    BP.encode(bit_width, &values, &encoded);

    var decoded: [1024]u8 = undefined;
    BP.decode(bit_width, &encoded, &decoded);

    try std.testing.expectEqual(values, decoded);
}

src/fastlanez.zig

+21-27
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,10 @@ pub fn FastLanez(comptime Element: type) type {
123123
/// The position in the output that we're writing to. Will finish equal to Width.
124124
out_idx: comptime_int = 0,
125125

126-
shift_bits: comptime_int = 0,
127-
mask_bits: comptime_int = Width,
126+
bit_idx: comptime_int = 0,
128127

129128
/// Invoke to store the next vector.
129+
/// Called T times, and writes W times. bit_idx tracks how many bits have been written into the result.
130130
pub inline fn pack(comptime self: *Self, out: *PackedBytes(Width), word: MM1024, state: MM1024) MM1024 {
131131
var tmp: MM1024 = undefined;
132132
if (self.t == 0) {
@@ -140,24 +140,23 @@ pub fn FastLanez(comptime Element: type) type {
140140
}
141141
self.t += 1;
142142

143-
// If we didn't take all W bits last time, then we load the remainder
144-
if (self.mask_bits < Width) {
145-
tmp = or_(tmp, and_rshift(word, self.mask_bits, bitmask(self.shift_bits)));
146-
}
147-
148-
// Update the number of mask bits
149-
self.mask_bits = @min(T - self.shift_bits, Width);
143+
const shift_bits = self.bit_idx % T;
144+
const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width));
150145

151-
// Pull the masked bits into the tmp register
152-
tmp = or_(tmp, and_lshift(word, self.shift_bits, bitmask(self.mask_bits)));
153-
self.shift_bits += Width;
146+
tmp = or_(tmp, and_lshift(word, shift_bits, bitmask(mask_bits)));
147+
self.bit_idx += mask_bits;
154148

155-
if (self.shift_bits >= T) {
156-
// If we have a full 1024 bits, then store it and reset the tmp register
149+
if (self.bit_idx % T == 0) {
150+
// If we have a full T bits, then store it and reset the tmp register
157151
store(out, self.out_idx, tmp);
158152
tmp = @splat(0);
159153
self.out_idx += 1;
160-
self.shift_bits -= T;
154+
155+
// Put the remainder of the bits in the next word
156+
if (mask_bits < Width) {
157+
tmp = or_(tmp, and_rshift(word, mask_bits, bitmask(Width - mask_bits)));
158+
self.bit_idx += (Width - mask_bits);
159+
}
161160
}
162161

163162
return tmp;
@@ -177,7 +176,7 @@ pub fn FastLanez(comptime Element: type) type {
177176
t: comptime_int = 0,
178177

179178
input_idx: comptime_int = 0,
180-
shift_bits: comptime_int = 0,
179+
bit_idx: comptime_int = 0,
181180

182181
pub inline fn unpack(comptime self: *Self, input: *const PackedBytes(Width), state: MM1024) struct { MM1024, MM1024 } {
183182
if (self.t > T) {
@@ -193,24 +192,19 @@ pub fn FastLanez(comptime Element: type) type {
193192
tmp = state;
194193
}
195194

196-
const mask_bits = @min(T - self.shift_bits, Width);
195+
const shift_bits = self.bit_idx % T;
196+
const mask_bits = @min(T - shift_bits, Width - (self.bit_idx % Width));
197197

198-
var next: MM1024 = undefined;
199-
if (self.shift_bits == T) {
200-
next = tmp;
201-
} else {
202-
next = and_rshift(tmp, self.shift_bits, bitmask(mask_bits));
203-
}
198+
var next: MM1024 = and_rshift(tmp, shift_bits, bitmask(mask_bits));
204199

205-
if (mask_bits != Width) {
200+
if (mask_bits != Width and self.input_idx < Width) {
206201
tmp = load(input, self.input_idx);
207202
self.input_idx += 1;
208203

209204
next = or_(next, and_lshift(tmp, mask_bits, bitmask(Width - mask_bits)));
210-
211-
self.shift_bits = Width - mask_bits;
205+
self.bit_idx += Width;
212206
} else {
213-
self.shift_bits += Width;
207+
self.bit_idx += mask_bits;
214208
}
215209

216210
return .{ next, tmp };

0 commit comments

Comments
 (0)