Skip to content

Commit

Permalink
nucleus sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
WuTheFWasThat committed Aug 27, 2019
1 parent f35fa1d commit ac5d522
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/generate_unconditional_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def sample_model(
length=None,
temperature=1,
top_k=0,
top_p=1,
models_dir='models',
):
"""
Expand Down Expand Up @@ -58,7 +59,7 @@ def sample_model(
hparams=hparams, length=length,
start_token=enc.encoder['<|endoftext|>'],
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)[:, 1:]

saver = tf.train.Saver()
Expand Down
3 changes: 2 additions & 1 deletion src/interactive_conditional_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def interact_model(
length=None,
temperature=1,
top_k=0,
top_p=1,
models_dir='models',
):
"""
Expand Down Expand Up @@ -61,7 +62,7 @@ def interact_model(
hparams=hparams, length=length,
context=context,
batch_size=batch_size,
temperature=temperature, top_k=top_k
temperature=temperature, top_k=top_k, top_p=top_p
)

saver = tf.train.Saver()
Expand Down
21 changes: 20 additions & 1 deletion src/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,25 @@ def _top_k():
)


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
def top_p_logits(logits, p):
    """Nucleus (top-p) sampling filter.

    Keeps the smallest set of most-likely tokens whose cumulative
    probability mass reaches ``p`` and masks every other logit to -1e10
    so it can never be sampled.

    Args:
        logits: float tensor of shape [batch, vocab] — per-token logits
            for the last decoder step.
        p: cumulative probability threshold in (0, 1]; ``p >= 1`` keeps
            every token (filter disabled).

    Returns:
        A tensor of the same shape as ``logits`` with out-of-nucleus
        entries replaced by -1e10.
    """
    # Fast path: p >= 1 keeps every token, so skip building the
    # sort/softmax/cumsum graph entirely. Only taken when p is a plain
    # Python number — a tensor-valued p must go through the graph ops.
    if isinstance(p, (int, float)) and p >= 1:
        return logits
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    # Cumulative probability mass over tokens, most-likely first.
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # Index of the last token kept in each row; the tf.maximum(..., 0)
        # guarantees at least one token survives even when the single
        # top token already exceeds p on its own.
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    # Smallest logit still inside the nucleus, one value per batch row.
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,  # effectively -inf after softmax
        logits,
    )


def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=1):
if start_token is None:
assert context is not None, 'Specify exactly one of start_token and context!'
else:
Expand All @@ -45,6 +63,7 @@ def body(past, prev, output):
next_outputs = step(hparams, prev, past=past)
logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
logits = top_k_logits(logits, k=top_k)
logits = top_p_logits(logits, p=top_p)
samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
return [
next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
Expand Down

4 comments on commit ac5d522

@KuCoolCan
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[7085, 2456, 3975, 284, 530, 11241, 11

@KuCoolCan
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

te amo

@KuCoolCan
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

edson y pakal

@KuCoolCan
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

amo a mis hijas perdon si me fui sin avisar

Please sign in to comment.