Pickle performance bottlenecks when using multiprocessing

A while ago I wrote a parameter scan (regularization in logistic regression, to be specific) that was taking a bit too long to execute. Since the machine on which it ran was essentially free of other users (and had some 20 cores lying unused 🙂 ) I decided to go multiprocessing. On such occasions I usually pick multiprocessing.Pool.map, as it seems to require the least boilerplate. Its main drawback is that you must provide a single iterable to which your target function is applied, so functions with more than one argument cannot be used directly. Since we have complete control over the target function, we can make it accept a single argument with all the necessary parameters packed inside. This is exactly what the WorkerTodo namedtuple does in the following code:

import time
from multiprocessing import Pool
from collections import namedtuple

WorkerTodo = namedtuple("WorkerTodo", "data threshold")
Data = namedtuple("Data", "x y z t")

def go(worker_todo):
    return sum(1 for measurement in worker_todo.data 
               if measurement.x > worker_todo.threshold)

def main():
    data = [Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    todos = [WorkerTodo(data, thr) for thr in xrange(4)]

    t_start = time.time()
    for todo in todos:
        go(todo)
    print "Direct call (no multiprocessing): %s seconds" % (time.time()-t_start)

    t_start = time.time()
    p = Pool(2)
    p.map(go, todos)
    print "Using: multiprocessing.Pool.map:  %s seconds" % (time.time()-t_start)

if __name__ == "__main__":
    main()

As you can see, I have included the unparallelized version for comparison. The results were disappointing, if not downright annoying:

Direct call (no multiprocessing): 1.43389391899 seconds
Using: multiprocessing.Pool.map:  33.8968751431 seconds

What is happening? To understand this, we will change the WorkerTodo definition to the following:

class WorkerTodo(object):
    def __init__(self, data, threshold):
        self.data = data
        self.threshold = threshold

    def __getstate__(self):
        print "WorkerTodo __getstate__ called for thr=%s" % self.threshold
        return (time.time(), self.data, self.threshold)

    def __setstate__(self, s):
        tstart, self.data, self.threshold = s
        print "pickle-unpickle time for thr=%s was %s seconds" \
                    % (self.threshold, time.time()-tstart)

Here we use a simple class to send parameters to the child process, with __getstate__ and __setstate__ methods added (these are specific to the pickle protocol). They are written in a way that shows whether our data gets pickled and, if so, how long it takes. The output explains the poor performance:

Direct call (no multiprocessing): 0.970059156418 seconds
WorkerTodo __getstate__ called for thr=0
WorkerTodo __getstate__ called for thr=1
pickle-unpickle time for thr=0 was 8.90351390839 seconds
WorkerTodo __getstate__ called for thr=2
pickle-unpickle time for thr=1 was 9.14890098572 seconds
WorkerTodo __getstate__ called for thr=3
pickle-unpickle time for thr=2 was 9.21250295639 seconds
pickle-unpickle time for thr=3 was 9.45859909058 seconds
Using: multiprocessing.Pool.map:  30.2289671898 seconds

Our data gets pickled before being sent to the workers, and that takes a very long time. Calling cPickle dumps/loads by hand makes us certain this is the reason:

import time
import cPickle

def main():
    data = [Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    t_start = time.time()
    pickle_data = cPickle.dumps(data)
    print "cpickle dumps took", time.time()-t_start

    t_start = time.time()
    cPickle.loads(pickle_data)
    print "cpickle loads took", time.time()-t_start

cpickle dumps took 10.0882921219
cpickle loads took 3.06176805496
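
A side note on pickling by hand: cPickle defaults to the ancient ASCII protocol 0, while the binary protocol 2 (cPickle.HIGHEST_PROTOCOL on Python 2) is usually noticeably faster and more compact. This only helps hand-rolled serialization, since multiprocessing chooses its protocol internally, but it is a cheap experiment. A minimal sketch (timings will of course vary):

import time
import cPickle
from collections import namedtuple

Data = namedtuple("Data", "x y z t")

def timed_dumps(data, protocol):
    # Serialize with the given pickle protocol and report time and size
    t_start = time.time()
    blob = cPickle.dumps(data, protocol)
    print "protocol %s: dumps took %s seconds (%s bytes)" \
                % (protocol, time.time()-t_start, len(blob))

data = [Data(i, i+1, i+2, i+3) for i in xrange(10**6)]
timed_dumps(data, 0)                         # the default: ASCII protocol 0
timed_dumps(data, cPickle.HIGHEST_PROTOCOL)  # binary protocol 2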

In the described case the solution was to use multiprocessing.Process instead of multiprocessing.Pool.map (note: it is far from perfect, as I will explain below). The final code looks like this:

import time
import cPickle
from multiprocessing import Process, Pool
from collections import namedtuple

Data = namedtuple("Data", "x y z t")

class WorkerTodo(object):
    def __init__(self, data, threshold):
        self.data = data
        self.threshold = threshold

    def __getstate__(self):
        print "WorkerTodo __getstate__ called for thr=%s" % self.threshold
        return (time.time(), self.data, self.threshold)

    def __setstate__(self, s):
        tstart, self.data, self.threshold = s
        print "pickle-unpickle time for thr=%s was %s seconds" \
                    % (self.threshold, time.time()-tstart)

def go(worker_todo):
    return sum(1 for measurement in worker_todo.data 
               if measurement.x > worker_todo.threshold)

def main():
    data = [Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    t_start = time.time()
    pickle_data = cPickle.dumps(data)
    print "cPickle test: dumps took %s seconds" % (time.time()-t_start)

    t_start = time.time()
    cPickle.loads(pickle_data)
    print "cPickle test: loads took %s seconds" % (time.time()-t_start)

    todos = [WorkerTodo(data, thr) for thr in xrange(4)]

    t_start = time.time()
    for todo in todos:
        go(todo)
    print "Direct call (no multiprocessing): %s seconds" % (time.time()-t_start)

    t_start = time.time()
    p = Pool(2)
    p.map(go, todos)
    print "Pool.map took", time.time()-t_start
    print "Using: multiprocessing.Pool.map:  %s seconds" % (time.time()-t_start)

    t_start = time.time()
    all_proc = []
    for todo in todos:
        p = Process(target=go, args=(todo,))
        p.start()
        all_proc.append(p)
    for p in all_proc:
        p.join()
    print "Using: multiprocessing.Process:  %s seconds" % (time.time()-t_start)

if __name__ == "__main__":
    main()

The added part, using multiprocessing.Process, indeed runs fast:

Direct call (no multiprocessing): 0.973602056503 seconds
Using: multiprocessing.Pool.map:  32.3261759281 seconds
Using: multiprocessing.Process:  0.727508068085 seconds

What happens here? In the last approach Python relies on the operating system (at least on Linux; see the remarks below) to provide independent copies of the data for the child processes: each Process is started with fork(), so the child simply inherits the parent's memory and no pickling takes place.

Both presented approaches share a serious limitation: the data gets copied multiple times in memory. With multiprocessing.Pool.map this is obvious: since pickle/unpickle gets called, independent objects are created. With multiprocessing.Process it is more subtle: even though copy-on-write is involved (at least on Linux), the data will eventually be copied by the operating system, triggered by mere data access (e.g. during iteration). This rather subtle effect comes from the fact that Python keeps a reference count for every object, and the reference count is stored inside the object itself. So as soon as the data is read (e.g. during iteration), the reference counts of the accessed objects are incremented, the memory pages holding them are written to, and the operating system is forced to make copies. Fortunately there is a solution for that as well; I'll try to post on it soon.
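
The pickling cost (though not the copy-on-write cost) can also be avoided while keeping Pool.map: on Linux, where workers are started with fork, data referenced from module level before the pool is created is inherited by the children for free, so only the small per-task parameters need to travel through pickle. A minimal sketch of the idea, assuming Linux fork semantics; the _shared_data global and go_threshold function are illustrative names, not part of the code above:

import time
from multiprocessing import Pool
from collections import namedtuple

Data = namedtuple("Data", "x y z t")

# Module-level global: workers forked from this process inherit it
# without any pickling taking place.
_shared_data = None

def go_threshold(threshold):
    # Only the small threshold integer travels through pickle; the big
    # list is read from the inherited global.
    return sum(1 for measurement in _shared_data
               if measurement.x > threshold)

def main():
    global _shared_data
    _shared_data = [Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    t_start = time.time()
    p = Pool(2)
    print p.map(go_threshold, range(4))
    print "Pool.map with inherited data: %s seconds" % (time.time()-t_start)

if __name__ == "__main__":
    main()

Note that the children still pay the reference-counting copy-on-write price described above as soon as they iterate over the data.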


Final remarks:

  • The above code was executed on a Linux system. According to the multiprocessing module documentation, Python code executed on Windows is likely to rely on pickling even more. So a program running reasonably on Linux may, on Windows, get hit by pickle bottlenecks in places you considered to be fine.
  • In this post I have completely omitted another not-so-pleasant pickle limitation: it won't pickle lambdas or locally defined functions (see https://docs.python.org/2/library/pickle.html#what-can-be-pickled-and-unpickled). As a consequence, you cannot use such functions as a multiprocessing target (a small demonstration follows right after this list). There are alternatives to the pickle module that overcome this limitation; unfortunately, you cannot plug them into the standard multiprocessing module. You can find alternative multiprocessing frameworks on PyPI (e.g. pathos) that use such pickle alternatives.
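
For the record, a minimal demonstration of that limitation (the exact exception and message differ a bit between pickle implementations and Python versions):

import cPickle

try:
    cPickle.dumps(lambda x: x + 1)
except Exception as e:
    # pickle stores functions by reference (module name + attribute name);
    # a lambda has no importable name, so serialization fails
    print "pickling a lambda failed:", e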
