diff --git a/src/preprocessing/gpt_encode.zig b/src/preprocessing/gpt_encode.zig
new file mode 100644
index 0000000..d465a52
--- /dev/null
+++ b/src/preprocessing/gpt_encode.zig
@@ -0,0 +1,374 @@
+//! Byte-level GPT-2 encoding: maps every byte 0..255 to a printable Unicode
+//! codepoint (one or two bytes of UTF-8), so arbitrary byte strings can be
+//! stored in a tokenizer vocab as valid UTF-8.
+const std = @import("std");
+
+// ENCODE_TABLE[b] is the UTF-8 encoding of the codepoint GPT-2 assigns to
+// byte b: printable ASCII and most of Latin-1 map to themselves, while
+// bytes 0x00..0x20, 0x7f..0xa0 and 0xad map to U+0100..U+0143.
+const ENCODE_TABLE: [256][]const u8 = .{
+    "\xc4\x80",
+    "\xc4\x81",
+    "\xc4\x82",
+    "\xc4\x83",
+    "\xc4\x84",
+    "\xc4\x85",
+    "\xc4\x86",
+    "\xc4\x87",
+    "\xc4\x88",
+    "\xc4\x89",
+    "\xc4\x8a",
+    "\xc4\x8b",
+    "\xc4\x8c",
+    "\xc4\x8d",
+    "\xc4\x8e",
+    "\xc4\x8f",
+    "\xc4\x90",
+    "\xc4\x91",
+    "\xc4\x92",
+    "\xc4\x93",
+    "\xc4\x94",
+    "\xc4\x95",
+    "\xc4\x96",
+    "\xc4\x97",
+    "\xc4\x98",
+    "\xc4\x99",
+    "\xc4\x9a",
+    "\xc4\x9b",
+    "\xc4\x9c",
+    "\xc4\x9d",
+    "\xc4\x9e",
+    "\xc4\x9f",
+    "\xc4\xa0",
+    "\x21",
+    "\x22",
+    "\x23",
+    "\x24",
+    "\x25",
+    "\x26",
+    "\x27",
+    "\x28",
+    "\x29",
+    "\x2a",
+    "\x2b",
+    "\x2c",
+    "\x2d",
+    "\x2e",
+    "\x2f",
+    "\x30",
+    "\x31",
+    "\x32",
+    "\x33",
+    "\x34",
+    "\x35",
+    "\x36",
+    "\x37",
+    "\x38",
+    "\x39",
+    "\x3a",
+    "\x3b",
+    "\x3c",
+    "\x3d",
+    "\x3e",
+    "\x3f",
+    "\x40",
+    "\x41",
+    "\x42",
+    "\x43",
+    "\x44",
+    "\x45",
+    "\x46",
+    "\x47",
+    "\x48",
+    "\x49",
+    "\x4a",
+    "\x4b",
+    "\x4c",
+    "\x4d",
+    "\x4e",
+    "\x4f",
+    "\x50",
+    "\x51",
+    "\x52",
+    "\x53",
+    "\x54",
+    "\x55",
+    "\x56",
+    "\x57",
+    "\x58",
+    "\x59",
+    "\x5a",
+    "\x5b",
+    "\x5c",
+    "\x5d",
+    "\x5e",
+    "\x5f",
+    "\x60",
+    "\x61",
+    "\x62",
+    "\x63",
+    "\x64",
+    "\x65",
+    "\x66",
+    "\x67",
+    "\x68",
+    "\x69",
+    "\x6a",
+    "\x6b",
+    "\x6c",
+    "\x6d",
+    "\x6e",
+    "\x6f",
+    "\x70",
+    "\x71",
+    "\x72",
+    "\x73",
+    "\x74",
+    "\x75",
+    "\x76",
+    "\x77",
+    "\x78",
+    "\x79",
+    "\x7a",
+    "\x7b",
+    "\x7c",
+    "\x7d",
+    "\x7e",
+    "\xc4\xa1",
+    "\xc4\xa2",
+    "\xc4\xa3",
+    "\xc4\xa4",
+    "\xc4\xa5",
+    "\xc4\xa6",
+    "\xc4\xa7",
+    "\xc4\xa8",
+    "\xc4\xa9",
+    "\xc4\xaa",
+    "\xc4\xab",
+    "\xc4\xac",
+    "\xc4\xad",
+    "\xc4\xae",
+    "\xc4\xaf",
+    "\xc4\xb0",
+    "\xc4\xb1",
+    "\xc4\xb2",
+    "\xc4\xb3",
+    "\xc4\xb4",
+    "\xc4\xb5",
+    "\xc4\xb6",
+    "\xc4\xb7",
+    "\xc4\xb8",
+    "\xc4\xb9",
+    "\xc4\xba",
+    "\xc4\xbb",
+    "\xc4\xbc",
+    "\xc4\xbd",
+    "\xc4\xbe",
+    "\xc4\xbf",
+    "\xc5\x80",
+    "\xc5\x81",
+    "\xc5\x82",
+    "\xc2\xa1",
+    "\xc2\xa2",
+    "\xc2\xa3",
+    "\xc2\xa4",
+    "\xc2\xa5",
+    "\xc2\xa6",
+    "\xc2\xa7",
+    "\xc2\xa8",
+    "\xc2\xa9",
+    "\xc2\xaa",
+    "\xc2\xab",
+    "\xc2\xac",
+    "\xc5\x83",
+    "\xc2\xae",
+    "\xc2\xaf",
+    "\xc2\xb0",
+    "\xc2\xb1",
+    "\xc2\xb2",
+    "\xc2\xb3",
+    "\xc2\xb4",
+    "\xc2\xb5",
+    "\xc2\xb6",
+    "\xc2\xb7",
+    "\xc2\xb8",
+    "\xc2\xb9",
+    "\xc2\xba",
+    "\xc2\xbb",
+    "\xc2\xbc",
+    "\xc2\xbd",
+    "\xc2\xbe",
+    "\xc2\xbf",
+    "\xc3\x80",
+    "\xc3\x81",
+    "\xc3\x82",
+    "\xc3\x83",
+    "\xc3\x84",
+    "\xc3\x85",
+    "\xc3\x86",
+    "\xc3\x87",
+    "\xc3\x88",
+    "\xc3\x89",
+    "\xc3\x8a",
+    "\xc3\x8b",
+    "\xc3\x8c",
+    "\xc3\x8d",
+    "\xc3\x8e",
+    "\xc3\x8f",
+    "\xc3\x90",
+    "\xc3\x91",
+    "\xc3\x92",
+    "\xc3\x93",
+    "\xc3\x94",
+    "\xc3\x95",
+    "\xc3\x96",
+    "\xc3\x97",
+    "\xc3\x98",
+    "\xc3\x99",
+    "\xc3\x9a",
+    "\xc3\x9b",
+    "\xc3\x9c",
+    "\xc3\x9d",
+    "\xc3\x9e",
+    "\xc3\x9f",
+    "\xc3\xa0",
+    "\xc3\xa1",
+    "\xc3\xa2",
+    "\xc3\xa3",
+    "\xc3\xa4",
+    "\xc3\xa5",
+    "\xc3\xa6",
+    "\xc3\xa7",
+    "\xc3\xa8",
+    "\xc3\xa9",
+    "\xc3\xaa",
+    "\xc3\xab",
+    "\xc3\xac",
+    "\xc3\xad",
+    "\xc3\xae",
+    "\xc3\xaf",
+    "\xc3\xb0",
+    "\xc3\xb1",
+    "\xc3\xb2",
+    "\xc3\xb3",
+    "\xc3\xb4",
+    "\xc3\xb5",
+    "\xc3\xb6",
+    "\xc3\xb7",
+    "\xc3\xb8",
+    "\xc3\xb9",
+    "\xc3\xba",
+    "\xc3\xbb",
+    "\xc3\xbc",
+    "\xc3\xbd",
+    "\xc3\xbe",
+    "\xc3\xbf",
+};
+
+/// Encodes `in` into `out`; returns the filled prefix of `out`.
+pub fn encode(out: []u8, in: []const u8) ![]u8 {
+    var i: usize = 0;
+    var out_idx: usize = 0;
+    while (i < in.len) : (i += 1) {
+        const slice = ENCODE_TABLE[in[i]];
+        const end_idx = out_idx + slice.len;
+        if (end_idx > out.len) {
+            return error.YourSliceIsTooSmall;
+        }
+        @memcpy(out[out_idx..end_idx], slice);
+        out_idx = end_idx;
+    }
+    return out[0..out_idx];
+}
+
+/// Exact output size `encode` needs for `in`.
+pub fn get_encoded_len(in: []const u8) usize {
+    var i: usize = 0;
+    var out_idx: usize = 0;
+    while (i < in.len) : (i += 1) {
+        const slice = ENCODE_TABLE[in[i]];
+        out_idx += slice.len;
+    }
+    return out_idx;
+}
+
+// DECODE_TABLE maps the first two bytes of an encoded stream (bitcast to a
+// native-endian u16, matching `decode` below) back to the original byte.
+// A single-byte encoding fills all 256 entries sharing its low byte, since
+// `decode` reads two bytes whenever two are available. 0xaa marks invalid
+// sequences.
+fn make_decode_table() [0x10000]u8 {
+    @setEvalBranchQuota(100000);
+    var ret: [0x10000]u8 = undefined;
+    @memset(&ret, 0xaa);
+    for (ENCODE_TABLE, 0..) |slice, i| {
+        if (slice.len == 1) {
+            var n: usize = 0;
+            while (n < 256) : (n += 1) {
+                const idx: u16 = @bitCast([2]u8{ slice[0], @as(u8, @intCast(n)) });
+                ret[idx] = @intCast(i);
+            }
+        }
+        if (slice.len == 2) {
+            const value: u16 = @bitCast((slice.ptr)[0..2].*);
+            ret[value] = @as(u8, @intCast(i));
+        }
+    }
+    return ret;
+}
+
+const DECODE_TABLE = make_decode_table();
+
+/// Decodes `in` into `out`; returns the filled prefix of `out`.
+pub fn decode(out: []u8, in: []const u8) ![]u8 {
+    var i: usize = 0;
+    var out_idx: usize = 0;
+    while (i < in.len) : (i += 1) {
+        var value: u16 = 0;
+        if (i + 1 < in.len) {
+            value = @bitCast((in.ptr + i)[0..2].*);
+        } else {
+            value = in[i];
+        }
+        const out_byte = DECODE_TABLE[value];
+        if (out_idx >= out.len) {
+            return error.YourSliceIsTooSmall;
+        }
+        out[out_idx] = out_byte;
+        out_idx += 1;
+        // The low byte's high bit is set only for the lead byte of a
+        // two-byte UTF-8 sequence (0xc2..0xc5 here), so skip one extra byte.
+        i += (value >> 7) & 1;
+    }
+    return out[0..out_idx];
+}
+
+/// Exact output size `decode` needs for `in`.
+pub fn get_decoded_len(in: []const u8) usize {
+    var i: usize = 0;
+    var out_idx: usize = 0;
+    while (i < in.len) : (i += 1) {
+        var value: u16 = 0;
+        if (i + 1 < in.len) {
+            value = @bitCast((in.ptr + i)[0..2].*);
+        } else {
+            value = in[i];
+        }
+        i += (value >> 7) & 1;
+        out_idx += 1;
+    }
+    return out_idx;
+}
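+
+// A quick round-trip sanity check; the expected bytes follow from
+// ENCODE_TABLE above (' ' -> "Ġ" = 0xc4 0xa0, '\n' -> "Ċ" = 0xc4 0x8a).
+test "gpt2 byte encoding round trip" {
+    const input = "hello world\n";
+    var enc_buf: [32]u8 = undefined;
+    var dec_buf: [32]u8 = undefined;
+    const encoded = try encode(&enc_buf, input);
+    try std.testing.expectEqual(get_encoded_len(input), encoded.len);
+    try std.testing.expectEqualStrings("hello\xc4\xa0world\xc4\x8a", encoded);
+    const decoded = try decode(&dec_buf, encoded);
+    try std.testing.expectEqualStrings(input, decoded);
+}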
diff --git a/src/preprocessing/tokenizer.zig b/src/preprocessing/tokenizer.zig
index 916db74..91ee14e 100644
--- a/src/preprocessing/tokenizer.zig
+++ b/src/preprocessing/tokenizer.zig
@@ -3,6 +3,7 @@ const mem = std.mem;
 const json = std.json;
 const Allocator = mem.Allocator;
 const Thread = std.Thread;
+const gpt_encode = @import("gpt_encode.zig");
 
 // Constants for parallel processing
 pub const MIN_PARALLEL_TEXT_SIZE = 10_000;
@@ -126,7 +127,7 @@
         const rstrip = obj.get("rstrip").?.bool;
         const normalized = obj.get("normalized").?.bool;
 
-        const content_copy = try allocator.dupe(u8, content);
+        const content_copy = try convertFromGpt2Chars(content, allocator);
         errdefer allocator.free(content_copy);
 
         try tokenizer.special_tokens.append(.{
@@ -148,7 +149,7 @@
         const token = entry.key_ptr.*;
         const id: u32 = @intCast(entry.value_ptr.*.integer);
 
-        const token_copy = try allocator.dupe(u8, token);
+        const token_copy = try convertFromGpt2Chars(token, allocator);
         errdefer allocator.free(token_copy);
 
         try tokenizer.tokens.put(token_copy, id);
@@ -157,6 +158,15 @@
 
         // Load merges
         const merges = model.get("merges").?.array;
        for (merges.items) |merge| {
+            // NOTE: if you eventually want to use the merges, the way to
+            // parse them is to first split on " " to get the two merge
+            // halves, then gpt-decode each half.
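+            // For example, roughly (an untested sketch; assumes each merge
+            // string really is "left right" with a single space):
+            //
+            //   var parts = std.mem.splitScalar(u8, merge.string, ' ');
+            //   const left = try convertFromGpt2Chars(parts.next().?, allocator);
+            //   const right = try convertFromGpt2Chars(parts.next().?, allocator);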
             const merge_str = try allocator.dupe(u8, merge.string);
             errdefer allocator.free(merge_str);
@@ -285,45 +289,19 @@
 
     // GPT2-specific functions from original
     fn convertToGpt2Chars(text: []const u8, allocator: Allocator) ![]u8 {
-        var result = std.ArrayList(u8).init(allocator);
-        errdefer result.deinit();
-
-        var i: usize = 0;
-        while (i < text.len) {
-            if (text[i] == ' ') {
-                try result.appendSlice(&GPT2_SPACE_PREFIX); // Ġ
-            } else if (text[i] == '\n') {
-                try result.appendSlice(&GPT2_NEWLINE_PREFIX); // Ċ
-            } else {
-                try result.append(text[i]);
-            }
-            i += 1;
-        }
-        return result.toOwnedSlice();
+        const out_len = gpt_encode.get_encoded_len(text);
+        const result = try allocator.alloc(u8, out_len);
+        errdefer allocator.free(result);
+        _ = try gpt_encode.encode(result, text);
+        return result;
     }
 
     fn convertFromGpt2Chars(text: []const u8, allocator: Allocator) ![]u8 {
-        var result = std.ArrayList(u8).init(allocator);
-        errdefer result.deinit();
-
-        var i: usize = 0;
-        while (i < text.len) {
-            // Check for Ġ
-            if (i + 1 < text.len and text[i] == 0xC4) {
-                if (text[i + 1] == 0xA0) {
-                    try result.append(' '); // Replace Ġ with space
-                    i += 2;
-                    continue;
-                } else if (text[i + 1] == 0x82) {
-                    try result.append('\n'); // Replace Ċ with newline
-                    i += 2;
-                    continue;
-                }
-            }
-            try result.append(text[i]);
-            i += 1;
-        }
-        return result.toOwnedSlice();
+        const out_len = gpt_encode.get_decoded_len(text);
+        const result = try allocator.alloc(u8, out_len);
+        errdefer allocator.free(result);
+        _ = try gpt_encode.decode(result, text);
+        return result;
     }
 
     // Parallel processing functions
@@ -354,17 +332,14 @@
         var tokens = std.ArrayList(u32).init(self.allocator);
         errdefer tokens.deinit();
 
-        const gpt2_text = try convertToGpt2Chars(chunk, self.allocator);
-        defer self.allocator.free(gpt2_text);
-
         var current_pos: usize = 0;
-        while (current_pos < gpt2_text.len) {
+        while (current_pos < chunk.len) {
             var current_node = self.trie_root;
             var longest_match: ?struct { id: u32, len: usize } = null;
             var match_length: usize = 0;
 
-            while (current_pos + match_length < gpt2_text.len) {
-                const byte = gpt2_text[current_pos + match_length];
+            while (current_pos + match_length < chunk.len) {
+                const byte = chunk[current_pos + match_length];
                 if (current_node.children.get(byte)) |next_node| {
                     current_node = next_node;
                     match_length += 1;
@@ -378,7 +353,7 @@
                 try tokens.append(match.id);
                 current_pos += match.len;
             } else {
-                try tokens.append(gpt2_text[current_pos]);
+                try tokens.append(chunk[current_pos]);
                 current_pos += 1;
             }
         }
@@ -498,9 +473,7 @@
             var it = self.tokens.iterator();
             while (it.next()) |entry| {
                 if (entry.value_ptr.* == token_id) {
-                    const converted = try convertFromGpt2Chars(entry.key_ptr.*, self.allocator);
-                    defer self.allocator.free(converted);
-                    try decoded.appendSlice(converted);
+                    try decoded.appendSlice(entry.key_ptr.*);
                     found = true;
                     break;
                 }