Commit 191981c: fix lda bugs
He Yunlong committed Aug 31, 2015
1 parent 142736c commit 191981c
Showing 6 changed files with 82 additions and 72 deletions.
22 changes: 20 additions & 2 deletions src/main/java/com/intel/distml/api/BigModelWriter.java
@@ -7,23 +7,41 @@
 import com.intel.distml.util.KeyRange;
 import com.intel.distml.util.Matrix;

+import java.util.HashMap;
+
 /**
  * Created by yunlong on 4/29/15.
  */
 public class BigModelWriter implements ModelWriter {

-    private static final int MAX_FETCH_SIZE = 2048000000; // 2g
+    private static final int MAX_FETCH_SIZE = 204800000; // 200M

     protected int maxFetchRows;

+    protected HashMap<String, Integer> fetchBatchSizes;
+
+    public BigModelWriter() {
+        this(10);
+    }
+
     public BigModelWriter(int estimatedRowSize) {
         maxFetchRows = MAX_FETCH_SIZE / estimatedRowSize;
+        fetchBatchSizes = new HashMap<String, Integer>();
     }

+    public void setParamRowSize(String matrixName, int rowSize) {
+        fetchBatchSizes.put(matrixName, MAX_FETCH_SIZE/rowSize);
+    }
+
     @Override
     public void writeModel (Model model, ServerDataBus dataBus) {

         for (String matrixName : model.dataMap.keySet()) {

+            int batchSize = maxFetchRows;
+            if (fetchBatchSizes.containsKey(matrixName)) {
+                batchSize = fetchBatchSizes.get(matrixName).intValue();
+            }
             DMatrix m = model.dataMap.get (matrixName);
             if (m.hasFlag(DMatrix.FLAG_PARAM)) {
                 long size = m.getRowKeys().size();
@@ -32,7 +50,7 @@ public void writeModel (Model model, ServerDataBus dataBus) {
                 m.setLocalCache(null);
                 long start = 0L;
                 while(start < size-1) {
-                    long end = Math.min(start + maxFetchRows, size) - 1;
+                    long end = Math.min(start + batchSize, size) - 1;
                     KeyRange range = new KeyRange(start, end);
                     System.out.println("fetch param: " + range);
                     Matrix result = dataBus.fetchFromServer (matrixName, range);
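The fix here is twofold: MAX_FETCH_SIZE drops from ~2 GB (close to integer overflow and far beyond practical buffer sizes) to ~200 MB, and callers can now override the estimated row size per matrix so wide rows still fit the budget. Below is a minimal sketch of the resulting batching arithmetic, using illustrative sizes rather than anything taken from the repository:

    // Sketch: how a fetch budget and an estimated row size translate into
    // row-range batches. All numbers are illustrative, not from the commit.
    public class FetchBatchDemo {
        static final int MAX_FETCH_SIZE = 204800000;   // 200M budget per fetch

        public static void main(String[] args) {
            int rowSize = 2000 * 10;                   // e.g. 2000 topics, ~10 bytes per counter
            int batchSize = MAX_FETCH_SIZE / rowSize;  // rows per fetch
            long size = 5_000_000L;                    // total rows in the parameter matrix

            long start = 0L;
            while (start < size - 1) {
                long end = Math.min(start + batchSize, size) - 1;
                System.out.println("fetch rows [" + start + ", " + end + "]");
                start = end + 1;                       // advance to the next batch
            }
        }
    }

Dividing the byte budget by the per-row size bounds each fetchFromServer call, so a large parameter matrix is pulled in many small ranges instead of one oversized request.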
77 changes: 27 additions & 50 deletions src/main/java/com/intel/distml/model/lda/LDAModel.java
@@ -20,25 +20,24 @@ public class LDAModel extends Model {
     private float alpha;
     private float beta;
     private int K;
-    public Dictionary dict;
+    private int vocabularySize;

     private float[] p;//temp variables for sampling

-    public LDAModel(float _alpha,float _beta,int _K,Dictionary _dict){
+    public LDAModel(float _alpha, float _beta, int _K, int _v){

         dataSetImmutable = false;
-        this.autoFetchParams=false;
-        this.autoPushUpdates=false;
+        this.autoFetchParams = false;
+        this.autoPushUpdates = false;

-        this.alpha=_alpha;
-        this.beta=_beta;
-        this.K=_K;
-        this.dict=_dict;
-        this.p=new float[_K];
+        this.alpha =_alpha;
+        this.beta =_beta;
+        this.K =_K;
+        this.p = new float[_K];
+        this.vocabularySize = _v;

-        registerMatrix(LDAModel.MATRIX_PARAM_WORDTOPIC,new ParamWordTopic(dict.getSize(),K));
-        registerMatrix(LDAModel.MATRIX_PARAM_TOPIC,new ParamTopic(K));
+        registerMatrix(LDAModel.MATRIX_PARAM_WORDTOPIC, new ParamWordTopic(vocabularySize, K));
+        registerMatrix(LDAModel.MATRIX_PARAM_TOPIC, new ParamTopic(K));

     }

@@ -118,70 +117,48 @@ public void compute(Matrix sample, int workerIndex, DataBus dataBus, final int it
         dataBus.pushUpdate(LDAModel.MATRIX_PARAM_TOPIC,topicsUpdate);

     }
-    //Help functions
-    LDADataMatrix sampling(LDADataMatrix ldaData,HashMapMatrix wordTopics,Topic topics){
-//        System.out.print("before gibbs sampling,the topics are ");
-//        for(int i=0;i<ldaData.topics.length;i++)System.out.print(ldaData.topics[i]+",");
-//        System.out.println();
-//
-//        System.out.print("before gibbs sampling topic array:");
-//        for(int i=0;i<K;i++)System.out.print(topics.element(i)+",");
-//        System.out.println();
-
-        long start=System.currentTimeMillis();
+    //Help functions
+    LDADataMatrix sampling(LDADataMatrix ldaData, HashMapMatrix wordTopics, Topic topics){

-        Integer[] numTopic=(Integer[])topics.values;
-        int[] numDocTopic=ldaData.nDocTopic;
-        for(int i=0;i<ldaData.words.length;i++){
+        Integer[] numTopic = (Integer[])topics.values;
+        int[] numDocTopic = ldaData.nDocTopic;
+        for(int i = 0; i < ldaData.words.length; i++) {

-            int topic=ldaData.topics[i];
-            int wordID=ldaData.words[i];
+            int topic = ldaData.topics[i];
+            int wordID = ldaData.words[i];

-            Integer[] thisWordTopics=(Integer[])wordTopics.get(wordID);
+            Integer[] thisWordTopics = (Integer[])wordTopics.get(wordID);

             numTopic[topic]--;
             thisWordTopics[topic]--;
             numDocTopic[topic]--;

-            float Vbeta = dict.getSize() * this.beta;
-//            System.out.println("Vbeta:" + Vbeta + "beta: " + this.beta +" dict size:" +dict.getSize());
+            float Vbeta = vocabularySize * this.beta;
             float Kalpha = this.K * this.alpha;

-            for(int k=0;k<K;k++){
+            for(int k = 0; k < K; k++) {
                 this.p[k] = (thisWordTopics[k] + beta) / (numTopic[k] + Vbeta) *
                         (numDocTopic[k] + alpha) / (ldaData.words.length - 1 + Kalpha);
-//                if(wordID==0)System.out.println("======" + (thisWordTopics[k] + beta) + "/" + numTopic[k] + "+" + Vbeta +
-//                        "*" + (numDocTopic[k] + alpha) + "/" + (ldaData.words.length - 1 + Kalpha) + "===the P is: "+ this.p[k] +"========");
             }

-            for(int k=1;k<K;k++){
-                this.p[k]+=this.p[k-1];
+            for (int k = 1; k < K; k++) {
+                this.p[k] += this.p[k-1];
             }

-            double u=Math.random()*this.p[K-1];
-//            double u=0.4*this.p[K-1];
+            double u = Math.random() * this.p[K-1];

-            for(topic=0;topic<K;topic++){
-                if(p[topic]>=u)break;
+            for(topic = 0; topic < K; topic++){
+                if(p[topic] >= u) break;
             }

             if(topic==K)System.out.println("max p: " +this.p[K-1] + " u: " + u);
             numTopic[topic]++;
             thisWordTopics[topic]++;
             numDocTopic[topic]++;

-            ldaData.topics[i]=topic;
+            ldaData.topics[i] = topic;
         }

-//        System.out.print("after gibbs sampling topic array:");
-//        for(int i=0;i<K;i++)System.out.print(topics.element(i)+",");
-//        System.out.println();
-//
-//        System.out.print("after gibbs sampling,the topics are ");
-//        for(int i=0;i<ldaData.topics.length;i++)System.out.print(ldaData.topics[i]+",");
-//        System.out.println();
-
-        System.out.println("LDA compute with time :" + (System.currentTimeMillis() - start));
         return ldaData;
     }

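The model now carries only a vocabulary size instead of the whole Dictionary, and the dead debug output around the sampler is gone. For orientation, here is a self-contained sketch of the collapsed Gibbs sampling step that sampling() implements; the class and array names are hypothetical stand-ins for DistML's Topic and HashMapMatrix structures:

    // Sketch of one collapsed Gibbs sampling pass over a document's tokens.
    // nWordTopic, nTopic, nDocTopic are the count tables; alpha/beta are priors.
    import java.util.Random;

    public class GibbsStep {
        static Random rnd = new Random();

        static void sampleDoc(int[] words, int[] topics,
                              int[][] nWordTopic, int[] nTopic, int[] nDocTopic,
                              float alpha, float beta, int V, int K) {
            float Vbeta = V * beta;
            float Kalpha = K * alpha;
            float[] p = new float[K];

            for (int i = 0; i < words.length; i++) {
                int w = words[i], t = topics[i];
                // Remove the token's current assignment from all counts.
                nWordTopic[w][t]--; nTopic[t]--; nDocTopic[t]--;

                // Full conditional, accumulated as a running (unnormalized) CDF.
                for (int k = 0; k < K; k++) {
                    p[k] = (nWordTopic[w][k] + beta) / (nTopic[k] + Vbeta)
                         * (nDocTopic[k] + alpha) / (words.length - 1 + Kalpha);
                    if (k > 0) p[k] += p[k - 1];
                }

                // Inverse-CDF draw of the new topic.
                double u = rnd.nextDouble() * p[K - 1];
                int k = 0;
                while (k < K - 1 && p[k] < u) k++;

                // Add the token back under its new topic.
                nWordTopic[w][k]++; nTopic[k]++; nDocTopic[k]++;
                topics[i] = k;
            }
        }
    }

Note the guard in the inverse-CDF search: clamping the index at K - 1 avoids the topic == K overrun that the original code only logs.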
12 changes: 9 additions & 3 deletions src/main/java/com/intel/distml/model/lda/LDAModelWriter.java
@@ -28,6 +28,12 @@ static class Pair{
         }
     }

+    Dictionary dict;
+
+    public LDAModelWriter(Dictionary dict) {
+        this.dict = dict;
+    }
+
     @Override
     public void writeModel(Model model, ServerDataBus dataBus) {
         for (String matrixName: model.dataMap.keySet()) {
@@ -77,7 +83,7 @@ public void writeToFile(String matrixName, Matrix m, String filePath, Model model
             for(int i=0;i<values.length;i++) {

                 StringBuilder tmp = new StringBuilder();
-                tmp.append(((LDAModel)model).dict.getWord(i));
+                tmp.append(dict.getWord(i));
                 tmp.append(":");
                 for(int j = 0;j<values[i].length;j++) {
                     tmp.append(values[i][j]);
@@ -101,7 +107,7 @@ public void writeToFile(String matrixName, Matrix m, String filePath, Model model
             bw = new BufferedWriter(new FileWriter(new File("rank-" + filePath)));


-            int Num=((LDAModel)model).dict.getSize() > 5?5:((LDAModel)model).dict.getSize();
+            int Num = dict.getSize() > 5? 5 : dict.getSize();
             Pair[] rankedWords = new Pair[Num];
             for(int i = 0;i < Num;i++){
                 rankedWords[i]=new Pair(0,0);
@@ -121,7 +127,7 @@ public void writeToFile(String matrixName, Matrix m, String filePath, Model model
                     rankedWords[j].wordID=pos;
                     rankedWords[j].count=max;
                     values[pos][i]=0; //remove the max number
-                    bw.write(((LDAModel) model).dict.getWord(pos) + ":" +max);
+                    bw.write(dict.getWord(pos) + ":" + max);
                     bw.newLine();
                 }
             }
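LDAModelWriter now takes the Dictionary in its constructor instead of casting the model, so writing no longer depends on LDAModel internals. The ranking loop above selects the top words per topic by repeatedly extracting the current maximum count; a tiny sketch of that selection idea, with made-up counts and no file I/O:

    // Sketch: pick the top-N words of one topic by repeated max extraction,
    // zeroing each maximum so the next pass finds the runner-up.
    public class TopWords {
        public static void main(String[] args) {
            int[] counts = {3, 17, 5, 9, 12};   // word counts for one topic (illustrative)
            int n = Math.min(5, counts.length);

            for (int rank = 0; rank < n; rank++) {
                int best = 0;
                for (int w = 1; w < counts.length; w++) {
                    if (counts[w] > counts[best]) best = w;
                }
                System.out.println("rank " + rank + ": word " + best + " count " + counts[best]);
                counts[best] = 0;   // remove the max so it is not selected again
            }
        }
    }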
18 changes: 11 additions & 7 deletions src/main/java/com/intel/distml/util/GeneralMatrix.java
@@ -252,7 +252,7 @@ else if (colKeys.equals(m.colKeys)) {

             long newFirst = Math.min(keys1.firstKey, keys2.firstKey);
             long newLast = Math.max(keys1.lastKey, keys2.lastKey);
-            int newSize = (int) (newLast - newFirst);
+            int newSize = (int) (newLast - newFirst + 1);
             int[] dims = new int[] {newSize, (int)colKeys.size()};
             T[][] _v = (T[][]) Array.newInstance(values[0][0].getClass(), dims);

@@ -264,11 +264,13 @@ else if (colKeys.equals(m.colKeys)) {
                 }
             }

-            int offset = (int) (keys2.firstKey - keys1.firstKey);
             if (keys2.lastKey > keys1.lastKey) {
-                for (int i = (int)keys1.size(); i < (keys2.lastKey); i++) {
+                int _offset = (int) keys1.size();
+                int offset = (int) (keys1.lastKey - keys2.firstKey + 1);
+                int count = (int) (keys2.lastKey - keys1.lastKey);
+                for (int i = 0; i < count; i++) {
                     for (int j = 0; j < colKeys.size(); j++) {
-                        _v[i][j] = mValues[i - offset][j];
+                        _v[i + _offset][j] = mValues[i + offset][j];
                     }
                 }
             }
@@ -281,10 +283,12 @@ else if (colKeys.equals(m.colKeys)) {
             }

             if (keys1.lastKey > keys2.lastKey) {
-                int offset = (int) (keys1.firstKey - keys2.firstKey);
-                for (int i = (int)keys2.size(); i < (keys1.lastKey); i++) {
+                int _offset = (int) keys2.size();
+                int offset = (int) (keys2.lastKey - keys1.firstKey + 1);
+                int count = (int) (keys1.lastKey - keys2.lastKey);
+                for (int i = 0; i < count; i++) {
                     for (int j = 0; j < colKeys.size(); j++) {
-                        _v[i][j] = values[i - offset][j];
+                        _v[i + _offset][j] = values[i + offset][j];
                     }
                 }
             }
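This file carries the main off-by-one fixes: the merged array was sized newLast - newFirst, one row short for inclusive key ranges, and the copy loops indexed source and destination with the same offset. A standalone sketch of the corrected merge of two overlapping inclusive row ranges, with simplified int rows standing in for the generic T[][] values:

    // Sketch: merge rows [first1..last1] (ours) with [first2..last2] (theirs)
    // into one array covering [min(first1,first2)..max(last1,last2)].
    // Mirrors the fixed sizing (+1) and offset arithmetic from this commit.
    import java.util.Arrays;

    public class RangeMergeDemo {
        public static void main(String[] args) {
            long first1 = 0, last1 = 4;     // our rows 0..4
            long first2 = 3, last2 = 7;     // their rows 3..7 (overlap on 3..4)
            int[] ours = {10, 11, 12, 13, 14};
            int[] theirs = {13, 14, 15, 16, 17};

            long newFirst = Math.min(first1, first2);
            long newLast = Math.max(last1, last2);
            int newSize = (int) (newLast - newFirst + 1);   // +1: ranges are inclusive
            int[] merged = new int[newSize];

            // Copy our rows into place.
            System.arraycopy(ours, 0, merged, (int) (first1 - newFirst), ours.length);

            // Append their rows that lie beyond our last row.
            if (last2 > last1) {
                int dst = (int) (last1 - newFirst + 1);     // first free slot
                int src = (int) (last1 - first2 + 1);       // skip the overlap
                int count = (int) (last2 - last1);
                System.arraycopy(theirs, src, merged, dst, count);
            }

            System.out.println(Arrays.toString(merged));    // [10, ..., 14, 15, 16, 17]
        }
    }

The sketch assumes our range starts first (first1 <= first2), which is the branch the keys2.lastKey > keys1.lastKey path handles; the symmetric branch swaps the roles of the two ranges.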
21 changes: 13 additions & 8 deletions src/main/scala/com/intel/distml/app/lda/LDA.scala
@@ -1,6 +1,6 @@
 package com.intel.distml.app.lda

-import com.intel.distml.api.Model
+import com.intel.distml.api.{BigModelWriter, Model}
 import com.intel.distml.model.lda.{LDAModel, LDADataMatrix, Dictionary}
 import com.intel.distml.platform.{TrainingContext, TrainingHelper}
 import org.apache.spark.{SparkContext, SparkConf}
@@ -14,11 +14,12 @@ import scala.collection.mutable.ListBuffer
 object LDA {

   def normalizeString(src : String) : String = {
-    src.replaceAll("[^A-Z,^a-z]", " ").trim().toLowerCase();
+    src.replaceAll("[^A-Z^a-z]", " ").trim().toLowerCase();
   }

   var dic = new Dictionary()
-  val K=10//topic number
+  val K = 2000 //topic number
+
   def main(args: Array[String]) {

     var sparkMaster = args(0)
@@ -31,7 +32,7 @@ object LDA {
     val conf = new SparkConf()
       .setMaster(sparkMaster)
       .setAppName("LDA")
-      .set("spark.executor.memory", "10g")
+      .set("spark.executor.memory", sparkMem)
       .set("spark.home", sparkHome)
       .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
       .setJars(Seq(appJars))
@@ -41,7 +42,7 @@ object LDA {

     println("====================start: ==========")
     // val trainingFile: String = "hdfs://dl-s1:9000/usr/ruixiang/lda/newdocs2.dat"
-    val trainingFile: String = "hdfs://dl-s1:9000/data/wiki/WestburyLab.wikicorp.201004"
+    val trainingFile: String = "hdfs://dl-s1:9000/data/wiki/wiki_1000000"

     var rawLines = spark.textFile(trainingFile).map(normalizeString).filter(s => s.length > 0)
     // rawLines = rawLines.repartition(2);
@@ -51,19 +52,23 @@ object LDA {
     words.foreach(x=>dic.addWord(x))

     println("====================the word number: " + words.size + " ==========")
+    for (i <- 0 to 99)
+      println(words(i))

     var rddTopic = rawLines.mapPartitions(transFromString2LDAData)
     //rddTopic=rddTopic.repartition(1);

     val config: TrainingContext = new TrainingContext();
-    config.iteration(100);
+    config.iteration(1);
     config.miniBatchSize(20);
     config.psCount(1);
     // config.workerCount(1)

-    val m: Model = new LDAModel(0.5f, 0.1f,K,dic)
+    val m: Model = new LDAModel(0.5f, 0.1f, K, dic.getSize)

-    TrainingHelper.startTraining(spark, m, rddTopic, config);
+    val writer = new BigModelWriter();
+    writer.setParamRowSize(LDAModel.MATRIX_PARAM_WORDTOPIC, K * 10);
+    TrainingHelper.startTraining(spark, m, rddTopic, config, writer);
     System.out.println("LDA has ended!")

   }
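The app now sizes word-topic fetches explicitly: with K = 2000 topics and roughly 10 bytes per topic counter (the same per-row figure BigModelWriter's default constructor assumes), setParamRowSize(MATRIX_PARAM_WORDTOPIC, K * 10) tells the writer each word row is about 20 KB. A quick arithmetic check of what that implies per fetch; the per-counter byte figure is an assumption, not something stated in the commit:

    // Sketch: the fetch-batch arithmetic behind setParamRowSize(..., K * 10).
    public class FetchSizing {
        public static void main(String[] args) {
            int K = 2000;                          // topics
            int bytesPerCounter = 10;              // assumed, mirrors the default row estimate
            int rowSize = K * bytesPerCounter;     // ~20 KB per word row
            int budget = 204800000;                // MAX_FETCH_SIZE (200M)
            System.out.println("rows per fetch: " + budget / rowSize);  // 10240
        }
    }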
4 changes: 2 additions & 2 deletions src/main/scala/com/intel/distml/app/lda/LocalLDA.scala
@@ -54,9 +54,9 @@ object LocalLDA {
     config.psCount(1);
     // config.workerCount(1)//it does not work now,dependent on partition number

-    val m: Model = new LDAModel(0.5f, 0.1f,K,dic)
+    val m: Model = new LDAModel(0.5f, 0.1f,K,dic.getSize)

-    TrainingHelper.startTraining(spark, m, rddTopic, config,new LDAModelWriter);
+    TrainingHelper.startTraining(spark, m, rddTopic, config,new LDAModelWriter(dic));
     // TrainingHelper.startTraining(spark, m, rddTopic, config,new LDAModelWriter());
     System.out.println("LDA has ended!")
   }
