Yo, I heard you like decorators

This starts out innocently enough. Suppose we have a generator function like this one, which produces the Collatz sequence:

def collatz(x):
    yield x
    while x != 1:
        x = (3 * x + 1) if (x % 2) else (x // 2)
        yield x

We can look at the items produced by this sequence by iterating over it:

>>> list(collatz(12))
[12, 6, 3, 10, 5, 16, 8, 4, 2, 1]

However, suppose we want to augment what this function returns without screwing with it too much. We might want to see the index of each item, for example. Python gives us enumerate for this purpose, but we have to call it on the result of the function:

>>> list(enumerate(collatz(12)))
[(0, 12), (1, 6), (2, 3), (3, 10), (4, 5), (5, 16), (6, 8), (7, 4), (8, 2), (9, 1)]

We could turn enumerate into a decorator, which would allow us to change what our function returns without changing how we call it:

def enumerator(f):
    def wrapper(*args, **kwargs):
        return enumerate(f(*args, **kwargs))
   
    return wrapper

@enumerator
def collatz(x):
    # ...

Now we get:

>>> list(collatz(12))
[(0, 12), (1, 6), (2, 3), (3, 10), (4, 5), (5, 16), (6, 8), (7, 4), (8, 2), (9, 1)]
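(A side note: a plain wrapper like this hides the original function's name and docstring - introspection reports wrapper instead of collatz. The standard fix is functools.wraps; here's a minimal sketch:

```python
import functools

def enumerator(f):
    @functools.wraps(f)  # copies f.__name__, __doc__, etc. onto the wrapper
    def wrapper(*args, **kwargs):
        return enumerate(f(*args, **kwargs))
    return wrapper

@enumerator
def collatz(x):
    """Yield the Collatz sequence starting at x."""
    yield x
    while x != 1:
        x = (3 * x + 1) if (x % 2) else (x // 2)
        yield x

print(collatz.__name__)  # "collatz", not "wrapper"
```

This doesn't change behavior at all; it just makes the decorated function look like the original to debuggers and help(). I'll leave it out of the rest of the examples to keep them short.)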

Decorators are a good way to extend the functionality of existing code without having to change it. One of my favorites is functools.lru_cache, which lets you transparently add memoization.
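For instance, here's a quick sketch of lru_cache memoizing a naive recursive Fibonacci - without the cache the recursion is exponential, with it each value is computed exactly once:

```python
import functools

@functools.lru_cache(maxsize=None)  # unbounded cache: remember every result
def fib(n):
    # Naive double recursion; the cache turns it from exponential to linear.
    return n if n < 2 else fib(n - 1) + fib(n - 2)

print(fib(100))  # returns immediately, thanks to memoization
```

The decorated function is called exactly as before; the caching is invisible at the call site, which is the whole appeal.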

Our simple decorator is fairly limited - it doesn't let us use the enumerate function's start parameter. We can fix that, but we'll have to write a decorator that accepts parameters. To do this, we write a function that returns a decorator:

def enumerator(start=0):
    def outer_wrapper(f):
        def inner_wrapper(*args, **kwargs):
            return enumerate(f(*args, **kwargs), start=start)
   
        return inner_wrapper
   
    return outer_wrapper

@enumerator(1)
def collatz(x):
    # ...

(This is also called a "decorator." My suggestion of "double dec-er" was rejected without, I thought, proper consideration.)

Now the indexes start from 1:

>>> list(collatz(12))
[(1, 12), (2, 6), (3, 3), (4, 10), (5, 5), (6, 16), (7, 8), (8, 4), (9, 2), (10, 1)]

The enumerator code is a bit confusing, but essentially we're doing:

enumerator_1 = enumerator(1)

@enumerator_1
def collatz(x):
    # ...

We could write a very similar decorator for a different function, like itertools.accumulate (which produces a running total):

def accumulator(func=operator.add):
    def outer_wrapper(f):
        def inner_wrapper(*args, **kwargs):
            return itertools.accumulate(f(*args, **kwargs), func=func)
       
        return inner_wrapper
   
    return outer_wrapper

@accumulator()
def collatz(x):
    # ...

However, that's quite repetitive. If we want to generalize it, we'll need another layer of wrapping...

def decorator_factory(wrapping_func):
    def decorator(**wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                return wrapping_func(result, **wrapping_kwargs)
           
            return inner_wrapper
       
        return outer_wrapper
   
    return decorator

This is sort of mind-bending, but it does let us create our decorators. We can then stack them:

enumerator = decorator_factory(enumerate)
accumulator = decorator_factory(itertools.accumulate)

@enumerator(start=1)
@accumulator()
def collatz(x):
    # ...

This produces the running total of our sequence, with indexes:

>>> list(collatz(12))
[(1, 12), (2, 18), (3, 21), (4, 31), (5, 36), (6, 52), (7, 60), (8, 64), (9, 66), (10, 67)]

To try to understand what's happening, we can get the same result by being more verbose:

enumerator = decorator_factory(enumerate)
accumulator = decorator_factory(itertools.accumulate)

enumerator_1 = enumerator(start=1)
accumulator_add = accumulator()

@enumerator_1
@accumulator_add
def collatz(x):
    # ...
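One more thing worth noting: stacked decorators apply bottom-up, so accumulator_add wraps collatz first and enumerator_1 wraps the result. Spelled out without the @ syntax, it's just nested calls (a self-contained sketch using the decorator_factory above):

```python
import itertools

def decorator_factory(wrapping_func):
    def decorator(**wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                return wrapping_func(result, **wrapping_kwargs)
            return inner_wrapper
        return outer_wrapper
    return decorator

def collatz(x):
    yield x
    while x != 1:
        x = (3 * x + 1) if (x % 2) else (x // 2)
        yield x

enumerator = decorator_factory(enumerate)
accumulator = decorator_factory(itertools.accumulate)

# The decorator closest to the function applies first,
# so accumulate runs before enumerate.
decorated = enumerator(start=1)(accumulator()(collatz))
print(list(decorated(12)))
```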

Our decorator_factory assumes that the first argument to the wrapping_func is the result of the function we're decorating. That works for enumerate(sequence, start) and itertools.accumulate(iterable, func), but it won't work for something with a different signature like itertools.dropwhile(predicate, iterable) - there the result is the second argument.

We can fix that problem by (1) allowing the decorator factory to accept positional arguments, and (2) telling the decorator factory the index at which to put the result of our function:

def decorator_factory(wrapping_func, result_index=0):
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)
           
            return inner_wrapper
       
        return outer_wrapper
   
    return decorator

This is bananas, but it works:

accumulator = decorator_factory(itertools.accumulate)
dropper = decorator_factory(itertools.dropwhile, 1)

@dropper(lambda x: x < 50)
@accumulator()
def collatz(x):
    # ...

Now we have the items from the running total, starting with the first one that's at least 50:

>>> list(collatz(12))
[52, 60, 64, 66, 67]
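As a sanity check of the result_index idea, itertools.takewhile has the same predicate-first signature as dropwhile, so the complementary decorator falls out for free (again a self-contained sketch, reusing the factory above):

```python
import itertools

def decorator_factory(wrapping_func, result_index=0):
    def decorator(*wrapping_args, **wrapping_kwargs):
        def outer_wrapper(f):
            def inner_wrapper(*args, **kwargs):
                result = f(*args, **kwargs)
                # Slot the decorated function's result into the right
                # position of the wrapping function's argument list.
                wrapping_args_ = list(wrapping_args)
                wrapping_args_.insert(result_index, result)
                return wrapping_func(*wrapping_args_, **wrapping_kwargs)
            return inner_wrapper
        return outer_wrapper
    return decorator

# takewhile(predicate, iterable): the iterable is the second argument
taker_while = decorator_factory(itertools.takewhile, 1)

@taker_while(lambda n: n > 2)  # stop at the first value <= 2
def collatz(x):
    yield x
    while x != 1:
        x = (3 * x + 1) if (x % 2) else (x // 2)
        yield x

print(list(collatz(12)))  # [12, 6, 3, 10, 5, 16, 8, 4]
```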

The decorator factory will work with any function that operates on an iterable, like most of the things in the more_itertools library:

chunker = decorator_factory(more_itertools.chunked)
taker = decorator_factory(more_itertools.take, 1)

@chunker(3)  # Group into chunks of 3
@taker(6)  # Get the first 6 items
def collatz(x):
    # ...

>>> it = collatz(12)
>>> list(it)
[[12, 6, 3], [10, 5, 16]]

Is any of this a good idea for real code? Probably not, but it was fun to puzzle out.

This post was inspired by the discussion in the issue tracker of the more-itertools library. If you've got good ideas, or crazy ones, let us know.