-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathLinearSVM.scala
111 lines (107 loc) · 4.74 KB
/
LinearSVM.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Wei Chen - LSVM - linear Support Vector Machine
// 2015-12-09
package com.scalaml.algorithm
import com.scalaml.general.MatrixFunc._
// LinearSVM = linear Support Vector Machine
// The core routine supports only binary (two-class) classification, because a
// linear classifier is only well suited to two-class problems
/** Linear support vector machine for two-class problems.
  *
  * Training labels are used as keys into `cost` (default keys `-1` and `1`)
  * and `predict` only ever returns `-1` or `1`. A constant `1.0` is appended to
  * every feature vector, so the last entry of `projector` acts as the bias.
  *
  * NOTE(review): the training loop looks like the dual coordinate-descent
  * scheme with shrinking used by LIBLINEAR for the L2-loss SVM — confirm
  * against the reference before relying on that description.
  */
class LinearSVM() extends Classification {
    val algoname: String = "LinearSVM"
    val version: String = "0.1"
    // Learned weights (features + trailing bias); empty until train() succeeds
    var projector = Array[Double]()
    var cost = Map(-1 -> 1.0, 1 -> 1.0) // Cost of two groups
    var limit = 1000 // Iteration limit
    var err = 1e-1 // Saturation error

    /** Drops the learned weights; configuration values are left untouched.
      *
      * @return always `true`
      */
    override def clear(): Boolean = {
        projector = Array[Double]()
        true
    }

    /** Reads the settings from `paras`.
      *
      * Recognised keys (first match wins, defaults apply when none is present):
      *  - `"COST"` / `"cost"`: `Map[Int, Double]`, default `Map(-1 -> 1.0, 1 -> 1.0)`
      *  - `"LIMIT"` / `"limit"`: `Int`, default `1000`
      *  - `"ERROR"` / `"error"` / `"err"`: `Double`, default `1e-1`
      *
      * All values are parsed before any field is assigned, so a failed call
      * leaves the previous configuration intact.
      *
      * @param paras parameter map
      * @return `true` on success; `false` (after printing the exception to
      *         stderr) when a value has the wrong runtime type
      */
    override def config(paras: Map[String, Any]): Boolean = try {
        // Value of the first key found in `paras`, or `default` when none is present
        def lookup(default: Any, keys: String*): Any =
            keys.find(paras.contains).map(k => paras(k)).getOrElse(default)
        // The Map cast is unchecked for its type arguments (erasure); wrong
        // key/value types only surface later, when train() reads the map.
        val newCost = lookup(Map(-1 -> 1.0, 1 -> 1.0), "COST", "cost").asInstanceOf[Map[Int, Double]]
        val newLimit = lookup(1000, "LIMIT", "limit").asInstanceOf[Int]
        val newErr = lookup(1e-1, "ERROR", "error", "err").asInstanceOf[Double]
        // Commit only after every cast succeeded
        cost = newCost
        limit = newLimit
        err = newErr
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // --- Sub Variables & Functions ---
    private val INF = Double.PositiveInfinity // Infinite
    private val rng = scala.util.Random // Shared random generator

    // Inner product, delegated to MatrixFunc.arraymultiply
    private def dot(x: Array[Double], y: Array[Double]): Double =
        arraymultiply(x, y).sum

    // In-place shuffle: position i is swapped with a random position in [i, size)
    private def randomSwap(arr: Array[Int]): Unit = {
        val arrsize = arr.size
        for(i <- 0 until arrsize) { // Randomize Saturation Direction
            val j = rng.nextInt(arrsize - i)
            val temp = arr(i) // Random SWAP i <-> i + j
            arr(i) = arr(i + j)
            arr(i + j) = temp
        }
    }

    // --- Function Core Start ---
    /** Fits the weight vector to labelled samples and stores it in `projector`.
      *
      * Every label must be a key of `cost`, and every sample must have the same
      * number of features as the first one. The loop stops after `limit` passes
      * or once the spread of projected gradients in a pass is at most `err`
      * with no sample left outside the active zone.
      *
      * @param data array of `(label, features)` pairs
      * @return `true` when training finished; `false` (after printing the
      *         exception to stderr) on empty data, ragged samples, or a label
      *         missing from `cost` — `projector` is left unchanged in that case
      */
    def train(
        data: Array[(Int, Array[Double])] // Data Array(yi, xi)
    ): Boolean = try { // - Feature Initialization
        val traindatasize = data.size
        val featuresize = data.head._2.size
        // Ragged rows would otherwise silently corrupt the weight vector
        require(
            data.forall(_._2.size == featuresize),
            "all training samples must have the same number of features"
        )
        val labels = data.map(_._1)
        val samples = data.map(_._2 :+ 1.0) // Bias-augmented once, reused every pass
        val w = new Array[Double](featuresize + 1) // Initial weighting
        val alpha = new Array[Double](traindatasize) // Alpha SV pointer
        val index = (0 until traindatasize).toArray // Initialize index
        val diag = cost.map(l => l._1 -> 0.5 / l._2) // Diag
        val QD = Array.tabulate(traindatasize) { i => // Per-sample step denominator
            diag(labels(i)) + samples(i).map(Math.pow(_, 2)).sum
        }
        // - Iteration Coefficients Initiation
        var saturated = false
        var iter = 0
        var old_PG_max = INF // Projected Gradient maximum saved
        while(iter < limit && !saturated) {
            iter += 1
            var new_PG_max = -INF // new Projected Gradient maximum
            var new_PG_min = INF // new Projected Gradient minimum
            randomSwap(index)
            var outzone = false
            for(i <- index) { // Loop data with SWAP index
                val yi = labels(i)
                val xi = samples(i)
                // Projected Gradient -> Cost with Alpha for PG -> 0
                val PG = yi * dot(w, xi) - 1 + alpha(i) * diag(yi)
                // if SV or Violate
                if(alpha(i) > 0 || PG < 0) {
                    new_PG_max = Math.max(new_PG_max, PG) // Sandwich Saturation
                    new_PG_min = Math.min(new_PG_min, PG) // Test if all PG ~= 0
                    val alpha_old = alpha(i)
                    alpha(i) = Math.max(alpha_old - PG / QD(i), 0.0) // Update Alpha
                    val d = yi * (alpha(i) - alpha_old) // Difference
                    var j = 0
                    while(j < w.length) { // wj += xij * d, in place
                        w(j) += xi(j) * d
                        j += 1
                    }
                } else if(PG <= old_PG_max) { // If in PG Zone
                    new_PG_max = Math.max(new_PG_max, 0.0) // Sandwich Saturation
                    new_PG_min = Math.min(new_PG_min, 0.0)
                } else outzone = true // Out of PG Zone
            }
            // Update valid PG Zone
            if(new_PG_max - new_PG_min > err) {
                old_PG_max = (if(new_PG_max > 0) new_PG_max else INF)
            } else if(outzone) { // Reset if PG not saturated correctly
                old_PG_max = INF
            } else saturated = true // Done
        }
        // Console.err.println("Iterate:" + iter + " SV_Number:" + alpha.filter(_ > 0).size)
        // Console.err.println("W = " + w.mkString(","))
        projector = w
        true
    } catch { case e: Exception =>
        Console.err.println(e)
        false
    }

    // --- Prediction Function ---
    /** Classifies each sample by the sign of its projection onto `projector`.
      *
      * @param data feature vectors; `1.0` is appended to each before projecting
      * @return `-1` where the projection is negative, `1` otherwise
      *
      * NOTE(review): the result before a successful `train` depends on how
      * `arraymultiply` treats arrays of different length — confirm.
      */
    def predict(data: Array[Array[Double]]): Array[Int] = {
        data.map { xt =>
            if(dot(xt :+ 1.0, projector) < 0) -1 else 1
        }
    }
}