Django Multiprocessing

Once we got to a certain scale, it became necessary to start parallelizing many of the scripts we run to keep our infrastructure healthy. It turns out that multiprocessing does not work very well out of the box with Django (we are currently running 1.11 on Python 2.7, though Python 3 is coming to our stack soon!). We could, of course, use Celery to distribute the work to our task-running infrastructure, but for many operations the effort to set that up is too large for the payoff.

Essentially, the problem is that Django assumes a single process, so it holds one connection for each service it talks to: in our case Postgres, Redis (for caching and other utility stuff), and Elasticsearch. When Python's multiprocessing forks child processes, those open connections are copied into the memory of each child and then used by all of them without any coordination. Very quickly you find that this is unusable in practice. There are numerous Stack Overflow posts outlining the problem and providing the solution (the one I first saw is linked in the full file below): close all the connections in the main process before spawning the child processes. Each child process will then detect that it doesn't have a connection and open a new one as needed.
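To see why closing before forking is enough, here is a minimal Django-free sketch. `LazyConnection`, `CONN`, `worker` and `run` are hypothetical names; the "connection" is just a string recording which process opened it, standing in for a socket that Django would open lazily on first use. It assumes Python 3 and the fork start method (Linux/macOS), which is what makes children inherit the parent's state in the first place.

```python
import multiprocessing
import os


class LazyConnection(object):
    """Hypothetical stand-in for a lazily opened service connection."""

    def __init__(self):
        self.handle = None  # would be a real socket in practice

    def get(self):
        # Open on first use, recording which process opened the "socket"
        if self.handle is None:
            self.handle = "opened-by-%d" % os.getpid()
        return self.handle

    def close(self):
        self.handle = None


CONN = LazyConnection()


def worker(results):
    # A forked child inherits CONN; because the parent closed it first,
    # get() opens a fresh connection owned by this child
    results.put(CONN.get())


def run(num_workers=3):
    parent_handle = CONN.get()  # parent uses the connection (e.g. to build its work list)
    CONN.close()  # close it so no child inherits the open handle
    ctx = multiprocessing.get_context("fork")  # assumption: fork start method
    results = ctx.Queue()
    procs = [ctx.Process(target=worker, args=(results,)) for _ in range(num_workers)]
    for p in procs:
        p.start()
    child_handles = [results.get() for _ in procs]  # read before join to avoid blocking
    for p in procs:
        p.join()
    return parent_handle, child_handles


if __name__ == "__main__":
    parent, children = run()
    print(parent, children)
```

Each child ends up with its own handle, none of which is the parent's. Skip the `CONN.close()` call and every child would report the parent's handle instead, which is exactly the shared-socket situation described above.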

This idea formed the basis for our solution, which extends the concept to all connection types and also gives programmers at Talentpair an easy interface for writing safe multiprocess Django code. It also picks up a few niceties along the way: aggregation of results, automatic worker allocation based on the number of cores, and periodic printouts showing how much work has been done.

We named the core class MultiProcess and implemented it as a context manager to make sure that setup and teardown are handled properly.

Here’s the first building block:

from django import db
from django.conf import settings
from django.core.cache import caches
from elasticsearch_dsl import connections


def close_service_connections():
    # close db connections, they will be recreated automatically
    db.connections.close_all()

    # close ES connection, needs to be manually recreated
    connections.connections.remove_connection("default")

    # close redis connections, will be recreated automatically
    for k in settings.CACHES.keys():
        caches[k].close()


def recreate_service_connections():

    # ES is the only one that needs to be recreated explicitly
    connections.connections.create_connection(hosts=[settings.ELASTIC_FULL_URL], timeout=20)

With these two pieces, we have a nice place to put code to close and restart connections to all our services.

The next question is what sort of data flow we should support. We chose the map pattern: the input to MultiProcess is a function that takes one argument, plus a list of things to pass to that function one at a time. In practice, this means we generally provide a function that operates on a single database record and a list of primary keys (or, as below, UUIDs).

def update_index_for_uuid(uuid):
    obj = get_instance_from_uuid(uuid)

    index_the_thing(obj)


list_of_uuids = [u1, u2, ...]

So using the MultiProcess module generally looks something like this:

uuids = [obj.uuid for obj in cls.objects.all()]
with MultiProcess() as mp:
    mp.map(update_index_for_uuid, uuids)
    results = mp.results()

This code

  • Checks the number of cores available and allocates one fewer child process than that (but always at least one) to do the work
  • Breaks the list of uuids up round-robin into sub-lists of (nearly) equal size, one for each worker
  • Waits for all children to finish and then collects the results
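The first two steps are plain arithmetic, so they can be pulled out and checked on their own. This is a sketch that mirrors the logic in the full file's `__init__` and `map`; the function names `default_num_workers` and `split_round_robin` are hypothetical.

```python
from multiprocessing import cpu_count


def default_num_workers(cores=None, max_workers=None):
    """Leave one core free, never go below one worker, optionally cap the total."""
    if cores is None:
        cores = cpu_count()
    n = max(cores - 1, 1)
    if max_workers:
        n = min(n, max_workers)
    return n


def split_round_robin(items, num_workers):
    """Item i goes to worker i % num_workers, so sub-list sizes differ by at most one."""
    items = list(items)
    return [items[w::num_workers] for w in range(num_workers)]


if __name__ == "__main__":
    print(default_num_workers())  # depends on the machine
    print(split_round_robin(range(7), 3))  # [[0, 3, 6], [1, 4], [2, 5]]
```

Slicing with a step gives the same assignment as the `idx % num_workers == worker_idx` loop, but walks the list once per worker without a Python-level modulo test on every item.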

The only real gotcha with this pattern is that the list of uuids must be built before entering the with block. Inside that block, all of the main process's service connections have already been closed.

Here is an abbreviated version of the MultiProcess context manager. The main features are:

  • Upon entering the with block, close all connections
  • Upon leaving the with block, reopen all connections, so that the main process can go on its merry way
  • Use a process-safe queue (from multiprocessing.Manager) to aggregate results and keep track of how many input list items have been operated on

from multiprocessing import Manager, Process

try:
    from Queue import Empty  # Python 2
except ImportError:
    from queue import Empty  # Python 3


class MultiProcess(object):

    queue = None
    num_workers = 4

    def __init__(self):
        # per-instance list, so separate MultiProcess objects don't share workers
        self.workers = []

    def __enter__(self):
        close_service_connections()
        return self

    def map(self, func, iterable):

        self.queue = Manager().Queue()
        items = list(iterable)

        for worker_idx in range(self.num_workers):

            # round-robin split: worker N gets items N, N + num_workers, ...
            worker_items = items[worker_idx::self.num_workers]

            p = Process(target=threadwrapper(func),
                        args=[self.queue, worker_items])
            p.start()
            self.workers.append(p)

        self._wait()

    def _wait(self):
        """ Wait for all workers to finish and wake up
        periodically to print out how much work has happened
        """
        # (the periodic status printout is omitted here)
        for p in self.workers:
            p.join()

    def results(self):
        rv = []
        try:
            while True:
                rv.append(self.queue.get(block=False))
        except Empty:
            return rv

    def __exit__(self, type, value, traceback):
        # recreate the connections so we can do more stuff
        # in the parent process
        recreate_service_connections()
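The result-aggregation side can be tried out without Django. In this sketch (hypothetical names `square_all`, `drain` and `run`; Python 3; fork start method assumed), each worker pushes results onto a `Manager().Queue()`, the parent computes the same qsize-based progress percentage the full file prints, and then drains the queue non-blockingly the way `results()` does.

```python
import multiprocessing
from queue import Empty


def square_all(results, items):
    # Each worker pushes one result per input item onto the shared queue
    for x in items:
        results.put(x * x)


def drain(q):
    # Non-blocking drain, the same approach as MultiProcess.results()
    out = []
    while True:
        try:
            out.append(q.get(block=False))
        except Empty:
            return out


def run(items, num_workers=2):
    ctx = multiprocessing.get_context("fork")  # assumption: fork start method
    manager = ctx.Manager()
    try:
        q = manager.Queue()
        procs = [
            ctx.Process(target=square_all, args=(q, items[w::num_workers]))
            for w in range(num_workers)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # Progress is "results queued so far / total items", as in _wait()
        percent = (q.qsize() * 100) // (len(items) or 1)
        return percent, drain(q)
    finally:
        manager.shutdown()


if __name__ == "__main__":
    print(run([1, 2, 3, 4, 5]))
```

Because all workers share one queue, results come back in whatever order the workers happened to finish. If you need to match results to inputs, have the mapped function return an `(input, result)` pair.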

The threadwrapper decorator:

  • Sets up connections for the child process
  • Loops over the list of items, calling func on each one and putting each result in the shared queue
  • Closes all connections
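The same three steps can be sketched with the setup and teardown passed in, so the wrapper runs without Django. `make_worker` is a hypothetical name, and `setup`/`teardown` stand in for `recreate_service_connections` and `close_service_connections`. One deliberate difference from the full file: a `try/finally` makes teardown run even when `catch_exceptions=False` lets an error escape (the original skips `close_service_connections()` in that case).

```python
import logging
import queue
import traceback

log = logging.getLogger(__name__)


def make_worker(func, setup, teardown, catch_exceptions=True):
    """Return a process target that runs func over items, one result per item."""

    def worker(results, items):
        setup()  # e.g. recreate_service_connections() in the real module
        try:
            for item in items:
                try:
                    rv = func(item)
                except Exception:
                    if not catch_exceptions:
                        raise
                    log.error("worker caught an error, continuing - %s", traceback.format_exc())
                    rv = None  # still one entry per item, so progress counts stay accurate
                results.put(rv)
        finally:
            teardown()  # runs even if an exception escapes the loop

    return worker


if __name__ == "__main__":
    # queue.Queue stands in for the Manager queue when testing in one process
    results = queue.Queue()
    worker = make_worker(lambda x: 10 // x, setup=lambda: None, teardown=lambda: None)
    worker(results, [1, 2, 0, 5])
    print([results.get_nowait() for _ in range(results.qsize())])  # [10, 5, None, 2]
```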

Note that this module is probably not appropriate for use inside celery workers or API code because in both of those cases the deployment infrastructure should already be making full use of available computing resources.

The full file for anyone that would like to use it:

from __future__ import print_function, absolute_import, unicode_literals

import time
try:
    from Queue import Empty
except ImportError:
    from queue import Empty

from multiprocessing import Process, cpu_count, Manager
import logging
import traceback

from django import db
from django.conf import settings
from django.core.cache import caches
from elasticsearch_dsl import connections


loggly = logging.getLogger('loggly')


class Timer(object):

    """ Simple class for timing code blocks
    """
    def __init__(self):
        self.start_time = time.time()

    def done(self):
        end_time = time.time()
        return int(end_time - self.start_time)


def close_service_connections():
    """ Close all connections before we spawn our processes
    This function should only be used when writing multiprocess scripts where connections need to be manually
    opened and closed so that processes don't reuse the same connection
    https://stackoverflow.com/questions/8242837/django-multiprocessing-and-database-connections
    """

    # close db connections, they will be recreated automatically
    db.connections.close_all()

    # close ES connection, needs to be manually recreated
    connections.connections.remove_connection("default")

    # close redis connections, will be recreated automatically
    for k in settings.CACHES.keys():
        caches[k].close()


def recreate_service_connections():
    """ All this happens automatically when django starts up; this function should only be used when writing
    multiprocess scripts where connections need to be manually opened and closed so that processes don't reuse
    the same connection
    """

    # ES is the only one that needs to be recreated explicitly
    connections.connections.create_connection(hosts=[settings.ELASTIC_FULL_URL], timeout=20)


def threadwrapper(some_function, catch_exceptions=True):
    """ This wrapper should only be used when a function is being called in a multiprocessing context
    """

    def wrapper(queue, items):
        recreate_service_connections()

        for i in items:
            try:
                rv = some_function(i)
            except Exception:
                rv = None

                if catch_exceptions:
                    loggly.error("threadwrapper caught an error, continuing - %s" % traceback.format_exc())
                else:
                    raise

            queue.put(rv, block=False)

        close_service_connections()

    return wrapper


class MultiProcess(object):
    """ Nicely abstracts away some of the challenges when doing multiprocessing with Django
    Unfortunately, it falls over when running tests, so it's not really tested
    We implement this as a context manager so we don't have to worry about garbage collection calling __del__
    """

    queue = None
    item_count = 1

    def __init__(self, num_workers=None, max_workers=None, debug_print=False, status_interval=20):

        # per-instance list; a class-level list would be shared by every MultiProcess instance
        self.workers = []

        if num_workers is None:

            # leave one cpu available for other stuff,
            # but always use at least one worker
            self.num_workers = cpu_count() - 1
            if self.num_workers < 2:
                self.num_workers = 1

            if max_workers and self.num_workers > max_workers:
                self.num_workers = max_workers

        else:
            self.num_workers = num_workers

        self.debug_print = debug_print

        self.status_interval = status_interval

        if debug_print:
            print("Using %s workers" % self.num_workers)

    def __enter__(self):
        close_service_connections()
        return self

    def map(self, func, iterable):

        self.queue = Manager().Queue()
        self.item_count = len(iterable) or 1

        for worker_idx in range(self.num_workers):

            items = []

            for idx, item in enumerate(iterable):
                if idx % self.num_workers == worker_idx:
                    items.append(item)

            if self.debug_print:
                print("Working on %s uids of %s in worker %s" % (len(items), len(iterable), worker_idx))

            p = Process(target=threadwrapper(func), args=[self.queue, items])
            p.start()
            self.workers.append(p)

        self._wait()

    def _wait(self):
        """ Wait for all workers to finish and wake up periodically to print out how much work has happened
        """
        total_time = Timer()

        while [p for p in self.workers if p.is_alive()]:

            tpt = Timer()

            for p in self.workers:
                p.join(timeout=self.status_interval)

                # Timer.done() already returns whole seconds
                interval_secs = tpt.done()

                # if we've timed out on the status interval, print it out and reset the counter
                if self.debug_print and interval_secs >= self.status_interval:
                    tpt = Timer()

                    total_secs = total_time.done()

                    percent = (self.queue.qsize() * 100) // self.item_count
                    print("--------- {}% done ({}s elapsed) ---------".format(percent, total_secs))

    def results(self):
        rv = []
        try:
            while True:
                rv.append(self.queue.get(block=False))
        except Empty:
            return rv

    def __exit__(self, type, value, traceback):
        # recreate the connections so we can do more stuff
        recreate_service_connections()

Enjoy and good luck with your Python 3 migrations!
