Skip to content

Commit de8220d

Browse files
authored
Add example: LSTM text generation in the browser (#103)
Add example: LSTM text generation in the browser
1 parent fef4593 commit de8220d

File tree

9 files changed

+5498
-0
lines changed

9 files changed

+5498
-0
lines changed

lstm-text-generation/.babelrc

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"presets": [
3+
[
4+
"env",
5+
{
6+
"esmodules": false,
7+
"targets": {
8+
"browsers": [
9+
"> 3%"
10+
]
11+
}
12+
}
13+
]
14+
],
15+
"plugins": [
16+
"transform-runtime"
17+
]
18+
}

lstm-text-generation/README.md

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# TensorFlow.js Example: Train LSTM to Generate Text
2+
3+
## Overview
4+
5+
This example illustrates how to use TensorFlow.js to train an LSTM model to
6+
generate random text based on the patterns in a text corpus such as
7+
Nietzsche's writing or the source code of TensorFlow.js itself.
8+
9+
The LSTM model operates at the character level. It takes a tensor of
10+
shape `[numExamples, sampleLen, charSetSize]` as the input. The input is a
11+
one-hot encoding of sequences of `sampleLen` characters. The characters
12+
belong to a set of `charSetSize` unique characters. With the input, the model
13+
outputs a tensor of shape `[numExamples, charSetSize]`, which represents the
14+
model's predicted probabilities of the character that follows the input sequence.
15+
The application then draws a random sample based on the predicted
16+
probabilities to get the next character. Once the next character is obtained,
17+
its one-hot encoding is concatenated with the previous input sequence to form
18+
the input for the next time step. This process is repeated in order to generate
19+
a character sequence of a given length. The randomness (diversity) is controlled
20+
by a temperature parameter.
21+
22+
The UI allows creation of models consisting of a single
23+
[LSTM layer](https://js.tensorflow.org/api/latest/#layers.lstm) or multiple,
24+
stacked LSTM layers.
25+
26+
This example also illustrates how to save a trained model in the browser's
27+
IndexedDB using TensorFlow.js's
28+
[model saving API](https://js.tensorflow.org/tutorials/model-save-load.html),
29+
so that the result of the training
30+
may persist across browser sessions. Once a previously-trained model is loaded
31+
from the IndexedDB, it can be used in text generation and/or further training.
32+
33+
This example is inspired by the LSTM text generation example from Keras:
34+
https://github.com/keras-team/keras/blob/master/examples/lstm_text_generation.py
35+
36+
## Usage
37+
38+
```sh
39+
yarn && yarn watch
40+
```

lstm-text-generation/data.js

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/**
2+
* @license
3+
* Copyright 2018 Google LLC. All Rights Reserved.
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
* =============================================================================
16+
*/
17+
18+
import * as tf from '@tensorflow/tfjs';
19+
20+
/**
21+
* A class for text data.
22+
*
23+
* This class manages the following:
24+
*
25+
* - Converting training data (as a string) into one-hot encoded vectors.
26+
* - Drawing random slices from the training data. This is useful for training
27+
* models and obtaining the seed text for model-based text generation.
28+
*/
29+
export class TextData {
  /**
   * Constructor of TextData.
   *
   * @param {string} dataIdentifier An identifier for this instance of TextData.
   * @param {string} textString The training text data.
   * @param {number} sampleLen Length of each training example, i.e., the input
   *   sequence length expected by the LSTM model.
   * @param {number} sampleStep How many characters to skip when going from one
   *   example of the training data (in `textString`) to the next.
   * @throws {Error} If `dataIdentifier` is falsy.
   */
  constructor(dataIdentifier, textString, sampleLen, sampleStep) {
    if (!dataIdentifier) {
      throw new Error('Model identifier is not provided.');
    }

    this.dataIdentifier_ = dataIdentifier;

    this.textString_ = textString;
    this.textLen_ = textString.length;
    this.sampleLen_ = sampleLen;
    this.sampleStep_ = sampleStep;

    // Order matters: the char set (and its inverse lookup map) must exist
    // before the full text can be converted to indices.
    this.getCharSet_();
    this.convertAllTextToIndices_();
    this.generateExampleBeginIndices_();
  }

  /**
   * Get data identifier.
   *
   * @returns {string} The data identifier.
   */
  dataIdentifier() {
    return this.dataIdentifier_;
  }

  /**
   * Get length of the training text data.
   *
   * @returns {number} Length of training text data.
   */
  textLen() {
    return this.textLen_;
  }

  /**
   * Get the length of each training example.
   *
   * @returns {number} The input sequence length expected by the LSTM model.
   */
  sampleLen() {
    return this.sampleLen_;
  }

  /**
   * Get the size of the character set.
   *
   * @returns {number} Size of the character set, i.e., how many unique
   *   characters there are in the training text data.
   */
  charSetSize() {
    return this.charSetSize_;
  }

  /**
   * Generate the next epoch of data for training models.
   *
   * Examples are drawn from the pre-shuffled begin indices, cycling through
   * them (with wrap-around) across successive calls.
   *
   * @param {number} numExamples Number examples to generate.
   * @returns {[tf.Tensor, tf.Tensor]} `xs` and `ys` Tensors.
   *   `xs` has the shape of `[numExamples, this.sampleLen, this.charSetSize]`.
   *   `ys` has the shape of `[numExamples, this.charSetSize]`.
   */
  nextDataEpoch(numExamples) {
    const xsBuffer = new tf.TensorBuffer([
        numExamples, this.sampleLen_, this.charSetSize_]);
    const ysBuffer = new tf.TensorBuffer([numExamples, this.charSetSize_]);
    for (let i = 0; i < numExamples; ++i) {
      const beginIndex = this.exampleBeginIndices_[
          this.examplePosition_ % this.exampleBeginIndices_.length];
      // One-hot encode the input sequence of `sampleLen_` characters.
      for (let j = 0; j < this.sampleLen_; ++j) {
        xsBuffer.set(1, i, j, this.indices_[beginIndex + j]);
      }
      // The label is the one-hot encoding of the character that follows.
      ysBuffer.set(1, i, this.indices_[beginIndex + this.sampleLen_]);
      this.examplePosition_++;
    }
    return [xsBuffer.toTensor(), ysBuffer.toTensor()];
  }

  /**
   * Get the unique character at given index from the character set.
   *
   * @param {number} index
   * @returns {string} The unique character at `index` of the character set.
   */
  getFromCharSet(index) {
    return this.charSet_[index];
  }

  /**
   * Convert text string to integer indices.
   *
   * @param {string} text Input text.
   * @returns {number[]} Indices of the characters of `text`; `-1` for
   *   characters absent from the training character set (matching
   *   `Array.prototype.indexOf` semantics).
   */
  textToIndices(text) {
    const indices = [];
    // Iterate by UTF-16 code unit (text[i]) to stay consistent with how the
    // character set itself is built in getCharSet_().
    for (let i = 0; i < text.length; ++i) {
      const index = this.charToIndex_.get(text[i]);
      indices.push(index === undefined ? -1 : index);
    }
    return indices;
  }

  /**
   * Get a random slice of text data.
   *
   * @returns {[string, number[]]} The string and index representation of the
   *   same slice.
   */
  getRandomSlice() {
    const startIndex =
        Math.round(Math.random() * (this.textLen_ - this.sampleLen_ - 1));
    const textSlice = this.slice_(startIndex, startIndex + this.sampleLen_);
    return [textSlice, this.textToIndices(textSlice)];
  }

  /**
   * Get a slice of the training text data.
   *
   * @param {number} startIndex
   * @param {number} endIndex
   * @returns {string} The result of the slicing.
   */
  slice_(startIndex, endIndex) {
    return this.textString_.slice(startIndex, endIndex);
  }

  /**
   * Get the set of unique characters from text.
   *
   * Characters are kept in order of first occurrence. Also builds an inverse
   * char -> index lookup map, so that textToIndices() runs in O(text.length)
   * instead of O(text.length * charSetSize).
   */
  getCharSet_() {
    this.charSet_ = [];
    const seen = new Set();
    for (let i = 0; i < this.textLen_; ++i) {
      const char = this.textString_[i];
      if (!seen.has(char)) {
        seen.add(char);
        this.charSet_.push(char);
      }
    }
    this.charSetSize_ = this.charSet_.length;
    // Inverse lookup table: character -> index into `charSet_`.
    this.charToIndex_ = new Map(this.charSet_.map((char, i) => [char, i]));
  }

  /**
   * Convert all training text to integer indices.
   */
  convertAllTextToIndices_() {
    this.indices_ = new Uint16Array(this.textToIndices(this.textString_));
  }

  /**
   * Generate the example-begin indices; shuffle them randomly.
   */
  generateExampleBeginIndices_() {
    // Prepare beginning indices of examples. Stop early enough that both the
    // full input window and the following label character are in range.
    this.exampleBeginIndices_ = [];
    for (let i = 0;
        i < this.textLen_ - this.sampleLen_ - 1;
        i += this.sampleStep_) {
      this.exampleBeginIndices_.push(i);
    }

    // Randomly shuffle the beginning indices.
    tf.util.shuffle(this.exampleBeginIndices_);
    this.examplePosition_ = 0;
  }
}

0 commit comments

Comments
 (0)