feat: add decorator_func_test.py
This commit is contained in:
70
decorator_func_test.py
Normal file
70
decorator_func_test.py
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import wraps
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
# https://github.com/aws/bedrock-agentcore-sdk-python/blob/52bc1946e3d007a168499069f4ce0a265371cd88/src/bedrock_agentcore/identity/auth.py#L220
|
||||
# Apache-2.0 license
|
||||
def requires_api_key(*, provider_name: str, into: str = "api_key") -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
# TODO
|
||||
async def _get_api_key():
|
||||
return f'[[TEST_API_KEY:${provider_name}]]' # TODO
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
api_key = await _get_api_key()
|
||||
kwargs[into] = api_key
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
if _has_running_loop():
|
||||
# for async env, eg. runtime
|
||||
ctx = contextvars.copy_context()
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(ctx.run, asyncio.run, _get_api_key())
|
||||
api_key = future.result()
|
||||
else:
|
||||
# for sync env, eg. local dev
|
||||
api_key = asyncio.run(_get_api_key())
|
||||
|
||||
kwargs[into] = api_key
|
||||
return func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
return async_wrapper
|
||||
else:
|
||||
return sync_wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _has_running_loop() -> bool:
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
return True
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
################################################################################
|
||||
|
||||
@requires_api_key(provider_name='demo1')
|
||||
def get_api_key1(*, api_key: str) -> str:
|
||||
return api_key
|
||||
|
||||
|
||||
@requires_api_key(provider_name='demo2')
|
||||
async def get_api_key2(*, api_key: str) -> str:
|
||||
return api_key
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print(f'API key: {get_api_key1(api_key="")}')
|
||||
print(f'API key: {asyncio.run(get_api_key2(api_key=""))}')
|
||||
Reference in New Issue
Block a user