Skip to content

Commit c6bf63c

Browse files
authored
Add files via upload
0 parents  commit c6bf63c

13 files changed

+1270
-0
lines changed

Category.java

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package com.dead.acctivi_classification;
2+
3+
public enum Category {
4+
Walk,Running, Climb_UP,Climb_Down, TEST
5+
}

Classifier.java

+187
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package com.dead.acctivi_classification;
2+
3+
import java.util.ArrayList;
4+
import java.util.Collections;
5+
import java.util.HashMap;
6+
import java.util.Iterator;
7+
import java.util.List;
8+
import java.util.Map;
9+
10+
import com.dead.acctivi_classification.distanceAlgorithm.DistanceAlgorithm;
11+
import com.dead.acctivi_classification.distanceAlgorithm.EuclideanDistance;
12+
13+
14+
public class Classifier {
15+
16+
private int K;
17+
private double splitRatio;
18+
private double accuracy = 0;
19+
20+
private DistanceAlgorithm distanceAlgorithm;
21+
private List<DataPoint> listDataPoint;
22+
private List<DataPoint> listTrainData;
23+
private List<DataPoint> listTestData;
24+
private List<DataPoint> listTestValidator;
25+
private List<Double> listDistance;
26+
27+
public Classifier(){
28+
K = 11;
29+
splitRatio = 0.8;
30+
distanceAlgorithm = new EuclideanDistance();
31+
listDataPoint = new ArrayList<>();
32+
listTrainData = new ArrayList<>();
33+
listTestData = new ArrayList<>();
34+
listTestValidator = new ArrayList<>();
35+
}
36+
public int getK() {
37+
return K;
38+
}
39+
40+
public void setK(int k) {
41+
K = k;
42+
}
43+
44+
public double getSplitRatio() {
45+
return splitRatio;
46+
}
47+
48+
public void setSplitRatio(double splitRatio) {
49+
this.splitRatio = splitRatio;
50+
}
51+
52+
public List<DataPoint> getListDataPoint() {
53+
return listDataPoint;
54+
}
55+
56+
public void setListDataPoint(List<DataPoint> listDataPoint) {
57+
this.listDataPoint.clear();
58+
this.listDataPoint.addAll(listDataPoint);
59+
}
60+
61+
public List<DataPoint> getListTrainData() {
62+
return listTrainData;
63+
}
64+
65+
public List<DataPoint> getListTestData() {
66+
return listTestData;
67+
}
68+
69+
public DistanceAlgorithm getDistanceAlgorithm() {
70+
return distanceAlgorithm;
71+
}
72+
73+
public void setDistanceAlgorithm(DistanceAlgorithm distanceAlgorithm) {
74+
this.distanceAlgorithm = distanceAlgorithm;
75+
}
76+
77+
public double getAccuracy() {
78+
return accuracy;
79+
}
80+
81+
public void splitData(){
82+
listTestData.clear();
83+
listTrainData.clear();
84+
int trainSize = (int)(listDataPoint.size() * splitRatio);
85+
int testSize = listDataPoint.size() - trainSize;
86+
Collections.shuffle(listDataPoint);
87+
for (int i = 0;i < trainSize; i++)
88+
listTrainData.add(listDataPoint.get(i));
89+
for (int i = 0; i < testSize; i++){
90+
DataPoint dataPointTest = new DataPoint(listDataPoint.get(i + trainSize));
91+
DataPoint dataPointValidator = new DataPoint(dataPointTest);
92+
dataPointTest.setCategory(Category.TEST);
93+
listTestData.add(dataPointTest);
94+
listTestValidator.add(dataPointValidator);
95+
}
96+
}
97+
98+
private List<Double> calculateDistances(DataPoint point){
99+
List<Double> listDistance = new ArrayList<>();
100+
for (DataPoint dataPoint:listTrainData){
101+
double distance = distanceAlgorithm.calculateDistance(point.getMY(), point.getVY(),point.getSDY(),point.getMZ(), point.getVZ(),point.getSDZ(),
102+
dataPoint.getMY(), dataPoint.getVY(),dataPoint.getSDY(),dataPoint.getMZ(), dataPoint.getVZ(),dataPoint.getSDZ());
103+
listDistance.add(distance);
104+
}
105+
return listDistance;
106+
}
107+
108+
// NOT SURE WHATS HAPPENING
109+
private Category getMaxCategory(HashMap<Category, Integer> hashMap){
110+
Iterator<Map.Entry<Category, Integer>> iterator = hashMap.entrySet().iterator();
111+
int maxCategory = Integer.MIN_VALUE;
112+
Category category = null;
113+
while (iterator.hasNext()) {
114+
Map.Entry<Category, Integer> item = iterator.next();
115+
if (item.getValue() > maxCategory){
116+
category = item.getKey();
117+
}
118+
}
119+
return category;
120+
}
121+
122+
123+
private Category classifyDataPoint(DataPoint point){
124+
HashMap<Category, Integer> hashMap = new HashMap<>();
125+
listDistance = calculateDistances(point);
126+
for (int i = 0; i < K; i++){
127+
double min = Double.MAX_VALUE;
128+
int minIndex = -1;
129+
for (int j = 0; j < listDistance.size(); j++){
130+
if (listDistance.get(j) < min){
131+
min = listDistance.get(j);
132+
minIndex = j;
133+
}
134+
}
135+
Category category = listTrainData.get(minIndex).getCategory();
136+
if (hashMap.containsKey(category)){
137+
hashMap.put(category, hashMap.get(category) + 1);
138+
}else{
139+
hashMap.put(category, 1);
140+
}
141+
listDistance.set(minIndex, Double.MAX_VALUE);
142+
}
143+
return getMaxCategory(hashMap);
144+
}
145+
146+
public void classify(){
147+
accuracy = 0;
148+
for (int i = 0;i < listTestData.size(); i++){
149+
DataPoint dataPoint = listTestData.get(i);
150+
Category category = classifyDataPoint(dataPoint);
151+
if (isCorrect(category, listTestValidator.get(i).getCategory()))
152+
accuracy++;
153+
dataPoint.setCategory(category);
154+
}
155+
accuracy /= listTestData.size();
156+
}
157+
158+
Category predictNew(double mY, double vY, double sdY, double mZ, double vZ, double sdZ){
159+
160+
DataPoint dataPoint = new DataPoint(mY,vY,sdY,mZ,vZ,sdZ,Category.values()[4]);
161+
dataPoint.setCategory(Category.TEST);
162+
Category category = classifyDataPoint(dataPoint);
163+
164+
return category;
165+
}
166+
167+
void addTrainData(){
168+
listTestData.clear();
169+
listTrainData.clear();
170+
int trainSize = (int)(listDataPoint.size() * 1);
171+
Collections.shuffle(listDataPoint);
172+
for (int i = 0;i < trainSize; i++){
173+
listTrainData.add(listDataPoint.get(i));
174+
}
175+
}
176+
177+
178+
private boolean isCorrect(Category predictedCategory, Category trueCategory){
179+
return predictedCategory.equals(trueCategory);
180+
}
181+
public void reset() {
182+
listDataPoint.clear();
183+
listTestData.clear();
184+
listTrainData.clear();
185+
}
186+
}
187+

Constants.java

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package com.dead.acctivi_classification;
2+
3+
public class Constants {
4+
public static final String DISTANCE_ALGORITHM = "distanceAlgorithm";
5+
public static final String K = "K";
6+
public static final String SPLITE_RATIO = "splitRatio";
7+
public static final String MINKOWSKI_P = "P";
8+
}

DataPoint.java

+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package com.dead.acctivi_classification;
2+
3+
public class DataPoint {
4+
5+
private double mY,vY,sdY, mZ, vZ, sdZ;
6+
private Category category;
7+
8+
public DataPoint(double mY, double vY, double sdY,double mZ, double vZ, double sdZ, Category category){
9+
this.mY =mY ;
10+
this.vY = vY;
11+
this.sdY = sdY;
12+
this.mZ =mZ ;
13+
this.vZ = vZ;
14+
this.sdZ = sdZ;
15+
this.category = category;
16+
}
17+
18+
public DataPoint(DataPoint dataPoint){
19+
this.mY = dataPoint.getMY();
20+
this.mZ = dataPoint.getMZ();
21+
this.vY = dataPoint.getVY();
22+
this.vZ = dataPoint.getVZ();
23+
this.sdY = dataPoint.getSDY();
24+
this.sdZ = dataPoint.getSDZ();
25+
this.category = dataPoint.getCategory();
26+
}
27+
28+
29+
public double getVZ() {
30+
return vZ;
31+
}
32+
33+
public double getSDZ() {
34+
return sdZ;
35+
}
36+
37+
public double getVY() {
38+
return vY;
39+
}
40+
41+
public double getMZ() {
42+
return mZ;
43+
}
44+
45+
public double getMY() {
46+
return mY;
47+
}
48+
49+
public double getSDY() {
50+
return sdY;
51+
}
52+
53+
public void setMY(double my){
54+
this.mY = my;
55+
}
56+
public void setVY(double vy) {
57+
this.vY = vy;
58+
}
59+
public void setSDY(double sdy) {
60+
this.sdY = sdy;
61+
}
62+
public void setMZ(double mz){
63+
this.mZ = mz;
64+
}
65+
public void setVZ(double vz) {
66+
this.vZ = vz;
67+
}
68+
public void setSDZ(double sdz) {
69+
this.sdZ = sdz;
70+
}
71+
72+
73+
public Category getCategory() {
74+
return category;
75+
}
76+
77+
public void setCategory(Category category) {
78+
this.category = category;
79+
}
80+
}

0 commit comments

Comments
 (0)