mypy fixes for generic.py, decurators.py, eclair.py

This commit is contained in:
dni 2022-07-20 09:36:13 +02:00
parent 7ad9ad799e
commit 56560fca02
3 changed files with 33 additions and 27 deletions

View file

@ -55,9 +55,9 @@ async def home(request: Request, lightning: str = None):
) )
async def extensions( async def extensions(
request: Request, request: Request,
user: User = Depends(check_user_exists), user = Depends(check_user_exists),
enable: str = Query(None), enable = Query(None),
disable: str = Query(None), disable = Query(None),
): ):
extension_to_enable = enable extension_to_enable = enable
extension_to_disable = disable extension_to_disable = disable
@ -109,10 +109,10 @@ nothing: create everything<br>
""", """,
) )
async def wallet( async def wallet(
request: Request = Query(None), request = Query(None),
nme: Optional[str] = Query(None), nme = Query(None),
usr: Optional[UUID4] = Query(None), usr = Query(None),
wal: Optional[UUID4] = Query(None), wal = Query(None),
): ):
user_id = usr.hex if usr else None user_id = usr.hex if usr else None
wallet_id = wal.hex if wal else None wallet_id = wal.hex if wal else None
@ -218,7 +218,7 @@ async def lnurl_full_withdraw_callback(request: Request):
@core_html_routes.get("/deletewallet", response_class=RedirectResponse) @core_html_routes.get("/deletewallet", response_class=RedirectResponse)
async def deletewallet(request: Request, wal: str = Query(...), usr: str = Query(...)): async def deletewallet(request: Request, wal = Query(...), usr = Query(...)):
user = await get_user(usr) user = await get_user(usr)
assert user is not None assert user is not None
user_wallet_ids = [u.id for u in user.wallets] user_wallet_ids = [u.id for u in user.wallets]

View file

@ -1,5 +1,7 @@
from http import HTTPStatus from http import HTTPStatus
from typing import Optional
from cerberus import Validator # type: ignore from cerberus import Validator # type: ignore
from fastapi import status from fastapi import status
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
@ -29,20 +31,20 @@ class KeyChecker(SecurityBase):
self._key_type = "invoice" self._key_type = "invoice"
self._api_key = api_key self._api_key = api_key
if api_key: if api_key:
self.model: APIKey = APIKey( key = APIKey(
**{"in": APIKeyIn.query}, **{"in": APIKeyIn.query},
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - QUERY", description="Wallet API Key - QUERY",
) )
else: else:
self.model: APIKey = APIKey( key = APIKey(
**{"in": APIKeyIn.header}, **{"in": APIKeyIn.header},
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
self.wallet = None self.model: APIKey = key
async def __call__(self, request: Request) -> Wallet: async def __call__(self, request: Request):
try: try:
key_value = ( key_value = (
self._api_key self._api_key
@ -52,12 +54,13 @@ class KeyChecker(SecurityBase):
# FIXME: Find another way to validate the key. A fetch from DB should be avoided here. # FIXME: Find another way to validate the key. A fetch from DB should be avoided here.
# Also, we should not return the wallet here - thats silly. # Also, we should not return the wallet here - thats silly.
# Possibly store it in a Redis DB # Possibly store it in a Redis DB
self.wallet = await get_wallet_for_key(key_value, self._key_type) wallet = await get_wallet_for_key(key_value, self._key_type)
if not self.wallet: if not wallet:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
detail="Invalid key or expired key.", detail="Invalid key or expired key.",
) )
self.wallet = wallet
except KeyError: except KeyError:
raise HTTPException( raise HTTPException(
@ -120,8 +123,8 @@ api_key_query = APIKeyQuery(
async def get_key_type( async def get_key_type(
r: Request, r: Request,
api_key_header: str = Security(api_key_header), api_key_header = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query = Security(api_key_query),
) -> WalletTypeInfo: ) -> WalletTypeInfo:
# 0: admin # 0: admin
# 1: invoice # 1: invoice
@ -134,9 +137,10 @@ async def get_key_type(
token = api_key_header if api_key_header else api_key_query token = api_key_header if api_key_header else api_key_query
try: try:
checker = WalletAdminKeyChecker(api_key=token) admin_checker = WalletAdminKeyChecker(api_key=token)
await checker.__call__(r) await admin_checker.__call__(r)
wallet = WalletTypeInfo(0, checker.wallet) wallet = WalletTypeInfo(0, admin_checker.wallet)
assert wallet.wallet is not None
if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and ( if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and (
LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS
): ):
@ -153,9 +157,9 @@ async def get_key_type(
raise raise
try: try:
checker = WalletInvoiceKeyChecker(api_key=token) invoice_checker = WalletInvoiceKeyChecker(api_key=token)
await checker.__call__(r) await invoice_checker.__call__(r)
wallet = WalletTypeInfo(1, checker.wallet) wallet = WalletTypeInfo(1, invoice_checker.wallet)
if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and ( if (LNBITS_ADMIN_USERS and wallet.wallet.user not in LNBITS_ADMIN_USERS) and (
LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS LNBITS_ADMIN_EXTENSIONS and pathname in LNBITS_ADMIN_EXTENSIONS
): ):
@ -174,8 +178,8 @@ async def get_key_type(
async def require_admin_key( async def require_admin_key(
r: Request, r: Request,
api_key_header: str = Security(api_key_header), api_key_header = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query = Security(api_key_query),
): ):
token = api_key_header if api_key_header else api_key_query token = api_key_header if api_key_header else api_key_query
@ -193,8 +197,8 @@ async def require_admin_key(
async def require_invoice_key( async def require_invoice_key(
r: Request, r: Request,
api_key_header: str = Security(api_key_header), api_key_header = Security(api_key_header),
api_key_query: str = Security(api_key_query), api_key_query = Security(api_key_query),
): ):
token = api_key_header if api_key_header else api_key_query token = api_key_header if api_key_header else api_key_query

View file

@ -7,7 +7,9 @@ from typing import AsyncGenerator, Dict, Optional
import httpx import httpx
from loguru import logger from loguru import logger
from websockets import connect
# mypy https://github.com/aaugustin/websockets/issues/940
from websockets.client import connect
from websockets.exceptions import ( from websockets.exceptions import (
ConnectionClosed, ConnectionClosed,
ConnectionClosedError, ConnectionClosedError,