From 8b60c64aded776d9eb57e601e194e4d65d03e76f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?dni=20=E2=9A=A1?= Date: Thu, 5 Jan 2023 12:02:23 +0100 Subject: [PATCH] fix lnaddress mypy issue --- lnbits/extensions/lnaddress/cloudflare.py | 7 ++++--- lnbits/extensions/lnaddress/crud.py | 2 ++ lnbits/extensions/lnaddress/lnurl.py | 15 ++++----------- lnbits/extensions/lnaddress/models.py | 4 ++-- lnbits/extensions/lnaddress/tasks.py | 13 ++++++++----- lnbits/extensions/lnaddress/views.py | 5 ++--- lnbits/extensions/lnaddress/views_api.py | 16 +++++++++------- pyproject.toml | 1 - 8 files changed, 31 insertions(+), 32 deletions(-) diff --git a/lnbits/extensions/lnaddress/cloudflare.py b/lnbits/extensions/lnaddress/cloudflare.py index 981a37b0..679cb515 100644 --- a/lnbits/extensions/lnaddress/cloudflare.py +++ b/lnbits/extensions/lnaddress/cloudflare.py @@ -16,7 +16,7 @@ async def cloudflare_create_record(domain: Domains, ip: str): "Content-Type": "application/json", } - cf_response = "" + cf_response = {} async with httpx.AsyncClient() as client: try: r = await client.post( @@ -31,9 +31,9 @@ async def cloudflare_create_record(domain: Domains, ip: str): }, timeout=40, ) - cf_response = json.loads(r.text) + cf_response = r.json() except AssertionError: - cf_response = "Error occured" + cf_response = {"error": "Error occured"} return cf_response @@ -53,3 +53,4 @@ async def cloudflare_deleterecord(domain: Domains, domain_id: str): cf_response = r.text except AssertionError: cf_response = "Error occured" + return cf_response diff --git a/lnbits/extensions/lnaddress/crud.py b/lnbits/extensions/lnaddress/crud.py index 25338215..0e590ec8 100644 --- a/lnbits/extensions/lnaddress/crud.py +++ b/lnbits/extensions/lnaddress/crud.py @@ -128,6 +128,7 @@ async def get_addresses(wallet_ids: Union[str, List[str]]) -> List[Addresses]: async def set_address_paid(payment_hash: str) -> Addresses: address = await get_address(payment_hash) + assert address if address.paid == False: await db.execute( @@ -146,6 +147,7 @@ async def set_address_paid(payment_hash: str) -> Addresses: async def set_address_renewed(address_id: str, duration: int): address = await get_address(address_id) + assert address extend_duration = int(address.duration) + duration await db.execute( diff --git a/lnbits/extensions/lnaddress/lnurl.py b/lnbits/extensions/lnaddress/lnurl.py index 6f799439..c4c3cea5 100644 --- a/lnbits/extensions/lnaddress/lnurl.py +++ b/lnbits/extensions/lnaddress/lnurl.py @@ -1,17 +1,9 @@ -import hashlib -import json from datetime import datetime, timedelta import httpx -from fastapi.params import Query -from lnurl import ( # type: ignore - LnurlErrorResponse, - LnurlPayActionResponse, - LnurlPayResponse, -) +from fastapi import Query, Request +from lnurl import LnurlErrorResponse from loguru import logger -from starlette.requests import Request -from starlette.responses import HTMLResponse from . import lnaddress_ext from .crud import get_address, get_address_by_username, get_domain @@ -52,6 +44,7 @@ async def lnurl_callback(address_id, amount: int = Query(...)): amount_received = amount domain = await get_domain(address.domain) + assert domain base_url = ( address.wallet_endpoint[:-1] @@ -79,7 +72,7 @@ async def lnurl_callback(address_id, amount: int = Query(...)): ) r = call.json() - except AssertionError as e: + except Exception: return LnurlErrorResponse(reason="ERROR") # resp = LnurlPayActionResponse(pr=r["payment_request"], routes=[]) diff --git a/lnbits/extensions/lnaddress/models.py b/lnbits/extensions/lnaddress/models.py index 248f856c..77eb3cd3 100644 --- a/lnbits/extensions/lnaddress/models.py +++ b/lnbits/extensions/lnaddress/models.py @@ -1,9 +1,9 @@ import json from typing import Optional -from fastapi.params import Query +from fastapi import Query from lnurl.types import LnurlPayMetadata -from pydantic.main import BaseModel +from pydantic import BaseModel class CreateDomain(BaseModel): diff --git a/lnbits/extensions/lnaddress/tasks.py b/lnbits/extensions/lnaddress/tasks.py index 0c377eec..bdbf5691 100644 --- a/lnbits/extensions/lnaddress/tasks.py +++ b/lnbits/extensions/lnaddress/tasks.py @@ -1,6 +1,7 @@ import asyncio import httpx +from loguru import logger from lnbits.core.models import Payment from lnbits.helpers import get_current_extension_name @@ -21,7 +22,9 @@ async def wait_for_paid_invoices(): async def call_webhook_on_paid(payment_hash): ### Use webhook to notify about cloudflare registration address = await get_address(payment_hash) + assert address domain = await get_domain(address.domain) + assert domain if not domain.webhook: return @@ -39,24 +42,24 @@ async def call_webhook_on_paid(payment_hash): }, timeout=40, ) - except AssertionError: - webhook = None + r.raise_for_status() + except Exception as e: + logger.error(f"lnaddress: error calling webhook on paid: {str(e)}") async def on_invoice_paid(payment: Payment) -> None: + if not payment.extra: + return if payment.extra.get("tag") == "lnaddress": - await payment.set_pending(False) await set_address_paid(payment_hash=payment.payment_hash) await call_webhook_on_paid(payment_hash=payment.payment_hash) elif payment.extra.get("tag") == "renew lnaddress": - await payment.set_pending(False) await set_address_renewed( address_id=payment.extra["id"], duration=payment.extra["duration"] ) await call_webhook_on_paid(payment_hash=payment.payment_hash) - else: return diff --git a/lnbits/extensions/lnaddress/views.py b/lnbits/extensions/lnaddress/views.py index 8c838f0c..d1a7be83 100644 --- a/lnbits/extensions/lnaddress/views.py +++ b/lnbits/extensions/lnaddress/views.py @@ -1,10 +1,8 @@ from http import HTTPStatus from urllib.parse import urlparse -from fastapi import Request -from fastapi.params import Depends +from fastapi import Depends, HTTPException, Request from fastapi.templating import Jinja2Templates -from starlette.exceptions import HTTPException from starlette.responses import HTMLResponse from lnbits.core.crud import get_wallet @@ -35,6 +33,7 @@ async def display(domain_id, request: Request): await purge_addresses(domain_id) wallet = await get_wallet(domain.wallet) + assert wallet url = urlparse(str(request.url)) return lnaddress_renderer().TemplateResponse( diff --git a/lnbits/extensions/lnaddress/views_api.py b/lnbits/extensions/lnaddress/views_api.py index 46ef6b99..d9e50e9d 100644 --- a/lnbits/extensions/lnaddress/views_api.py +++ b/lnbits/extensions/lnaddress/views_api.py @@ -1,9 +1,7 @@ from http import HTTPStatus from urllib.parse import urlparse -from fastapi import Request -from fastapi.params import Depends, Query -from starlette.exceptions import HTTPException +from fastapi import Depends, HTTPException, Query, Request from lnbits.core.crud import get_user from lnbits.core.services import check_transaction_status, create_invoice @@ -11,7 +9,7 @@ from lnbits.decorators import WalletTypeInfo, get_key_type from lnbits.extensions.lnaddress.models import CreateAddress, CreateDomain from . import lnaddress_ext -from .cloudflare import cloudflare_create_record, cloudflare_deleterecord +from .cloudflare import cloudflare_create_record from .crud import ( check_address_available, create_address, @@ -35,7 +33,8 @@ async def api_domains( wallet_ids = [g.wallet.id] if all_wallets: - wallet_ids = (await get_user(g.wallet.user)).wallet_ids + user = await get_user(g.wallet.user) + wallet_ids = user.wallet_ids if user else [] return [domain.dict() for domain in await get_domains(wallet_ids)] @@ -69,7 +68,7 @@ async def api_domain_create( cf_response = await cloudflare_create_record(domain=domain, ip=root_url) - if not cf_response or cf_response["success"] != True: + if not cf_response or not cf_response["success"]: await delete_domain(domain.id) raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, @@ -106,7 +105,8 @@ async def api_addresses( wallet_ids = [g.wallet.id] if all_wallets: - wallet_ids = (await get_user(g.wallet.user)).wallet_ids + user = await get_user(g.wallet.user) + wallet_ids = user.wallet_ids if user else [] return [address.dict() for address in await get_addresses(wallet_ids)] @@ -227,7 +227,9 @@ async def api_lnaddress_make_address( @lnaddress_ext.get("/api/v1/addresses/{payment_hash}") async def api_address_send_address(payment_hash): address = await get_address(payment_hash) + assert address domain = await get_domain(address.domain) + assert domain try: status = await check_transaction_status(domain.wallet, payment_hash) is_paid = not status.pending diff --git a/pyproject.toml b/pyproject.toml index e2116ed0..2d11fba7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,6 @@ exclude = """(?x)( | ^lnbits/extensions/boltz. | ^lnbits/extensions/boltcards. | ^lnbits/extensions/livestream. - | ^lnbits/extensions/lnaddress. | ^lnbits/extensions/lnurldevice. | ^lnbits/extensions/satspay. | ^lnbits/extensions/watchonly.