
Commit 32cad3b

Merge pull request #11 from matlab-deep-learning/tokenizer_optimizations
Tokenizer optimizations
2 parents: 1978a49 + 3e0d206

5 files changed: 22 additions & 19 deletions

bert/+tokenizer/+internal/BasicTokenizer.m

Lines changed: 3 additions & 4 deletions

@@ -34,12 +34,11 @@
      u = this.cleanText(u);
      u = this.tokenizeCJK(u);
      text = u.string();
-     origTokens = this.whiteSpaceTokenize(text);
      if this.IgnoreCase
-         origTokens = lower(origTokens);
-         origTokens = textanalytics.unicode.nfd(origTokens);
+         text = lower(text);
+         text = textanalytics.unicode.nfd(text);
      end
-     u = textanalytics.unicode.UTF32(origTokens);
+     u = textanalytics.unicode.UTF32(text);
      cats = u.characterCategories('Granularity','detailed');
      if this.IgnoreCase
          [u,cats] = this.stripAccents(u,cats);
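This diff normalizes the whole text once (lowercase, then NFD) before any whitespace splitting, instead of splitting first and normalizing token by token. A minimal sketch of why the two are equivalent; the example string is made up, and it assumes the internal textanalytics.unicode.nfd helper accepts scalar strings and string arrays alike, as the two call sites in the diff suggest:

    % Whole-text pass (new behavior): one lower + one nfd call.
    text = "Café NAÏVE façade";
    whole = textanalytics.unicode.nfd(lower(text));

    % Per-token pass (old behavior), rejoined for comparison.
    toks = split(strip(text)).';
    perToken = join(textanalytics.unicode.nfd(lower(toks))," ");

    % Lowercasing and NFD neither create nor remove whitespace, so
    % both paths agree while the new one avoids per-token overhead.
    assert(isequal(whole,perToken))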

bert/+tokenizer/+internal/FullTokenizer.m

Lines changed: 2 additions & 1 deletion

@@ -85,9 +85,10 @@
      % tokens = tokenize(tokenizer,text) tokenizes the input
      % string text using the FullTokenizer specified by tokenizer.
      basicToks = this.Basic.tokenize(txt);
+     basicToksUnicode = textanalytics.unicode.UTF32(basicToks);
      subToks = cell(numel(basicToks),1);
      for i = 1:numel(basicToks)
-         subToks{i} = this.WordPiece.tokenize(basicToks{i});
+         subToks{i} = this.WordPiece.tokenize(basicToksUnicode(i));
      end
      toks = cat(2,subToks{:});
  end
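Here the basic tokens are converted to UTF32 in one vectorized call, and the loop hands pre-converted elements to WordPiece rather than letting each call re-convert its own token. A rough, hypothetical micro-benchmark of the pattern; the token array is made up, and it assumes only what the diff itself relies on, namely that the UTF32 constructor accepts string arrays and that the result supports indexing:

    toks = repmat("unaffable",1,10000);

    % New pattern: one vectorized conversion, indexed in the loop.
    tic
    u = textanalytics.unicode.UTF32(toks);
    for i = 1:numel(toks)
        t = u(i);
    end
    toc

    % Old pattern: one scalar conversion per token.
    tic
    for i = 1:numel(toks)
        t = textanalytics.unicode.UTF32(toks(i));
    end
    toc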

bert/+tokenizer/+internal/WhitespaceTokenizer.m

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
      % by splitting str on whitespace.
      arguments
          ~
-         text (1,1) string
+         text
      end
      text = strip(text);
      text = split(text).';
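Relaxing the arguments block from (1,1) string to an unvalidated input skips size and type checking on every call, which matters since this sits in the tokenizer's hot path, and it accepts anything that strip and split can handle. Usage is unchanged; a small sketch:

    wsTok = bert.tokenizer.internal.WhitespaceTokenizer;
    toks = wsTok.tokenize(" foo bar baz ");
    % toks is the row vector ["foo" "bar" "baz"]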

bert/+tokenizer/+internal/WordPieceTokenizer.m

Lines changed: 9 additions & 10 deletions

@@ -37,16 +37,15 @@
          this.Vocab = this.parseVocab(vocab);
      end

-     function tokens = tokenize(this,text)
+     function tokens = tokenize(this,utext)
          arguments
              this
-             text (1,1) string
+             utext
          end
          tokens = string.empty();
-         wsTokens = this.WhitespaceTokenizer.tokenize(text);
-         wsTokensU = textanalytics.unicode.UTF32(wsTokens);
-         for i = 1:numel(wsTokensU)
-             token = wsTokensU(i);
+         sub = textanalytics.unicode.UTF32();
+         for i = 1:numel(utext)
+             token = utext(i);
              if numel(token.Data)>this.MaxChar
                  tokens = [tokens,this.Unk]; %#ok
                  continue
@@ -57,14 +56,14 @@
          while start<(numel(token.Data)+1)
              finish = numel(token.Data);
              currentSub = [];
-             while start<finish+1
-                 sub = textanalytics.unicode.UTF32();
+             while start<finish+1
                  sub.Data = token.Data(start:finish);
                  if start>1
                      sub.Data = [uint32('##'),sub.Data];
                  end
-                 if this.Vocab.isVocabularyWord(sub.string())
-                     currentSub = sub.string();
+                 strForm = sub.string();
+                 if this.Vocab.isVocabularyWord(strForm)
+                     currentSub = strForm;
                      break
                  end
                  finish = finish-1;
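Three optimizations land here: tokenize now takes a pre-split textanalytics.unicode.UTF32 array (utext) instead of a raw string, so whitespace splitting and UTF32 conversion move to the caller and happen once; a single sub buffer is allocated up front and reused across the greedy longest-match loop instead of being reallocated per iteration; and sub.string() is computed once per candidate instead of twice on a match. The new calling convention, mirroring the updated tests below:

    enc = wordEncoding(["foo","bar","##foo"]);
    tok = bert.tokenizer.internal.WordPieceTokenizer(enc);
    wsTok = bert.tokenizer.internal.WhitespaceTokenizer;
    ustr = textanalytics.unicode.UTF32(wsTok.tokenize("foo bar barfoo"));
    toks = tok.tokenize(ustr);
    % toks == ["foo" "bar" "bar" "##foo"]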

test/bert/tokenizer/internal/tWordPieceTokenizer.m

Lines changed: 7 additions & 3 deletions

@@ -39,7 +39,8 @@ function canSetUnknownToken(test)
      tok = bert.tokenizer.internal.WordPieceTokenizer(enc,'UnknownToken',unk);
      test.verifyEqual(tok.Unk,unk)
      str = "blah";
-     act_out = tok.tokenize(str);
+     ustr = textanalytics.unicode.UTF32(str);
+     act_out = tok.tokenize(ustr);
      exp_out = unk;
      test.verifyEqual(act_out,exp_out);
  end
@@ -50,7 +51,8 @@ function canSetMaxTokenLength(test)
      tok = bert.tokenizer.internal.WordPieceTokenizer(enc,'MaxTokenLength',maxLen);
      test.verifyEqual(tok.MaxChar,maxLen);
      str = "foo";
-     act_out = tok.tokenize(str);
+     ustr = textanalytics.unicode.UTF32(str);
+     act_out = tok.tokenize(ustr);
      exp_out = tok.Unk;
      test.verifyEqual(act_out,exp_out);
  end
@@ -59,7 +61,9 @@ function canTokenize(test)
      enc = wordEncoding(["foo","bar","##foo"]);
      tok = bert.tokenizer.internal.WordPieceTokenizer(enc);
      str = "foo bar foobar barba bafoobar barfoo";
-     act_out = tok.tokenize(str);
+     wsTok = bert.tokenizer.internal.WhitespaceTokenizer;
+     ustr = textanalytics.unicode.UTF32(wsTok.tokenize(str));
+     act_out = tok.tokenize(ustr);
      exp_out = ["foo","bar",tok.Unk,tok.Unk,tok.Unk,"bar","##foo"];
      test.verifyEqual(act_out,exp_out);
  end
