#!/usr/bin/env python3
"""Demo of a ``requires_api_key`` decorator that injects an API key kwarg.

The key retrieval is currently a placeholder stub (see the TODO markers);
running this file as a script prints the injected value for one synchronous
and one asynchronous demo function.
"""
import asyncio
import concurrent.futures
import contextvars
import inspect
from functools import wraps
from typing import Any, Callable


# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220
# Apache-2.0 license
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
    """Build a decorator that passes an API key to the wrapped callable.

    Args:
        provider_name: Name of the provider the key is requested for; it is
            embedded in the placeholder key string.
        into: Name of the keyword argument that receives the key. Any value
            the caller supplied under this name is overwritten.

    Returns:
        A decorator. Applied to a coroutine function it yields a coroutine
        function; applied to a plain function it yields a plain function.
    """

    def decorator(func: Callable) -> Callable:
        # TODO
        async def _get_api_key() -> str:
            # NOTE(review): the "$" is emitted literally (this is an f-string,
            # not a shell/JS template) -- confirm that is intended.
            return f'[[TEST_API_KEY:${provider_name}]]'  # TODO

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs[into] = await _get_api_key()
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            if _has_running_loop():
                # for async env, eg. runtime
                # asyncio.run() cannot be called from a thread that already
                # has a running loop, so the coroutine is driven on a helper
                # thread, inside a copy of the caller's context variables.
                ctx = contextvars.copy_context()
                # A single task is submitted, so one worker is sufficient.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(ctx.run, asyncio.run, _get_api_key())
                    api_key = future.result()
            else:
                # for sync env, eg. local dev
                api_key = asyncio.run(_get_api_key())

            kwargs[into] = api_key
            return func(*args, **kwargs)

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        if inspect.iscoroutinefunction(func):
            return async_wrapper
        return sync_wrapper

    return decorator


def _has_running_loop() -> bool:
    """Return True when the calling thread has a running asyncio event loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running.
        return False
    return True


################################################################################


@requires_api_key(provider_name='demo1')
def get_api_key1(*, api_key: str) -> str:
    """Return the key injected by the decorator (synchronous demo)."""
    return api_key


@requires_api_key(provider_name='demo2')
async def get_api_key2(*, api_key: str) -> str:
    """Return the key injected by the decorator (asynchronous demo)."""
    return api_key


if __name__ == '__main__':
    print(f'API key: {get_api_key1(api_key="")}')
    print(f'API key: {asyncio.run(get_api_key2(api_key=""))}')