-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathMarkovChain.py
61 lines (43 loc) · 1.76 KB
/
MarkovChain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from collections import Counter
from operator import itemgetter
import numpy as np
class MarkovChain():
    """Word-level Markov chain text generator.

    The chain is trained on the whitespace-separated tokens of ``text``. A
    state is a tuple of ``order`` consecutive tokens, and the model stores,
    for each state, the empirical distribution of the token that follows it.

    The special token ``'<eos>'`` marks sentence boundaries:
      * it is prepended to the token stream;
      * sentences start from a state whose first token is ``'<eos>'``;
      * generation stops when ``'<eos>'`` is drawn.

    Attributes:
        text: the training text, as given.
        order: number of tokens per state.
        begin: sorted list of states whose first token is ``'<eos>'``.
        vocab: sorted list of distinct tokens. Column ``j`` of ``P``
            corresponds to ``vocab[j]``.
        states: dict mapping each state tuple to its row index in ``P``.
        P: ``(n_states, n_words)`` array of transition probabilities. A row
            is all zeros for a state that is never followed by a token.
    """

    def __init__(self, text, order):
        """Build the transition matrix from ``text``.

        Args:
            text: training text. Tokens are obtained with ``str.split()``;
                sentence boundaries are expected to be ``'<eos>'`` tokens.
            order: number of preceding tokens that condition the next one.

        Raises:
            ValueError: if ``order`` is less than 1.
        """
        if order < 1:
            raise ValueError("order must be >= 1, got %r" % (order,))
        self.text = text
        self.order = order
        tokens = ['<eos>'] + self.text.split()

        # Materialize the states once. In Python 3, zip() is a one-shot
        # iterator, so iterating it several times silently yields nothing.
        state_set = set(zip(*[tokens[i:] for i in range(order)]))
        vocab = sorted(set(tokens))
        states_lookup = {state: row for row, state in enumerate(state_set)}
        words_lookup = {word: col for col, word in enumerate(vocab)}

        # Count every (state, next token) pair using (order + 1)-grams.
        counts = np.zeros((len(state_set), len(vocab)))
        ngrams = zip(*[tokens[i:] for i in range(order + 1)])
        for ngram, c in Counter(ngrams).items():
            counts[states_lookup[ngram[:order]], words_lookup[ngram[order]]] = c

        # Normalize each row into a probability distribution. States with no
        # successor give 0/0 = nan, which is replaced by an all-zero row.
        with np.errstate(invalid='ignore', divide='ignore'):
            P = counts / counts.sum(axis=1)[:, None]
        P[np.isnan(P)] = 0

        # All begin states start with '<eos>', so sorting the whole tuple
        # orders them by the tokens that follow. Unlike key=itemgetter(1),
        # this also works for 1-tuples (order == 1).
        self.begin = sorted(state for state in state_set if state[0] == '<eos>')
        self.vocab = vocab
        self.states = states_lookup
        self.P = P

    def start_sentence(self):
        """Return a uniformly chosen state from ``self.begin``."""
        return self.begin[np.random.choice(len(self.begin))]

    def next_word(self, last_state):
        """Sample the token that follows ``last_state``.

        Args:
            last_state: a state tuple that appeared in the training text.

        Returns:
            The sampled token as ``str``. Returns ``'<eos>'`` when
            ``last_state`` was never followed by a token in the training
            text (its row of ``P`` is all zeros). Sampling from such a row
            would make ``np.random.choice`` raise, because the
            probabilities would not sum to 1.
        """
        probs = self.P[self.states[last_state]]
        if not probs.any():
            return '<eos>'
        return str(np.random.choice(self.vocab, p=probs))

    def generate_sentence(self):
        """Generate one sentence by walking the chain from a random start state.

        Returns:
            The generated tokens joined by single spaces. The leading
            ``'<eos>'`` of the start state and the terminating ``'<eos>'``
            are excluded.
        """
        state = self.start_sentence()
        words = list(state)
        word = self.next_word(state)
        while word != '<eos>':
            words.append(word)
            # The next state is the last `order` tokens produced so far.
            # order >= 1 is guaranteed by __init__, so the slice is correct.
            state = tuple(words[-self.order:])
            word = self.next_word(state)
        return ' '.join(words[1:])
def last_state(sentence, order):
    """Return the final ``order`` whitespace-separated words of ``sentence``.

    Args:
        sentence: text to split with ``str.split()``.
        order: number of trailing words to keep.

    Returns:
        A tuple of at most ``order`` words. If ``sentence`` has fewer than
        ``order`` words, all of them are returned. ``order <= 0`` yields an
        empty tuple; a plain ``[-order:]`` slice would instead return the
        whole sentence when ``order == 0``, because ``-0 == 0``.
    """
    word_list = sentence.split()
    start = max(len(word_list) - order, 0)
    return tuple(word_list[start:])