71 lines
2.1 KiB
Python
71 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import asyncio
import contextvars
import inspect
from functools import wraps
from typing import Any, Callable
|
|
|
|
|
|
# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220
|
|
# Apache-2.0 license
|
|
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
    """Build a decorator that injects an API key into the wrapped function's kwargs.

    Key retrieval is currently a stub: it always produces the placeholder
    string ``[[TEST_API_KEY:<provider_name>]]``.

    Args:
        provider_name: Name of the credential provider. Currently it is only
            used to build the placeholder key.
        into: Name of the keyword argument the key is written to. Any value the
            caller passed under this name is overwritten.

    Returns:
        A decorator. Coroutine functions are wrapped in an ``async`` wrapper;
        every other callable is wrapped in a synchronous wrapper.
    """

    def decorator(func: Callable) -> Callable:
        # TODO: replace this stub with a real credential-provider lookup.
        async def _get_api_key() -> str:
            # Plain Python f-string interpolation. The original used a
            # JavaScript-style "${...}", which leaked a literal "$" into the key.
            return f'[[TEST_API_KEY:{provider_name}]]'

        @wraps(func)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            kwargs[into] = await _get_api_key()
            return await func(*args, **kwargs)

        @wraps(func)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            if _has_running_loop():
                # An event loop is already running in this thread (e.g. an async
                # runtime), so asyncio.run() cannot be called here directly.
                # Run the coroutine on a fresh loop in a worker thread instead,
                # under a copy of the current contextvars.
                import concurrent.futures

                ctx = contextvars.copy_context()
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(ctx.run, asyncio.run, _get_api_key())
                    api_key = future.result()
            else:
                # No running loop (e.g. a plain script during local dev).
                api_key = asyncio.run(_get_api_key())

            kwargs[into] = api_key
            return func(*args, **kwargs)

        # inspect.iscoroutinefunction replaces asyncio.iscoroutinefunction,
        # which is deprecated as of Python 3.14.
        return async_wrapper if inspect.iscoroutinefunction(func) else sync_wrapper

    return decorator
|
|
|
|
|
|
def _has_running_loop() -> bool:
    """Tell whether the calling thread is currently inside a running asyncio event loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises RuntimeError when no loop is running here.
        return False
    else:
        return True
|
|
|
|
|
|
################################################################################
|
|
|
|
@requires_api_key(provider_name='demo1')
def get_api_key1(*, api_key: str) -> str:
    """Demo of the synchronous path: hand back whatever key the decorator injected."""
    key = api_key
    return key
|
|
|
|
|
|
@requires_api_key(provider_name='demo2')
async def get_api_key2(*, api_key: str) -> str:
    """Demo of the asynchronous path: hand back whatever key the decorator injected."""
    key = api_key
    return key
|
|
|
|
|
|
if __name__ == '__main__':
    # Exercise both wrappers. The placeholder api_key="" passed here is
    # overwritten by the decorator before the function body runs.
    sync_key = get_api_key1(api_key="")
    print(f'API key: {sync_key}')

    async_key = asyncio.run(get_api_key2(api_key=""))
    print(f'API key: {async_key}')
|