Skip to content

Commit bad6a8e

Browse files
committed
Update test
1 parent 0f9a1fd commit bad6a8e

File tree

2 files changed

+9
-53
lines changed

2 files changed

+9
-53
lines changed

src/main/scala/hu/sztaki/ilab/ps/matrix/factorization/PSOnlineMatrixFactorizationAndTopKGenerator.scala

+1 -1
Original file line number | Diff line number | Diff line change
@@ -95,7 +95,7 @@ object PSOnlineMatrixFactorizationAndTopKGenerator {
9595
}, x => x.targetWorker)
9696

9797
FlinkParameterServer.transform(
98-
partitionedInput, workerLogic, serverLogic,workerParallelism, psParallelism, iterationWaitTime)
98+
partitionedInput, workerLogic, serverLogic, workerParallelism, psParallelism, iterationWaitTime)
9999
.flatMap( new CollectTopKFromEachWorker(K, userMemory, workerParallelism)).setParallelism(1)
100100
}
101101

src/test/scala/hu/sztaki/ilab/ps/matrix/factorization/PSOnlineMatrixFactorizationAndTopKGeneratorTest.scala

+8 -52
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,7 @@
11
package hu.sztaki.ilab.ps.matrix.factorization
22

3-
import java.io.{FileWriter, PrintWriter}
4-
3+
import hu.sztaki.ilab.ps.matrix.factorization.sink.nDCGSink
54
import hu.sztaki.ilab.ps.matrix.factorization.utils.Rating
6-
import hu.sztaki.ilab.ps.matrix.factorization.utils.Utils.{ItemId, UserId}
7-
import hu.sztaki.ilab.ps.matrix.factorization.utils.Vector.VectorLength
8-
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction
95
import org.apache.flink.streaming.api.scala._
106

117
class PSOnlineMatrixFactorizationAndTopKGeneratorTest {
@@ -17,26 +13,26 @@ object PSOnlineMatrixFactorizationAndTopKGeneratorTest {
1713
val numFactors = 10
1814
val rangeMin = -0.01
1915
val rangeMax = 0.01
20-
val learningRate = 0.4
16+
val learningRate = 0.2
2117
val userMemory = 4
2218
val K = 100
23-
val workerK = 50
19+
val workerK = 100
2420
val bucketSize = 100
2521
val negativeSampleRate = 9
26-
val pullLimit = 1200
22+
val pullLimit = 800
2723
val workerParallelism = 4
2824
val psParallelism = 4
2925
val iterationWaitTime = 20000
3026

3127
def main(args: Array[String]): Unit = {
3228
val env = StreamExecutionEnvironment.getExecutionEnvironment
33-
val src = env.readTextFile("TestData/test_batch").map(line => {
29+
val src = env.readTextFile("TestData/week_all").map(line => {
3430
val fieldsArray = line.split(",")
3531

3632
Rating(fieldsArray(1).toInt, fieldsArray(2).toInt, 1.0, fieldsArray(0).toLong)
3733
})
3834

39-
PSOnlineMatrixFactorizationAndTopKGenerator.psOnlineLearnerAndGenerator(
35+
val topK = PSOnlineMatrixFactorizationAndTopKGenerator.psOnlineLearnerAndGenerator(
4036
src,
4137
numFactors,
4238
rangeMin,
@@ -48,50 +44,10 @@ object PSOnlineMatrixFactorizationAndTopKGeneratorTest {
4844
workerK,
4945
bucketSize,
5046
pullLimit = pullLimit,
51-
iterationWaitTime = iterationWaitTime).addSink(new RichSinkFunction[(UserId, ItemId, Long, List[(Double, ItemId)])] {
52-
53-
var sumnDCG = 0.0
54-
var counter = 0
55-
var hit = 0
56-
57-
val log2: VectorLength = Math.log(2)
58-
59-
override def invoke(value: (UserId, ItemId, Long, List[(Double, ItemId)])): Unit = {
60-
61-
val index = value._4.indexWhere (
62-
recommendation => recommendation._2 == value._2 ) match {
63-
64-
case -1 => Int.MaxValue
65-
66-
case i => i + 1
67-
}
68-
69-
70-
val nDCG = index match {
71-
72-
case Int.MaxValue => 0.0
73-
74-
case i => log2 / Math.log(1.0 + i)
75-
}
76-
77-
if(nDCG != 0)
78-
hit += 1
79-
sumnDCG += nDCG
80-
counter += 1
81-
}
82-
83-
override def close(): Unit = {
84-
val outputFile = new PrintWriter(new FileWriter("TestData/PSOnlineMatrixFactorizationAndTopKGenerator_nDCG.out"))
85-
86-
val avgnDCG = sumnDCG / counter
47+
iterationWaitTime = iterationWaitTime)
8748

88-
outputFile write s"nDCG: $avgnDCG \n"
89-
outputFile write s"hit: $hit \n"
90-
outputFile write s"invokes: $counter"
49+
nDCGSink.nDCGPeriodsToCsv(topK, "TestData/onlineMF_nDCG.csv", 86400)
9150

92-
outputFile close()
93-
}
94-
}).setParallelism(1)
9551

9652
env.execute()
9753
}

0 commit comments

Comments
 (0)