Add Tree Models
Meteos only supports DecisionTreeRegression in tree models until now. This patch add support models as below. - DecisionTreeClassification - RandomForestRegression - RandomForestClassification implements blueprint add-support-models Change-Id: Ia41b852c1cc6ffa10033b3a6140a969cf3c6f716
This commit is contained in:
parent
d412e10193
commit
7d2d5ddfb0
|
@ -66,6 +66,8 @@ from pyspark.mllib.regression import RidgeRegressionModel
|
|||
from pyspark.mllib.regression import RidgeRegressionWithSGD
|
||||
from pyspark.mllib.tree import DecisionTree
|
||||
from pyspark.mllib.tree import DecisionTreeModel
|
||||
from pyspark.mllib.tree import RandomForest
|
||||
from pyspark.mllib.tree import RandomForestModel
|
||||
from pyspark.mllib.util import MLUtils
|
||||
|
||||
|
||||
|
@ -335,10 +337,14 @@ class NaiveBayesModelController(ModelController):
|
|||
return model.predict(params.split(','))
|
||||
|
||||
|
||||
class DecisionTreeModelController(ModelController):
|
||||
class TreeModelController(ModelController):
|
||||
|
||||
def __init__(self):
|
||||
super(DecisionTreeModelController, self).__init__()
|
||||
def __init__(self, train_name, model_name, algorithm):
|
||||
super(TreeModelController, self).__init__()
|
||||
self.train_class = eval(train_name)
|
||||
self.model_class = eval(model_name)
|
||||
self.algorithm = algorithm
|
||||
self.model_params = {}
|
||||
|
||||
def _parse_to_libsvm(self, param):
|
||||
|
||||
|
@ -359,17 +365,76 @@ class DecisionTreeModelController(ModelController):
|
|||
|
||||
return SparseVector.parse(parsed_str)
|
||||
|
||||
def _parse_model_params(self, params):
|
||||
|
||||
p = {}
|
||||
p['maxDepth'] = int(params.get('maxDepth', 5))
|
||||
p['maxBins'] = int(params.get('maxBins', 32))
|
||||
|
||||
if self.algorithm == 'Classification':
|
||||
p['numClasses'] = int(params.get('numClasses', 2))
|
||||
|
||||
if self.__class__.__name__ == 'RandomForestModelController':
|
||||
p['numTrees'] = int(params.get('numTrees', 3))
|
||||
|
||||
self.model_params = p
|
||||
|
||||
def _create_model(self, data, params, format='csv'):
|
||||
|
||||
self._parse_model_params(params)
|
||||
if format == 'csv':
|
||||
points = data.map(self.parsePoint)
|
||||
else:
|
||||
points = data
|
||||
|
||||
if (self.__class__.__name__ == 'DecisionTreeModelController' and
|
||||
self.algorithm == 'Regression'):
|
||||
|
||||
return getattr(self.train_class,
|
||||
'trainRegressor')(points,
|
||||
{},
|
||||
**self.model_params)
|
||||
|
||||
elif (self.__class__.__name__ == 'DecisionTreeModelController' and
|
||||
self.algorithm == 'Classification'):
|
||||
|
||||
numClasses = self.model_params.pop('numClasses')
|
||||
|
||||
return getattr(self.train_class,
|
||||
'trainClassifier')(points,
|
||||
numClasses,
|
||||
{},
|
||||
**self.model_params)
|
||||
|
||||
if (self.__class__.__name__ == 'RandomForestModelController' and
|
||||
self.algorithm == 'Regression'):
|
||||
|
||||
numTrees = self.model_params.pop('numTrees')
|
||||
|
||||
return getattr(self.train_class,
|
||||
'trainRegressor')(points,
|
||||
{},
|
||||
numTrees,
|
||||
**self.model_params)
|
||||
|
||||
elif (self.__class__.__name__ == 'RandomForestModelController' and
|
||||
self.algorithm == 'Classification'):
|
||||
|
||||
numClasses = self.model_params.pop('numClasses')
|
||||
numTrees = self.model_params.pop('numTrees')
|
||||
|
||||
return getattr(self.train_class,
|
||||
'trainClassifier')(points,
|
||||
numClasses,
|
||||
{},
|
||||
numTrees,
|
||||
**self.model_params)
|
||||
|
||||
def create_model(self, data, params):
|
||||
return self._create_model(data, params)
|
||||
|
||||
def create_model_libsvm(self, data, params):
|
||||
|
||||
impurity = params.get('impurity', 'variance')
|
||||
maxDepth = int(params.get('maxDepth', 5))
|
||||
maxBins = int(params.get('maxBins', 32))
|
||||
|
||||
return DecisionTree.trainRegressor(data,
|
||||
categoricalFeaturesInfo={},
|
||||
impurity=impurity,
|
||||
maxDepth=maxDepth,
|
||||
maxBins=maxBins)
|
||||
return self._create_model(data, params, format='libsvm')
|
||||
|
||||
def evaluate_model(self, context, model, data):
|
||||
|
||||
|
@ -384,7 +449,7 @@ class DecisionTreeModelController(ModelController):
|
|||
return result
|
||||
|
||||
def load_model(self, context, path):
|
||||
return DecisionTreeModel.load(context, path)
|
||||
return getattr(self.model_class, 'load')(context, path)
|
||||
|
||||
def predict(self, model, params):
|
||||
return model.predict(params.split(','))
|
||||
|
@ -394,6 +459,26 @@ class DecisionTreeModelController(ModelController):
|
|||
return model.predict(parsed_params)
|
||||
|
||||
|
||||
class DecisionTreeModelController(TreeModelController):
|
||||
|
||||
def __init__(self, algorithm):
|
||||
train_name = 'DecisionTree'
|
||||
model_name = 'DecisionTreeModel'
|
||||
super(DecisionTreeModelController, self).__init__(train_name,
|
||||
model_name,
|
||||
algorithm)
|
||||
|
||||
|
||||
class RandomForestModelController(TreeModelController):
|
||||
|
||||
def __init__(self, algorithm):
|
||||
train_name = 'RandomForest'
|
||||
model_name = 'RandomForestModel'
|
||||
super(RandomForestModelController, self).__init__(train_name,
|
||||
model_name,
|
||||
algorithm)
|
||||
|
||||
|
||||
class Word2VecModelController(ModelController):
|
||||
|
||||
def __init__(self):
|
||||
|
@ -490,7 +575,13 @@ class MeteosSparkController(object):
|
|||
elif model_type == 'RidgeRegression':
|
||||
self.controller = RidgeRegressionModelController()
|
||||
elif model_type == 'DecisionTreeRegression':
|
||||
self.controller = DecisionTreeModelController()
|
||||
self.controller = DecisionTreeModelController('Regression')
|
||||
elif model_type == 'DecisionTreeClassification':
|
||||
self.controller = DecisionTreeModelController('Classification')
|
||||
elif model_type == 'RandomForestRegression':
|
||||
self.controller = RandomForestModelController('Regression')
|
||||
elif model_type == 'RandomForestClassification':
|
||||
self.controller = RandomForestModelController('Classification')
|
||||
elif model_type == 'Word2Vec':
|
||||
self.controller = Word2VecModelController()
|
||||
elif model_type == 'FPGrowth':
|
||||
|
|
Loading…
Reference in New Issue