Skip to main content

contextvar


contextvars.ContextVar

파라미터를 통해 함수를 호출해가며 값을 전달할 때, 첫 번째 인자를 전달한 함수로부터

  • 콜 스택이 깊어질 수록
  • 파라미터 수가 늘어날 수록

함수의 시그니처가 길어지고, 파라미터의 수정이 어려워집니다.

이를 전역변수로 해결하려고 하면 멀티 스레드 환경에서 문제가 발생합니다.

import random
import time
from concurrent.futures import ThreadPoolExecutor

num_step = 0


def step(n: int) -> None:
if n == 3:
return

global num_step
s = num_step
print(f"n: {n}, s: {s}")

time.sleep(random.random() * 2)

token = num_step
num_step += 1
try:
step(n + 1)
finally:
num_step = token

s = num_step
print(f"n: {n}, s: {s}")


with ThreadPoolExecutor(max_workers=3) as executor:
executor.map(step, [0] * 3)
n: 0, s: 0
n: 0, s: 0
n: 0, s: 0
n: 1, s: 1
n: 1, s: 2
n: 1, s: 3
n: 2, s: 4
n: 2, s: 5
n: 2, s: 6
n: 2, s: 6
n: 1, s: 3
n: 0, s: 2
n: 2, s: 2
n: 1, s: 4
n: 0, s: 0
n: 2, s: 0
n: 1, s: 5
n: 0, s: 1

위 코드에서 단순히 step(0)을 호출한다면 ns가 같은 값이 출력되지만, 멀티 스레드로 실행하면 실행 순서에 따라 ns가 다른 값이 출력됩니다.

이를 해결하기위해 contextvars.ContextVar를 사용하면 아래와 같이 각 함수 호출은 독립적인 컨택스트에서 실행되기 때문에 문제가 발생하지 않습니다.

import random
import time
from concurrent.futures import ThreadPoolExecutor
from contextvars import ContextVar

num_step = ContextVar[int]("num_step", default=0)


def step(n: int) -> None:
if n == 3:
return

s = num_step.get()
print(f"n: {n}, s: {s}")

time.sleep(random.random() * 2)

token = num_step.set(s + 1)
try:
step(n + 1)
finally:
num_step.reset(token)

s = num_step.get()
print(f"n: {n}, s: {s}")


with ThreadPoolExecutor(max_workers=3) as executor:
executor.map(step, [0] * 3)
n: 0, s: 0
n: 0, s: 0
n: 0, s: 0
n: 1, s: 1
n: 1, s: 1
n: 2, s: 2
n: 1, s: 1
n: 2, s: 2
n: 2, s: 2
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0
n: 2, s: 2
n: 1, s: 1
n: 0, s: 0

마찬가지로 asyncio에서도 사용할 수 있습니다.

import asyncio
import random
from contextvars import ContextVar

num_step = ContextVar[int]("num_step", default=0)


async def step(n: int) -> None:
if n == 3:
return

s = num_step.get()
print(f"n: {n}, s: {s}")

await asyncio.sleep(random.random() * 2)

token = num_step.set(s + 1)
try:
await step(n + 1)
finally:
num_step.reset(token)

s = num_step.get()
print(f"n: {n}, s: {s}")


async def main() -> None:
cos = [step(0) for _ in range(3)]
await asyncio.gather(*cos)


asyncio.run(main())

contextlib.contextmanager와 함께 사용하기

컨택스트 변수는 값을 설정한 후 함수를 호출하고, 함수가 종료되면 원래 값으로 복원하는 방식을 반복적으로 사용해야하는 경우가 있습니다.

이때 contextlib.contextmanager를 사용하여 setreset을 감싸고, get을 위한 함수를 따로 만들어 노출시키면 개발자의 실수를 줄일 수 있습니다.

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Generator

_num_step = ContextVar[int]("num_step", default=0)


@contextmanager
def with_next_step() -> Generator[None, Any, None]:
token = _num_step.set(num_step.get() + 1)
try:
yield
finally:
_num_step.reset(token)


def get_step() -> int:
return _num_step.get()