From 56560fca02c335534d0974b133814c4811ae88d8 Mon Sep 17 00:00:00 2001 From: dni Date: Wed, 20 Jul 2022 09:36:13 +0200 Subject: [PATCH] mypy fixes for generic.py, decurators.py, eclair.py --- lnbits/core/views/generic.py | 16 +++++++-------- lnbits/decorators.py | 40 ++++++++++++++++++++---------------- lnbits/wallets/eclair.py | 4 +++- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 7ef306dc..5f44131f 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -55,9 +55,9 @@ async def home(request: Request, lightning: str = None): ) async def extensions( request: Request, - user: User = Depends(check_user_exists), - enable: str = Query(None), - disable: str = Query(None), + user = Depends(check_user_exists), + enable = Query(None), + disable = Query(None), ): extension_to_enable = enable extension_to_disable = disable @@ -109,10 +109,10 @@ nothing: create everything
""", ) async def wallet( - request: Request = Query(None), - nme: Optional[str] = Query(None), - usr: Optional[UUID4] = Query(None), - wal: Optional[UUID4] = Query(None), + request = Query(None), + nme = Query(None), + usr = Query(None), + wal = Query(None), ): user_id = usr.hex if usr else None wallet_id = wal.hex if wal else None @@ -218,7 +218,7 @@ async def lnurl_full_withdraw_callback(request: Request): @core_html_routes.get("/deletewallet", response_class=RedirectResponse) -async def deletewallet(request: Request, wal: str = Query(...), usr: str = Query(...)): +async def deletewallet(request: Request, wal = Query(...), usr = Query(...)): user = await get_user(usr) assert user is not None user_wallet_ids = [u.id for u in user.wallets] diff --git a/lnbits/decorators.py b/lnbits/decorators.py index e65b9041..77fe3227 100644 --- a/lnbits/decorators.py +++ b/lnbits/decorators.py @@ -1,5 +1,7 @@ from http import HTTPStatus +from typing import Optional + from cerberus import Validator # type: ignore from fastapi import status from fastapi.exceptions import HTTPException @@ -29,20 +31,20 @@ class KeyChecker(SecurityBase): self._key_type = "invoice" self._api_key = api_key if api_key: - self.model: APIKey = APIKey( + key = APIKey( **{"in": APIKeyIn.query}, name="X-API-KEY", description="Wallet API Key - QUERY", ) else: - self.model: APIKey = APIKey( + key = APIKey( **{"in": APIKeyIn.header}, name="X-API-KEY", description="Wallet API Key - HEADER", ) - self.wallet = None + self.model: APIKey = key - async def __call__(self, request: Request) -> Wallet: + async def __call__(self, request: Request): try: key_value = ( self._api_key @@ -52,12 +54,13 @@ class KeyChecker(SecurityBase): # FIXME: Find another way to validate the key. A fetch from DB should be avoided here. # Also, we should not return the wallet here - thats silly. # Possibly store it in a Redis DB - self.wallet = await get_wallet_for_key(key_value, self._key_type) - if not self.wallet: + wallet = await get_wallet_for_key(key_value, self._key_type) + if not wallet: raise HTTPException( status_code=HTTPStatus.UNAUTHORIZED, detail="Invalid key or expired key.", ) + self.wallet = wallet except KeyError: raise HTTPException( @@ -120,8 +123,8 @@ api_key_query = APIKeyQuery( async def get_key_type( r: Request, - api_key_header: str = Security(api_key_header), - api_key_query: str = Security(api_key_query), + api_key_header = Security(api_key_header), + api_key_query = Security(api_key_query), ) -> WalletTypeInfo: # 0: admin # 1: invoice @@ -134,9 +137,10 @@ async def get_key_type( token = api_key_header if api_key_header else api_key_query try: - checker = WalletAdminKeyChecker(api_key=token) - await checker.__call__(r) - wallet = WalletTypeInfo(0, checker.wallet) + admin_checker = WalletAdminKeyChecker(api_key=token) + await admin_checker.__call__(r) + wallet = WalletTypeInfo(0, admin_checker.wallet) + assert wallet.wallet is not None if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and ( LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS ): @@ -153,9 +157,9 @@ async def get_key_type( raise try: - checker = WalletInvoiceKeyChecker(api_key=token) - await checker.__call__(r) - wallet = WalletTypeInfo(1, checker.wallet) + invoice_checker = WalletInvoiceKeyChecker(api_key=token) + await invoice_checker.__call__(r) + wallet = WalletTypeInfo(1, invoice_checker.wallet) if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and ( LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS ): @@ -174,8 +178,8 @@ async def get_key_type( async def require_admin_key( r: Request, - api_key_header: str = Security(api_key_header), - api_key_query: str = Security(api_key_query), + api_key_header = Security(api_key_header), + api_key_query = Security(api_key_query), ): token = api_key_header if api_key_header else api_key_query @@ -193,8 +197,8 @@ async def require_admin_key( async def require_invoice_key( r: Request, - api_key_header: str = Security(api_key_header), - api_key_query: str = Security(api_key_query), + api_key_header = Security(api_key_header), + api_key_query = Security(api_key_query), ): token = api_key_header if api_key_header else api_key_query diff --git a/lnbits/wallets/eclair.py b/lnbits/wallets/eclair.py index 0ac3fd2a..bad707ff 100644 --- a/lnbits/wallets/eclair.py +++ b/lnbits/wallets/eclair.py @@ -7,7 +7,9 @@ from typing import AsyncGenerator, Dict, Optional import httpx from loguru import logger -from websockets import connect + +# mypy https://github.com/aaugustin/websockets/issues/940 +from websockets.client import connect from websockets.exceptions import ( ConnectionClosed, ConnectionClosedError,