diff --git a/decorator_func_test.py b/decorator_func_test.py new file mode 100644 index 0000000..0e5db27 --- /dev/null +++ b/decorator_func_test.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import asyncio +import contextvars +from functools import wraps +from typing import Any, Callable + + +# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220 +# Apache-2.0 license +def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable: + def decorator(func: Callable) -> Callable: + # TODO + async def _get_api_key(): + return f'[[TEST_API_KEY:${provider_name}]]' # TODO + + @wraps(func) + async def async_wrapper(*args: Any, **kwargs: Any) -> Any: + api_key = await _get_api_key() + kwargs[into] = api_key + return await func(*args, **kwargs) + + @wraps(func) + def sync_wrapper(*args: Any, **kwargs: Any) -> Any: + if _has_running_loop(): + # for async env, eg. runtime + ctx = contextvars.copy_context() + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(ctx.run, asyncio.run, _get_api_key()) + api_key = future.result() + else: + # for sync env, eg. local dev + api_key = asyncio.run(_get_api_key()) + + kwargs[into] = api_key + return func(*args, **kwargs) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +def _has_running_loop() -> bool: + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + +################################################################################ + +@requires_api_key(provider_name='demo1') +def get_api_key1(*, api_key: str) -> str: + return api_key + + +@requires_api_key(provider_name='demo2') +async def get_api_key2(*, api_key: str) -> str: + return api_key + + +if __name__ == '__main__': + print(f'API key: {get_api_key1(api_key="")}') + print(f'API key: {asyncio.run(get_api_key2(api_key=""))}')