diff --git a/lnbits/decorators.py b/lnbits/decorators.py index 440f8fa1..5beef46a 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,5 +1,6 @@ from functools import wraps from http import HTTPStatus +from base64 import b64decode from fastapi.security import api_key from pydantic.types import UUID4 @@ -12,6 +13,7 @@ from fastapi.exceptions import HTTPException from fastapi.openapi.models import APIKey, APIKeyIn from fastapi.params import Security from fastapi.security.api_key import APIKeyHeader, APIKeyQuery +from fastapi.security import OAuth2PasswordBearer from fastapi.security.base import SecurityBase from starlette.requests import Request @@ -47,13 +49,13 @@ class KeyChecker(SecurityBase): raise HTTPException(status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.") except KeyError: - raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, + raise HTTPException(status_code=HTTPStatus.BAD_REQUEST, detail="`X-API-KEY` header missing.") class WalletInvoiceKeyChecker(KeyChecker): """ WalletInvoiceKeyChecker will ensure that the provided invoice - wallet key is correct and populate g().wallet with the wallet + wallet key is correct and populate g().wallet with the wallet for the key in `X-API-key`. The checker will raise an HTTPException when the key is wrong in some ways. @@ -65,7 +67,7 @@ class WalletInvoiceKeyChecker(KeyChecker): class WalletAdminKeyChecker(KeyChecker): """ WalletAdminKeyChecker will ensure that the provided admin - wallet key is correct and populate g().wallet with the wallet + wallet key is correct and populate g().wallet with the wallet for the key in `X-API-key`. The checker will raise an HTTPException when the key is wrong in some ways. @@ -85,14 +87,19 @@ class WalletTypeInfo(): api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, description="Admin or Invoice key for wallet API's") api_key_query = APIKeyQuery(name="api-key", auto_error=False, description="Admin or Invoice key for wallet API's") -async def get_key_type(r: Request, - api_key_header: str = Security(api_key_header), +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") +async def get_key_type(r: Request, + token: str = Security(oauth2_scheme), + api_key_header: str = Security(api_key_header), api_key_query: str = Security(api_key_query)) -> WalletTypeInfo: # 0: admin # 1: invoice # 2: invalid + # print("TOKEN", b64decode(token).decode("utf-8").split(":")) + + key_type, key = b64decode(token).decode("utf-8").split(":") try: - checker = WalletAdminKeyChecker(api_key=api_key_query) + checker = WalletAdminKeyChecker(api_key=key if token else api_key_query) await checker.__call__(r) return WalletTypeInfo(0, checker.wallet) except HTTPException as e: @@ -104,7 +111,7 @@ async def get_key_type(r: Request, raise try: - checker = WalletInvoiceKeyChecker() + checker = WalletInvoiceKeyChecker(api_key=key if token else None) await checker.__call__(r) return WalletTypeInfo(1, checker.wallet) except HTTPException as e: @@ -121,7 +128,7 @@ def api_validate_post_request(*, schema: dict): async def wrapped_view(**kwargs): if "application/json" not in request.headers["Content-Type"]: raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.BAD_REQUEST, detail=jsonify({"message": "Content-Type must be `application/json`."}) ) @@ -131,10 +138,10 @@ def api_validate_post_request(*, schema: dict): if not v.validate(g().data): raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, + status_code=HTTPStatus.BAD_REQUEST, detail=jsonify({"message": f"Errors in request data: {v.errors}"}) ) - + return await view(**kwargs) @@ -144,7 +151,7 @@ def api_validate_post_request(*, schema: dict): async def check_user_exists(usr: UUID4) -> User: - g().user = await get_user(usr.hex) + g().user = await get_user(usr.hex) if not g().user: raise HTTPException( status_code=HTTPStatus.NOT_FOUND,