"""
(c) 2015, Stephen Hansen, [email protected]
Python script for tutorial illustrating collapsed Gibbs sampling for Latent Dirichlet Allocation.
See http://nbviewer.ipython.org/github/sekhansen/text-mining-tutorial/blob/master/tutorial_notebook.ipynb for an explanation of the commands used below.
"""
import pandas as pd
import topicmodels

########## select data on which to run topic model #########
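# speech_data_extend.txt is assumed here (as in the tutorial notebook) to be a
# tab-delimited file of State of the Union speeches with year, president and
# speech columns; keep only the post-war sample (1947 onward).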
data = pd.read_table("speech_data_extend.txt",encoding="utf-8")
data = data[data.year >= 1947]

########## clean documents #########
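# Cleaning pipeline (see the notebook linked above for details): build a RawDocs
# object with the "long" stopword list, drop length-1 tokens, remove stopwords,
# stem, remove stopwords again on the stems, and rank the stems by tf-idf.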
docsobj = topicmodels.RawDocs(data.speech, "long")
docsobj.token_clean(1)
docsobj.stopword_remove("tokens")
docsobj.stem()
docsobj.stopword_remove("stems")
docsobj.term_rank("stems")
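# Keep only the highest-ranked stems: drop every stem whose tf-idf score falls
# below that of the stem ranked 5000th (so roughly the top 5,000 stems remain).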
docsobj.rank_remove("tfidf","stems",docsobj.tfidf_ranking[5000][1])
all_stems = [s for d in docsobj.stems for s in d]
print("number of unique stems = %d" % len(set(all_stems)))
print("number of total stems = %d" % len(all_stems))

########## estimate topic model #########
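# Estimate a 30-topic LDA model by collapsed Gibbs sampling. Per the notebook,
# sample() takes a burn-in, a thinning interval, and a number of samples to
# store (the second call continues the chain from its current state),
# samples_keep(4) retains the last four stored samples, and topic_content(20)
# writes out the top 20 terms of each topic.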
ldaobj = topicmodels.LDA.LDAGibbs(docsobj.stems,30)
ldaobj.sample(0,50,10)
ldaobj.sample(0,50,10)
ldaobj.samples_keep(4)
ldaobj.topic_content(20)
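# dt and tt are the document-topic and term-topic distributions averaged over
# the retained samples; dict_print() saves the vocabulary used in estimation
# (see the notebook for the exact output files).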
dt = ldaobj.dt_avg()
tt = ldaobj.tt_avg()
ldaobj.dict_print()
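# Replace the raw speech column with one topic-share column per topic (T0-T29)
# and save the document-level results.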
data = data.drop('speech',axis=1)
for i in range(ldaobj.K):
    data['T' + str(i)] = dt[:,i]
data.to_csv("final_output.csv",index=False)

########## query aggregate documents #########
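# Join the stems back into strings, aggregate them by year and president, and
# query the aggregated documents against the estimated topics. Per the notebook,
# query() takes the number of Gibbs iterations to run, and perplexity() allows
# the two query runs to be compared.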
data['speech'] = [' '.join(s) for s in docsobj.stems]
aggspeeches = data.groupby(['year','president'])['speech'].apply(lambda x: ' '.join(x))
aggdocs = topicmodels.RawDocs(aggspeeches)
queryobj = topicmodels.LDA.QueryGibbs(aggdocs.tokens,ldaobj.token_key,ldaobj.tt)
queryobj.query(10)
queryobj.perplexity()
queryobj.query(30)
queryobj.perplexity()
dt_query = queryobj.dt_avg()
aggdata = pd.DataFrame(dt_query,index=aggspeeches.index,columns=['T' + str(i) for i in range(queryobj.K)])
aggdata.to_csv("final_output_agg.csv")

########## top topics #########
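# For each president, demean the topic shares, average them over his documents,
# order the presidents chronologically, and keep the five largest topics.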
def top_topics(x):
    top = x.values.argsort()[-5:][::-1]
    return pd.Series(top,index=range(1,6))

temp = aggdata.reset_index()
ranking = temp.set_index('president')
ranking = ranking - ranking.mean()
ranking = ranking.groupby(level='president').mean()
ranking = ranking.sort_values('year')
ranking = ranking.drop('year',axis=1)
ranking = ranking.apply(top_topics,axis=1)
ranking.to_csv("president_top_topics.csv")

########## Plot topics #########
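# Plot a dendrogram grouping the estimated topics by similarity (see the
# notebook for details of the distance measure used).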
pltobj, dendrogram = ldaobj.dendrogram()
pltobj.show()