From 4c05018883676f2d2fcc32f5a5d7afe2e3899736 Mon Sep 17 00:00:00 2001 From: Joan Varvenne Date: Wed, 14 Sep 2016 11:38:33 +0100 Subject: [PATCH] Make it possible to pick the number of samples used for SVM. Change-Id: I72b625fcefd4b8fa87cc49f71a847e8223395305 --- monasca_analytics/sml/svm_one_class.py | 17 +++++++++++++---- test/sml/test_svm_one_class.py | 5 ++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/monasca_analytics/sml/svm_one_class.py b/monasca_analytics/sml/svm_one_class.py index 941b10a..ba5e970 100644 --- a/monasca_analytics/sml/svm_one_class.py +++ b/monasca_analytics/sml/svm_one_class.py @@ -20,6 +20,8 @@ import numpy as np from sklearn import svm import voluptuous +import monasca_analytics.banana.typeck.type_util as type_util +import monasca_analytics.component.params as params from monasca_analytics.sml import base from monasca_analytics.util import validation_utils as vu @@ -36,24 +38,31 @@ class SvmOneClass(base.BaseSML): def __init__(self, _id, _config): super(SvmOneClass, self).__init__(_id, _config) + self._nb_samples = int(_config["nb_samples"]) @staticmethod def validate_config(_config): svm_schema = voluptuous.Schema({ - "module": voluptuous.And(basestring, vu.NoSpaceCharacter()) + "module": voluptuous.And(basestring, vu.NoSpaceCharacter()), + "nb_samples": voluptuous.Or(float, int) }, required=True) return svm_schema(_config) @staticmethod def get_default_config(): - return {"module": SvmOneClass.__name__} + return { + "module": SvmOneClass.__name__, + "nb_samples": N_SAMPLES + } @staticmethod def get_params(): - return [] + return [ + params.ParamDescriptor("nb_samples", type_util.Number(), N_SAMPLES) + ] def number_of_samples_required(self): - return N_SAMPLES + return self._nb_samples def _generate_train_test_sets(self, samples, ratio_train): num_samples_train = int(len(samples) * ratio_train) diff --git a/test/sml/test_svm_one_class.py b/test/sml/test_svm_one_class.py index f2d670e..2c5d78d 100644 --- a/test/sml/test_svm_one_class.py +++ b/test/sml/test_svm_one_class.py @@ -29,7 +29,10 @@ class TestSvmOneClass(MonanasTestCase): def setUp(self): super(TestSvmOneClass, self).setUp() - self.svm = svm_one_class.SvmOneClass("fakeid", {"module": "fake"}) + self.svm = svm_one_class.SvmOneClass("fakeid", { + "module": "fake", + "nb_samples": 1000 + }) def tearDown(self): super(TestSvmOneClass, self).tearDown()