-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathtrain_svm.m
executable file
·83 lines (65 loc) · 2.81 KB
/
train_svm.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
function new_model = train_svm(model_name, paths)
% TRAIN_SVM Train an SVM classifier with the specified images.
%   Runs a cross-validation (via cross_validate) to choose the SVM
%   parameters, then trains the final model with libsvm's svmtrain on the
%   whole training set and saves it to <model_save_path>/<model_name>.mat.
%
%   Assumes pedestrian samples are labelled 1.0 and non-pedestrian samples
%   -1.0 (NOTE(review): the labels are produced by get_feature_matrix;
%   confirm there).
%   (Code using libsvm)
%
% INPUT:
%   model_name: name for saving the SVM model. It is used both as the .mat
%               file name and as the field / saved variable name, so it must
%               be a valid MATLAB identifier (checked with isvarname).
%   paths:      (optional) cell array {model_save_path, positive_images_path,
%               negative_images_path}. When omitted, the three folders are
%               chosen interactively with uigetdir.
%
% OUTPUT:
%   new_model: struct with one field named model_name that holds the libSVM
%              model returned by svmtrain. Returns [] when a folder
%              selection dialog is cancelled.
%
% ERRORS:
%   train_svm:invalidModelName  model_name missing or not a valid identifier.
%
%$ Author Jose Marcos Rodriguez $
%$ Date: 2013/11/09 $
%$ Revision: 1.2 $

%% argument checks
% model_name becomes a dynamic struct field and a saved variable name below;
% reject bad names before any expensive work (dialogs, features, cross-val).
if nargin < 1 || ~isvarname(model_name)
    error('train_svm:invalidModelName', ...
          'model_name must be a valid MATLAB identifier.');
end

%% path stuff
if nargin < 2
    model_save_path      = uigetdir('.models', 'Select model save folder');
    positive_images_path = uigetdir('dataset', 'Select positive image folder');
    negative_images_path = uigetdir('dataset', 'Select negative image folder');
    % uigetdir returns the numeric 0 (a double) when a dialog is cancelled
    if isa(model_save_path,'double') || ...
       isa(positive_images_path,'double') || ...
       isa(negative_images_path,'double')
        cprintf('Errors','Invalid paths...\nexiting...\n\n')
        % assign the output so callers that request it do not hit an
        % "output argument not assigned" error
        new_model = [];
        return
    end
else
    model_save_path      = paths{1};
    positive_images_path = paths{2};
    negative_images_path = paths{3};
end

% model file path without extension (also handed to cross_validate)
model_base = fullfile(model_save_path, model_name);

%% train matrix and labels
% one parameter set drives both the sample counts and the SVM search below
params = get_params('train_svm_params');
[positive_images, negative_images] = ...
    get_files(params.num_positive_instances, params.num_negative_instances, ...
              {positive_images_path, negative_images_path});
[labels, train_matrix] = get_feature_matrix(positive_images, negative_images);

% =====================================================================
%% SVM STUFF
% Cross validation (k-fold crossval)
% =====================================================================
disp(params);
svm_params = cross_validate(params.kernel, params.cost_range, ...
                            params.gamma_range, train_matrix, labels, ...
                            model_base);
% Flush pending GUI events; the original author notes the GUI otherwise
% appears frozen because MATLAB runs this on its single main thread.
drawnow;

% =====================================================================
%% SVM training
% =====================================================================
svm_start = tic;
cprintf('Comments', 'beggining svm train...\n')
new_model.(model_name) = svmtrain(labels, train_matrix, svm_params);
svm_elapsed = toc(svm_start);
fprintf('SVM training done in: %f seconds.\n', svm_elapsed);

% Report the real file location, passing it through %s so backslashes or
% percent signs in (e.g. Windows) paths are not treated as format escapes.
model_file = strcat(model_base, '.mat');
fprintf('Saving model in %s\n', model_file);
save(model_file, '-struct', 'new_model', model_name);
end