
Optimizing My Cellular Automata

September 5th, 2015

Last time we ended up with a nice little class that implemented our one dimensional cellular automaton. That code works perfectly fine if all you want is to play around with it and make nice figures, but we have bigger plans here. Before exploring this subject any further, we might wanna test our code for performance, as some extra speed is gonna make our work easier in later posts.

First we'll take the original code.

from random import randint

def rolling_window(arr, wsize):
    arr = arr[-wsize//2 + 1:] + arr + arr[:wsize//2]
    for i in range(len(arr) - wsize + 1):
        yield arr[i:i + wsize]


class CARef:
    def __init__(self, ncells, k, init='center'):
        self.ncells = ncells
        self.k = k

        if init == 'center':
            self.cells = [0] * self.ncells
            self.cells[self.ncells // 2] = 1
        elif init == 'random':
            self.cells = [randint(0, 1) for _ in range(self.ncells)]
        else:
            self.cells = list(init)
            assert len(self.cells) == self.ncells

    @staticmethod
    def make_rule(rule_id, k):
        rule_len = 2 ** k
        rule = list(map(int, bin(rule_id)[2:]))
        rule = [0] * (rule_len - len(rule)) + rule
        rule = rule[::-1]
        return rule

    def state_id(self, state):
        assert len(state) == self.k
        return int(''.join(map(str, state)), base=2)

    def apply_rule(self, rule):
        assert len(rule) == 2 ** self.k
        self.cells = [rule[self.state_id(w)]
                      for w in rolling_window(self.cells, self.k)]
        return self

This will be our reference implementation. Let's make a simple script that validates future implementations by comparing them with CARef, and times a simple benchmark.

def test(CACmp):
    rule_ref = CARef.make_rule(126, 3)
    rule_cmp = CACmp.make_rule(126, 3)
    assert all(i == j for i, j in zip(rule_ref, rule_cmp))

    caref = CARef(255, 3)
    cacmp = CACmp(255, 3)
    assert all(i == j for i, j in zip(caref.cells, cacmp.cells))

    for i in range(100):
        caref.apply_rule(rule_ref)
        cacmp.apply_rule(rule_cmp)
        assert all(i == j for i, j in zip(caref.cells, cacmp.cells))

    caref = CARef(255, 3, init='random')
    cacmp = CACmp(255, 3, init=caref.cells)
    assert all(i == j for i, j in zip(caref.cells, cacmp.cells))

    for i in range(100):
        caref.apply_rule(rule_ref)
        cacmp.apply_rule(rule_cmp)
        assert all(i == j for i, j in zip(caref.cells, cacmp.cells))

    print('All tests passed.')


def time(CAutomaton, number_of_loops):
    from timeit import timeit
    prep = ';'.join(['ca = CAutomaton(250, 3)',
                     'rule = CAutomaton.make_rule(126, 3)'])
    main = 'ca.apply_rule(rule)'

    total_time = timeit(
        main, prep, number=number_of_loops, globals={"CAutomaton": CAutomaton}
    )
    avg_time = total_time / number_of_loops

    res = '{} loops, total {:.2e} s, avg. {:.2e} s/loop'
    print(res.format(number_of_loops, total_time, avg_time))

test(CARef)
time(CARef, 10000)

All tests passed.
10000 loops, total 1.25e+01 s, avg. 1.25e-03 s/loop

Ok, it takes around 12.5 seconds for my computer to apply a rule 10,000 times, but is that fast or slow? I'm gonna bet it's slow, this being a pure python implementation and all, but to get an idea of just how slow, I made a simple C implementation for us to have as a performance goal. You can check the code here.

$ gcc-5 -O3 -std=c99 cprof.c -o cprof && ./cprof
All tests passed.
10000 loops, total 2.00e-02 s, avg. 2.00e-06 s/loop

Whoa, that was fast! About 625 times faster, to be precise. That's a bit embarrassing, I guess, but keep in mind we use python not for its runtime speed but because we love simple and beautiful code. Besides, I'm sure we can make our code a lot faster without much work.

The python stack offers a handful of optimization alternatives that vary a lot in how they go about making the code run faster. One of the most promising and actively developed is PyPy, which is basically a JITed python interpreter. Its main disadvantage is that most packages that use the CPython C-API (i.e. all the important ones, like numpy, scipy, pandas, etc.) are not compatible with it. Luckily our implementation is pure python, making it a perfect test case for pypy. Using it is just a matter of replacing python with pypy on the command line.
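
Concretely, assuming the reference code and the test/time calls above are all saved as prof0.py (the same module the ipython session further down imports from), the whole run is just:

$ pypy prof0.py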

All tests passed.
10000 loops, total 1.17e+00 s, avg. 1.17e-04 s/loop

Now, that's a considerable speed up for literally no extra work, props to the pypy team. Before moving on to other methods we might wanna address the state_id method. As noted in the last post, all those string conversions definitely aren't helping performance. Let's try to fix it with some clever bit shifting (ok, not that clever). If you don't get what's going on, just remember that 1 << n == 2 ** n and 0 << n == 0, so what this function is basically doing is sum(2**i for i, j in enumerate(reversed(state)) if j), which is your standard way of converting binary to decimal. The only reason I'm using bit shifting arithmetic is because it's a lil' bit faster.
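
Just to make the equivalence concrete, a quick check on a sample neighborhood (the state [1, 1, 0] should map to 6):

state = [1, 1, 0]

# the standard binary-to-decimal conversion
print(sum(2**i for i, j in enumerate(reversed(state)) if j))  # 6

# the same result with shifts, as in state_id below
sid = 0
for i, j in enumerate(reversed(state)):
    sid |= j << i
print(sid)  # 6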

class CAOpt1(CARef):
    def state_id(self, state):
        assert len(state) == self.k
        sid = 0
        for i, j in enumerate(reversed(state)):
            sid |= j << i
        return sid

test(CAOpt1)
time(CAOpt1, 10000)

All tests passed.
10000 loops, total 5.59e-01 s, avg. 5.59e-05 s/loop

Nice, we got an even better speed up (there might be some unnecessary overhead due to the subclassing, but I don't wanna rewrite the whole class here). We're at about 22x now relative to the pure python implementation, finally below the 1 s mark, and "only" an order of magnitude slower than C.

I'm pretty satisfied with pypy, but let's see if we can make it even faster using numpy arrays. The main strategy when writing numpy code is to try to get rid of all the loops you can and use ufuncs and fancy indexing instead. The first challenge is the fact that our code uses a rolling window, and this kind of stencil operation isn't directly supported by the numpy API. Surprisingly, some quick googling revealed that there actually is a way of doing rolling windows in numpy. Here is how it looks.

import numpy as np

def rolling_window_np(arr, wsize):
    arr = np.concatenate((arr[-wsize//2+1:], arr, arr[:wsize//2]))
    shape = arr.shape[:-1] + (arr.shape[-1] - wsize + 1, wsize)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)

print(rolling_window_np([1,2,3,4,5,6], 3))
# output:
[[6 1 2]
 [1 2 3]
 [2 3 4]
 [3 4 5]
 [4 5 6]
 [5 6 1]]

Ok, this definitely works, but if this implementation looks a bit arcane, it's because it is. I mean, stride_tricks? Where did that come from? It's definitely nowhere in the online documentation. I even managed to segfault the interpreter by messing with the strides. Well, better not worry about that too much; let's see how this implementation compares with our pure python one.

$ ipython
>>> from prof0 import rolling_window as rolling_window_py
>>> x = np.arange(10)
>>> %timeit list(rolling_window_py(x, 3))
100000 loops, best of 3: 16.4 µs per loop

>>> %timeit rolling_window_np(x, 3)
10000 loops, best of 3: 29.9 µs per loop

>>> x = np.arange(255)
>>> %timeit list(rolling_window_py(x, 3))
10000 loops, best of 3: 129 µs per loop

>>> %timeit rolling_window_np(x, 3)
10000 loops, best of 3: 28.3 µs per loop

As it normally goes with numpy, the stride-trick rolling window is only faster once the array is large enough that the numpy overhead is small compared to the operations performed. Also worth noting is how the two handle memory: the strided result is just a view into the padded array, so the windows themselves aren't copied, but any operation on it (like the shifts and reductions we're about to do) materializes a full \(N_{cells}\times k\) intermediate array, whereas the generator version only ever holds one window at a time. Since we're going to work with a number of cells that's neither too large nor too small (somewhere between 200 and 500), this nifty implementation looks good enough for us.
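
To double-check that the stride trick really gives a view rather than a copy, we can peek at the strides of the result (the exact numbers below assume a 64-bit integer array, where each element takes 8 bytes):

import numpy as np

x = np.arange(255)
windows = rolling_window_np(x, 3)
print(windows.shape)    # (255, 3): one row per window
print(windows.strides)  # (8, 8): consecutive rows start just one element apart,
                        # so the windows overlap in memory instead of being copied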

Now, just substituting the rolling window function is unlikely to yield any significant speedup; we also need to vectorize the state_id method. Looking at the method used in CAOpt1, this is just a matter of using the right numpy ufuncs. Here's the result.

import numpy as np

def rolling_window(arr, wsize):
    arr = np.concatenate((arr[-wsize//2+1:], arr, arr[:wsize//2]))
    shape = arr.shape[:-1] + (arr.shape[-1] - wsize + 1, wsize)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)

class CAOpt2(CARef):
    def __init__(self, ncells, k, init='center'):
        CARef.__init__(self, ncells, k, init)
        self.cells = np.array(self.cells, dtype=np.byte)
        self.base = np.arange(self.k)[::-1]

    @staticmethod
    def make_rule(rule_num, k):
        return np.array(CARef.make_rule(rule_num, k))

    def apply_rule(self, rule):
        assert len(rule) == 2 ** self.k
        states = rolling_window(self.cells, self.k)
        state_ids = np.bitwise_or.reduce(states << self.base, axis=1)
        self.cells = rule[state_ids]
        return self

test(CAOpt2)
time(CAOpt2, 10000)

All tests passed.
10000 loops, total 7.57e-01 s, avg. 7.57e-05 s/loop

Just a bit slower than the best pypy result. One last trick I want to try is using fancy indexing to completely eliminate rolling_window from apply_rule. To do that we'll use an \(N_{cells}\times k\) array called neigh, which is basically an adjacency list of the cells, that is, the ith row of the array contains the indices of the k neighbors of the ith cell. Then we can get the states by simply evaluating cells[neigh] (this is not just an optimization strategy, it will come in handy in later posts). Here's a quick illustration of what neigh holds, followed by the end result (with no subclassing).
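
On a 6-cell automaton with k = 3 (reusing the numpy rolling_window from above; the cell values are just made up for illustration), neigh and the fancy indexing look like this:

import numpy as np

cells = np.array([0, 1, 0, 0, 1, 1], dtype=np.byte)
neigh = rolling_window(np.arange(6), 3)

print(neigh[0])         # [5 0 1]: the indices of the neighbors of cell 0, wrapping around
print(cells[neigh][0])  # [1 0 1]: the states of those neighbors, selected by fancy indexing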

import numpy as np

def rolling_window(arr, wsize):
    arr = np.concatenate((arr[-wsize//2+1:], arr, arr[:wsize//2]))
    shape = arr.shape[:-1] + (arr.shape[-1] - wsize + 1, wsize)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)

class CAOpt3:
    def __init__(self, ncells, k, init='center'):
        self.ncells = ncells
        self.k = k
        self.base = np.arange(self.k)[::-1]
        self.neigh = rolling_window(np.arange(self.ncells), self.k)

        if init == 'center':
            self.cells = np.zeros(self.ncells, dtype=np.byte)
            self.cells[self.ncells // 2] = 1
        elif init == 'random':
            self.cells = np.random.randint(0, 2, size=self.ncells).astype(np.byte)
        else:
            self.cells = np.array(init)
            assert len(self.cells) == self.ncells

    @staticmethod
    def make_rule(rule_id, k):
        rule_len = 2 ** k
        rule = list(map(int, bin(rule_id)[2:]))
        rule = [0] * (rule_len - len(rule)) + rule
        rule = rule[::-1]
        return np.array(rule)

    def apply_rule(self, rule):
        assert len(rule) == 2 ** self.k
        states = self.cells[self.neigh]
        state_ids = np.bitwise_or.reduce(states << self.base, axis=1)
        self.cells = rule[state_ids]
        return self

test(CAOpt3)
time(CAOpt3, 10000)

All tests passed.
10000 loops, total 3.53e-01 s, avg. 3.53e-05 s/loop

Faster than pypy! I think I'm calling it a day. We could probably get even better performance by using numba to jit some of the most work-intensive parts (a rough sketch of that idea follows the summary table below), or even use cython to build a compiled extension, but more than performance I want to keep this code simple and easily modifiable, and 35 µs per iteration looks great for our purposes. Here is a little summary of what we did:

Implementation            Time per \(10^4\) iterations    Speedup    Cumulative Speedup
Pure Python               12.5 s                           -          -
Pypy                      1.17 s                           11 x       11 x
Numpy                     0.76 s                           1.5 x      16.5 x
Pypy + Bitshift           0.56 s                           1.4 x      22.3 x
Numpy + Fancy Indexing    0.35 s                           1.6 x      35.7 x
C                         0.02 s                           17 x       625 x
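
For the curious, here's a minimal sketch of what the numba route mentioned above might look like. It's untimed and not part of the comparison; it assumes numba is available and reimplements the same wrap-around neighborhood and bit ordering used by rolling_window and state_id:

import numba
import numpy as np

@numba.njit
def apply_rule_jit(cells, rule, k):
    n = cells.shape[0]
    out = np.empty_like(cells)
    left = (k + 1) // 2 - 1           # how far the window extends to the left of each cell
    for i in range(n):
        sid = 0
        for j in range(k):
            # periodic boundary, same neighborhood as the padded rolling window
            idx = (i - left + j) % n
            sid = (sid << 1) | cells[idx]
        out[i] = rule[sid]
    return out

# e.g. inside apply_rule: self.cells = apply_rule_jit(self.cells, rule, self.k)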
