Merge "Enable KMeans to support a text format"
This commit is contained in:
commit
9536896c68
|
@ -130,10 +130,13 @@ class ModelController(object):
|
|||
def textToIndex(self, text):
|
||||
return HashingTF().transform(text.split(" "))
|
||||
|
||||
def parseTextRDDToIndex(self, data):
|
||||
def parseTextRDDToIndex(self, data, label=True):
|
||||
|
||||
labels = data.map(lambda line: float(line.split(" ", 1)[0]))
|
||||
documents = data.map(lambda line: line.split(" ", 1)[1].split(" "))
|
||||
if label:
|
||||
labels = data.map(lambda line: float(line.split(" ", 1)[0]))
|
||||
documents = data.map(lambda line: line.split(" ", 1)[1].split(" "))
|
||||
else:
|
||||
documents = data.map(lambda line: line.split(" "))
|
||||
|
||||
tf = HashingTF().transform(documents)
|
||||
tf.cache()
|
||||
|
@ -141,7 +144,10 @@ class ModelController(object):
|
|||
idfIgnore = IDF(minDocFreq=2).fit(tf)
|
||||
index = idfIgnore.transform(tf)
|
||||
|
||||
return labels.zip(index).map(lambda line: LabeledPoint(line[0], line[1]))
|
||||
if label:
|
||||
return labels.zip(index).map(lambda line: LabeledPoint(line[0], line[1]))
|
||||
else:
|
||||
return index
|
||||
|
||||
def evaluateBinaryClassification(self, predictionAndLabels):
|
||||
|
||||
|
@ -166,22 +172,40 @@ class KMeansModelController(ModelController):
|
|||
|
||||
def __init__(self):
|
||||
super(KMeansModelController, self).__init__()
|
||||
self.model_params = {}
|
||||
|
||||
def _parse_model_params(self, params):
|
||||
|
||||
p = {}
|
||||
p['numClasses'] = int(params.get('numClasses', 2))
|
||||
p['maxIterations'] = int(params.get('numIterations', 10))
|
||||
p['runs'] = int(params.get('runs', 10))
|
||||
p['initializationMode'] = params.get('mode', 'random')
|
||||
|
||||
self.model_params = p
|
||||
|
||||
def create_model(self, data, params):
|
||||
|
||||
numClasses = int(params.get('numClasses', 2))
|
||||
numIterations = int(params.get('numIterations', 10))
|
||||
runs = int(params.get('runs', 10))
|
||||
mode = params.get('mode', 'random')
|
||||
self._parse_model_params(params)
|
||||
numClasses = self.model_params.pop('numClasses')
|
||||
|
||||
parsedData = data.map(
|
||||
lambda line: array([float(x) for x in line.split(',')]))
|
||||
|
||||
return KMeans.train(parsedData,
|
||||
numClasses,
|
||||
maxIterations=numIterations,
|
||||
runs=runs,
|
||||
initializationMode=mode)
|
||||
**self.model_params)
|
||||
|
||||
def create_model_text(self, data, params):
|
||||
|
||||
self._parse_model_params(params)
|
||||
numClasses = self.model_params.pop('numClasses')
|
||||
|
||||
parsedData = self.parseTextRDDToIndex(data, label=False)
|
||||
|
||||
return KMeans.train(parsedData,
|
||||
numClasses,
|
||||
**self.model_params)
|
||||
|
||||
def load_model(self, context, path):
|
||||
return KMeansModel.load(context, path)
|
||||
|
@ -189,6 +213,10 @@ class KMeansModelController(ModelController):
|
|||
def predict(self, model, params):
|
||||
return model.predict(params.split(','))
|
||||
|
||||
def predict_text(self, model, params):
|
||||
index = self.textToIndex(params)
|
||||
return model.predict(index)
|
||||
|
||||
|
||||
class RecommendationController(ModelController):
|
||||
|
||||
|
|
Loading…
Reference in New Issue