Files
python-tests/decorator_func_test.py

71 lines
2.1 KiB
Python

#!/usr/bin/env python3
import asyncio
import concurrent.futures
import contextvars
import inspect
from functools import wraps
from typing import Any, Callable
# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220
# Apache-2.0 license
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
    """Decorator factory that injects an API key into a function's kwargs.

    The decorated function receives the key as the keyword argument named
    ``into``; any value the caller passed under that name is overwritten.
    Both plain functions and ``async def`` functions are supported.

    Args:
        provider_name: Name of the credential provider the key belongs to.
            Currently only embedded in a placeholder key (see TODO below).
        into: Keyword-argument name under which the key is passed.

    Returns:
        A decorator that wraps a sync or async callable.
    """

    def decorator(func: Callable) -> Callable:
        # TODO: replace this placeholder with a real credential lookup.
        async def _get_api_key() -> str:
            return f"[[TEST_API_KEY:{provider_name}]]"

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs[into] = await _get_api_key()
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            if _has_running_loop():
                # A loop is already running in this thread (e.g. an async
                # runtime), and asyncio.run() refuses to start a nested loop
                # in the same thread. Run it in a single worker thread,
                # carrying over the caller's contextvars.
                ctx = contextvars.copy_context()
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(ctx.run, asyncio.run, _get_api_key())
                    api_key = future.result()
            else:
                # No running loop (e.g. a local dev script): drive the
                # coroutine to completion right here.
                api_key = asyncio.run(_get_api_key())
            kwargs[into] = api_key
            return func(*args, **kwargs)

        # inspect.iscoroutinefunction instead of asyncio.iscoroutinefunction,
        # which is deprecated since Python 3.14.
        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
def _has_running_loop() -> bool:
try:
asyncio.get_running_loop()
return True
except RuntimeError:
return False
################################################################################
@requires_api_key(provider_name="demo1", into="api_key")
def get_api_key1(*, api_key: str) -> str:
    """Sync demo: return the key that ``requires_api_key`` injected."""
    return api_key
@requires_api_key(provider_name="demo2", into="api_key")
async def get_api_key2(*, api_key: str) -> str:
    """Async demo: return the key that ``requires_api_key`` injected."""
    return api_key
if __name__ == "__main__":
    # Exercise the sync wrapper first (no running loop here), then the async one.
    sync_key = get_api_key1(api_key="")
    print(f"API key: {sync_key}")
    async_key = asyncio.run(get_api_key2(api_key=""))
    print(f"API key: {async_key}")