1
1
package hu .sztaki .ilab .ps .matrix .factorization
2
2
3
- import java .io .{FileWriter , PrintWriter }
4
-
3
+ import hu .sztaki .ilab .ps .matrix .factorization .sink .nDCGSink
5
4
import hu .sztaki .ilab .ps .matrix .factorization .utils .Rating
6
- import hu .sztaki .ilab .ps .matrix .factorization .utils .Utils .{ItemId , UserId }
7
- import hu .sztaki .ilab .ps .matrix .factorization .utils .Vector .VectorLength
8
- import org .apache .flink .streaming .api .functions .sink .RichSinkFunction
9
5
import org .apache .flink .streaming .api .scala ._
10
6
11
7
class PSOnlineMatrixFactorizationAndTopKGeneratorTest {
@@ -17,26 +13,26 @@ object PSOnlineMatrixFactorizationAndTopKGeneratorTest {
17
13
val numFactors = 10
18
14
val rangeMin = - 0.01
19
15
val rangeMax = 0.01
20
- val learningRate = 0.4
16
+ val learningRate = 0.2
21
17
val userMemory = 4
22
18
val K = 100
23
- val workerK = 50
19
+ val workerK = 100
24
20
val bucketSize = 100
25
21
val negativeSampleRate = 9
26
- val pullLimit = 1200
22
+ val pullLimit = 800
27
23
val workerParallelism = 4
28
24
val psParallelism = 4
29
25
val iterationWaitTime = 20000
30
26
31
27
def main (args : Array [String ]): Unit = {
32
28
val env = StreamExecutionEnvironment .getExecutionEnvironment
33
- val src = env.readTextFile(" TestData/test_batch " ).map(line => {
29
+ val src = env.readTextFile(" TestData/week_all " ).map(line => {
34
30
val fieldsArray = line.split(" ," )
35
31
36
32
Rating (fieldsArray(1 ).toInt, fieldsArray(2 ).toInt, 1.0 , fieldsArray(0 ).toLong)
37
33
})
38
34
39
- PSOnlineMatrixFactorizationAndTopKGenerator .psOnlineLearnerAndGenerator(
35
+ val topK = PSOnlineMatrixFactorizationAndTopKGenerator .psOnlineLearnerAndGenerator(
40
36
src,
41
37
numFactors,
42
38
rangeMin,
@@ -48,50 +44,10 @@ object PSOnlineMatrixFactorizationAndTopKGeneratorTest {
48
44
workerK,
49
45
bucketSize,
50
46
pullLimit = pullLimit,
51
- iterationWaitTime = iterationWaitTime).addSink(new RichSinkFunction [(UserId , ItemId , Long , List [(Double , ItemId )])] {
52
-
53
- var sumnDCG = 0.0
54
- var counter = 0
55
- var hit = 0
56
-
57
- val log2 : VectorLength = Math .log(2 )
58
-
59
- override def invoke (value : (UserId , ItemId , Long , List [(Double , ItemId )])): Unit = {
60
-
61
- val index = value._4.indexWhere (
62
- recommendation => recommendation._2 == value._2 ) match {
63
-
64
- case - 1 => Int .MaxValue
65
-
66
- case i => i + 1
67
- }
68
-
69
-
70
- val nDCG = index match {
71
-
72
- case Int .MaxValue => 0.0
73
-
74
- case i => log2 / Math .log(1.0 + i)
75
- }
76
-
77
- if (nDCG != 0 )
78
- hit += 1
79
- sumnDCG += nDCG
80
- counter += 1
81
- }
82
-
83
- override def close (): Unit = {
84
- val outputFile = new PrintWriter (new FileWriter (" TestData/PSOnlineMatrixFactorizationAndTopKGenerator_nDCG.out" ))
85
-
86
- val avgnDCG = sumnDCG / counter
47
+ iterationWaitTime = iterationWaitTime)
87
48
88
- outputFile write s " nDCG: $avgnDCG \n "
89
- outputFile write s " hit: $hit \n "
90
- outputFile write s " invokes: $counter"
49
+ nDCGSink.nDCGPeriodsToCsv(topK, " TestData/onlineMF_nDCG.csv" , 86400 )
91
50
92
- outputFile close()
93
- }
94
- }).setParallelism(1 )
95
51
96
52
env.execute()
97
53
}
0 commit comments