feat: add decorator_func_test.py
This commit is contained in:
70
decorator_func_test.py
Normal file
70
decorator_func_test.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import contextvars
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220
|
||||||
|
# Apache-2.0 license
|
||||||
|
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
# TODO
|
||||||
|
async def _get_api_key():
|
||||||
|
return f'[[TEST_API_KEY:${provider_name}]]' # TODO
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
api_key = await _get_api_key()
|
||||||
|
kwargs[into] = api_key
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
if _has_running_loop():
|
||||||
|
# for async env, eg. runtime
|
||||||
|
ctx = contextvars.copy_context()
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
|
future = executor.submit(ctx.run, asyncio.run, _get_api_key())
|
||||||
|
api_key = future.result()
|
||||||
|
else:
|
||||||
|
# for sync env, eg. local dev
|
||||||
|
api_key = asyncio.run(_get_api_key())
|
||||||
|
|
||||||
|
kwargs[into] = api_key
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
if asyncio.iscoroutinefunction(func):
|
||||||
|
return async_wrapper
|
||||||
|
else:
|
||||||
|
return sync_wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def _has_running_loop() -> bool:
|
||||||
|
try:
|
||||||
|
asyncio.get_running_loop()
|
||||||
|
return True
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
|
||||||
|
@requires_api_key(provider_name='demo1')
|
||||||
|
def get_api_key1(*, api_key: str) -> str:
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
@requires_api_key(provider_name='demo2')
|
||||||
|
async def get_api_key2(*, api_key: str) -> str:
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print(f'API key: {get_api_key1(api_key="")}')
|
||||||
|
print(f'API key: {asyncio.run(get_api_key2(api_key=""))}')
|
||||||
Reference in New Issue
Block a user