diff --git a/lnbits/auth_bearer.py b/lnbits/auth_bearer.py deleted file mode 100644 index 163785dd..00000000 --- a/lnbits/auth_bearer.py +++ /dev/null @@ -1,51 +0,0 @@ -from fastapi import Request, HTTPException -from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey - -# https://medium.com/data-rebels/fastapi-authentication-revisited-enabling-api-key-authentication-122dc5975680 - -from fastapi import Security, Depends, FastAPI, HTTPException -from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey -from fastapi.security.base import SecurityBase - - -API_KEY = "usr" -API_KEY_NAME = "X-API-key" - -api_key_query = APIKeyQuery(name=API_KEY_NAME, auto_error=False) -api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) - - -class AuthBearer(SecurityBase): - def __init__(self, scheme_name: str = None, auto_error: bool = True): - self.scheme_name = scheme_name or self.__class__.__name__ - self.auto_error = auto_error - - async def __call__(self, request: Request): - key = await self.get_api_key() - print(key) - # credentials: HTTPAuthorizationCredentials = await super(AuthBearer, self).__call__(request) - # if credentials: - # if not credentials.scheme == "Bearer": - # raise HTTPException( - # status_code=403, detail="Invalid authentication scheme.") - # if not self.verify_jwt(credentials.credentials): - # raise HTTPException( - # status_code=403, detail="Invalid token or expired token.") - # return credentials.credentials - # else: - # raise HTTPException( - # status_code=403, detail="Invalid authorization code.") - - async def get_api_key( - self, - api_key_query: str = Security(api_key_query), - api_key_header: str = Security(api_key_header), - ): - if api_key_query == API_KEY: - return api_key_query - elif api_key_header == API_KEY: - return api_key_header - else: - raise HTTPException( - status_code=403, detail="Could not validate credentials" - ) diff --git a/lnbits/core/services.py b/lnbits/core/services.py index 08ee6e37..64db8d72 100644 --- a/lnbits/core/services.py +++ b/lnbits/core/services.py @@ -143,23 +143,25 @@ async def pay_invoice( if wallet.balance_msat < 0: raise PermissionError("Insufficient balance.") - if internal_checking_id: - # mark the invoice from the other side as not pending anymore - # so the other side only has access to his new money when we are sure - # the payer has enough to deduct from + if internal_checking_id: + # mark the invoice from the other side as not pending anymore + # so the other side only has access to his new money when we are sure + # the payer has enough to deduct from + async with db.connect() as conn: await update_payment_status( checking_id=internal_checking_id, pending=False, conn=conn ) - # notify receiver asynchronously + # notify receiver asynchronously - from lnbits.tasks import internal_invoice_queue + from lnbits.tasks import internal_invoice_queue - await internal_invoice_queue.put(internal_checking_id) - else: - # actually pay the external invoice - payment: PaymentResponse = await WALLET.pay_invoice(payment_request) - if payment.checking_id: + await internal_invoice_queue.put(internal_checking_id) + else: + # actually pay the external invoice + payment: PaymentResponse = await WALLET.pay_invoice(payment_request) + if payment.checking_id: + async with db.connect() as conn: await create_payment( checking_id=payment.checking_id, fee=payment.fee_msat, @@ -169,13 +171,13 @@ async def pay_invoice( **payment_kwargs, ) await delete_payment(temp_id, conn=conn) - else: - raise PaymentFailure( - payment.error_message - or "Payment failed, but backend didn't give us an error message." - ) + else: + raise PaymentFailure( + payment.error_message + or "Payment failed, but backend didn't give us an error message." + ) - return invoice.payment_hash + return invoice.payment_hash async def redeem_lnurl_withdraw( @@ -314,7 +316,8 @@ async def check_invoice_status( if not payment.pending: return status if payment.is_out and status.failed: - print(f" - deleting outgoing failed payment {payment.checking_id}: {status}") + print( + f" - deleting outgoing failed payment {payment.checking_id}: {status}") await payment.delete() elif not status.pending: print( diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py index 5e0ededf..5dbe92a7 100644 --- a/lnbits/core/views/generic.py +++ b/lnbits/core/views/generic.py @@ -4,18 +4,16 @@ from typing import Optional from fastapi import Request, status from fastapi.exceptions import HTTPException -from fastapi.param_functions import Body from fastapi.params import Depends, Query from fastapi.responses import FileResponse, RedirectResponse from fastapi.routing import APIRouter from pydantic.types import UUID4 -from starlette.responses import HTMLResponse +from starlette.responses import HTMLResponse, JSONResponse from lnbits.core import db from lnbits.core.models import User from lnbits.decorators import check_user_exists from lnbits.helpers import template_renderer, url_for -from lnbits.requestvars import g from lnbits.settings import LNBITS_ALLOWED_USERS, LNBITS_SITE_TITLE, SERVICE_FEE from ..crud import ( @@ -32,7 +30,7 @@ from ..services import pay_invoice, redeem_lnurl_withdraw core_html_routes: APIRouter = APIRouter(tags=["Core NON-API Website Routes"]) -@core_html_routes.get("/favicon.ico") +@core_html_routes.get("/favicon.ico", response_class=FileResponse) async def favicon(): return FileResponse("lnbits/core/static/favicon.ico") @@ -44,7 +42,11 @@ async def home(request: Request, lightning: str = None): ) -@core_html_routes.get("/extensions", name="core.extensions") +@core_html_routes.get( + "/extensions", + name="core.extensions", + response_class=HTMLResponse, +) async def extensions( request: Request, user: User = Depends(check_user_exists), @@ -77,9 +79,19 @@ async def extensions( ) -@core_html_routes.get("/wallet", response_class=HTMLResponse) -# Not sure how to validate -# @validate_uuids(["usr", "nme"]) +@core_html_routes.get( + "/wallet", + response_class=HTMLResponse, + description=""" +Args: + +just **wallet_name**: create a new user, then create a new wallet for user with wallet_name
+just **user_id**: return the first user wallet or create one if none found (with default wallet_name)
+**user_id** and **wallet_name**: create a new wallet for user with wallet_name
+**user_id** and **wallet_id**: return that wallet if user is the owner
+nothing: create everything
+""", +) async def wallet( request: Request = Query(None), nme: Optional[str] = Query(None), @@ -91,12 +103,6 @@ async def wallet( wallet_name = nme service_fee = int(SERVICE_FEE) if int(SERVICE_FEE) == SERVICE_FEE else SERVICE_FEE - # just wallet_name: create a new user, then create a new wallet for user with wallet_name - # just user_id: return the first user wallet or create one if none found (with default wallet_name) - # user_id and wallet_name: create a new wallet for user with wallet_name - # user_id and wallet_id: return that wallet if user is the owner - # nothing: create everything - if not user_id: user = await get_user((await create_account()).id) else: @@ -137,14 +143,13 @@ async def wallet( ) -@core_html_routes.get("/withdraw") -# @validate_uuids(["usr", "wal"], required=True) +@core_html_routes.get("/withdraw", response_class=JSONResponse) async def lnurl_full_withdraw(request: Request): - user = await get_user(request.args.get("usr")) + user = await get_user(request.query_params.get("usr")) if not user: return {"status": "ERROR", "reason": "User does not exist."} - wallet = user.get_wallet(request.args.get("wal")) + wallet = user.get_wallet(request.query_params.get("wal")) if not wallet: return {"status": "ERROR", "reason": "Wallet does not exist."} @@ -159,18 +164,17 @@ async def lnurl_full_withdraw(request: Request): } -@core_html_routes.get("/withdraw/cb") -# @validate_uuids(["usr", "wal"], required=True) +@core_html_routes.get("/withdraw/cb", response_class=JSONResponse) async def lnurl_full_withdraw_callback(request: Request): - user = await get_user(request.args.get("usr")) + user = await get_user(request.query_params.get("usr")) if not user: return {"status": "ERROR", "reason": "User does not exist."} - wallet = user.get_wallet(request.args.get("wal")) + wallet = user.get_wallet(request.query_params.get("wal")) if not wallet: return {"status": "ERROR", "reason": "Wallet does not exist."} - pr = request.args.get("pr") + pr = request.query_params.get("pr") async def pay(): try: @@ -180,14 +184,14 @@ async def lnurl_full_withdraw_callback(request: Request): asyncio.create_task(pay()) - balance_notify = request.args.get("balanceNotify") + balance_notify = request.query_params.get("balanceNotify") if balance_notify: await save_balance_notify(wallet.id, balance_notify) return {"status": "OK"} -@core_html_routes.get("/deletewallet") +@core_html_routes.get("/deletewallet", response_class=RedirectResponse) async def deletewallet(request: Request, wal: str = Query(...), usr: str = Query(...)): user = await get_user(usr) user_wallet_ids = [u.id for u in user.wallets] @@ -211,14 +215,13 @@ async def deletewallet(request: Request, wal: str = Query(...), usr: str = Query @core_html_routes.get("/withdraw/notify/{service}") -# @validate_uuids(["wal"], required=True) async def lnurl_balance_notify(request: Request, service: str): - bc = await get_balance_check(request.args.get("wal"), service) + bc = await get_balance_check(request.query_params.get("wal"), service) if bc: redeem_lnurl_withdraw(bc.wallet, bc.url) -@core_html_routes.get("/lnurlwallet") +@core_html_routes.get("/lnurlwallet", response_class=RedirectResponse) async def lnurlwallet(request: Request): async with db.connect() as conn: account = await create_account(conn=conn) @@ -228,7 +231,7 @@ async def lnurlwallet(request: Request): asyncio.create_task( redeem_lnurl_withdraw( wallet.id, - request.args.get("lightning"), + request.query_params.get("lightning"), "LNbits initial funding: voucher redeem.", {"tag": "lnurlwallet"}, 5, # wait 5 seconds before sending the invoice to the service diff --git a/lnbits/extensions/bleskomat/helpers.py b/lnbits/extensions/bleskomat/helpers.py index 1062ca27..6e55b3df 100644 --- a/lnbits/extensions/bleskomat/helpers.py +++ b/lnbits/extensions/bleskomat/helpers.py @@ -35,8 +35,8 @@ def generate_bleskomat_lnurl_secret(api_key_id: str, signature: str): return m.hexdigest() -def get_callback_url(request: Request): - return request.url_for("bleskomat.api_bleskomat_lnurl") +def get_callback_url(req: Request): + return req.url_for("bleskomat.api_bleskomat_lnurl") def is_supported_lnurl_subprotocol(tag: str) -> bool: diff --git a/lnbits/extensions/bleskomat/lnurl_api.py b/lnbits/extensions/bleskomat/lnurl_api.py index fa3e6133..4faa0ee9 100644 --- a/lnbits/extensions/bleskomat/lnurl_api.py +++ b/lnbits/extensions/bleskomat/lnurl_api.py @@ -25,9 +25,9 @@ from .helpers import ( # Handles signed URL from Bleskomat ATMs and "action" callback of auto-generated LNURLs. @bleskomat_ext.get("/u", name="bleskomat.api_bleskomat_lnurl") -async def api_bleskomat_lnurl(request: Request): +async def api_bleskomat_lnurl(req: Request): try: - query = request.query_params + query = req.query_params # Unshorten query if "s" is used instead of "signature". if "s" in query: @@ -96,7 +96,7 @@ async def api_bleskomat_lnurl(request: Request): ) # Reply with LNURL response object. - return lnurl.get_info_response_object(secret) + return lnurl.get_info_response_object(secret, req) # No signature provided. # Treat as "action" callback. diff --git a/lnbits/extensions/bleskomat/models.py b/lnbits/extensions/bleskomat/models.py index e96ddb80..89aefe1f 100644 --- a/lnbits/extensions/bleskomat/models.py +++ b/lnbits/extensions/bleskomat/models.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, validator from starlette.requests import Request from lnbits import bolt11 -from lnbits.core.services import pay_invoice +from lnbits.core.services import pay_invoice, PaymentFailure from . import db from .exchange_rates import exchange_rate_providers, fiat_currencies @@ -119,13 +119,13 @@ class BleskomatLnurl(BaseModel): tag = self.tag if tag == "withdrawRequest": try: - payment_hash = await pay_invoice( + await pay_invoice( wallet_id=self.wallet, payment_request=query["pr"] ) + except (ValueError, PermissionError, PaymentFailure) as e: + raise LnurlValidationError("Failed to pay invoice: " + str(e)) except Exception: - raise LnurlValidationError("Failed to pay invoice") - if not payment_hash: - raise LnurlValidationError("Failed to pay invoice") + raise LnurlValidationError("Unexpected error") async def use(self, conn) -> bool: now = int(time.time()) diff --git a/lnbits/extensions/bleskomat/views.py b/lnbits/extensions/bleskomat/views.py index c3e775c8..92d47513 100644 --- a/lnbits/extensions/bleskomat/views.py +++ b/lnbits/extensions/bleskomat/views.py @@ -14,13 +14,13 @@ templates = Jinja2Templates(directory="templates") @bleskomat_ext.get("/", response_class=HTMLResponse) -async def index(request: Request, user: User = Depends(check_user_exists)): +async def index(req: Request, user: User = Depends(check_user_exists)): bleskomat_vars = { - "callback_url": get_callback_url(request=request), + "callback_url": get_callback_url(req), "exchange_rate_providers": exchange_rate_providers_serializable, "fiat_currencies": fiat_currencies, } return bleskomat_renderer().TemplateResponse( "bleskomat/index.html", - {"request": request, "user": user.dict(), "bleskomat_vars": bleskomat_vars}, + {"request": req, "user": user.dict(), "bleskomat_vars": bleskomat_vars}, ) diff --git a/lnbits/extensions/satsdice/views.py b/lnbits/extensions/satsdice/views.py index a9cbcaf7..72e24867 100644 --- a/lnbits/extensions/satsdice/views.py +++ b/lnbits/extensions/satsdice/views.py @@ -61,7 +61,7 @@ async def displaywin( request: Request, link_id: str = Query(None), payment_hash: str = Query(None) ): satsdicelink = await get_satsdice_pay(link_id) - if not satsdiceLink: + if not satsdicelink: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="satsdice link does not exist." ) diff --git a/lnbits/extensions/tipjar/views_api.py b/lnbits/extensions/tipjar/views_api.py index 55c2d83f..c4d72ccd 100644 --- a/lnbits/extensions/tipjar/views_api.py +++ b/lnbits/extensions/tipjar/views_api.py @@ -169,7 +169,7 @@ async def api_delete_tip( raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="No tip with this ID!" ) - if tip.wallet != g.wallet.id: + if tip.wallet != wallet.wallet.id: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Not authorized to delete this tip!", @@ -189,7 +189,7 @@ async def api_delete_tipjar( raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="No tipjar with this ID!" ) - if tipjar.wallet != g.wallet.id: + if tipjar.wallet != wallet.wallet.id: raise HTTPException( status_code=HTTPStatus.FORBIDDEN,