diff --git a/lnbits/core/services/fiat_providers.py b/lnbits/core/services/fiat_providers.py index 68cab298..085d32c3 100644 --- a/lnbits/core/services/fiat_providers.py +++ b/lnbits/core/services/fiat_providers.py @@ -5,7 +5,7 @@ import time from loguru import logger from lnbits.core.crud import get_wallet -from lnbits.core.crud.payments import create_payment, get_standalone_payment +from lnbits.core.crud.payments import create_payment from lnbits.core.models import CreatePayment, Payment, PaymentState from lnbits.core.models.misc import SimpleStatus from lnbits.db import Connection @@ -27,6 +27,73 @@ async def handle_fiat_payment_confirmation( logger.warning(e) +def check_stripe_signature( + payload: bytes, + sig_header: str | None, + secret: str | None, + tolerance_seconds=300, +): + if not sig_header: + logger.warning("Stripe-Signature header is missing.") + raise ValueError("Stripe-Signature header is missing.") + + if not secret: + logger.warning("Stripe webhook signing secret is not set.") + raise ValueError("Stripe webhook cannot be verified.") + + # Split the Stripe-Signature header + items = dict(i.split("=") for i in sig_header.split(",")) + timestamp = int(items["t"]) + signature = items["v1"] + + # Check timestamp tolerance + if abs(time.time() - timestamp) > tolerance_seconds: + logger.warning("Timestamp outside tolerance.") + logger.debug( + f"Current time: {time.time()}, " + f"Timestamp: {timestamp}, " + f"Tolerance: {tolerance_seconds} seconds" + ) + + raise ValueError("Timestamp outside tolerance." f"Timestamp: {timestamp}") + + signed_payload = f"{timestamp}.{payload.decode()}" + + # Compute HMAC SHA256 using the webhook secret + computed_signature = hmac.new( + key=secret.encode(), msg=signed_payload.encode(), digestmod=hashlib.sha256 + ).hexdigest() + + # Compare signatures using constant time comparison + if hmac.compare_digest(computed_signature, signature) is not True: + logger.warning("Stripe signature verification failed.") + raise ValueError("Stripe signature verification failed.") + + +async def test_connection(provider: str) -> SimpleStatus: + """ + Test the connection to Stripe by checking if the API key is valid. + This function should be called when setting up or testing the Stripe integration. + """ + fiat_provider = await get_fiat_provider(provider) + if not fiat_provider: + return SimpleStatus( + success=False, + message=f"Fiat provider '{provider}' not found.", + ) + status = await fiat_provider.status() + if status.error_message: + return SimpleStatus( + success=False, + message=f"Cconnection test failed: {status.error_message}", + ) + + return SimpleStatus( + success=True, + message="Connection test successful." f" Balance: {status.balance}.", + ) + + async def _credit_fiat_service_fee_wallet( payment: Payment, conn: Connection | None = None ): @@ -104,90 +171,3 @@ async def _debit_fiat_service_faucet_wallet( status=PaymentState.SUCCESS, conn=conn, ) - - -async def handle_stripe_event(event: dict): - event_id = event.get("id") - event_object = event.get("data", {}).get("object", {}) - object_type = event_object.get("object") - payment_hash = event_object.get("metadata", {}).get("payment_hash") - logger.debug( - f"Handling Stripe event: '{event_id}'. Type: '{object_type}'." - f" Payment hash: '{payment_hash}'." - ) - if not payment_hash: - logger.warning("Stripe event does not contain a payment hash.") - return - - payment = await get_standalone_payment(payment_hash) - if not payment: - logger.warning(f"No payment found for hash: '{payment_hash}'.") - return - await payment.check_fiat_status() - - -def check_stripe_signature( - payload: bytes, - sig_header: str | None, - secret: str | None, - tolerance_seconds=300, -): - if not sig_header: - logger.warning("Stripe-Signature header is missing.") - raise ValueError("Stripe-Signature header is missing.") - - if not secret: - logger.warning("Stripe webhook signing secret is not set.") - raise ValueError("Stripe webhook cannot be verified.") - - # Split the Stripe-Signature header - items = dict(i.split("=") for i in sig_header.split(",")) - timestamp = int(items["t"]) - signature = items["v1"] - - # Check timestamp tolerance - if abs(time.time() - timestamp) > tolerance_seconds: - logger.warning("Timestamp outside tolerance.") - logger.debug( - f"Current time: {time.time()}, " - f"Timestamp: {timestamp}, " - f"Tolerance: {tolerance_seconds} seconds" - ) - - raise ValueError("Timestamp outside tolerance." f"Timestamp: {timestamp}") - - signed_payload = f"{timestamp}.{payload.decode()}" - - # Compute HMAC SHA256 using the webhook secret - computed_signature = hmac.new( - key=secret.encode(), msg=signed_payload.encode(), digestmod=hashlib.sha256 - ).hexdigest() - - # Compare signatures using constant time comparison - if hmac.compare_digest(computed_signature, signature) is not True: - logger.warning("Stripe signature verification failed.") - raise ValueError("Stripe signature verification failed.") - - -async def test_connection(provider: str) -> SimpleStatus: - """ - Test the connection to Stripe by checking if the API key is valid. - This function should be called when setting up or testing the Stripe integration. - """ - fiat_provider = await get_fiat_provider(provider) - if not fiat_provider: - return SimpleStatus( - success=False, - message=f"Fiat provider '{provider}' not found.", - ) - status = await fiat_provider.status() - if status.error_message: - return SimpleStatus( - success=False, - message=f"Cconnection test failed: {status.error_message}", - ) - - return SimpleStatus( - success=True, - message="Connection test successful." f" Balance: {status.balance}.", - ) diff --git a/lnbits/core/services/payments.py b/lnbits/core/services/payments.py index 5a55a451..7fdfcfff 100644 --- a/lnbits/core/services/payments.py +++ b/lnbits/core/services/payments.py @@ -110,7 +110,7 @@ async def create_payment_request( async def create_fiat_invoice( wallet_id: str, invoice_data: CreateInvoice, conn: Connection | None = None -): +) -> Payment: fiat_provider_name = invoice_data.fiat_provider if not fiat_provider_name: raise ValueError("Fiat provider is required for fiat invoices.") diff --git a/lnbits/core/templates/admin/_tab_fiat_providers.html b/lnbits/core/templates/admin/_tab_fiat_providers.html index 774a6b5b..6806365e 100644 --- a/lnbits/core/templates/admin/_tab_fiat_providers.html +++ b/lnbits/core/templates/admin/_tab_fiat_providers.html @@ -103,10 +103,10 @@ diff --git a/lnbits/core/views/callback_api.py b/lnbits/core/views/callback_api.py index 37b9d2aa..d1314dff 100644 --- a/lnbits/core/views/callback_api.py +++ b/lnbits/core/views/callback_api.py @@ -1,10 +1,18 @@ -from fastapi import APIRouter, Request +import json +from fastapi import APIRouter, Request +from loguru import logger + +from lnbits.core.crud.payments import ( + get_standalone_payment, +) from lnbits.core.models.misc import SimpleStatus +from lnbits.core.models.payments import CreateInvoice from lnbits.core.services.fiat_providers import ( check_stripe_signature, - handle_stripe_event, ) +from lnbits.core.services.payments import create_fiat_invoice +from lnbits.fiat.base import FiatSubscriptionPaymentOptions from lnbits.settings import settings callback_router = APIRouter(prefix="/api/v1/callback", tags=["callback"]) @@ -33,3 +41,105 @@ async def api_generic_webhook_handler( success=False, message=f"Unknown fiat provider '{provider_name}'.", ) + + +async def handle_stripe_event(event: dict): + event_id = event.get("id") + event_type = event.get("type") + if event_type == "checkout.session.completed": + await _handle_stripe_checkout_session_completed(event) + elif event_type == "invoice.paid": + await _handle_stripe_subscription_invoice_paid(event) + else: + logger.info( + f"Unhandled Stripe event type: '{event_type}'." f" Event ID: '{event_id}'." + ) + + +async def _handle_stripe_checkout_session_completed(event: dict): + event_id = event.get("id") + event_object = event.get("data", {}).get("object", {}) + object_type = event_object.get("object") + payment_hash = event_object.get("metadata", {}).get("payment_hash") + lnbits_action = event_object.get("metadata", {}).get("lnbits_action") + logger.debug( + f"Handling Stripe event: '{event_id}'. Type: '{object_type}'." + f" Payment hash: '{payment_hash}'." + ) + if lnbits_action != "invoice": + logger.warning(f"Stripe event is not an invoice: '{lnbits_action}'.") + return + + if not payment_hash: + raise ValueError("Stripe event does not contain a payment hash.") + + payment = await get_standalone_payment(payment_hash) + if not payment: + raise ValueError(f"No payment found for hash: '{payment_hash}'.") + await payment.check_fiat_status() + + +async def _handle_stripe_subscription_invoice_paid(event: dict): + invoice = event.get("data", {}).get("object", {}) + parent = invoice.get("parent", {}) + + currency = invoice.get("currency", "").upper() + if not currency: + raise ValueError("Stripe invoice.paid event missing 'currency'.") + + amount_paid = invoice.get("amount_paid") + if not amount_paid: + raise ValueError("Stripe invoice.paid event missing 'amount_paid'.") + + payment_options = await _get_stripe_subscription_payment_options(parent) + if not payment_options.wallet_id: + raise ValueError("Stripe invoice.paid event missing 'wallet_id' in metadata.") + + memo = " | ".join( + [i.get("description", "") for i in invoice.get("lines", {}).get("data", [])] + + [payment_options.memo or "", invoice.get("customer_email", "")] + ) + + extra = { + **(payment_options.extra or {}), + "fiat_method": "subscription", + "tag": payment_options.tag, + "subscription": { + "checking_id": invoice.get("id"), + "payment_request": invoice.get("hosted_invoice_url"), + }, + } + + payment = await create_fiat_invoice( + wallet_id=payment_options.wallet_id, + invoice_data=CreateInvoice( + unit=currency, + amount=amount_paid / 100, # convert cents to dollars + memo=memo, + extra=extra, + fiat_provider="stripe", + ), + ) + + await payment.check_fiat_status() + + +async def _get_stripe_subscription_payment_options( + parent: dict, +) -> FiatSubscriptionPaymentOptions: + if not parent or not parent.get("type") == "subscription_details": + raise ValueError("Stripe invoice.paid event does not contain a subscription.") + + metadata = parent.get("subscription_details", {}).get("metadata", {}) + + if metadata.get("lnbits_action") != "subscription": + raise ValueError("Stripe invoice.paid metadata action is not 'subscription'.") + + if "extra" in metadata: + try: + metadata["extra"] = json.loads(metadata["extra"]) + except json.JSONDecodeError as exc: + logger.warning(exc) + metadata["extra"] = {} + + return FiatSubscriptionPaymentOptions(**metadata) diff --git a/lnbits/core/views/fiat_api.py b/lnbits/core/views/fiat_api.py index f9a29685..d2a09688 100644 --- a/lnbits/core/views/fiat_api.py +++ b/lnbits/core/views/fiat_api.py @@ -3,9 +3,11 @@ from http import HTTPStatus from fastapi import APIRouter, Depends, HTTPException from lnbits.core.models.misc import SimpleStatus +from lnbits.core.models.wallets import WalletTypeInfo from lnbits.core.services.fiat_providers import test_connection -from lnbits.decorators import check_admin +from lnbits.decorators import check_admin, require_admin_key from lnbits.fiat import StripeWallet, get_fiat_provider +from lnbits.fiat.base import CreateFiatSubscription, FiatSubscriptionResponse fiat_router = APIRouter(tags=["Fiat API"], prefix="/api/v1/fiat") @@ -19,20 +21,67 @@ async def api_test_fiat_provider(provider: str) -> SimpleStatus: return await test_connection(provider) +@fiat_router.post( + "/{provider}/subscription", + status_code=HTTPStatus.OK, +) +async def create_subscription( + provider: str, + data: CreateFiatSubscription, + key_type: WalletTypeInfo = Depends(require_admin_key), +) -> FiatSubscriptionResponse: + fiat_provider = await get_fiat_provider(provider) + if not fiat_provider: + raise HTTPException(404, "Fiat provider not found") + + wallet_id = data.payment_options.wallet_id + + if wallet_id and wallet_id != key_type.wallet.id: + raise HTTPException( + 403, + "Wallet id does not match your API key." + "Leave it empty to use your key's wallet.", + ) + + data.payment_options.wallet_id = key_type.wallet.id + subscription_response = await fiat_provider.create_subscription( + data.subscription_id, data.quantity, data.payment_options + ) + return subscription_response + + +@fiat_router.delete( + "/{provider}/subscription/{subscription_id}", + status_code=HTTPStatus.OK, +) +async def cancel_subscription( + provider: str, + subscription_id: str, + key_type: WalletTypeInfo = Depends(require_admin_key), +) -> FiatSubscriptionResponse: + fiat_provider = await get_fiat_provider(provider) + if not fiat_provider: + raise HTTPException(404, "Fiat provider not found") + + resp = await fiat_provider.cancel_subscription(subscription_id, key_type.wallet.id) + + return resp + + @fiat_router.post( "/{provider}/connection_token", status_code=HTTPStatus.OK, dependencies=[Depends(check_admin)], ) async def connection_token(provider: str): - provider_wallet = await get_fiat_provider(provider) + fiat_provider = await get_fiat_provider(provider) if provider == "stripe": - if not isinstance(provider_wallet, StripeWallet): + if not isinstance(fiat_provider, StripeWallet): raise HTTPException( status_code=500, detail="Stripe wallet/provider not configured" ) try: - tok = await provider_wallet.create_terminal_connection_token() + tok = await fiat_provider.create_terminal_connection_token() secret = tok.get("secret") if not secret: raise HTTPException( diff --git a/lnbits/fiat/base.py b/lnbits/fiat/base.py index 025b9c31..162ed85b 100644 --- a/lnbits/fiat/base.py +++ b/lnbits/fiat/base.py @@ -4,6 +4,8 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator, Coroutine from typing import TYPE_CHECKING, Any, NamedTuple +from pydantic import BaseModel, Field + if TYPE_CHECKING: pass @@ -76,6 +78,54 @@ class FiatPaymentStatus(NamedTuple): return "pending" +class FiatSubscriptionPaymentOptions(BaseModel): + + memo: str | None = Field( + default=None, + description="Payments created by the recurring subscription" + " will have this memo.", + ) + wallet_id: str | None = Field( + default=None, + description="Payments created by the recurring subscription" + " will be made to this wallet.", + ) + subscription_request_id: str | None = Field( + default=None, + description="Unique ID that can be used to identify the subscription request." + "If not provided, one will be generated.", + ) + tag: str | None = Field( + default=None, + description="Payments created by the recurring subscription" + " will have this tag. Admin only.", + ) + extra: dict[str, Any] | None = Field( + default=None, + description="Payments created by the recurring subscription" + " will merge this extra data to the payment extra. Admin only.", + ) + + success_url: str | None = Field( + default="https://my.lnbits.com", + description="The URL to redirect the user to after the" + " subscription is successfully created.", + ) + + +class CreateFiatSubscription(BaseModel): + subscription_id: str + quantity: int + payment_options: FiatSubscriptionPaymentOptions + + +class FiatSubscriptionResponse(BaseModel): + ok: bool = True + subscription_request_id: str | None = None + checkout_session_url: str | None = None + error_message: str | None = None + + class FiatPaymentSuccessStatus(FiatPaymentStatus): paid = True @@ -111,6 +161,32 @@ class FiatProvider(ABC): ) -> Coroutine[None, None, FiatInvoiceResponse]: pass + @abstractmethod + def create_subscription( + self, + subscription_id: str, + quantity: int, + payment_options: FiatSubscriptionPaymentOptions, + **kwargs, + ) -> Coroutine[None, None, FiatSubscriptionResponse]: + pass + + @abstractmethod + def cancel_subscription( + self, + subscription_id: str, + correlation_id: str, + **kwargs, + ) -> Coroutine[None, None, FiatSubscriptionResponse]: + """ + Cancel a subscription. + Args: + subscription_id: The ID of the subscription to cancel. + correlation_id: An identifier used to verify that the subscription belongs + to the user that made the request. Usually the wallet ID. + """ + pass + @abstractmethod def pay_invoice( self, diff --git a/lnbits/fiat/stripe.py b/lnbits/fiat/stripe.py index 25d0c6ee..aae6dcf2 100644 --- a/lnbits/fiat/stripe.py +++ b/lnbits/fiat/stripe.py @@ -1,5 +1,6 @@ import asyncio import json +import uuid from collections.abc import AsyncGenerator from datetime import datetime, timedelta, timezone from typing import Any, Literal @@ -9,7 +10,7 @@ import httpx from loguru import logger from pydantic import BaseModel, Field, ValidationError -from lnbits.helpers import normalize_endpoint +from lnbits.helpers import normalize_endpoint, urlsafe_short_hash from lnbits.settings import settings from .base import ( @@ -21,9 +22,11 @@ from .base import ( FiatPaymentSuccessStatus, FiatProvider, FiatStatusResponse, + FiatSubscriptionPaymentOptions, + FiatSubscriptionResponse, ) -FiatMethod = Literal["checkout", "terminal"] +FiatMethod = Literal["checkout", "terminal", "subscription"] class StripeTerminalOptions(BaseModel): @@ -43,6 +46,14 @@ class StripeCheckoutOptions(BaseModel): line_item_name: str | None = None +class StripeSubscriptionOptions(BaseModel): + class Config: + extra = "ignore" + + checking_id: str | None = None + payment_request: str | None = None + + class StripeCreateInvoiceOptions(BaseModel): class Config: extra = "ignore" @@ -50,6 +61,7 @@ class StripeCreateInvoiceOptions(BaseModel): fiat_method: FiatMethod = "checkout" terminal: StripeTerminalOptions | None = None checkout: StripeCheckoutOptions | None = None + subscription: StripeSubscriptionOptions | None = None class StripeWallet(FiatProvider): @@ -118,17 +130,125 @@ class StripeWallet(FiatProvider): if opts.fiat_method == "checkout": return await self._create_checkout_invoice( - amount_cents, currency, payment_hash, memo, opts + amount_cents, currency, payment_hash, memo, opts.checkout ) if opts.fiat_method == "terminal": return await self._create_terminal_invoice( - amount_cents, currency, payment_hash, opts + amount_cents, currency, payment_hash, opts.terminal ) + if opts.fiat_method == "subscription": + return self._create_subscription_invoice(opts.subscription) + return FiatInvoiceResponse( ok=False, error_message=f"Unsupported fiat_method: {opts.fiat_method}" ) + async def create_subscription( + self, + subscription_id: str, + quantity: int, + payment_options: FiatSubscriptionPaymentOptions, + **kwargs, + ) -> FiatSubscriptionResponse: + success_url = ( + payment_options.success_url + or settings.stripe_payment_success_url + or "https://lnbits.com" + ) + + if not payment_options.subscription_request_id: + payment_options.subscription_request_id = str(uuid.uuid4()) + payment_options.extra = payment_options.extra or {} + payment_options.extra["subscription_request_id"] = ( + payment_options.subscription_request_id + ) + + form_data: list[tuple[str, str]] = [ + ("mode", "subscription"), + ("success_url", success_url), + ("line_items[0][price]", subscription_id), + ("line_items[0][quantity]", f"{quantity}"), + ] + subscription_data = {**payment_options.dict(), "lnbits_action": "subscription"} + subscription_data["extra"] = json.dumps(subscription_data.get("extra") or {}) + + form_data += self._encode_metadata( + "subscription_data[metadata]", + subscription_data, + ) + + try: + r = await self.client.post( + "/v1/checkout/sessions", + headers=self._build_headers_form(), + content=urlencode(form_data), + ) + r.raise_for_status() + data = r.json() + + url = data.get("url") + if not url: + return FiatSubscriptionResponse( + ok=False, error_message="Server error: missing url" + ) + return FiatSubscriptionResponse( + ok=True, + checkout_session_url=url, + subscription_request_id=payment_options.subscription_request_id, + ) + except json.JSONDecodeError as exc: + logger.warning(exc) + return FiatSubscriptionResponse( + ok=False, error_message="Server error: invalid json response" + ) + except Exception as exc: + logger.warning(exc) + return FiatSubscriptionResponse( + ok=False, error_message=f"Unable to connect to {self.endpoint}." + ) + + async def cancel_subscription( + self, + subscription_id: str, + correlation_id: str, + **kwargs, + ) -> FiatSubscriptionResponse: + try: + params = { + "query": f"metadata['wallet_id']:'{correlation_id}'" + " AND " + f"metadata['subscription_request_id']:'{subscription_id}'" + } + r = await self.client.get( + "/v1/subscriptions/search", + params=params, + ) + r.raise_for_status() + search_result = r.json() + data = search_result.get("data") or [] + if not data or len(data) == 0: + return FiatSubscriptionResponse( + ok=False, error_message="Subscription not found." + ) + + subscription = data[0] + subscription_id = subscription.get("id") + if not subscription_id: + return FiatSubscriptionResponse( + ok=False, error_message="Subscription ID not found." + ) + + r = await self.client.delete(f"/v1/subscriptions/{subscription_id}") + r.raise_for_status() + + return FiatSubscriptionResponse(ok=True) + except Exception as exc: + logger.warning(exc) + return FiatSubscriptionResponse( + ok=False, error_message="Unable to un subscribe." + ) + async def pay_invoice(self, payment_request: str) -> FiatPaymentResponse: raise NotImplementedError("Stripe does not support paying invoices directly.") @@ -146,6 +266,11 @@ class StripeWallet(FiatProvider): r.raise_for_status() return self._status_from_payment_intent(r.json()) + if stripe_id.startswith("in_"): + r = await self.client.get(f"/v1/invoices/{stripe_id}") + r.raise_for_status() + return self._status_from_invoice(r.json()) + logger.debug(f"Unknown Stripe id prefix: {checking_id}") return FiatPaymentPendingStatus() @@ -176,9 +301,9 @@ class StripeWallet(FiatProvider): currency: str, payment_hash: str, memo: str | None, - opts: StripeCreateInvoiceOptions, + opts: StripeCheckoutOptions | None = None, ) -> FiatInvoiceResponse: - co = opts.checkout or StripeCheckoutOptions() + co = opts or StripeCheckoutOptions() success_url = ( co.success_url or settings.stripe_payment_success_url @@ -190,6 +315,7 @@ class StripeWallet(FiatProvider): ("mode", "payment"), ("success_url", success_url), ("metadata[payment_hash]", payment_hash), + ("metadata[lnbits_action]", "invoice"), ("line_items[0][price_data][currency]", currency.lower()), ("line_items[0][price_data][product_data][name]", line_item_name), ("line_items[0][price_data][unit_amount]", str(amount_cents)), @@ -228,9 +354,9 @@ class StripeWallet(FiatProvider): amount_cents: int, currency: str, payment_hash: str, - opts: StripeCreateInvoiceOptions, + opts: StripeTerminalOptions | None = None, ) -> FiatInvoiceResponse: - term = opts.terminal or StripeTerminalOptions() + term = opts or StripeTerminalOptions() data: dict[str, str] = { "amount": str(amount_cents), "currency": currency.lower(), @@ -265,6 +391,18 @@ class StripeWallet(FiatProvider): ok=False, error_message=f"Unable to connect to {self.endpoint}." ) + def _create_subscription_invoice( + self, + opts: StripeSubscriptionOptions | None = None, + ) -> FiatInvoiceResponse: + term = opts or StripeSubscriptionOptions() + + return FiatInvoiceResponse( + ok=True, + checking_id=term.checking_id or urlsafe_short_hash(), + payment_request=term.payment_request or "", + ) + def _normalize_stripe_id(self, checking_id: str) -> str: """Remove our internal prefix so Stripe sees a real id.""" return ( @@ -308,6 +446,18 @@ class StripeWallet(FiatProvider): return FiatPaymentPendingStatus() + def _status_from_invoice(self, invoice: dict) -> FiatPaymentStatus: + """Map an Invoice to LNbits fiat status.""" + status = invoice.get("status") + + if status == "paid": + return FiatPaymentSuccessStatus() + + if status in ["uncollectible", "void"]: + return FiatPaymentFailedStatus() + + return FiatPaymentPendingStatus() + def _build_headers_form(self) -> dict[str, str]: return {**self.headers, "Content-Type": "application/x-www-form-urlencoded"} @@ -316,7 +466,7 @@ class StripeWallet(FiatProvider): ) -> list[tuple[str, str]]: out: list[tuple[str, str]] = [] for k, v in (md or {}).items(): - out.append((f"{prefix}[{k}]", str(v))) + out.append((f"{prefix}[{k}]", str(v or ""))) return out def _parse_create_opts(