Just now, I was working on a script that loads a bunch of files, runs them through a neural network, does a few other simple things, then writes the data back out. The script took 15 minutes and I really wanted it to be faster...
My gut reaction was to batch the data. I knew I could run a batch of >100 examples almost as fast as a single example, so I thought, ah yes, that's the right place to start speeding things up. But then I thought... well... maybe I should measure first.
If you have PyCharm Professional (free for academics), then profiling your code is easy. You just hit the profile button instead of the normal run or debug buttons. At the end (or at any time during the run), you can view a table or graph of which functions take the most time. Here's what I saw...
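If you don't have PyCharm, the standard library's cProfile gives you the same kind of table. Here's a minimal sketch, assuming the script's work lives in a main() function (main and profile.out are just placeholders):

import cProfile
import pstats

# run the whole script under the profiler and dump the raw stats to a file
cProfile.run('main()', 'profile.out')
# print the ten functions with the highest cumulative time
pstats.Stats('profile.out').sort_stats('cumulative').print_stats(10)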
The boxes in red are very slow, taking up a lot of the total run time. To my surprise, the functions for calling the neural network aren't even on here. Instead, write_example is taking 96.1% of the time! After parallelizing write_example to use all my CPUs...
import pathlib
from multiprocessing import Pool
from typing import Dict

class MultiprocessingExampleWriter:
    def __init__(self):
        self.pool = Pool()  # defaults to one worker process per CPU core

    def write(self, outdir: pathlib.Path, out_example: Dict, example_idx: int, save_format: str = 'pkl'):
        # hand the slow disk write off to a worker process and return immediately
        self.pool.apply_async(func=write_example, args=(outdir, out_example, example_idx, save_format))
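For context, here's a rough sketch of how the writer slots into the processing loop (examples, run_network, and outdir are placeholders, not my actual pipeline). One caveat: apply_async returns immediately, so you have to close and join the pool before the script exits, or any in-flight writes get killed along with the worker processes:

writer = MultiprocessingExampleWriter()
for idx, example in enumerate(examples):
    out_example = run_network(example)  # placeholder for the real model call
    writer.write(outdir, out_example, idx)
# wait for all queued writes to finish before exiting
writer.pool.close()
writer.pool.join()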
The total run time went from 15 minutes to under 2 minutes. Huge win, and the fix was way simpler than batching my data would have been. The code looks much nicer, too.