How I’ve managed to shoot myself in the foot with numpy and cPickle

Currently, I use python mostly for data analysis and modeling. Whenever I can, I take a pipeline-like approach, where data is processed in multiple steps. Those are implemented in separate py files, with cPickle used for data persistence and exchange. You can think of this as a poor man’s mapreduce.

Usually the development is an iterative process. I run and change a given step over and over until I’m happy (at least for a moment) with the result. This means loading the same data with cPickle multiple times. Which, as it turns out, I’ve been doing wrong for a very long time.

So, what was wrong? Consider this code for pickling and unpickling a large numpy array:

import numpy as np
import cPickle
import time

data = np.random.sample(int(1e7))
t1 = time.time()
with open("data.pkl", "wb") as of_:
    cPickle.dump(data, of_)
print "Write took", time.time()-t1

t1 = time.time()
with open("data.pkl", "rb") as of_:
    cPickle.load(of_)
print "Read took", time.time()-t1

On my laptop, the write/read parts of the above take 12 and 6.5 seconds respectively. The code is straightforward and nothing looks wrong – except that I have omitted the pickle protocol version, which, as it turns out, has a dramatic impact on performance. If we set the protocol to the latest and greatest:

with open("data.pkl", "wb") as of_:
    cPickle.dump(data, of_, cPickle.HIGHEST_PROTOCOL)

the write/read times drop to 0.7 and 0.1 seconds respectively. That is nearly a two orders of magnitude difference!
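Part of the story is easy to see by comparing the size of the serialized output: the default protocol 0 writes a text-based representation, while the newer binary protocols store each float as its raw 8 bytes. Below is a minimal sketch of that comparison, using the stdlib pickle (the cPickle equivalent) and a plain list of floats standing in for the numpy array, in Python 3 syntax:

```python
import pickle

# a list of floats with long decimal representations, standing in
# for the large numpy array from the timing experiment above
data = [i / 3.0 for i in range(100000)]

# protocol 0: text-based, every float serialized as its decimal repr
text_based = pickle.dumps(data, 0)

# highest protocol: binary, every float stored as raw 8 bytes
binary = pickle.dumps(data, pickle.HIGHEST_PROTOCOL)

print(len(text_based), len(binary))  # the binary form is considerably smaller
```

The text form is both larger and slower to produce and parse, which is consistent with the timings above.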

It is nothing unusual nowadays to have data big enough to make the total loading time (i.e. summed over all “change it and run” iterations during a day) significant, e.g. half an hour. That time is 100% wasted. If this feels like overreacting, think how a half-minute lag at the start of a program would affect your comfort as a developer.

Of course we could take another approach by loading our data once inside an ipython notebook and doing all of the development there. This I try to avoid whenever I can, but that’s a topic for another post.

So remember – pickle protocol matters!

Unittesting print statements

Recently I was refactoring a small package that is supposed to allow execution of arbitrary python code on a remote machine. The first implementation was working nicely, but with one serious drawback – the function handling the actual code execution was running in a synchronous (blocking) mode. As a result, all of the output (both stdout and stderr) was presented only at the end, i.e. when the code finished its execution. This was unacceptable, since the package should work in a way as transparent to the user as possible – a wall of text appearing only when the code completes its task wouldn’t do.

The goal of the refactoring was simple – to have the output presented to the user immediately after it was printed on the remote host. As a TDD worshipper I wanted to start this in a kosher way, i.e. with a test. And I got stuck.

For a day or so I had no idea how to progress. How do you unittest print statements? It’s funny when I think about it now. I have used a similar technique many times in the past for output redirection, yet somehow I hadn’t managed to make the connection with this problem.

The print statement

So how do you do it? First we should understand what happens when a print statement is executed. In python 2.x the print statement does two things – it converts the provided expressions into strings and writes the result to a file-like object handling the stdout. Conveniently, that object is available as sys.stdout (i.e. as a part of the sys module). So all you have to do is overwrite sys.stdout with your own object providing a ‘write’ method. Later you may discover that some other methods are also needed (e.g. ‘flush’ is quite often used), but for starters, having only the ‘write’ method should be sufficient.

A first try – simple stdout interceptor

The code below does just that. The MyOutput class is designed to replace the original sys.stdout:

import unittest
import sys

def fn_print(nrepeat):
    print "ab"*nrepeat

class MyTest(unittest.TestCase):
    def test_stdout(self):
        class MyOutput(object):
            def __init__(self):
                self.data = []

            def write(self, s):
                self.data.append(s)

            def __str__(self):
                return "".join(self.data)

        stdout_org = sys.stdout
        my_stdout = MyOutput()
        try:
            sys.stdout = my_stdout
            fn_print(2)
        finally:
            sys.stdout = stdout_org

        self.assertEqual(str(my_stdout), "abab\n")

if __name__ == "__main__":
    unittest.main()

The fn_print function provides output to test against. After replacing sys.stdout we call this function and compare the obtained output with the expected one. It is worth noting that in the example above the original sys.stdout is first preserved and then carefully restored inside the ‘finally’ block. If you don’t do this, you are likely to lose any output coming from other tests.

Is my code async? Logging time of arrival

In the second example we will address the original problem – is the output presented as a wall of text at the end, or in real time as we want? For this we will add time-of-arrival logging to the object replacing sys.stdout:

import unittest
import time
import sys

def fn_print_with_delay(nrepeat):
    for i in xrange(nrepeat):
        print    # prints a single newline
        time.sleep(0.5)

class TestServer(unittest.TestCase):
    def test_stdout_time(self):
        class TimeLoggingOutput(object):
            def __init__(self):
                self.data = []
                self.timestamps = []

            def write(self, s):
                self.timestamps.append(time.time())
                self.data.append(s)

        stdout_org = sys.stdout
        my_stdout = TimeLoggingOutput()
        nrep = 3 # make sure is >1
        try:
            sys.stdout = my_stdout
            fn_print_with_delay(nrep)
        finally:
            sys.stdout = stdout_org

        for i in xrange(nrep):
            if i > 0:
                dt = my_stdout.timestamps[i]-my_stdout.timestamps[i-1]
                self.assertTrue(0.5<dt<0.52)

if __name__ == "__main__":
    unittest.main()

The code is pretty much self-explanatory – the fn_print_with_delay function prints newlines at half-second intervals. We override sys.stdout with an instance of a class that stores timestamps (obtained with time.time()) of all calls to the write method. At the end we assert the timestamps are spaced approximately half a second apart. The code above works as expected:

.
----------------------------------------------------------------------
Ran 1 test in 1.502s

OK

If we change the interval inside the fn_print_with_delay function to one second, the test will (fortunately) fail.

Wrap-up

As we saw, testing for expected output is in fact trivial – all you have to do is put an instance of a class with a ‘write’ method in the proper place (i.e. sys.stdout). The only ‘gotcha’ is the cleanup – you should remember to restore sys.stdout to its original state. You may apply the exact same technique if you need to test stderr (just target sys.stderr instead of sys.stdout). It is also worth noting that with a similar technique you could intercept (or completely silence) output coming from external libraries.
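To avoid repeating the save/replace/restore dance in every test, the whole pattern can be wrapped in a context manager. A minimal sketch of that idea – the captured_stdout name is mine, not part of any library, and it is written in Python 3 syntax (where print is a function):

```python
import sys
from contextlib import contextmanager

class CapturedOutput(object):
    """File-like object collecting everything written to it."""
    def __init__(self):
        self.data = []

    def write(self, s):
        self.data.append(s)

    def flush(self):  # some libraries call flush(), so provide it
        pass

    def __str__(self):
        return "".join(self.data)

@contextmanager
def captured_stdout():
    # save the original stream, swap in the interceptor, and
    # guarantee restoration even if the body raises
    stdout_org = sys.stdout
    my_stdout = CapturedOutput()
    sys.stdout = my_stdout
    try:
        yield my_stdout
    finally:
        sys.stdout = stdout_org

# usage inside a test:
with captured_stdout() as out:
    print("ab" * 2)
assert str(out) == "abab\n"
```

The ‘finally’-based cleanup from the examples above now lives in one place instead of being copied into every test method.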

Modify MS Word documents with python

A couple of months ago my wife completed and defended her Ph.D. thesis in archaeology. To our surprise, she received a proposal to turn it into a book. I probably don’t need to write how excited we are, especially since this kind of thing happens rarely in her field.

Unfortunately it also means lots of work. The thesis is about a thousand pages long and is written in MS Word (docx format). Now my wife must once again go through the whole text and edit it to meet the print standards of the publisher. One of the things that needs to be modified is the format of citations. Currently those look like the following

(M. Mouse, 1901; D. Duck, 1999)

and need to become

(MOUSE, 1901; DUCK, 1999)

All citations were done manually in the original text, i.e. were not handled by any sort of bibliography manager (I’m not sure how handy or useful such a thing in MS Word would be, or even whether one is included, but that’s a different story).

After some experimentation I’ve managed to write a simple python script able to find and modify citations so they look as desired. I’m going to show you how to get started solving this or a similar problem in python.

The toolbox

The whole maneuver was possible thanks to the fact that the thesis was saved in the docx format. Essentially it’s a zip file with a bunch of xml files in it. The one with the document text is named (surprise, surprise) document.xml. So all we have to do is unzip the docx file, use python to parse the xml and modify the text to our needs, and then zip it back.

As you can see, our toolbox is very simple – so far it is python plus the zip/unzip commands. In principle we could ditch the external zip commands and use python’s own zipfile module, but this seems like minor overkill, as the number of zip/unzip operations we need to perform is not that large.
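For completeness, here is what the zipfile route might look like, in case you prefer to avoid the external commands. This is a sketch under the assumption that the document body lives at word/document.xml inside the container (which is where docx keeps it); the function names are mine:

```python
import shutil
import zipfile

def extract_document_xml(docx_path, out_path):
    # a docx file is a zip container; pull out just the document body
    with zipfile.ZipFile(docx_path) as zf:
        with zf.open("word/document.xml") as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)

def rebuild_docx(src_docx, new_document_xml, out_docx):
    # copy every entry from the original container verbatim,
    # swapping in the modified document.xml
    with open(new_document_xml, "rb") as f:
        new_body = f.read()
    with zipfile.ZipFile(src_docx) as zin, \
         zipfile.ZipFile(out_docx, "w", zipfile.ZIP_DEFLATED) as zout:
        for item in zin.infolist():
            if item.filename == "word/document.xml":
                zout.writestr(item, new_body)
            else:
                zout.writestr(item, zin.read(item.filename))
```

Copying all other entries untouched matters – a docx contains several support files (styles, relationships, content types) that Word expects to find intact.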

The last thing needed is an xml formatter or pretty printer. The file we want to modify (document.xml) is essentially one line long, so any form of ‘manual’ inspection, e.g. performing a diff, would be impractical in that form. An xml formatter adds line breaks and indentation (i.e. makes the file human readable), so it becomes possible to visually check what changes were made. For this I was using xmllint:

xmllint --format old_file.xml > new_file.xml

Once again, it would be possible to perform this within python (you can easily google a solution, e.g. with lxml), but I will stick to the external tool, as it worked fine for me. It is worth noting that using the mentioned cli tools does not imply manual operations, since you can incorporate them into your script (an os.system call is good enough for the job).

Start small

My general advice is to start small. If you plan on modifying a long document (as I did), create a new one with a couple of pages copied and pasted into it. Then unzip it, parse the document.xml file, and save it to a new file without any modifications. Complete the first iteration by building a new docx file (i.e. overwrite the original document.xml file with the freshly created one). My initial script is below, and surprisingly – it wasn’t working properly:

import xml.etree.ElementTree as ET

tree = ET.parse('document.xml.org')
root = tree.getroot()
for element in root.iter():
    pass

tree.write(open('document.xml', 'wb'), encoding='utf-8')

The resulting docx file opened OK in libreoffice. It was also ok in gmail preview. But MS Word wasn’t happy with the result and refused to open the file. A quick look at the original and the new document.xml files shows where the problem is:

<!-- beginning of the original document.xml file -->
<w:document xmlns:ve="http://schemas.openxmlformats.org/markup-compatibility/2006" xmlns:o="urn:schemas-microsoft-com:office:office" xmlns:r="http://schemas.openxmlformats.org/officeDocument/2006/relationships" xmlns:m="http://schemas.openxmlformats.org/officeDocument/2006/math" xmlns:v="urn:schemas-microsoft-com:vml" xmlns:wp="http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing" xmlns:w10="urn:schemas-microsoft-com:office:word" xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main" xmlns:wne="http://schemas.microsoft.com/office/word/2006/wordml">
  <w:body>
    <w:p w:rsidR="00E961ED" w:rsidRPr="00AE4808" w:rsidRDefault="00AE4808">
<!-- remaining part omitted for brevity -->
<!-- beginning of the document.xml file obtained in the first try -->
<ns0:document xmlns:ns0="http://schemas.openxmlformats.org/wordprocessingml/2006/main">
  <ns0:body>
    <ns0:p ns0:rsidR="00E961ED" ns0:rsidRDefault="00AE4808" ns0:rsidRPr="00AE4808">
<!-- remaining part omitted for brevity -->

If you scroll both listings sideways you will notice that the second file lacks most of the namespace mappings present in the first (original) file. It seems that ElementTree is not especially user-friendly when it comes to handling namespaces (see this stackoverflow question for details). Fortunately, in our case the fix is quite easy – use a different library to handle xml parsing and creation. Below you can find a working snippet, this time using lxml:

from lxml import etree as ET
tree = ET.parse('document.xml.org')
root = tree.getroot()

for element in root.iter():
    pass

tree.write('document.xml', xml_declaration = True, encoding = "UTF-8", method = "xml", standalone = "yes")

The file created with the above snippet doesn’t differ from the original one, so the resulting docx file is correctly opened inside Word. Yay!

Regex or bust!

Finally we are on track to tackle the problem. Since citations come in a consistent format (which we want to change by deleting groups of characters or making them upper case), this seems a natural place for a regex. Unfortunately a regex won’t work for us out of the box, since the text is scattered across multiple xml elements. We need to gather it for the whole document and somehow keep track of the link between a given letter and the element it belongs to. On top of that, we need to be able to mark a letter for deletion or for upper-casing. This is achieved with the following class:

KILL = 0
CAPS = 1
class ElementsWithText(object):
    def __init__(self):
        self.elements = []
        self.commands = {}

    def append(self, element):
        self.elements.append(element)

    def __unicode__(self):
        return u"".join([x.text for x in self.elements])

    def set_command(self, index, command):
        self.commands[index] = command

    def finalize(self):
        iletter = -1
        for element in self.elements:
            final_text_for_this_element = u""
            for letter in element.text:
                iletter += 1
                if iletter not in self.commands:
                    final_text_for_this_element += letter
                else:
                    if self.commands[iletter] == KILL:
                        continue
                    elif self.commands[iletter] == CAPS:
                        final_text_for_this_element += letter.upper()
            element.text = final_text_for_this_element

Essentially, this is a container for all elements, with some utility methods. The __unicode__ method builds the complete text from all elements stored inside the object. This, as mentioned earlier, is crucial if we want to use a regex. The set_command method stores the desired action to be performed (here KILL or CAPS) on a given letter (i.e. at a given “global” index or position inside the text). Finally, the finalize method goes through all elements and modifies their text in accordance with the instructions encoded inside the self.commands instance data.

The above class can be used in the following way:

from lxml import etree as ET
import re

tree = ET.parse('document.xml.org')
root = tree.getroot()

elements_data = ElementsWithText()
for element in root.iter():
    if element.text:
        elements_data.append(element)

str_text = unicode(elements_data)
for match in re.finditer("\(([^ ]+?\. )(\w+)", str_text):
    for i in xrange(match.start(1), match.end(1)):
        elements_data.set_command(i, KILL)

    for i in xrange(match.start(2), match.end(2)):
        elements_data.set_command(i, CAPS)

elements_data.finalize()

tree.write('document.xml', xml_declaration = True, encoding = "UTF-8", method = "xml", standalone = "yes")

After parsing the document.xml file with lxml, we store the elements having non-empty text inside an instance of our class (i.e. ElementsWithText). Then we build the complete text (the str_text variable above), on which we can run regex matching. The regular expression we use allows marking which parts of the text should be deleted and which should be capitalized. A call to the finalize method performs those modifications. At the end we land with a modified document.xml file we can put inside a new docx file.
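To see the marking logic in isolation, here is the same KILL/CAPS scheme applied to a plain string instead of xml elements. This is a toy reconstruction (the real document needed a few more regex variants, e.g. for multi-author citations):

```python
import re

KILL = 0
CAPS = 1

text = "As argued before (M. Mouse, 1901) the settlement is older."
commands = {}

# group 1 ("M. ") is marked for deletion, group 2 ("Mouse") for upper-casing
for match in re.finditer(r"\(([^ ]+?\. )(\w+)", text):
    for i in range(match.start(1), match.end(1)):
        commands[i] = KILL
    for i in range(match.start(2), match.end(2)):
        commands[i] = CAPS

# apply the per-letter commands, exactly as finalize does per element
result = []
for i, letter in enumerate(text):
    if commands.get(i) == KILL:
        continue
    result.append(letter.upper() if commands.get(i) == CAPS else letter)

print("".join(result))  # As argued before (MOUSE, 1901) the settlement is older.
```

The only extra complication in the real script is that the letter indices span multiple elements, which is what the ElementsWithText bookkeeping takes care of.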

Summary

We have learned how to modify text inside a docx (the latest MS Word format) file. A crucial part of the process was understanding how to write a document.xml file conformant with the docx format. It was also necessary to code some additional infrastructure in order to be able to use regex.

I played with the code that was the basis for this post for a couple of hours before obtaining the final solution (there were some special cases and exceptions to the general rule of what and how to change). Was it useful for fixing the citations? An honest answer would be “partially”. It turned out that setting letters to upper case was not enough – MS Word has a special style called “Small Caps” that makes things look slightly nicer. So the “delete” part of the program was OK, the “upper case” part not fully. And since at the time I had no possibility to work further on the problem (and the deadline for another manuscript version was close), part of the task had to be performed in a tedious, manual way.

Multiprocessing and exceptions – some batteries not included

Today I’m going to write about a not-so-minor inconvenience one faces when using the built-in multiprocessing module – how child process exceptions are presented to the user. I will also show you how to improve on this, so in case something goes wrong you don’t have to guess where the problem is.

Standalone multiprocessing

Throughout this story, we will stick to the very simple calculation shown below. We have our computation code contained in the ‘go’ function and want to apply it to a range of parameters. We decided to make use of the facilities provided by the multiprocessing module. Unfortunately, during a long and tiring coding sprint, a bug crept into our code:

from multiprocessing import Pool

def go(x):
    ret = 0.
    for i in xrange(x+1):
        ret += 1./(5-i)
    return ret

def main():
    pool = Pool(processes=4)  
    print pool.map(go, range(10))

if __name__ == "__main__":
    main()

The output we get after running the code above is far from being over-verbose:

Traceback (most recent call last):
  File "go_1.py", line 14, in <module>
    main()
  File "go_1.py", line 11, in main
    print pool.map(go, range(10))
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 251, in map
    return self.map_async(func, iterable, chunksize).get()
  File "/usr/lib/python2.7/multiprocessing/pool.py", line 558, in get
    raise self._value
ZeroDivisionError: float division by zero

From such a traceback we can find out the type of the exception and the target function of the Pool.map call. In the case of our ‘go’ function, guessing where the problem lies is fairly simple – the target function is short, with only a single place this exception may be coming from. In real life the target function will usually be far more complicated and may call other functions from external modules. So seeing a traceback similar to the one above doesn’t help at all. Is it our code that threw the exception? numpy? scikit-learn? Happy guessing – the lack of information about which line of our code caused it makes our life miserable. At this point we have two possibilities – launch a proper python debugger, or try to obtain the traceback as it would be presented to us if the code were run in the non-multiprocessing way.

Since a traceback is often enough to understand what the problem is, this time we will leave the debugger at rest and try to obtain a more informative printout.

The traceback module

In order to improve our situation we will use the traceback module to, ehm… obtain a traceback. To make the solution reusable, we will put it into a decorator:

import functools
import traceback
import sys

def get_traceback(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception, ex:
            ret = '#' * 60
            ret += "\nException caught:"
            ret += "\n"+'-'*60
            ret += "\n" + traceback.format_exc()
            ret += "\n" + '-' * 60
            ret += "\n"+ "#" * 60
            print >> sys.stderr, ret
            sys.stderr.flush()
            raise ex

    return wrapper

The code above simply prints the traceback in case of problems, i.e. when an exception is thrown and not handled inside the wrapped function. After applying it to our function:

@get_traceback
def go(x):
   (...)

The error message starts to be meaningful:

Exception caught:
------------------------------------------------------------
Traceback (most recent call last):
  File "./go_2.py", line 10, in wrapper
    return f(*args, **kwargs)
  File "./go_2.py", line 28, in go
    ret += 1./(5-i)
ZeroDivisionError: float division by zero

(this is actually repeated a couple of times, since our exception is thrown inside more than one process). In the above output you can see exactly where (i.e. which line number) the problem is coming from.

It is worth noting that the usage of the functools.wraps helper decorator is crucial in our case – without it, the __name__ attribute of the decorated function gets lost (i.e. set to ‘wrapper’), which then makes the pickle module fail. The latter is used by the multiprocessing module to serialize the function executed inside child processes. You can verify this by getting rid of functools and then setting the __name__ of the resulting decorated function manually.
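The effect is easy to see on its own. A minimal sketch comparing a plain decorator with a functools.wraps one (Python 3 syntax; function names are mine):

```python
import functools

def deco_plain(f):
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

def deco_wraps(f):
    @functools.wraps(f)  # copies __name__, __module__, __doc__ onto wrapper
    def wrapper(*args, **kwargs):
        return f(*args, **kwargs)
    return wrapper

@deco_plain
def go_plain(x):
    return x

@deco_wraps
def go_wraps(x):
    return x

print(go_plain.__name__)  # wrapper  -> pickle cannot look "go_plain" up by name
print(go_wraps.__name__)  # go_wraps -> serializable, multiprocessing is happy
```

Pickle serializes functions by reference, i.e. by module and name, so the mismatch between the attribute name and __name__ is exactly what breaks the plain variant.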

So at this point we are able to get a proper traceback which could be enough. But there is also a different possibility I would like to explore.

The fun way

A while ago I discovered a little gem – the joblib package. In order to get it, run ‘pip install joblib’ inside your virtualenv. Among other things, it offers an alternative to the multiprocessing module when doing parallel computation similar to ours. With joblib, we can rewrite our code in the following way:

from joblib import Parallel, delayed
def go(x):
    ret = 0.
    for i in xrange(x+1):
        ret += 1./(5-i)
    return ret

def main():
    print Parallel(n_jobs=4)(delayed(go)(i) for i in range(10))

if __name__ == "__main__":
    main()

The (partial) output we get from running it is the following:

/home/tfruboes/2017.02.threadedGIL/go_3.py in go(x=5)
      1 from joblib import Parallel, delayed
      2 
      3 def go(x):
      4     ret = 0.
      5     for i in xrange(x+1):
----> 6         ret += 1./(5-i)
        ret = 2.283333333333333
        i = 5
      7     return ret
      8 
      9 def main():
     10     print Parallel(n_jobs=4)(delayed(go)(i) for i in range(10))

ZeroDivisionError: float division by zero

As you can see, we get a code listing with the line causing the exception marked. Below that line, you can also see information on the local variables at the point the exception was thrown. You may also notice that the arguments with which the ‘go’ function was called are printed as well. So tons of useful information, which in many cases will allow us to immediately understand the problem. Neat!

Wrap up

We have seen that under normal conditions the multiprocessing module won’t give us the usual amount of information on an exception being thrown inside a child process. This is slightly surprising, as one could expect that (following the “batteries included” philosophy) this would be done in exactly the same way as when the multiprocessing module is not used. In order to get this info you should use the traceback module. Or, in some cases, go for joblib. Note that it offers far more than nice printouts in case of problems.

A quick and easy way to view cProfile results

Every now and then I need to profile parts of my code (as we all do). For me this used to happen with a frequency low enough to prevent remembering what one does with the output of the profiler, i.e. how to display the results in an interpretable way. So every time I had to measure code performance, I googled the cProfile module and looked up examples of how to display, sort and interpret the results. This was a minor annoyance.

Fortunately, it is no more thanks to the snakeviz package. All you have to do in order to use it is run

pip install snakeviz

inside your virtualenv. Then proceed as usual – in your code, import cProfile and replace the direct call to the main() function (or any other function you wish to use as the starting point for profiling) with

cProfile.run('main()', "path_to_stats.prof_file")

After your program terminates just point snakeviz to the profiler output file

snakeviz path_to_stats.prof_file

This should pop up a browser window with a nice and meaningful visualization of the results. And that’s it – you get your results presented in a clean and ready-to-consume way.

In order to show snakeviz in action I have created a very simple mock-up script:

import time
import cProfile

def read_data():
    time.sleep(1.5)

def clean_data():
    time.sleep(2.7)

def fit():
    time.sleep(4.2)

def build_model():
    clean_data()
    fit()

def main():
    read_data()
    build_model()

if __name__ == "__main__":
    cProfile.run('main()', "stats.prof")

The results are shown below. After launching snakeviz you can select one of two display styles – Sunburst or Icicle. The latter seems to be more informative, and it’s shown in the image below (click on the image for a full-size version):

Visualizing cProfile results with snakeviz.

Snakeviz provides two ways to explore profiler data. You can choose the sorting criterion in the output table (e.g. the number of given function calls or cumulative time) or select a subset of the output graph to display.

If you work iteratively (i.e. “change code, rerun the profiler” multiple times), snakeviz also fits this pattern nicely. In order to display updated results, you don’t have to go back to the command line and restart it. Assuming the profiler output went to the same file, just go to the browser, hit refresh (F5 in most browsers) and voila – a new set of results appears.

Snakeviz is a very well-thought-out tool. Not overdone, with a “just right” set of features for the job it aims to do.

Final notes:

  • For more info on snakeviz go visit the project homepage.
  • Another notable tool for the job is RunSnakeRun. Unfortunately, it seems to no longer be maintained (I never actually tried it, so your mileage may vary).
  • Sometimes cProfile is not enough, i.e. you may need more detailed information. The line_profiler package is another great utility module, allowing you to see line-by-line execution times of the profiled function. I’m going to cover this module on the blog soon.
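For reference, the manual pstats route that snakeviz saves you from looks roughly like this – a stdlib-only sketch, with an arbitrary file name and a toy workload standing in for real code:

```python
import cProfile
import pstats

def work():
    # toy workload standing in for real code being profiled
    return sum(i * i for i in range(100000))

# collect the profile explicitly, avoiding the string-based cProfile.run
profiler = cProfile.Profile()
profiler.enable()
work()
profiler.disable()
profiler.dump_stats("stats.prof")

# load the dump and print the 5 most expensive entries by cumulative time
stats = pstats.Stats("stats.prof")
stats.sort_stats("cumulative").print_stats(5)
```

This is exactly the “display, sort and interpret” incantation I kept having to google – snakeviz replaces it with a single command.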

Pickle performance bottlenecks when using multiprocessing

A while ago I wrote a parameter scan (regularization in logistic regression, to be specific) that was taking a bit too long to execute. Since the machine on which it was executed was essentially other-user-free (and had some 20 cores laying unused 🙂 ) I decided to go multiprocessing. Usually I pick multiprocessing.Pool.map on such occasions, where the boilerplate seems to be the smallest. The main drawback here is that you must provide an iterable over which your target function is applied. As a consequence, functions with more than one argument cannot be used. Since we have complete control over the target function, we can make it accept a single argument with all the necessary parameters packed inside. This is exactly what the WorkerTodo namedtuple does in the following code:

import time
from multiprocessing import Pool
from collections import namedtuple

WorkerTodo = namedtuple("WorkerTodo", "data threshold")
Data = namedtuple("Data", "x y z t")

def go(worker_todo):
    return sum(1 for measurement in worker_todo.data 
               if measurement.x > worker_todo.threshold)

def main():
    data = [ Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    todos = [WorkerTodo(data, thr) for thr in xrange(4)]

    t_start = time.time()
    for todo in todos:
        go(todo)
    print "Direct call (no multiprocessing): %s seconds" % (time.time()-t_start)

    t_start = time.time()
    p = Pool(2)
    p.map(go, todos)
    print "Using: multiprocessing.Pool.map:  %s seconds" % (time.time()-t_start)

if __name__ == "__main__":
    main()

As you can see, I have included the unparallelized version for comparison. The results were at least disappointing, if not annoying:

Direct call (no multiprocessing): 1.43389391899 seconds
Using: multiprocessing.Pool.map:  33.8968751431 seconds

What is happening? In order to understand this, we will change the WorkerTodo definition to the following:

class WorkerTodo(object):
    def __init__(self, data, threshold):
        self.data = data
        self.threshold = threshold

    def __getstate__(self):
        print "WorkerTodo __getstate__ called for thr=%s" % self.threshold
        return (time.time(), self.data, self.threshold)

    def __setstate__(self, s):
        tstart, self.data, self.threshold = s
        print "pickle-unpickle time for thr=%s was %s seconds" \
                    % (self.threshold, time.time()-tstart)

Here we use a simple class to send parameters to the child process, with __getstate__ and __setstate__ methods added (those are specific to the pickle protocol). They are written in a way that shows whether our data gets pickled, and if so, how long it takes. The output explains the poor performance:

Direct call (no multiprocessing): 0.970059156418 seconds
WorkerTodo __getstate__ called for thr=0
WorkerTodo __getstate__ called for thr=1
pickle-unpickle time for thr=0 was 8.90351390839 seconds
WorkerTodo __getstate__ called for thr=2
pickle-unpickle time for thr=1 was 9.14890098572 seconds
WorkerTodo __getstate__ called for thr=3
pickle-unpickle time for thr=2 was 9.21250295639 seconds
pickle-unpickle time for thr=3 was 9.45859909058 seconds
Using: multiprocessing.Pool.map:  30.2289671898 seconds

Our data gets pickled before being sent to the workers, which takes a very long time. Calling cPickle dumps/loads by hand makes us certain this is the reason:

import cPickle
def main():
    data = [ Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    t_start = time.time()
    pickle_data = cPickle.dumps(data)
    print "cpickle dumps took", time.time()-t_start

    t_start = time.time()
    cPickle.loads(pickle_data)
    print "cpickle loads took", time.time()-t_start

cpickle dumps took 10.0882921219
cpickle loads took 3.06176805496

In the described case the solution was to use multiprocessing.Process instead of multiprocessing.Pool.map (note: this is far from perfect, as I will explain below). The final code looks like the following:

import time
import cPickle
from multiprocessing import Process, Pool
from collections import namedtuple

Data = namedtuple("Data", "x y z t")

class WorkerTodo(object):
    def __init__(self, data, threshold):
        self.data = data
        self.threshold = threshold

    def __getstate__(self):
        print "WorkerTodo __getstate__ called for thr=%s" % self.threshold
        return (time.time(), self.data, self.threshold)

    def __setstate__(self, s):
        tstart, self.data, self.threshold = s
        print "pickle-unpickle time for thr=%s was %s seconds" \
                    % (self.threshold, time.time()-tstart)

def go(worker_todo):
    return sum(1 for measurement in worker_todo.data 
               if measurement.x > worker_todo.threshold)

def main():
    data = [ Data(i, i+1, i+2, i+3) for i in xrange(10**6)]

    t_start = time.time()
    pickle_data = cPickle.dumps(data)
    print "cPickle test: dumps took %s seconds" % (time.time()-t_start)

    t_start = time.time()
    cPickle.loads(pickle_data)
    print "cPickle test: loads took %s seconds" % (time.time()-t_start)

    todos = [WorkerTodo(data, thr) for thr in xrange(4)]

    t_start = time.time()
    for todo in todos:
        go(todo)
    print "Direct call (no multiprocessing): %s seconds" % (time.time()-t_start)

    t_start = time.time()
    p = Pool(2)
    p.map(go, todos)
    print "Using: multiprocessing.Pool.map:  %s seconds" % (time.time()-t_start)

    t_start = time.time()
    all_proc = []
    for todo in todos:
        p = Process(target=go, args=(todo,))
        p.start()
        all_proc.append(p)
    for p in all_proc:
        p.join()
    print "Using: multiprocessing.Process:  %s seconds" % (time.time()-t_start)

if __name__ == "__main__":
    main()

The added part, using multiprocessing.Process, runs fast indeed:

Direct call (no multiprocessing): 0.973602056503 seconds
Using: multiprocessing.Pool.map:  32.3261759281 seconds
Using: multiprocessing.Process:  0.727508068085 seconds

What happens? In the last approach Python relies on the operating system (at least on Linux, see the remarks below): after fork, the child processes get a copy-on-write view of the parent's data, so nothing needs to be pickled.

Both presented approaches share a serious limitation – data gets copied multiple times in memory. With multiprocessing.Pool.map this is obvious – since pickle/unpickle gets called, independent objects are created. With multiprocessing.Process it is more subtle – although copy-on-write is involved (at least on Linux), the data will eventually get copied by the operating system, triggered by mere data access (e.g. during iteration). The effect comes from the fact that Python does reference counting for objects, and the reference count is stored inside the object itself. So as soon as data is read (e.g. during iteration), the reference counts of the accessed objects are increased, the memory pages holding them get written to, and the operating system is forced to make copies. Fortunately there is a solution also for that; I'll try to post on it soon.
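One direction that avoids the refcount-triggered copies is to keep the payload outside Python's refcounted object space entirely. This is just a hedged sketch of that idea (not necessarily the solution I'll describe in the follow-up post): a multiprocessing.RawArray lives in shared memory, so workers can read it without dirtying any pages. The names run_demo and count_above are made up for the example; the same multiprocessing API exists in both Python 2 and 3.

```python
import multiprocessing as mp

def _init(arr):
    # Runs once per worker: stash the shared array in the worker's globals.
    global shared
    shared = arr

def count_above(args):
    threshold, n = args
    # Reading a RawArray does not touch Python per-object refcounts,
    # so no copy-on-write page copies are triggered.
    return sum(1 for i in range(n) if shared[i] > threshold)

def run_demo(n=10**4):
    # 'd' = C double; the buffer lives in shared memory, not on the heap
    # of refcounted Python objects.
    arr = mp.RawArray('d', n)
    for i in range(n):
        arr[i] = float(i)
    pool = mp.Pool(2, initializer=_init, initargs=(arr,))
    try:
        # Only the small (threshold, n) tuples get pickled, not the data.
        return pool.map(count_above, [(0.5, n), (1.5, n)])
    finally:
        pool.close()
        pool.join()

if __name__ == "__main__":
    print(run_demo())
```

The trade-off is that this only works for flat numeric data – you lose the convenience of pickling arbitrary Python objects.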


Final remarks:

  • The above code was executed on a Linux system. According to the multiprocessing module documentation, Python code executed on a Windows system is likely to rely on pickling even more. So a program running reasonably on Linux may, on Windows, get hit by pickle bottlenecks in places you considered fine.
  • In this post I have completely omitted another not so pleasant pickle limitation – it won't pickle lambdas or locally defined functions (see https://docs.python.org/2/library/pickle.html#what-can-be-pickled-and-unpickled). As a consequence, you cannot use such functions as multiprocessing target functions. There are alternatives to the pickle module that overcome this limitation; unfortunately, you cannot plug them into the standard multiprocessing module. You can find alternative multiprocessing frameworks on PyPI (e.g. pathos) that use pickle alternatives.
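The lambda limitation is easy to see in a minimal example (shown with Python 3's pickle; cPickle in Python 2 behaves the same way):

```python
import pickle  # cPickle in Python 2

def module_level(x):
    return x + 1

# Module-level functions pickle fine: pickle stores them by qualified name.
pickle.dumps(module_level)

# Lambdas (and locally defined functions) have no importable name,
# so pickling them fails:
try:
    pickle.dumps(lambda x: x + 1)
except Exception as err:
    print("cannot pickle a lambda:", err)
```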

Being defensive with pickle in an evolving environment

Pickle is Python's built-in object persistence solution. Although very useful, care must be taken when using it with class definitions that may change, i.e. are under active development. Consider the following example:

import cPickle

class Test():
    def __init__(self):
        self.var1 = 1
        self.var2 = 2

t1 = Test()
print t1.__dict__
t_pickle_str = cPickle.dumps(t1)

class Test():
    def __init__(self):
        self.var3 = 3

t2 = cPickle.loads(t_pickle_str)
print t2.__dict__

Both printouts will show the var1 and var2 instance variables and no var3, despite the fact that the class logic changed in the meantime. This is normal and expected behavior – pickle stores instance data, not code.

At some point I needed protection against this in one of my data cleaning algorithms, which was wrapped inside a class. The set of parameters used by the algorithm (and stored as instance variables) was too expensive to determine (train) each time it was used in production code, but quite cheap to pickle. The algorithm itself was tweaked and changed from time to time, which could lead to subtle (and silent) bugs if a wrong (old) pickle file was used with a class definition containing the updated algorithm.

In order to handle such a situation, it is possible to exploit the fact that pickle doesn't serialize class-level variables. During deserialization, those are simply taken from the current class definition. Thanks to this, it is possible to introduce class version control during pickle/unpickle. Consider the following mixin:

class Versionable(object):
    def __getstate__(self):
        if not hasattr(self, "_class_version"):
            raise Exception("Your class must define _class_version class variable")
        return dict(_class_version=self._class_version, **self.__dict__)
    def __setstate__(self, dict_):
        version_present_in_pickle = dict_.pop("_class_version")
        if version_present_in_pickle != self._class_version:
            raise Exception("Class versions differ: in pickle file: {}, "
                            "in current class definition: {}"
                            .format(version_present_in_pickle,
                                    self._class_version))
        self.__dict__ = dict_

Here the __getstate__ and __setstate__ pickling protocol methods (not to be confused with the pickle protocol version) are provided. The __getstate__ method attaches the current class version (taken from the _class_version class-level variable that must be defined in a subclass) to the pickled data. The __setstate__ method compares this value read from the pickle with the one from the current class definition. If there is a mismatch, an exception is raised.

The following code shows Versionable mixin (saved into versionable.py file) in action:

from versionable import Versionable
import cPickle

class TestVersioning(Versionable):
    _class_version = 1

t1 = TestVersioning()

t_pickle_str = cPickle.dumps(t1)

class TestVersioning(Versionable):
    _class_version = 2

t2 = cPickle.loads(t_pickle_str)

This leads to the following output:

Traceback (most recent call last):
  File "/home/tfruboes/test.py", line 16, in <module>
    t2 = cPickle.loads(t_pickle_str)
  File "/home/tfruboes/versionable.py", line 20, in __setstate__
    self._class_version))
Exception: Class versions differ: in pickle file: 1, in current class definition: 2

So, as long as you remember to bump the version number whenever incompatible changes are made, you are safe.

Some random notes:

  • For the _class_version class variable you can use anything that is comparable with the “==” operator. So if you want to be more descriptive and provide more than one version number (e.g. major and minor), a dict will also work.
  • There is an alternative approach to implementing versioning safety, using the copy_reg module. For this see Item 44 of the “Effective Python” book (if you haven’t visited the books section, it’s the right time 🙂 )
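To illustrate the first point, here is a sketch with a dict as the version tag. The mixin is repeated from above so the snippet is self-contained, and the Cleaner class name is made up for the example (shown with Python 3's pickle; cPickle in Python 2 behaves the same way):

```python
import pickle  # cPickle in Python 2

# The Versionable mixin from above, repeated for self-containment.
class Versionable(object):
    def __getstate__(self):
        if not hasattr(self, "_class_version"):
            raise Exception("Your class must define _class_version class variable")
        return dict(_class_version=self._class_version, **self.__dict__)

    def __setstate__(self, dict_):
        version_present_in_pickle = dict_.pop("_class_version")
        if version_present_in_pickle != self._class_version:
            raise Exception("Class versions differ: in pickle file: {}, "
                            "in current class definition: {}"
                            .format(version_present_in_pickle,
                                    self._class_version))
        self.__dict__ = dict_

# Dicts compare element-wise with ==, so one works fine as a version tag:
class Cleaner(Versionable):
    _class_version = {"major": 1, "minor": 0}

payload = pickle.dumps(Cleaner())

# Simulate a code change between dump and load:
Cleaner._class_version = {"major": 1, "minor": 1}
try:
    pickle.loads(payload)
except Exception as err:
    print(err)
```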