Date: 2022-08-23
This week I wrote two Python libraries, goto and bytecode, to explore the idea of creating or modifying functions by generating bytecode directly.
In this post we’re going to explore the first of the two (the second will come in a later post).
The idea of goto is pretty simple: add a “new” control flow mechanism that lets a programmer move execution around arbitrarily within a function.
>>> import goto
>>>
>>> @goto.goto
... def example():
...     '''
...     returns 1
...     '''
...     x = 1
...     example.goto("skip")
...     x = 2
...     example.label("skip")
...     return x
...
>>> example()
1
It’s actually fairly simple to build using only tools available in the standard library, with the dis module being the key piece.
Let’s take our function from before (without the decorator) and use dis.dis to take a look at the bytecode.
>>> import dis
>>>
>>> def example():
...     '''
...     returns 1
...     '''
...     x = 1
...     example.goto("skip")
...     x = 2
...     example.label("skip")
...     return x
...
>>> dis.dis(example)
  5           0 LOAD_CONST               1 (1)
              2 STORE_FAST               0 (x)

  6           4 LOAD_GLOBAL              0 (example)
              6 LOAD_METHOD              1 (goto)
              8 LOAD_CONST               2 ('skip')
             10 CALL_METHOD              1
             12 POP_TOP

  7          14 LOAD_CONST               3 (2)
             16 STORE_FAST               0 (x)

  8          18 LOAD_GLOBAL              0 (example)
             20 LOAD_METHOD              2 (label)
             22 LOAD_CONST               2 ('skip')
             24 CALL_METHOD              1
             26 POP_TOP

  9          28 LOAD_FAST                0 (x)
             30 RETURN_VALUE
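Reading each row from left to right, dis shows the source line number (only on the first instruction of that line), the byte offset of the instruction, the opcode name, the raw argument, and, in parentheses, a hint that dis generates for us that lets us know what the given argument value translates to.
What we’re interested in is finding contiguous regions of bytecode in our functions that look like lines 6 and 8, and replacing them with some kind of jump.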
dis comes to the rescue again, this time in the form of dis.Bytecode. By passing our function of interest into dis.Bytecode, we can get all the info we need to scan the instructions looking for these patterns.
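As a quick illustration of what dis.Bytecode hands back, the sketch below prints each instruction’s offset, opname, and argval for the example function. Opcode names differ between CPython versions (3.12, for instance, uses LOAD_ATTR where 3.10 uses LOAD_METHOD), so the only thing it relies on is that the goto call shows up as three consecutive argvals: the function’s own name, the method name, and the label string.

```python
import dis


def example():
    x = 1
    example.goto("skip")  # never executed; we only inspect the bytecode
    x = 2
    example.label("skip")
    return x


# Each item yielded by dis.Bytecode is a dis.Instruction with
# offset, opname and argval attributes (among others).
for instr in dis.Bytecode(example):
    print(instr.offset, instr.opname, repr(instr.argval))

# The resolved arguments in order; the call sites appear as runs like
# ["example", "goto", "skip"] regardless of the exact opcode names.
argvals = [instr.argval for instr in dis.Bytecode(example)]
```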
In this case, I used zip to create a sliding window of 3 instructions:
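import dis

code = dis.Bytecode(f)
instr = list(code)  # materialize the instructions so they can be sliced
# zip stops at the shortest input, yielding every run of 3 consecutive instructions
for i1, i2, i3 in zip(instr, instr[1:], instr[2:]):
    ...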
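The zip trick generalizes to any window size. A small sketch, where windows is a hypothetical helper name rather than anything from the library:

```python
from typing import Iterator, Sequence, Tuple, TypeVar

T = TypeVar("T")


def windows(seq: Sequence[T], n: int) -> Iterator[Tuple[T, ...]]:
    """Yield every run of n consecutive items from seq.

    zip stops at its shortest argument (seq[n - 1:]), so the last window
    ends exactly at the final item and no padding is produced.
    """
    return zip(*(seq[i:] for i in range(n)))


print(list(windows("abcde", 3)))
# → [('a', 'b', 'c'), ('b', 'c', 'd'), ('c', 'd', 'e')]
```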
Then I checked whether each window of instructions followed the pattern:
# _LOAD_GLOBAL and friends hold opcode numbers (e.g. dis.opmap["LOAD_GLOBAL"])
if (
    (i1.opcode == _LOAD_GLOBAL or i1.opcode == _LOAD_DEREF)
    and i1.argval == f.__name__
    and i2.opcode == _LOAD_METHOD
    and i3.opcode == _LOAD_CONST
):
    if i2.argval == "label":
        targets[i3.argval] = (
            i3.offset + 6  # Add 6 to skip past the label function call instructions
        ) // 2  # divide by 2 to get instruction offset from byte offset
    elif i2.argval == "goto":
        gotos[i1.offset] = i3.argval
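To make the + 6 and // 2 concrete, here is that arithmetic checked against the dis listing above. It is a worked example for CPython 3.10 only (2-byte instructions, jump arguments counted in instructions rather than bytes), and jump_target is a hypothetical name, not part of the library. The label’s LOAD_CONST 'skip' sits at byte offset 22; it and the CALL_METHOD and POP_TOP after it take 3 * 2 = 6 bytes, so execution should resume at byte offset 28 (LOAD_FAST x), which is instruction 14.

```python
# CPython 3.10 wordcode: every instruction is 2 bytes (opcode byte + argument byte).
INSTRUCTION_SIZE = 2
# Bytes from the start of a label's LOAD_CONST to the next statement:
# LOAD_CONST, CALL_METHOD and POP_TOP.
LABEL_TAIL_BYTES = 3 * INSTRUCTION_SIZE


def jump_target(label_const_offset: int) -> int:
    """Instruction index just past a label call, as JUMP_ABSOLUTE expects in 3.10."""
    return (label_const_offset + LABEL_TAIL_BYTES) // INSTRUCTION_SIZE


target = jump_target(22)
print(target, target * INSTRUCTION_SIZE)  # → 14 28
```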
This pattern check is part of an initial pass to collect all the goto and label statements. We want to know the location of every label ahead of time (a goto can jump forward to a label the scan hasn’t reached yet), which makes our lives much easier when actually generating the new bytecode.
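The collection pass can also be written so it isn’t tied to one interpreter version. The sketch below compares opnames rather than opcode constants, and accepts LOAD_ATTR as well as LOAD_METHOD because CPython 3.12 folded method lookups into LOAD_ATTR. It records raw byte offsets and leaves out the + 6 and // 2 adjustment, which only matches 3.10’s instruction layout. collect_labels_and_gotos is a hypothetical helper, not part of the goto library.

```python
import dis
from types import FunctionType
from typing import Dict, Tuple

_NAME_LOADS = {"LOAD_GLOBAL", "LOAD_DEREF"}   # how f refers to itself
_ATTR_LOADS = {"LOAD_METHOD", "LOAD_ATTR"}    # .goto / .label lookup


def collect_labels_and_gotos(f: FunctionType) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Return (labels, gotos).

    labels maps a label name to the byte offset of its LOAD_CONST;
    gotos maps the byte offset of each goto's first instruction to its label name.
    """
    instr = list(dis.Bytecode(f))
    labels: Dict[str, int] = {}
    gotos: Dict[int, str] = {}
    for i1, i2, i3 in zip(instr, instr[1:], instr[2:]):
        if (
            i1.opname in _NAME_LOADS
            and i1.argval == f.__name__
            and i2.opname in _ATTR_LOADS
            and i3.opname == "LOAD_CONST"
        ):
            if i2.argval == "label":
                labels[i3.argval] = i3.offset
            elif i2.argval == "goto":
                gotos[i1.offset] = i3.argval
    return labels, gotos


def example():
    x = 1
    example.goto("skip")
    x = 2
    example.label("skip")
    return x


labels, gotos = collect_labels_and_gotos(example)
```

Because labels is complete before any rewriting starts, a goto that appears earlier in the function than its label can still be resolved.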
Now that we have the locations of all gotos and labels, generating the new bytecode is pretty trivial. All you need to know is that the JUMP_ABSOLUTE instruction exists; it’s what we use to do the goto-ing.
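import io

writer = io.BytesIO()
for i1 in instr:
    if i1.offset in gotos:
        # Overwrite the goto's first instruction with a jump to its label
        target = targets[gotos[i1.offset]]
        writer.write(bytes([_JUMP_ABSOLUTE, target]))
    else:
        # Copy the instruction as-is; only the low byte of the argument belongs here,
        # since any higher bits are carried by a preceding EXTENDED_ARG instruction
        writer.write(bytes([i1.opcode, (i1.arg or 0) & 0xFF]))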
When we encounter a goto, we replace its first instruction with a JUMP_ABSOLUTE to the associated label, so the rest of the goto call is never reached. Otherwise we copy the original instruction unchanged.
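One caveat with writing bytes([_JUMP_ABSOLUTE, target]) directly: bytes() only accepts values from 0 to 255, so a function long enough to put a label past instruction 255 raises ValueError. CPython handles larger arguments with EXTENDED_ARG prefixes, each supplying 8 more high bits. The following is a minimal sketch of that encoding; encode_instruction is a hypothetical helper, and it only builds opcode/argument pairs, without the inline CACHE entries that 3.11+ expects after some instructions. Inserting prefixes also shifts every later offset, so the 1:1 replacement above would need extra bookkeeping to use it.

```python
import dis

EXTENDED_ARG = dis.opmap["EXTENDED_ARG"]


def encode_instruction(opcode: int, arg: int) -> bytes:
    """Encode one wordcode instruction, prefixing EXTENDED_ARG as needed."""
    if arg < 0:
        raise ValueError("argument must be non-negative")
    # Split arg into bytes, least significant first.
    chunks = []
    while True:
        chunks.append(arg & 0xFF)
        arg >>= 8
        if arg == 0:
            break
    chunks.reverse()  # most significant byte first
    out = bytearray()
    # Every byte except the lowest goes into its own EXTENDED_ARG prefix.
    for high in chunks[:-1]:
        out += bytes([EXTENDED_ARG, high])
    out += bytes([opcode, chunks[-1]])
    return bytes(out)
```

For example, an argument of 300 becomes an EXTENDED_ARG 1 prefix followed by the real opcode with argument 44, since 300 = 1 * 256 + 44.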
Then all that’s left is to use the new bytecode to construct a new function. We do so by reusing most of the original properties of the function we’re decorating (replacing only the co_code attribute) and then initializing a types.FunctionType object.
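from types import FunctionType

bytecode = writer.getvalue()
# Copy the original code object, swapping in only the rewritten bytecode
new_code = code.codeobj.replace(co_code=bytecode)
func = FunctionType(
    new_code, f.__globals__, f.__name__, f.__defaults__, f.__closure__
)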
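The same replace-then-FunctionType recipe works for any code-object field, not just co_code. Because raw bytecode is tied to a specific CPython version, the sketch below swaps a string constant in co_consts instead, which keeps it runnable on any Python that has CodeType.replace (3.8+). with_constant_replaced is a hypothetical helper for illustration.

```python
from types import FunctionType


def with_constant_replaced(f: FunctionType, old: object, new: object) -> FunctionType:
    """Return a copy of f whose code object has `old` swapped for `new` in co_consts."""
    code = f.__code__
    # Compare types too, so replacing 1 does not also hit True (since 1 == True).
    consts = tuple(
        new if type(c) is type(old) and c == old else c for c in code.co_consts
    )
    new_code = code.replace(co_consts=consts)
    # Reuse the original globals, name, defaults and closure.
    return FunctionType(new_code, f.__globals__, f.__name__, f.__defaults__, f.__closure__)


def greet():
    return "hello"


farewell = with_constant_replaced(greet, "hello", "goodbye")
print(greet(), farewell())  # → hello goodbye
```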
You can check out the full source here.
The approach has its limits, though. First, we fall prey to what seems to be an optimization CPython makes, where anything after a return statement doesn’t have any bytecode generated for it.
>>> import dis
>>>
>>> def f():
...     f.goto("after")
...     return 5
...     f.label("after")
...     return 6
...
>>> dis.dis(f)
  2           0 LOAD_GLOBAL              0 (f)
              2 LOAD_METHOD              1 (goto)
              4 LOAD_CONST               1 ('after')
              6 CALL_METHOD              1
              8 POP_TOP

  3          10 LOAD_CONST               2 (5)
             12 RETURN_VALUE
As you can see, the label and the second return are nowhere to be found, so the initial scan never finds the label to jump to.
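One way to confirm this without reading the listing is to collect every argument dis reports for the compiled function. In this sketch marker is a hypothetical stand-in for f.goto and f.label and is never called. On CPython 3.10, the version used throughout this post, and on later releases as far as I know, the constant "before" shows up while "after" does not, because nothing after the first return is compiled.

```python
import dis


def marker(_name: str) -> None:
    """Hypothetical stand-in for f.goto / f.label; never called here."""


def f():
    marker("before")
    return 5
    marker("after")  # unreachable: expected to be dropped by the compiler
    return 6


# argval holds the resolved argument of each instruction (names, constants, ...).
loaded = {instr.argval for instr in dis.get_instructions(f)}
print("before" in loaded, "after" in loaded)
```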
Additionally, there are no safety nets here; it’s very easy to make things go horribly wrong. This is the first time I can remember triggering segfaults while developing anything in Python.
>>> import goto
>>>
>>> @goto.goto
... def f():
...     f.goto("in-loop")
...     for i in range(5):
...         f.label("in-loop")  # Jumping into a loop like this will cause a segfault
...
>>> f()
fish: Job 1, 'python' terminated by signal SIGSEGV (Address boundary error)
But if you’re careful, you can do some silly stuff.
>>> import goto
>>>
>>> @goto.goto
... def f():
...     for i in range(4):
...         print("a", i)
...         f.goto("loop-b")
...         f.label("loop-a")
...     for i in range(4):
...         print("b", i)
...         f.goto("loop-a")
...         f.label("loop-b")
...
>>> f()
a 0
b 1
a 2
b 3
b 0
a 1
b 2
a 3