-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrainClassifier.m
More file actions
38 lines (30 loc) · 983 Bytes
/
trainClassifier.m
File metadata and controls
38 lines (30 loc) · 983 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
% trainClassifier  Script: build an SVM classifier from extracted image
% features, score it against the test split, and persist the model.
%
% Workspace variables created: class_labels, trainingSet, trainingLabels,
% testSet, testLabels, request, testFeatures, trainingFeatures, classifier,
% predLabels, stats.

% --- Dependencies -------------------------------------------------------
% Bring the SVM library into scope.
importDependencies();

% --- Categories ---------------------------------------------------------
% 12-by-1 cell array of class names (one row per line).
class_labels = {
    'airplanes'
    'barrel'
    'bonsai'
    'brontosaurus'
    'camera'
    'ceiling_fan'
    'anchor'
    'binocular'
    'brain'
    'butterfly'
    'car_side'
    'cellphone'
    };

% --- Data ---------------------------------------------------------------
% getDatabank is given the class names; sampleData turns its result into
% training and test sets together with their labels.
[trainingSet, trainingLabels, testSet, testLabels] = sampleData(getDatabank(class_labels));

% --- Feature request ----------------------------------------------------
% One row per feature: {function handle, option string}.
% NOTE(review): 'Default' is presumably an option preset understood by
% extractFeature -- confirm there.
request = {
    @roundness,      'Default'
    @elongation,     'Default'
    @rectangularity, 'Default'
    @HOG,            'Default'
    @solidity,       'Default'
    % @curveness,    'Default'   (disabled)
    };

% --- Feature extraction -------------------------------------------------
% Test set is processed first, then the training set (order kept as-is).
testFeatures     = extractFeature(request, testSet);
trainingFeatures = extractFeature(request, trainingSet);

% --- Training -----------------------------------------------------------
classifier = trainsvm(trainingFeatures, trainingLabels, class_labels);

% --- Evaluation ---------------------------------------------------------
% Predict labels for the test features; stats is passed on to the report.
[predLabels, stats] = predictsvm(classifier, testFeatures, testLabels, class_labels);
showStatistics(stats, class_labels, predLabels, testLabels);

% --- Persistence --------------------------------------------------------
% Writes the variable 'classifier' to the MAT-file 'cache' (cache.mat).
save('cache', 'classifier');