contextvar
contextvars.ContextVar
파라미터를 통해 함수를 호출해가며 값을 전달할 때, 첫 번째 인자를 전달한 함수로부터
- 콜 스택이 깊어질 수록
- 파라미터 수가 늘어날 수록
함수의 시그니처가 길어지고, 파라미터의 수정이 어려워집니다.
이를 전역변수로 해결하려고 하면 멀티 스레드 환경에서 문제가 발생합니다.
import random
import time
from concurrent.futures import ThreadPoolExecutor
num_step = 0
def step(n: int) -> None:
if n == 3:
return
global num_step
s = num_step
print(f"n: {n}, s: {s}")
time.sleep(random.random() * 2)
token = num_step
num_step += 1
try:
step(n + 1)
finally:
num_step = token
s = num_step
print(f"n: {n}, s: {s}")
with ThreadPoolExecutor(max_workers=3) as executor:
executor.map(step, [0] * 3)
n: 0, s: 0
n: 0, s: 0
n: 0, s: 0
n: 1, s: 1
n: 1, s: 2
n: 1, s: 3
n: 2, s: 4
n: 2, s: 5
n: 2, s: 6
n: 2, s: 6
n: 1, s: 3
n: 0, s: 2
n: 2, s: 2
n: 1, s: 4
n: 0, s: 0
n: 2, s: 0
n: 1, s: 5
n: 0, s: 1
위 코드에서 단순히 step(0)
을 호출한다면 n
과 s
가 같은 값이 출력되지만, 멀티 스레드로 실행하면 실행 순서에 따라 n
과 s
가 다른 값이 출력됩니다.
이를 해결하기위해 contextvars.ContextVar
를 사용하면 아래와 같이 각 함수 호출은 독립적인 컨택스트에서 실행되기 때문에 문제가 발생하지 않습니다.
import random
import time
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar
num_step = ContextVar[int]("num_step", default=0)
def step(n: int) -> None:
if n == 3:
return
s = num_step.get()
print(f"n: {n}, s: {s}")
time.sleep(random.random() * 2)
token = num_step.set(s + 1)
try:
step(n + 1)
finally:
num_step.reset(token)
s = num_step.get()
print(f"n: {n}, s: {s}")
with ThreadPoolExecutor(max_workers=3) as executor:
executor.map(step, [0] * 3)
n: 0, s: 0
n: 0, s: 0
n: 0, s: 0
n: 1, s: 1
n: 1, s: 1
n: 2, s: 2
n: 1, s: 1
n: 2, s: 2
n: 2, s: 2
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0
마찬가지로 asyncio에서도 사용할 수 있습니다.
import asyncio
import random
from contextvars import ContextVar
num_step = ContextVar[int]("num_step", default=0)
async def step(n: int) -> None:
if n == 3:
return
s = num_step.get()
print(f"n: {n}, s: {s}")
await asyncio.sleep(random.random() * 2)
token = num_step.set(s + 1)
try:
await step(n + 1)
finally:
num_step.reset(token)
s = num_step.get()
print(f"n: {n}, s: {s}")
async def main() -> None:
cos = [step(0) for _ in range(3)]
await asyncio.gather(*cos)
asyncio.run(main())
contextlib.contextmanager
와 함께 사용하기
컨택스트 변수는 값을 설정한 후 함수를 호출하고, 함수가 종료되면 원래 값으로 복원하는 방식을 반복적으로 사용해야하는 경우가 있습니다.
이때 contextlib.contextmanager
를 사용하여 set
과 reset
을 감싸고, get
을 위한 함수를 따로 만들어 노출시키면 개발자의 실수를 줄일 수 있습니다.
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Generator
_num_step = ContextVar[int]("num_step", default=0)
@contextmanager
def with_next_step() -> Generator[None, Any, None]:
token = _num_step.set(_num_step.get() + 1)
try:
yield
finally:
_num_step.reset(token)
def get_step() -> int:
return _num_step.get()