How to train a GMM with an initial GaussianMixtureModel?

We're trying to train a Gaussian Mixture Model (GMM) from a specified initial model, in Python with MLlib on Spark. The PySpark 1.5.1 documentation says to pass a GaussianMixtureModel object as the "initialModel" parameter of the GaussianMixture.train method. Before building our own initial model (the plan is to use a k-means result, for instance), we simply wanted to test this scenario. So we initialize a second training with the GaussianMixtureModel returned by a first training. But this trivial scenario throws an error. Could you please help us determine what is going on here? Thanks a lot, Guillaume

PS: we run (Py)Spark 1.5.1 with Hadoop 2.6

Below is the trivial scenario code and the error:

from pyspark.mllib.clustering import GaussianMixture
from numpy import array
import sys
import os
import pyspark

### Local default options
K=2 # "k" (int) Set the number of Gaussians in the mixture model.  Default: 2
convergenceTol=1e-3 # "convergenceTol" (double) Set the largest change in log-likelihood at which convergence is considered to have occurred.
maxIterations=100 # "maxIterations" (int) Set the maximum number of iterations to run. Default: 100
seed=None # "seed" (long) Set the random seed

### Load and parse the sample data
data = sc.textFile("gmm_data.txt") # Data from the dummy set here: data/mllib/gmm_data.txt
parsedData = data.map(lambda line: array([float(x) for x in line.strip().split(' ')]))
print type(parsedData)
print type(parsedData.first())

### 1st training: Build the GMM
gmm = GaussianMixture.train(parsedData, K, convergenceTol,
                            maxIterations, seed)

# output parameters of model
for i in range(2):
    print ("weight = ", gmm.weights[i], "mu = ", gmm.gaussians[i].mu,
        "sigma = ", gmm.gaussians[i].sigma.toArray())

### 2nd training: Re-build a GMM using an initial model
initialModel = gmm
print type(initialModel)
gmm = GaussianMixture.train(parsedData, K, convergenceTol, maxIterations, seed, initialModel)

And this is the output, with the error:

<class 'pyspark.rdd.PipelinedRDD'>
<type 'numpy.ndarray'>
('weight = ', 0.51945003367044018, 'mu = ', DenseVector([-0.1045,
0.0429]), 'sigma = ', array([[ 4.90706817, -2.00676881],
       [-2.00676881,  1.01143891]]))
('weight = ', 0.48054996632955982, 'mu = ', DenseVector([0.0722,
0.0167]), 'sigma = ', array([[ 4.77975653,  1.87624558],
       [ 1.87624558,  0.91467242]]))
<class 'pyspark.mllib.clustering.GaussianMixtureModel'>

Py4JJavaError                             Traceback (most recent call last)
<ipython-input-30-0008fe75eb61> in <module>()
     33 initialModel = gmm
     34 print type(initialModel)
---> 35 gmm = GaussianMixture.train(parsedData, K, convergenceTol,
maxIterations, seed, initialModel) #

in train(cls, rdd, k, convergenceTol, maxIterations, seed,
    306         java_model =
    307                                    k, convergenceTol,
maxIterations, seed,
--> 308                                    initialModelWeights,
initialModelMu, initialModelSigma)
    309         return GaussianMixtureModel(java_model)

in callMLlibFunc(name, *args)
    128     sc = SparkContext._active_spark_context
    129     api = getattr(sc._jvm.PythonMLLibAPI(), name)
--> 130     return callJavaFunc(sc, api, *args)

in callJavaFunc(sc, func, *args)
    120 def callJavaFunc(sc, func, *args):
    121     """ Call Java Function """
--> 122     args = [_py2java(sc, a) for a in args]
    123     return _java2py(sc, func(*args))

in _py2java(sc, obj)
     86     else:
     87         data = bytearray(PickleSerializer().dumps(obj))
---> 88         obj = sc._jvm.SerDe.loads(data)
     89     return obj

in __call__(self, *args)
    536         answer = self.gateway_client.send_command(command)
    537         return_value = get_return_value(answer, self.gateway_client,
--> 538                 self.target_id,
    540         for temp_arg in temp_args:

/opt/spark/spark-1.5.1-bin-hadoop2.6/python/pyspark/sql/utils.pyc in
deco(*a, **kw)
     34     def deco(*a, **kw):
     35         try:
---> 36             return f(*a, **kw)
     37         except py4j.protocol.Py4JJavaError as e:
     38             s = e.java_exception.toString()

in get_return_value(answer, gateway_client, target_id, name)
    298                 raise Py4JJavaError(
    299                     'An error occurred while calling {0}{1}{2}.\n'.
--> 300                     format(target_id, '.', name), value)
    301             else:
    302                 raise Py4JError(

Py4JJavaError: An error occurred while calling
: net.razorvine.pickle.PickleException: expected zero arguments for
construction of ClassDict (for numpy.core.multiarray._reconstruct)
at net.razorvine.pickle.objects.ClassDictConstructor.construct(
at net.razorvine.pickle.Unpickler.load_reduce(
at net.razorvine.pickle.Unpickler.dispatch(
at net.razorvine.pickle.Unpickler.load(
at net.razorvine.pickle.Unpickler.loads(
at org.apache.spark.mllib.api.python.SerDe$.loads(PythonMLLibAPI.scala:1462)
at org.apache.spark.mllib.api.python.SerDe.loads(PythonMLLibAPI.scala)
at sun.reflect.GeneratedMethodAccessor31.invoke(Unknown Source)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(
at java.lang.reflect.Method.invoke(
at py4j.reflection.MethodInvoker.invoke(
at py4j.reflection.ReflectionEngine.invoke(
at py4j.Gateway.invoke(
at py4j.commands.AbstractCommand.invokeMethod(
at py4j.commands.CallCommand.execute(


This was a bug and should already be fixed on master and on branches 1.4 through 1.6. See SPARK-12006 and the corresponding PR.
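
For the plan mentioned in the question (seeding the GMM from a k-means result), the starting weights, means and covariances can be estimated from hard cluster assignments. Below is a minimal sketch in plain Python, not Spark code: the function name init_params_from_labels, the toy points and the labels are all hypothetical, and the labels are assumed to come from some earlier clustering step.

```python
def init_params_from_labels(points, labels, k):
    """Estimate GMM starting values from hard cluster assignments.

    points: list of equal-length lists of floats
    labels: list of ints in range(k), one per point (e.g. from k-means)
    Returns (weights, means, covariances) as plain Python lists.
    """
    n = len(points)
    dim = len(points[0])
    weights, means, covs = [], [], []
    for c in range(k):
        members = [p for p, l in zip(points, labels) if l == c]
        m = len(members)
        if m < 2:
            # A sample covariance needs at least two points
            raise ValueError("cluster %d has fewer than 2 points" % c)
        # Mixture weight: fraction of all points assigned to this cluster
        weights.append(m / float(n))
        # Component mean: per-dimension average of the cluster members
        mean = [sum(p[j] for p in members) / float(m) for j in range(dim)]
        # Component covariance: sample covariance (divides by m - 1)
        cov = [[sum((p[a] - mean[a]) * (p[b] - mean[b]) for p in members)
                / float(m - 1)
                for b in range(dim)]
               for a in range(dim)]
        means.append(mean)
        covs.append(cov)
    return weights, means, covs


# Hypothetical toy data: two well-separated groups with known labels
points = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [2.0, 2.0],
          [10.0, 10.0], [12.0, 10.0], [11.0, 13.0]]
labels = [0, 0, 0, 0, 1, 1, 1]
weights, means, covs = init_params_from_labels(points, labels, 2)
```

How these lists are then wrapped into a GaussianMixtureModel depends on the Spark version and is not covered by this sketch. A cluster whose points are collinear yields a singular covariance, so the result is worth checking before it is used as a starting point.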

