-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathSolversManager.cs
159 lines (147 loc) · 6.26 KB
/
SolversManager.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
using System.Collections.Generic;
using UnityEditor;
using UnityEngine;
/// <summary>
/// Holds the available loss solvers and forwards loss computations to the selected one.
/// A MetricsManager must be assigned through the MetricsManager property before any
/// loss is computed, because the percentage-error solvers evaluate samples through it.
/// </summary>
class SolversManager
{
    /// <summary>Loss functions that can be selected via CurrentSolverType.</summary>
    public enum SolverType
    {
        /// <summary>Largest per-sample percentage error over all samples.</summary>
        MaxPercentageError,
        /// <summary>Mean per-sample percentage error over all samples.</summary>
        AveragePercentageError
    }

    private MetricsManager metricsManager = null;

    #region Solvers

    /// <summary>Base class for a loss over paired estimate/reference colour lists.</summary>
    abstract class Solver
    {
        /// <summary>Sample evaluator; assigned by the SolversManager.MetricsManager setter.</summary>
        public MetricsManager metricsManager = null;

        /// <summary>
        /// Mean of the per-sample losses, multiplied by 100.
        /// </summary>
        /// <param name="estimates">Estimated colours; estimates[j] is paired with reference[j].</param>
        /// <param name="reference">Reference colours; must hold at least estimates.Count entries.</param>
        /// <returns>The scaled mean loss, or 0 when <paramref name="estimates"/> is empty.</returns>
        /// <exception cref="System.ArgumentNullException">Either list is null.</exception>
        /// <exception cref="System.ArgumentException"><paramref name="reference"/> is shorter than <paramref name="estimates"/>.</exception>
        public virtual double computeLoss(List<Color> estimates, List<Color> reference) {
            ValidateInputs(estimates, reference);
            if (estimates.Count == 0) {
                // Dividing by zero samples would yield NaN; an empty set has no error.
                return 0.0;
            }
            double cost = 0.0;
            for (int j = 0; j < estimates.Count; j++) {
                cost += computeSampleLoss(estimates[j], reference[j]);
            }
            cost /= estimates.Count;
            cost *= 100;
            return cost;
        }

        /// <summary>Loss contributed by a single estimate/reference pair.</summary>
        public abstract double computeSampleLoss(Color estimate, Color reference);

        /// <summary>Rejects null lists and a reference list too short to pair with every estimate.</summary>
        protected static void ValidateInputs(List<Color> estimates, List<Color> reference) {
            if (estimates == null) {
                throw new System.ArgumentNullException("estimates");
            }
            if (reference == null) {
                throw new System.ArgumentNullException("reference");
            }
            if (reference.Count < estimates.Count) {
                throw new System.ArgumentException("reference must contain at least as many samples as estimates", "reference");
            }
        }
    }

    /// <summary>Solvers whose per-sample loss is a symmetric relative (percentage) difference.</summary>
    abstract class PercentageErrorSolver : Solver
    {
        /// <summary>
        /// Mean of |x - y| / (|x| + |y|) over the evaluated channels, a value in [0, 1].
        /// Only the x channel is used when the current metric is Luminance; otherwise x, y and z.
        /// Channels where both values are zero (0/0 = NaN) are skipped; if every channel is
        /// skipped the sample loss is 0.
        /// </summary>
        public override double computeSampleLoss(Color estimate, Color reference) {
            // NOTE(review): assumes evaluateSample returns the evaluated estimate at [0] and the
            // evaluated reference at [1] — confirm against MetricsManager.
            Vector3[] evaluation = metricsManager.evaluateSample(estimate, reference);
            Vector3 evaluatedEstimate = evaluation[0];
            Vector3 evaluatedReference = evaluation[1];

            // Could use sqrt on each term to give higher feedback to larger differences.
            double cost = 0.0;
            int validChannels = 0;
            AccumulateRelativeDifference(evaluatedEstimate.x, evaluatedReference.x, ref cost, ref validChannels);
            if (metricsManager.CurrentMetricType != MetricsManager.MetricType.Luminance) {
                AccumulateRelativeDifference(evaluatedEstimate.y, evaluatedReference.y, ref cost, ref validChannels);
                AccumulateRelativeDifference(evaluatedEstimate.z, evaluatedReference.z, ref cost, ref validChannels);
            }
            if (validChannels > 0) {
                cost /= validChannels;
            }
            return cost;
        }

        /// <summary>
        /// Adds |a - b| / (|a| + |b|) to <paramref name="cost"/> and counts the channel,
        /// unless the ratio is NaN (both values zero).
        /// </summary>
        private static void AccumulateRelativeDifference(float a, float b, ref double cost, ref int validChannels) {
            // Kept in float arithmetic so results match the original per-channel computation.
            float relative = System.Math.Abs(a - b) / (System.Math.Abs(a) + System.Math.Abs(b));
            if (!float.IsNaN(relative)) {
                cost += relative;
                validChannels++;
            }
        }
    }

    /// <summary>Worst-case (maximum) per-sample percentage error.</summary>
    class MaxPercentageErrorSolver : PercentageErrorSolver
    {
        /// <summary>Maximum per-sample loss multiplied by 100; 0 when <paramref name="estimates"/> is empty.</summary>
        public override double computeLoss(List<Color> estimates, List<Color> reference) {
            ValidateInputs(estimates, reference);
            double cost = 0.0;
            for (int j = 0; j < estimates.Count; j++) {
                cost = System.Math.Max(cost, computeSampleLoss(estimates[j], reference[j]));
            }
            cost *= 100;
            return cost;
        }
    }

    /// <summary>
    /// Mean per-sample percentage error. Uses Solver.computeLoss unchanged:
    /// computeSampleLoss already averages over the evaluated channels, so no further
    /// division by the channel count is applied (it would under-report the error by 3x).
    /// </summary>
    class AveragePercentageErrorSolver : PercentageErrorSolver
    {
    }

    /// <summary>Sum of absolute channel differences reported by MetricsManager.computeSampleLoss.</summary>
    class L1NormSolver : Solver
    {
        public override double computeSampleLoss(Color estimate, Color reference) {
            Vector3 sample_cost = metricsManager.computeSampleLoss(estimate, reference);
            return System.Math.Abs(sample_cost.x) + System.Math.Abs(sample_cost.y) + System.Math.Abs(sample_cost.z);
        }
    }

    /// <summary>Euclidean length of the channel differences reported by MetricsManager.computeSampleLoss.</summary>
    class L2NormSolver : Solver
    {
        public override double computeSampleLoss(Color estimate, Color reference) {
            Vector3 sample_cost = metricsManager.computeSampleLoss(estimate, reference);
            double x = sample_cost.x;
            double y = sample_cost.y;
            double z = sample_cost.z;
            // sqrt of the sum of squares; summing per-component square roots would give the L1 norm.
            return System.Math.Sqrt(x * x + y * y + z * z);
        }
    }

    /// <summary>Squared Euclidean length of the channel differences reported by MetricsManager.computeSampleLoss.</summary>
    class L2NormSquaredSolver : Solver
    {
        public override double computeSampleLoss(Color estimate, Color reference) {
            Vector3 sample_cost = metricsManager.computeSampleLoss(estimate, reference);
            double x = sample_cost.x;
            double y = sample_cost.y;
            double z = sample_cost.z;
            return x * x + y * y + z * z;
        }
    }

    #endregion

    private Solver currentSolver;

    private readonly Dictionary<SolverType, Solver> solvers = new Dictionary<SolverType, Solver> {
        { SolverType.MaxPercentageError, new MaxPercentageErrorSolver() },
        { SolverType.AveragePercentageError, new AveragePercentageErrorSolver() }
    };

    /// <summary>Solver that SetCurrentSolver will activate.</summary>
    public SolverType CurrentSolverType { get; private set; }

    /// <summary>Metric type of the assigned MetricsManager (requires one to be assigned).</summary>
    public MetricsManager.MetricType CurrentMetricType {
        get { return metricsManager.CurrentMetricType; }
    }

    public SolversManager() {
        Reset();
    }

    /// <summary>
    /// The metrics manager used for sample evaluation. Setting it also hands it to every
    /// registered solver.
    /// </summary>
    public MetricsManager MetricsManager {
        get { return metricsManager; }
        set {
            metricsManager = value;
            foreach (Solver solver in solvers.Values) {
                solver.metricsManager = value;
            }
        }
    }

    /// <summary>Restores the default solver type (AveragePercentageError); call SetCurrentSolver to apply it.</summary>
    public void Reset() {
        CurrentSolverType = SolverType.AveragePercentageError;
    }

    /// <summary>GUI hook for solver selection; the selector is currently disabled.</summary>
    public void populateGUI() {
        //CurrentSolverType = (SolverType)EditorGUILayout.EnumPopup(new GUIContent("Solver:", "The solver method"), CurrentSolverType, CustomStyles.defaultGUILayoutOption);
    }

    /// <summary>Activates the solver registered for CurrentSolverType.</summary>
    public void SetCurrentSolver() {
        currentSolver = solvers[CurrentSolverType];
    }

    /// <summary>Forwards to MetricsManager.SetCurrentMetric on the assigned metrics manager.</summary>
    public void SetCurrentMetric() {
        metricsManager.SetCurrentMetric();
    }

    /// <summary>
    /// Computes the loss with the active solver. If SetCurrentSolver has not been called yet,
    /// the solver for CurrentSolverType is activated first instead of failing with a null reference.
    /// </summary>
    public double computeLoss(List<Color> estimates, List<Color> reference) {
        if (currentSolver == null) {
            SetCurrentSolver();
        }
        return currentSolver.computeLoss(estimates, reference);
    }
}