Skip to content

Commit

Permalink
fix: unify property and variable names across the library
Browse files Browse the repository at this point in the history
  • Loading branch information
niieani committed Nov 13, 2024
1 parent ed01980 commit 6030d91
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 37 deletions.
21 changes: 11 additions & 10 deletions src/BytePairEncodingCore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import { escapeRegExp } from './util.js'
export type RawBytePairRanks = readonly (string | readonly number[])[]

export interface BytePairEncodingConfig {
mergeableBytePairRanks: RawBytePairRanks
specialTokenMapping?: Map<string, number>
bytePairRankDecoder: RawBytePairRanks
specialTokensEncoder?: Map<string, number>
tokenSplitRegex: RegExp
}

Expand Down Expand Up @@ -38,18 +38,18 @@ export class BytePairEncodingCore {
private textEncoder = new TextEncoder()

constructor({
mergeableBytePairRanks: bytePairEncoder,
specialTokenMapping: specialTokenEncoder,
bytePairRankDecoder,
specialTokensEncoder,
tokenSplitRegex,
}: BytePairEncodingConfig) {
this.bytePairRankDecoder = bytePairEncoder
this.bytePairRankDecoder = bytePairRankDecoder
this.bytePairStringRankEncoder = new Map<string, number>()

// size without array holes (which may be present in the encoder)
this.mergeableBytePairRankCount = Object.keys(bytePairEncoder).length
this.mergeableBytePairRankCount = Object.keys(bytePairRankDecoder).length
const binaryLookup: [Uint8Array, number][] = []
// forEach skips array holes:
bytePairEncoder.forEach((value, rank) => {
bytePairRankDecoder.forEach((value, rank) => {
if (typeof value === 'string') {
this.bytePairStringRankEncoder.set(value, rank)
return
Expand All @@ -61,9 +61,10 @@ export class BytePairEncodingCore {
this.bytePairNonUtfSortedEncoder = binaryLookup.sort((a, b) =>
compareUint8Arrays(a[0], b[0]),
)
this.specialTokensEncoder = specialTokenEncoder ?? new Map<string, number>()
this.specialTokensDecoder = specialTokenEncoder
? new Map([...specialTokenEncoder].map(([key, value]) => [value, key]))
this.specialTokensEncoder =
specialTokensEncoder ?? new Map<string, number>()
this.specialTokensDecoder = specialTokensEncoder
? new Map([...specialTokensEncoder].map(([key, value]) => [value, key]))
: new Map<number, string>()
this.tokenSplitRegex = tokenSplitRegex

Expand Down
22 changes: 11 additions & 11 deletions src/GptEncoding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,39 +60,39 @@ export class GptEncoding {

modelName?: ModelName
private bytePairEncodingCoreProcessor: BytePairEncodingCore
private specialTokenMapping: Map<string, number>
private specialTokensEncoder: Map<string, number>
private specialTokensSet: Set<string>
private allSpecialTokenRegex: RegExp
private defaultSpecialTokenConfig: SpecialTokenConfig

readonly vocabularySize: number

private constructor({
mergeableBytePairRanks,
specialTokenMapping,
bytePairRankDecoder: mergeableBytePairRanks,
specialTokensEncoder,
expectedVocabularySize,
modelName,
...rest
}: EncodingParams) {
this.specialTokenMapping = specialTokenMapping
this.specialTokensSet = new Set<string>(this.specialTokenMapping.keys())
this.specialTokensEncoder = specialTokensEncoder
this.specialTokensSet = new Set<string>(this.specialTokensEncoder.keys())
this.allSpecialTokenRegex = getSpecialTokenRegex(this.specialTokensSet)

this.bytePairEncodingCoreProcessor = new BytePairEncodingCore({
mergeableBytePairRanks,
specialTokenMapping,
bytePairRankDecoder: mergeableBytePairRanks,
specialTokensEncoder,
...rest,
})
this.defaultSpecialTokenConfig = this.processSpecialTokens()

const maxTokenValue = Math.max(
mergeableBytePairRanks.length - 1,
getMaxValueFromMap(specialTokenMapping),
getMaxValueFromMap(specialTokensEncoder),
)

this.vocabularySize =
this.bytePairEncodingCoreProcessor.mergeableBytePairRankCount +
specialTokenMapping.size
specialTokensEncoder.size

if (expectedVocabularySize !== undefined) {
if (this.vocabularySize !== expectedVocabularySize) {
Expand Down Expand Up @@ -245,8 +245,8 @@ export class GptEncoding {

const params: ChatParameters | undefined =
chatModelParams[model as ChatModelName]
const chatStartToken = this.specialTokenMapping.get(ImStart)
const chatEndToken = this.specialTokenMapping.get(ImEnd)
const chatStartToken = this.specialTokensEncoder.get(ImStart)
const chatEndToken = this.specialTokensEncoder.get(ImEnd)

if (!params || chatStartToken === undefined || chatEndToken === undefined) {
throw new Error(`Model '${model}' does not support chat.`)
Expand Down
6 changes: 3 additions & 3 deletions src/encodingParams/Cl100KBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
} from '../specialTokens.js'

export function Cl100KBase(
mergeableBytePairRanks: RawBytePairRanks,
bytePairRankDecoder: RawBytePairRanks,
): EncodingParams {
const specialTokenMapping = new Map<string, number>([
[EndOfText, 100_257],
Expand All @@ -29,7 +29,7 @@ export function Cl100KBase(
return {
tokenSplitRegex:
/(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/giu,
mergeableBytePairRanks,
specialTokenMapping,
bytePairRankDecoder,
specialTokensEncoder: specialTokenMapping,
}
}
6 changes: 3 additions & 3 deletions src/encodingParams/O200KBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
} from '../specialTokens.js'

export function O200KBase(
mergeableBytePairRanks: RawBytePairRanks,
bytePairRankDecoder: RawBytePairRanks,
): EncodingParams {
const specialTokenMapping = new Map<string, number>([
[EndOfText, 199_999],
Expand All @@ -29,7 +29,7 @@ export function O200KBase(
return {
tokenSplitRegex:
/(?:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/giu,
mergeableBytePairRanks,
specialTokenMapping,
bytePairRankDecoder,
specialTokensEncoder: specialTokenMapping,
}
}
6 changes: 3 additions & 3 deletions src/encodingParams/P50KBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import { type EncodingParams, tokenSplitRegex } from '../modelParams.js'
import { EndOfText } from '../specialTokens.js'

export function P50KBase(
mergeableBytePairRanks: RawBytePairRanks,
bytePairRankDecoder: RawBytePairRanks,
): EncodingParams {
return {
expectedVocabularySize: 50_281,
tokenSplitRegex,
mergeableBytePairRanks,
specialTokenMapping: new Map<string, number>([[EndOfText, 50_256]]),
bytePairRankDecoder,
specialTokensEncoder: new Map<string, number>([[EndOfText, 50_256]]),
}
}
6 changes: 3 additions & 3 deletions src/encodingParams/P50KEdit.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { type EncodingParams, tokenSplitRegex } from '../modelParams.js'
import { EndOfText, FimMiddle, FimPrefix, FimSuffix } from '../specialTokens.js'

export function P50KEdit(
mergeableBytePairRanks: RawBytePairRanks,
bytePairRankDecoder: RawBytePairRanks,
): EncodingParams {
const specialTokenMapping = new Map<string, number>([
[EndOfText, 50_256],
Expand All @@ -15,7 +15,7 @@ export function P50KEdit(

return {
tokenSplitRegex,
mergeableBytePairRanks,
specialTokenMapping,
bytePairRankDecoder,
specialTokensEncoder: specialTokenMapping,
}
}
6 changes: 3 additions & 3 deletions src/encodingParams/R50KBase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import { type EncodingParams, tokenSplitRegex } from '../modelParams.js'
import { EndOfText } from '../specialTokens.js'

export function R50KBase(
mergeableBytePairRanks: RawBytePairRanks,
bytePairRankDecoder: RawBytePairRanks,
): EncodingParams {
return {
expectedVocabularySize: 50_257,
tokenSplitRegex,
mergeableBytePairRanks,
specialTokenMapping: new Map<string, number>([[EndOfText, 50_256]]),
bytePairRankDecoder,
specialTokensEncoder: new Map<string, number>([[EndOfText, 50_256]]),
}
}
2 changes: 1 addition & 1 deletion src/modelParams.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export interface EncodingParams extends BytePairEncodingConfig {
* It's complex due to its need to deal with a wide variety of cases in text processing.
*/
tokenSplitRegex: RegExp
specialTokenMapping: Map<string, number>
specialTokensEncoder: Map<string, number>
modelName?: ModelName
}

Expand Down

0 comments on commit 6030d91

Please sign in to comment.