diff --git a/corextopic/corextopic.py b/corextopic/corextopic.py
index 0055b07..38fd46d 100644
--- a/corextopic/corextopic.py
+++ b/corextopic/corextopic.py
@@ -18,17 +18,17 @@ License: Apache V2
 """
 
-import warnings
+from os import makedirs, path
+
+import joblib
 import numpy as np  # Tested with 1.8.0
-from os import makedirs
-from os import path
+import scipy.sparse as ss
+from six import string_types  # For Python 2&3 compatible string checking
+
 try:
     from scipy.special import logsumexp
 except ImportError:
     from scipy.misc import logsumexp  # Tested with 0.13.0
-import scipy.sparse as ss
-from six import string_types  # For Python 2&3 compatible string checking
-import joblib
 
 
 class Corex(object):
@@ -103,13 +103,15 @@ class Corex(object):
 
     def __init__(self, n_hidden=2, max_iter=200, eps=1e-5, seed=None, verbose=False, count='binarize', tree=True, **kwargs):
-        self.n_hidden = n_hidden  # Number of hidden factors to use (Y_1,...Y_m) in paper
+        # Number of hidden factors to use (Y_1,...Y_m) in paper
+        self.n_hidden = n_hidden
         self.max_iter = max_iter  # Maximum number of updates to run, regardless of convergence
         self.eps = eps  # Change to signal convergence
         self.tree = tree
         np.random.seed(seed)  # Set seed for deterministic results
         self.verbose = verbose
-        self.t = 20  # Initial softness of the soft-max function for alpha (see NIPS paper [1])
+        # Initial softness of the soft-max function for alpha (see NIPS paper [1])
+        self.t = 20
         self.count = count  # Which strategy, if necessary, for binarizing count data
         if verbose > 0:
             np.set_printoptions(precision=3, suppress=True, linewidth=200)
@@ -150,7 +152,8 @@ def fit(self, X, y=None, anchors=None, anchor_strength=1, words=None, docs=None)
         """
        Fit CorEx on the data X. See fit_transform.
         """
-        self.fit_transform(X, anchors=anchors, anchor_strength=anchor_strength, words=words, docs=docs)
+        self.fit_transform(
+            X, anchors=anchors, anchor_strength=anchor_strength, words=words, docs=docs)
         return self
 
     def fit_transform(self, X, y=None, anchors=None, anchor_strength=1, words=None, docs=None):
@@ -182,7 +185,8 @@
         p_y_given_x = np.random.random((self.n_samples, self.n_hidden))
         if anchors is not None:
             for j, a in enumerate(anchors):
-                p_y_given_x[:, j] = 0.5 * p_y_given_x[:, j] + 0.5 * X[:, a].mean(axis=1).A1  # Assumes X is a binary matrix
+                p_y_given_x[:, j] = 0.5 * p_y_given_x[:, j] + 0.5 * \
+                    X[:, a].mean(axis=1).A1  # Assumes X is a binary matrix
 
         for nloop in range(self.max_iter):
             if nloop > 1:
@@ -191,10 +195,12 @@
                         # Switch label for Y_j so that it is correlated with the top word
                         p_y_given_x[:, j] = 1. - p_y_given_x[:, j]
             self.log_p_y = self.calculate_p_y(p_y_given_x)
-            self.theta = self.calculate_theta(X, p_y_given_x, self.log_p_y)  # log p(x_i=1|y) nv by m by k
+            self.theta = self.calculate_theta(
+                X, p_y_given_x, self.log_p_y)  # log p(x_i=1|y) nv by m by k
             if nloop > 0:  # Structure learning step
-                self.alpha = self.calculate_alpha(X, p_y_given_x, self.theta, self.log_p_y, self.tcs)
+                self.alpha = self.calculate_alpha(
+                    X, p_y_given_x, self.theta, self.log_p_y, self.tcs)
             if anchors is not None:
                 for a in flatten(anchors):
                     self.alpha[:, a] = 0
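For orientation, a minimal usage sketch of the anchored fit path touched above. This is not part of the patch; the corpus, vocabulary, and anchor words are made-up examples.

    import numpy as np
    import scipy.sparse as ss
    from corextopic import corextopic as ct

    # Binary doc-word matrix; CorEx binarizes counts by default (count='binarize')
    X = ss.csr_matrix(np.random.randint(0, 2, size=(100, 5)))
    words = ['apple', 'banana', 'cherry', 'date', 'elderberry']

    model = ct.Corex(n_hidden=2, seed=1)
    # Anchors can be word labels (requires `words`) or column indices;
    # anchor_strength > 1 upweights the anchored columns during fitting.
    model = model.fit(X, words=words,
                      anchors=[['apple', 'banana'], ['cherry']],
                      anchor_strength=2)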
@@ -203,7 +209,8 @@
 
             p_y_given_x, _, log_z = self.calculate_latent(X, self.theta)
-            self.update_tc(log_z)  # Calculate TC and record history to check convergence
+            # Calculate TC and record history to check convergence
+            self.update_tc(log_z)
             self.print_verbose()
             if self.convergence():
                 break
@@ -213,8 +220,10 @@
 
         if anchors is None:
             self.sort_and_output(X)
-        self.p_y_given_x, self.log_p_y_given_x, self.log_z = self.calculate_latent(X, self.theta)  # Needed to output labels
-        self.mis = self.calculate_mis(self.theta, self.log_p_y)  # / self.h_x  # could normalize MIs
+        self.p_y_given_x, self.log_p_y_given_x, self.log_z = self.calculate_latent(
+            X, self.theta)  # Needed to output labels
+        # / self.h_x  # could normalize MIs
+        self.mis = self.calculate_mis(self.theta, self.log_p_y)
         return self.labels
 
     def transform(self, X, details=False):
@@ -236,11 +245,16 @@
             log_p = np.empty((2, n_samples, self.n_hidden))
             c0 = np.einsum('ji,ij->j', alpha, self.theta[0])
             c1 = np.einsum('ji,ij->j', alpha, self.theta[1])  # length n_hidden
-            info0 = np.einsum('ji,ij->ij', alpha, self.theta[2] - self.theta[0])
-            info1 = np.einsum('ji,ij->ij', alpha, self.theta[3] - self.theta[1])
-            log_p[1] = c1 + X.dot(info1)  # sum_i log p(xi=xi^l|y_j=1)  # Shape is 2 by l by j
+            info0 = np.einsum('ji,ij->ij', alpha,
+                              self.theta[2] - self.theta[0])
+            info1 = np.einsum('ji,ij->ij', alpha,
+                              self.theta[3] - self.theta[1])
+            # sum_i log p(xi=xi^l|y_j=1)  # Shape is 2 by l by j
+            log_p[1] = c1 + X.dot(info1)
             log_p[0] = c0 + X.dot(info0)  # sum_i log p(xi=xi^l|y_j=0)
-            surprise = [-np.sum([log_p[labels[l, j], l, j] for j in range(self.n_hidden)]) for l in range(n_samples)]
+            surprise = [-np.sum([log_p[labels[l, j], l, j]
+                                 for j in range(self.n_hidden)])
+                        for l in range(n_samples)]
             return p_y_given_x, log_z, np.array(surprise)
         elif details:
             return p_y_given_x, log_z
@@ -267,7 +281,8 @@
             doc_length = ss.diags(1. / length, 0)
             # max_counts = ss.diags(1. / X.max(axis=1).A.ravel(), 0)
             X = doc_length * X * bg_rate
-            X.data = np.clip(X.data, 0, 1)  # np.log(X.data) / (np.log(X.data) + 1)
+            # np.log(X.data) / (np.log(X.data) + 1)
+            X.data = np.clip(X.data, 0, 1)
         return X
 
     def initialize_parameters(self, X, words, docs):
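A companion sketch (not from the patch) exercising the reflowed transform path above, including the details='surprise' branch; the data is random and purely illustrative.

    import numpy as np
    import scipy.sparse as ss
    from corextopic import corextopic as ct

    X = ss.csr_matrix(np.random.randint(0, 2, size=(200, 50)))
    model = ct.Corex(n_hidden=3, seed=1).fit(X)

    # Posterior p(y_j=1|doc), per-document log normalizers, and surprise scores
    p_y_given_x, log_z, surprise = model.transform(X, details='surprise')
    print(p_y_given_x.shape)  # (200, 3)
    print(surprise[:3])       # larger values flag documents the model finds unlikely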
@@ -284,9 +299,11 @@
             X.sum(axis=0)).ravel()  # 1-d array of total word occurrences. (Probably slow for CSR)
         if np.any(self.word_counts == 0) or np.any(self.word_counts == self.n_samples):
             print('WARNING: Some words never appear (or always appear)')
-            self.word_counts = self.word_counts.clip(0.01, self.n_samples - 0.01)
+            self.word_counts = self.word_counts.clip(
+                0.01, self.n_samples - 0.01)
         self.word_freq = (self.word_counts).astype(float) / self.n_samples
-        self.px_frac = (np.log1p(-self.word_freq) - np.log(self.word_freq)).reshape((-1, 1))  # nv by 1
+        self.px_frac = (np.log1p(-self.word_freq) -
+                        np.log(self.word_freq)).reshape((-1, 1))  # nv by 1
         self.lp0 = np.log1p(-self.word_freq).reshape((-1, 1))  # log p(x_i=0)
         self.h_x = binary_entropy(self.word_freq)
         if self.verbose:
@@ -295,9 +312,10 @@
         self.words = words
         if words is not None:
             if len(words) != X.shape[1]:
-                print('WARNING: number of column labels != number of columns of X. Check len(words) and X.shape[1]')
-            col_index2word = {index:word for index,word in enumerate(words)}
-            word2col_index = {word:index for index,word in enumerate(words)}
+                print(
+                    'WARNING: number of column labels != number of columns of X. Check len(words) and X.shape[1]')
+            col_index2word = {index: word for index, word in enumerate(words)}
+            word2col_index = {word: index for index, word in enumerate(words)}
             self.col_index2word = col_index2word
             self.word2col_index = word2col_index
         else:
@@ -307,8 +325,9 @@
         self.docs = docs
         if docs is not None:
             if len(docs) != X.shape[0]:
-                print('WARNING: number of row labels != number of rows of X. Check len(docs) and X.shape[0]')
-            row_index2doc = {index:doc for index,doc in enumerate(docs)}
+                print(
+                    'WARNING: number of row labels != number of rows of X. Check len(docs) and X.shape[0]')
+            row_index2doc = {index: doc for index, doc in enumerate(docs)}
             self.row_index2doc = row_index2doc
         else:
             self.row_index2doc = None
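Restating the per-word statistics that initialize_parameters/update_word_parameters compute above as a standalone sketch; binary_entropy here is a simplified paraphrase of the module's helper, and the counts are made up.

    import numpy as np

    def binary_entropy(p):
        # H(p) in nats for p(x_i=1) = p
        return -p * np.log(p) - (1 - p) * np.log1p(-p)

    n_samples = 100
    # Clip so no word "always" or "never" appears, exactly as the code above does
    word_counts = np.array([1., 25., 99.]).clip(0.01, n_samples - 0.01)
    word_freq = word_counts / n_samples                 # p(x_i=1)
    px_frac = np.log1p(-word_freq) - np.log(word_freq)  # log[p(x_i=0)/p(x_i=1)]
    lp0 = np.log1p(-word_freq)                          # log p(x_i=0)
    h_x = binary_entropy(word_freq)                     # marginal entropy per word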
@@ -324,9 +343,11 @@
             X.sum(axis=0)).ravel()  # 1-d array of total word occurrences. (Probably slow for CSR)
         if np.any(self.word_counts == 0) or np.any(self.word_counts == self.n_samples):
             print('WARNING: Some words never appear (or always appear)')
-            self.word_counts = self.word_counts.clip(0.01, self.n_samples - 0.01)
+            self.word_counts = self.word_counts.clip(
+                0.01, self.n_samples - 0.01)
         self.word_freq = (self.word_counts).astype(float) / self.n_samples
-        self.px_frac = (np.log1p(-self.word_freq) - np.log(self.word_freq)).reshape((-1, 1))  # nv by 1
+        self.px_frac = (np.log1p(-self.word_freq) -
+                        np.log(self.word_freq)).reshape((-1, 1))  # nv by 1
         self.lp0 = np.log1p(-self.word_freq).reshape((-1, 1))  # log p(x_i=0)
         self.h_x = binary_entropy(self.word_freq)
         if self.verbose:
@@ -334,9 +355,10 @@
         self.words = words
         if words is not None:
             if len(words) != X.shape[1]:
-                print('WARNING: number of column labels != number of columns of X. Check len(words) and X.shape[1]')
-            col_index2word = {index:word for index,word in enumerate(words)}
-            word2col_index = {word:index for index,word in enumerate(words)}
+                print(
+                    'WARNING: number of column labels != number of columns of X. Check len(words) and X.shape[1]')
+            col_index2word = {index: word for index, word in enumerate(words)}
+            word2col_index = {word: index for index, word in enumerate(words)}
             self.col_index2word = col_index2word
             self.word2col_index = word2col_index
         else:
@@ -358,12 +380,15 @@
                 if isinstance(anchor, string_types):
                     if self.words is not None:
                         if anchor in self.word2col_index:
-                            new_anchor_list.append(self.word2col_index[anchor])
+                            new_anchor_list.append(
+                                self.word2col_index[anchor])
                         else:
-                            w = 'WARNING: Anchor word not in word column labels provided to CorEx: {}'.format(anchor)
+                            w = 'WARNING: Anchor word not in word column labels provided to CorEx: {}'.format(
+                                anchor)
                             print(w)
                     else:
-                        raise NameError("Provided non-index anchors to CorEx without also providing 'words'")
+                        raise NameError(
+                            "Provided non-index anchors to CorEx without also providing 'words'")
                 else:
                     new_anchor_list.append(anchor)
             # Update anchors with new anchor list
@@ -387,7 +412,8 @@
         p_dot_y = X.T.dot(p_y_given_x).clip(0.01 * np.exp(log_p_y), (n_samples - 0.01) * np.exp(
             log_p_y))  # nv by ns dot ns by m -> nv by m  # TODO: Change to CSC for speed?
         lp_1g1 = np.log(p_dot_y) - np.log(n_samples) - log_p_y
-        lp_1g0 = np.log(self.word_counts[:, np.newaxis] - p_dot_y) - np.log(n_samples) - log_1mp(log_p_y)
+        lp_1g0 = np.log(
+            self.word_counts[:, np.newaxis] - p_dot_y) - np.log(n_samples) - log_1mp(log_p_y)
         lp_0g0 = log_1mp(lp_1g0)
         lp_0g1 = log_1mp(lp_1g1)
         return np.array([lp_0g0, lp_0g1, lp_1g0, lp_1g1])  # 4 by nv by m
@@ -407,7 +433,8 @@
         # t = 20 + (20 * np.abs(tcs) / tc_oom).reshape((self.n_hidden, 1))  # worked well in many tests
         t = (1 + self.t * np.abs(tcs).reshape((self.n_hidden, 1)))
         maxmis = np.max(mis, axis=0)
-        for i in np.where((mis == maxmis).sum(axis=0))[0]:  # Break ties for the largest MI
+        # Break ties for the largest MI
+        for i in np.where((mis == maxmis).sum(axis=0))[0]:
             mis[:, i] += 1e-10 * np.random.random(self.n_hidden)
             maxmis[i] = np.max(mis[:, i])
         with np.errstate(under='ignore'):
@@ -425,9 +452,12 @@
         ns, nv = X.shape
         log_pygx_unnorm = np.empty((2, ns, self.n_hidden))
         c0 = np.einsum('ji,ij->j', self.alpha, theta[0] - self.lp0)
-        c1 = np.einsum('ji,ij->j', self.alpha, theta[1] - self.lp0)  # length n_hidden
-        info0 = np.einsum('ji,ij->ij', self.alpha, theta[2] - theta[0] + self.px_frac)
-        info1 = np.einsum('ji,ij->ij', self.alpha, theta[3] - theta[1] + self.px_frac)
+        c1 = np.einsum('ji,ij->j', self.alpha,
+                       theta[1] - self.lp0)  # length n_hidden
+        info0 = np.einsum('ji,ij->ij', self.alpha,
+                          theta[2] - theta[0] + self.px_frac)
+        info1 = np.einsum('ji,ij->ij', self.alpha,
+                          theta[3] - theta[1] + self.px_frac)
         log_pygx_unnorm[1] = self.log_p_y + c1 + X.dot(info1)
         log_pygx_unnorm[0] = log_1mp(self.log_p_y) + c0 + X.dot(info0)
         return self.normalize_latent(log_pygx_unnorm)
@@ -454,7 +484,8 @@
         """
         with np.errstate(under='ignore'):
-            log_z = logsumexp(log_pygx_unnorm, axis=0)  # Essential to maintain precision.
+            # Essential to maintain precision.
+            log_z = logsumexp(log_pygx_unnorm, axis=0)
             log_pygx = log_pygx_unnorm[1] - log_z
             p_norm = np.exp(log_pygx)
         return p_norm.clip(1e-6, 1 - 1e-6), log_pygx, log_z  # ns by m
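The logsumexp call that normalize_latent wraps above is what keeps the posterior finite when the unnormalized log-probabilities are very negative; a toy demonstration with made-up log scores:

    import numpy as np
    from scipy.special import logsumexp

    # Unnormalized log p(y=0|x) (row 0) and log p(y=1|x) (row 1) for two documents
    log_pygx_unnorm = np.array([[-1050.0, -3.0],
                                [-1052.0, -1.0]])
    # Naive np.exp() underflows to 0/0 in the first column; subtracting the
    # logsumexp normalizer first keeps everything finite.
    log_z = logsumexp(log_pygx_unnorm, axis=0)
    p_y1 = np.exp(log_pygx_unnorm[1] - log_z)
    print(p_y1)  # approximately [0.119, 0.881]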
@@ -472,7 +503,8 @@ def print_verbose(self):
 
     def convergence(self):
         if len(self.tc_history) > 10:
-            dist = -np.mean(self.tc_history[-10:-5]) + np.mean(self.tc_history[-5:])
+            dist = -np.mean(self.tc_history[-10:-5]) + \
+                np.mean(self.tc_history[-5:])
             return np.abs(dist) < self.eps  # Check for convergence.
         else:
             return False
@@ -483,7 +515,7 @@ def __getstate__(self):
         self_dict = self.__dict__.copy()
         return self_dict
 
-    def save(self, filename, ensure_compatibility = True):
+    def save(self, filename, ensure_compatibility=True):
         """ Pickle a class instance. E.g., corex.save('saved.pkl')
         When set to True, ensure_compatibility resets self.words before saving
@@ -494,7 +526,7 @@
         the topics via get_topics().
         """
         # Avoid saving words with object.
-        #TODO: figure out why Unicode sometimes causes an issue with loading after pickling
+        # TODO: figure out why Unicode sometimes causes an issue with loading after pickling
         temp_words = self.words
         if ensure_compatibility and (self.words is not None):
             self.words = None
@@ -503,7 +535,10 @@
         import pickle
         if path.dirname(filename) and not path.exists(path.dirname(filename)):
             makedirs(path.dirname(filename))
-        pickle.dump(self, open(filename, 'wb'), protocol=-1)
+
+        with open(filename, "wb") as fp:
+            pickle.dump(self, fp, protocol=-1)
+
         # Restore words to CorEx object
         self.words = temp_words
@@ -523,17 +558,23 @@
         self.words = temp_words
 
     def sort_and_output(self, X):
-        order = np.argsort(self.tcs)[::-1]  # Order components from strongest TC to weakest
+        # Order components from strongest TC to weakest
+        order = np.argsort(self.tcs)[::-1]
         self.tcs = self.tcs[order]  # TC for each component
         self.alpha = self.alpha[order]  # Connections between X_i and Y_j
-        self.log_p_y = self.log_p_y[order]  # Parameters defining the representation
-        self.theta = self.theta[:, :, order]  # Parameters defining the representation
+        # Parameters defining the representation
+        self.log_p_y = self.log_p_y[order]
+        # Parameters defining the representation
+        self.theta = self.theta[:, :, order]
 
     def calculate_mis(self, theta, log_p_y):
         """Return MI in nats, size n_hidden by n_variables"""
         p_y = np.exp(log_p_y).reshape((-1, 1))  # size n_hidden, 1
-        mis = self.h_x - p_y * binary_entropy(np.exp(theta[3]).T) - (1 - p_y) * binary_entropy(np.exp(theta[2]).T)
-        return (mis - 1. / (2. * self.n_samples)).clip(0.)  # P-T bias correction
+        mis = self.h_x - p_y * \
+            binary_entropy(np.exp(theta[3]).T) - (1 - p_y) * \
+            binary_entropy(np.exp(theta[2]).T)
+        # P-T bias correction
+        return (mis - 1. / (2. * self.n_samples)).clip(0.)
 
     def get_topics(self, n_words=10, topic=None, print_words=True):
         """
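A round-trip sketch (not part of the patch) for the context-managed pickle save added above; the path is hypothetical, and since ensure_compatibility=True strips `words` before pickling, they are re-attached after loading.

    import numpy as np
    import scipy.sparse as ss
    from corextopic import corextopic as ct

    words = ['w%d' % i for i in range(20)]
    X = ss.csr_matrix(np.random.randint(0, 2, size=(50, 20)))
    model = ct.Corex(n_hidden=2, seed=1).fit(X, words=words)

    model.save('saved_models/corex.pkl', ensure_compatibility=True)  # dir created if missing
    restored = ct.load('saved_models/corex.pkl')
    restored.set_words(words)  # re-attach the column labels stripped at save time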
@@ -554,22 +595,26 @@
         # Determine whether to return column word labels or indices
         if self.words is None:
             print_words = False
-            print("NOTE: 'words' not provided to CorEx. Returning topics as lists of column indices")
+            print(
+                "NOTE: 'words' not provided to CorEx. Returning topics as lists of column indices")
         elif len(self.words) != self.alpha.shape[1]:
             print_words = False
-            print('WARNING: number of column labels != number of columns of X. Cannot reliably add labels to topics. Check len(words) and X.shape[1]. Use .set_words() to fix')
+            print(
+                'WARNING: number of column labels != number of columns of X. Cannot reliably add labels to topics. Check len(words) and X.shape[1]. Use .set_words() to fix')
 
-        topics = [] # TODO: make this faster, it's slower than it should be
+        topics = []  # TODO: make this faster, it's slower than it should be
         for n in topic_ns:
             # Get indices of which words belong to the topic
             inds = np.where(self.alpha[n] >= 1.)[0]
             # Sort topic words according to mutual information
-            inds = inds[np.argsort(-self.alpha[n,inds] * self.mis[n,inds])]
+            inds = inds[np.argsort(-self.alpha[n, inds] * self.mis[n, inds])]
             # Create topic to return
             if print_words is True:
-                topic = [(self.col_index2word[ind], self.sign[n,ind]*self.mis[n,ind]) for ind in inds[:n_words]]
+                topic = [(self.col_index2word[ind], self.sign[n, ind]
+                          * self.mis[n, ind]) for ind in inds[:n_words]]
             else:
-                topic = [(ind, self.sign[n,ind]*self.mis[n,ind]) for ind in inds[:n_words]]
+                topic = [(ind, self.sign[n, ind]*self.mis[n, ind])
+                         for ind in inds[:n_words]]
             # Add topic to list of topics if returning all topics. Otherwise, return topic
             if len(topic_ns) != 1:
                 topics.append(topic)
@@ -599,10 +644,12 @@ def get_top_docs(self, n_docs=10, topic=None, sort_by='log_prob', print_docs=Tru
         # Determine whether to return row doc labels or indices
         if self.docs is None:
             print_docs = False
-            print("NOTE: 'docs' not provided to CorEx. Returning top docs as lists of row indices")
+            print(
+                "NOTE: 'docs' not provided to CorEx. Returning top docs as lists of row indices")
         elif len(self.docs) != self.labels.shape[0]:
-            print_words = False
-            print('WARNING: number of row labels != number of rows of X. Cannot reliably add labels. Check len(docs) and X.shape[0]. Use .set_docs() to fix')
+            print_docs = False
+            print(
+                'WARNING: number of row labels != number of rows of X. Cannot reliably add labels. Check len(docs) and X.shape[0]. Use .set_docs() to fix')
         # Get appropriate matrix to sort
         if sort_by == 'log_prob':
             doc_values = self.log_p_y_given_x
@@ -614,12 +661,14 @@
             return
         # Get top docs for each topic
         doc_inds = np.argsort(-doc_values, axis=0)
-        top_docs = [] # TODO: make this faster, it's slower than it should be
+        top_docs = []  # TODO: make this faster, it's slower than it should be
         for n in topic_ns:
             if print_docs is True:
-                topic_docs = [(self.row_index2doc[ind], doc_values[ind,n]) for ind in doc_inds[:n_docs,n]]
+                topic_docs = [(self.row_index2doc[ind], doc_values[ind, n])
+                              for ind in doc_inds[:n_docs, n]]
             else:
-                topic_docs = [(ind, doc_values[ind,n]) for ind in doc_inds[:n_docs,n]]
+                topic_docs = [(ind, doc_values[ind, n])
+                              for ind in doc_inds[:n_docs, n]]
             # Add docs to list of top docs per topic if returning all topics. Otherwise, return
             if len(topic_ns) != 1:
                 top_docs.append(topic_docs)
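A short sketch of the accessors reformatted above; `model` and `words` are assumed to come from a fit like the earlier sketches.

    # Each topic is a list of (word, signed MI) pairs, strongest first
    for i, topic in enumerate(model.get_topics(n_words=5)):
        print(i, [word for word, mi in topic])

    # Top documents per topic as (doc label or index, value) pairs;
    # sort_by can be 'log_prob' or 'tc'
    top_docs = model.get_top_docs(n_docs=3, sort_by='log_prob')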
@@ -632,9 +681,10 @@ def set_words(self, words):
         self.words = words
         if words is not None:
             if len(words) != self.alpha.shape[1]:
-                print('WARNING: number of column labels != number of columns of X. Check len(words) and .alpha.shape[1]')
-            col_index2word = {index:word for index,word in enumerate(words)}
-            word2col_index = {word:index for index,word in enumerate(words)}
+                print(
+                    'WARNING: number of column labels != number of columns of X. Check len(words) and .alpha.shape[1]')
+            col_index2word = {index: word for index, word in enumerate(words)}
+            word2col_index = {word: index for index, word in enumerate(words)}
             self.col_index2word = col_index2word
             self.word2col_index = word2col_index
@@ -642,8 +692,9 @@ def set_docs(self, docs):
         self.docs = docs
         if docs is not None:
             if len(docs) != self.labels.shape[0]:
-                print('WARNING: number of row labels != number of rows of X. Check len(docs) and .labels.shape[0]')
-            row_index2doc = {index:doc for index,doc in enumerate(docs)}
+                print(
+                    'WARNING: number of row labels != number of rows of X. Check len(docs) and .labels.shape[0]')
+            row_index2doc = {index: doc for index, doc in enumerate(docs)}
             self.row_index2doc = row_index2doc
@@ -668,7 +719,8 @@ def flatten(a):
 def load(filename):
     """ Unpickle class instance. """
     import pickle
-    return pickle.load(open(filename, 'rb'))
+    with open(filename, 'rb') as fp:
+        return pickle.load(fp)
 
 
 def load_joblib(filename):
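And the joblib counterpart to the pickle round trip, again assuming `model` and `words` from the earlier save/load sketch; the filename is hypothetical.

    from corextopic import corextopic as ct

    model.save_joblib('saved_models/corex.joblib')
    restored = ct.load_joblib('saved_models/corex.joblib')
    restored.set_words(words)  # save_joblib also clears `words` before dumping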