[feat] Stripe subscription (#3369)

This commit is contained in:
Vlad Stan 2025-10-17 01:14:06 +03:00 committed by GitHub
parent 182894fd93
commit bf06def9b7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 472 additions and 107 deletions

View file

@ -5,7 +5,7 @@ import time
from loguru import logger
from lnbits.core.crud import get_wallet
from lnbits.core.crud.payments import create_payment, get_standalone_payment
from lnbits.core.crud.payments import create_payment
from lnbits.core.models import CreatePayment, Payment, PaymentState
from lnbits.core.models.misc import SimpleStatus
from lnbits.db import Connection
@ -27,6 +27,73 @@ async def handle_fiat_payment_confirmation(
logger.warning(e)
def check_stripe_signature(
payload: bytes,
sig_header: str | None,
secret: str | None,
tolerance_seconds=300,
):
if not sig_header:
logger.warning("Stripe-Signature header is missing.")
raise ValueError("Stripe-Signature header is missing.")
if not secret:
logger.warning("Stripe webhook signing secret is not set.")
raise ValueError("Stripe webhook cannot be verified.")
# Split the Stripe-Signature header
items = dict(i.split("=") for i in sig_header.split(","))
timestamp = int(items["t"])
signature = items["v1"]
# Check timestamp tolerance
if abs(time.time() - timestamp) > tolerance_seconds:
logger.warning("Timestamp outside tolerance.")
logger.debug(
f"Current time: {time.time()}, "
f"Timestamp: {timestamp}, "
f"Tolerance: {tolerance_seconds} seconds"
)
raise ValueError("Timestamp outside tolerance." f"Timestamp: {timestamp}")
signed_payload = f"{timestamp}.{payload.decode()}"
# Compute HMAC SHA256 using the webhook secret
computed_signature = hmac.new(
key=secret.encode(), msg=signed_payload.encode(), digestmod=hashlib.sha256
).hexdigest()
# Compare signatures using constant time comparison
if hmac.compare_digest(computed_signature, signature) is not True:
logger.warning("Stripe signature verification failed.")
raise ValueError("Stripe signature verification failed.")
async def test_connection(provider: str) -> SimpleStatus:
"""
Test the connection to Stripe by checking if the API key is valid.
This function should be called when setting up or testing the Stripe integration.
"""
fiat_provider = await get_fiat_provider(provider)
if not fiat_provider:
return SimpleStatus(
success=False,
message=f"Fiat provider '{provider}' not found.",
)
status = await fiat_provider.status()
if status.error_message:
return SimpleStatus(
success=False,
message=f"Cconnection test failed: {status.error_message}",
)
return SimpleStatus(
success=True,
message="Connection test successful." f" Balance: {status.balance}.",
)
async def _credit_fiat_service_fee_wallet(
payment: Payment, conn: Connection | None = None
):
@ -104,90 +171,3 @@ async def _debit_fiat_service_faucet_wallet(
status=PaymentState.SUCCESS,
conn=conn,
)
async def handle_stripe_event(event: dict):
event_id = event.get("id")
event_object = event.get("data", {}).get("object", {})
object_type = event_object.get("object")
payment_hash = event_object.get("metadata", {}).get("payment_hash")
logger.debug(
f"Handling Stripe event: '{event_id}'. Type: '{object_type}'."
f" Payment hash: '{payment_hash}'."
)
if not payment_hash:
logger.warning("Stripe event does not contain a payment hash.")
return
payment = await get_standalone_payment(payment_hash)
if not payment:
logger.warning(f"No payment found for hash: '{payment_hash}'.")
return
await payment.check_fiat_status()
def check_stripe_signature(
payload: bytes,
sig_header: str | None,
secret: str | None,
tolerance_seconds=300,
):
if not sig_header:
logger.warning("Stripe-Signature header is missing.")
raise ValueError("Stripe-Signature header is missing.")
if not secret:
logger.warning("Stripe webhook signing secret is not set.")
raise ValueError("Stripe webhook cannot be verified.")
# Split the Stripe-Signature header
items = dict(i.split("=") for i in sig_header.split(","))
timestamp = int(items["t"])
signature = items["v1"]
# Check timestamp tolerance
if abs(time.time() - timestamp) > tolerance_seconds:
logger.warning("Timestamp outside tolerance.")
logger.debug(
f"Current time: {time.time()}, "
f"Timestamp: {timestamp}, "
f"Tolerance: {tolerance_seconds} seconds"
)
raise ValueError("Timestamp outside tolerance." f"Timestamp: {timestamp}")
signed_payload = f"{timestamp}.{payload.decode()}"
# Compute HMAC SHA256 using the webhook secret
computed_signature = hmac.new(
key=secret.encode(), msg=signed_payload.encode(), digestmod=hashlib.sha256
).hexdigest()
# Compare signatures using constant time comparison
if hmac.compare_digest(computed_signature, signature) is not True:
logger.warning("Stripe signature verification failed.")
raise ValueError("Stripe signature verification failed.")
async def test_connection(provider: str) -> SimpleStatus:
"""
Test the connection to Stripe by checking if the API key is valid.
This function should be called when setting up or testing the Stripe integration.
"""
fiat_provider = await get_fiat_provider(provider)
if not fiat_provider:
return SimpleStatus(
success=False,
message=f"Fiat provider '{provider}' not found.",
)
status = await fiat_provider.status()
if status.error_message:
return SimpleStatus(
success=False,
message=f"Cconnection test failed: {status.error_message}",
)
return SimpleStatus(
success=True,
message="Connection test successful." f" Balance: {status.balance}.",
)

View file

@ -110,7 +110,7 @@ async def create_payment_request(
async def create_fiat_invoice(
wallet_id: str, invoice_data: CreateInvoice, conn: Connection | None = None
):
) -> Payment:
fiat_provider_name = invoice_data.fiat_provider
if not fiat_provider_name:
raise ValueError("Fiat provider is required for fiat invoices.")

View file

@ -103,10 +103,10 @@
<q-card-section>
<span v-text="$t('webhook_events_list')"></span>
<ul>
<li><code>checkout.session.async_payment_failed</code></li>
<li><code>checkout.session.async_payment_succeeded</code></li>
<li><code>checkout.session.completed</code></li>
<li><code>checkout.session.expired</code></li>
- the user completed the checkout process
<li><code>invoice.paid</code></li>
- the invoice was successfully paid (for subscriptions)
</ul>
</q-card-section>
</q-expansion-item>

View file

@ -1,10 +1,18 @@
from fastapi import APIRouter, Request
import json
from fastapi import APIRouter, Request
from loguru import logger
from lnbits.core.crud.payments import (
get_standalone_payment,
)
from lnbits.core.models.misc import SimpleStatus
from lnbits.core.models.payments import CreateInvoice
from lnbits.core.services.fiat_providers import (
check_stripe_signature,
handle_stripe_event,
)
from lnbits.core.services.payments import create_fiat_invoice
from lnbits.fiat.base import FiatSubscriptionPaymentOptions
from lnbits.settings import settings
callback_router = APIRouter(prefix="/api/v1/callback", tags=["callback"])
@ -33,3 +41,105 @@ async def api_generic_webhook_handler(
success=False,
message=f"Unknown fiat provider '{provider_name}'.",
)
async def handle_stripe_event(event: dict):
event_id = event.get("id")
event_type = event.get("type")
if event_type == "checkout.session.completed":
await _handle_stripe_checkout_session_completed(event)
elif event_type == "invoice.paid":
await _handle_stripe_subscription_invoice_paid(event)
else:
logger.info(
f"Unhandled Stripe event type: '{event_type}'." f" Event ID: '{event_id}'."
)
async def _handle_stripe_checkout_session_completed(event: dict):
event_id = event.get("id")
event_object = event.get("data", {}).get("object", {})
object_type = event_object.get("object")
payment_hash = event_object.get("metadata", {}).get("payment_hash")
lnbits_action = event_object.get("metadata", {}).get("lnbits_action")
logger.debug(
f"Handling Stripe event: '{event_id}'. Type: '{object_type}'."
f" Payment hash: '{payment_hash}'."
)
if lnbits_action != "invoice":
logger.warning(f"Stripe event is not an invoice: '{lnbits_action}'.")
return
if not payment_hash:
raise ValueError("Stripe event does not contain a payment hash.")
payment = await get_standalone_payment(payment_hash)
if not payment:
raise ValueError(f"No payment found for hash: '{payment_hash}'.")
await payment.check_fiat_status()
async def _handle_stripe_subscription_invoice_paid(event: dict):
invoice = event.get("data", {}).get("object", {})
parent = invoice.get("parent", {})
currency = invoice.get("currency", "").upper()
if not currency:
raise ValueError("Stripe invoice.paid event missing 'currency'.")
amount_paid = invoice.get("amount_paid")
if not amount_paid:
raise ValueError("Stripe invoice.paid event missing 'amount_paid'.")
payment_options = await _get_stripe_subscription_payment_options(parent)
if not payment_options.wallet_id:
raise ValueError("Stripe invoice.paid event missing 'wallet_id' in metadata.")
memo = " | ".join(
[i.get("description", "") for i in invoice.get("lines", {}).get("data", [])]
+ [payment_options.memo or "", invoice.get("customer_email", "")]
)
extra = {
**(payment_options.extra or {}),
"fiat_method": "subscription",
"tag": payment_options.tag,
"subscription": {
"checking_id": invoice.get("id"),
"payment_request": invoice.get("hosted_invoice_url"),
},
}
payment = await create_fiat_invoice(
wallet_id=payment_options.wallet_id,
invoice_data=CreateInvoice(
unit=currency,
amount=amount_paid / 100, # convert cents to dollars
memo=memo,
extra=extra,
fiat_provider="stripe",
),
)
await payment.check_fiat_status()
async def _get_stripe_subscription_payment_options(
parent: dict,
) -> FiatSubscriptionPaymentOptions:
if not parent or not parent.get("type") == "subscription_details":
raise ValueError("Stripe invoice.paid event does not contain a subscription.")
metadata = parent.get("subscription_details", {}).get("metadata", {})
if metadata.get("lnbits_action") != "subscription":
raise ValueError("Stripe invoice.paid metadata action is not 'subscription'.")
if "extra" in metadata:
try:
metadata["extra"] = json.loads(metadata["extra"])
except json.JSONDecodeError as exc:
logger.warning(exc)
metadata["extra"] = {}
return FiatSubscriptionPaymentOptions(**metadata)

View file

@ -3,9 +3,11 @@ from http import HTTPStatus
from fastapi import APIRouter, Depends, HTTPException
from lnbits.core.models.misc import SimpleStatus
from lnbits.core.models.wallets import WalletTypeInfo
from lnbits.core.services.fiat_providers import test_connection
from lnbits.decorators import check_admin
from lnbits.decorators import check_admin, require_admin_key
from lnbits.fiat import StripeWallet, get_fiat_provider
from lnbits.fiat.base import CreateFiatSubscription, FiatSubscriptionResponse
fiat_router = APIRouter(tags=["Fiat API"], prefix="/api/v1/fiat")
@ -19,20 +21,67 @@ async def api_test_fiat_provider(provider: str) -> SimpleStatus:
return await test_connection(provider)
@fiat_router.post(
"/{provider}/subscription",
status_code=HTTPStatus.OK,
)
async def create_subscription(
provider: str,
data: CreateFiatSubscription,
key_type: WalletTypeInfo = Depends(require_admin_key),
) -> FiatSubscriptionResponse:
fiat_provider = await get_fiat_provider(provider)
if not fiat_provider:
raise HTTPException(404, "Fiat provider not found")
wallet_id = data.payment_options.wallet_id
if wallet_id and wallet_id != key_type.wallet.id:
raise HTTPException(
403,
"Wallet id does not match your API key."
"Leave it empty to use your key's wallet.",
)
data.payment_options.wallet_id = key_type.wallet.id
subscription_response = await fiat_provider.create_subscription(
data.subscription_id, data.quantity, data.payment_options
)
return subscription_response
@fiat_router.delete(
"/{provider}/subscription/{subscription_id}",
status_code=HTTPStatus.OK,
)
async def cancel_subscription(
provider: str,
subscription_id: str,
key_type: WalletTypeInfo = Depends(require_admin_key),
) -> FiatSubscriptionResponse:
fiat_provider = await get_fiat_provider(provider)
if not fiat_provider:
raise HTTPException(404, "Fiat provider not found")
resp = await fiat_provider.cancel_subscription(subscription_id, key_type.wallet.id)
return resp
@fiat_router.post(
"/{provider}/connection_token",
status_code=HTTPStatus.OK,
dependencies=[Depends(check_admin)],
)
async def connection_token(provider: str):
provider_wallet = await get_fiat_provider(provider)
fiat_provider = await get_fiat_provider(provider)
if provider == "stripe":
if not isinstance(provider_wallet, StripeWallet):
if not isinstance(fiat_provider, StripeWallet):
raise HTTPException(
status_code=500, detail="Stripe wallet/provider not configured"
)
try:
tok = await provider_wallet.create_terminal_connection_token()
tok = await fiat_provider.create_terminal_connection_token()
secret = tok.get("secret")
if not secret:
raise HTTPException(

View file

@ -4,6 +4,8 @@ from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator, Coroutine
from typing import TYPE_CHECKING, Any, NamedTuple
from pydantic import BaseModel, Field
if TYPE_CHECKING:
pass
@ -76,6 +78,54 @@ class FiatPaymentStatus(NamedTuple):
return "pending"
class FiatSubscriptionPaymentOptions(BaseModel):
memo: str | None = Field(
default=None,
description="Payments created by the recurring subscription"
" will have this memo.",
)
wallet_id: str | None = Field(
default=None,
description="Payments created by the recurring subscription"
" will be made to this wallet.",
)
subscription_request_id: str | None = Field(
default=None,
description="Unique ID that can be used to identify the subscription request."
"If not provided, one will be generated.",
)
tag: str | None = Field(
default=None,
description="Payments created by the recurring subscription"
" will have this tag. Admin only.",
)
extra: dict[str, Any] | None = Field(
default=None,
description="Payments created by the recurring subscription"
" will merge this extra data to the payment extra. Admin only.",
)
success_url: str | None = Field(
default="https://my.lnbits.com",
description="The URL to redirect the user to after the"
" subscription is successfully created.",
)
class CreateFiatSubscription(BaseModel):
subscription_id: str
quantity: int
payment_options: FiatSubscriptionPaymentOptions
class FiatSubscriptionResponse(BaseModel):
ok: bool = True
subscription_request_id: str | None = None
checkout_session_url: str | None = None
error_message: str | None = None
class FiatPaymentSuccessStatus(FiatPaymentStatus):
paid = True
@ -111,6 +161,32 @@ class FiatProvider(ABC):
) -> Coroutine[None, None, FiatInvoiceResponse]:
pass
@abstractmethod
def create_subscription(
self,
subscription_id: str,
quantity: int,
payment_options: FiatSubscriptionPaymentOptions,
**kwargs,
) -> Coroutine[None, None, FiatSubscriptionResponse]:
pass
@abstractmethod
def cancel_subscription(
self,
subscription_id: str,
correlation_id: str,
**kwargs,
) -> Coroutine[None, None, FiatSubscriptionResponse]:
"""
Cancel a subscription.
Args:
subscription_id: The ID of the subscription to cancel.
correlation_id: An identifier used to verify that the subscription belongs
to the user that made the request. Usually the wallet ID.
"""
pass
@abstractmethod
def pay_invoice(
self,

View file

@ -1,5 +1,6 @@
import asyncio
import json
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime, timedelta, timezone
from typing import Any, Literal
@ -9,7 +10,7 @@ import httpx
from loguru import logger
from pydantic import BaseModel, Field, ValidationError
from lnbits.helpers import normalize_endpoint
from lnbits.helpers import normalize_endpoint, urlsafe_short_hash
from lnbits.settings import settings
from .base import (
@ -21,9 +22,11 @@ from .base import (
FiatPaymentSuccessStatus,
FiatProvider,
FiatStatusResponse,
FiatSubscriptionPaymentOptions,
FiatSubscriptionResponse,
)
FiatMethod = Literal["checkout", "terminal"]
FiatMethod = Literal["checkout", "terminal", "subscription"]
class StripeTerminalOptions(BaseModel):
@ -43,6 +46,14 @@ class StripeCheckoutOptions(BaseModel):
line_item_name: str | None = None
class StripeSubscriptionOptions(BaseModel):
class Config:
extra = "ignore"
checking_id: str | None = None
payment_request: str | None = None
class StripeCreateInvoiceOptions(BaseModel):
class Config:
extra = "ignore"
@ -50,6 +61,7 @@ class StripeCreateInvoiceOptions(BaseModel):
fiat_method: FiatMethod = "checkout"
terminal: StripeTerminalOptions | None = None
checkout: StripeCheckoutOptions | None = None
subscription: StripeSubscriptionOptions | None = None
class StripeWallet(FiatProvider):
@ -118,17 +130,125 @@ class StripeWallet(FiatProvider):
if opts.fiat_method == "checkout":
return await self._create_checkout_invoice(
amount_cents, currency, payment_hash, memo, opts
amount_cents, currency, payment_hash, memo, opts.checkout
)
if opts.fiat_method == "terminal":
return await self._create_terminal_invoice(
amount_cents, currency, payment_hash, opts
amount_cents, currency, payment_hash, opts.terminal
)
if opts.fiat_method == "subscription":
return self._create_subscription_invoice(opts.subscription)
return FiatInvoiceResponse(
ok=False, error_message=f"Unsupported fiat_method: {opts.fiat_method}"
)
async def create_subscription(
self,
subscription_id: str,
quantity: int,
payment_options: FiatSubscriptionPaymentOptions,
**kwargs,
) -> FiatSubscriptionResponse:
success_url = (
payment_options.success_url
or settings.stripe_payment_success_url
or "https://lnbits.com"
)
if not payment_options.subscription_request_id:
payment_options.subscription_request_id = str(uuid.uuid4())
payment_options.extra = payment_options.extra or {}
payment_options.extra["subscription_request_id"] = (
payment_options.subscription_request_id
)
form_data: list[tuple[str, str]] = [
("mode", "subscription"),
("success_url", success_url),
("line_items[0][price]", subscription_id),
("line_items[0][quantity]", f"{quantity}"),
]
subscription_data = {**payment_options.dict(), "lnbits_action": "subscription"}
subscription_data["extra"] = json.dumps(subscription_data.get("extra") or {})
form_data += self._encode_metadata(
"subscription_data[metadata]",
subscription_data,
)
try:
r = await self.client.post(
"/v1/checkout/sessions",
headers=self._build_headers_form(),
content=urlencode(form_data),
)
r.raise_for_status()
data = r.json()
url = data.get("url")
if not url:
return FiatSubscriptionResponse(
ok=False, error_message="Server error: missing url"
)
return FiatSubscriptionResponse(
ok=True,
checkout_session_url=url,
subscription_request_id=payment_options.subscription_request_id,
)
except json.JSONDecodeError as exc:
logger.warning(exc)
return FiatSubscriptionResponse(
ok=False, error_message="Server error: invalid json response"
)
except Exception as exc:
logger.warning(exc)
return FiatSubscriptionResponse(
ok=False, error_message=f"Unable to connect to {self.endpoint}."
)
async def cancel_subscription(
self,
subscription_id: str,
correlation_id: str,
**kwargs,
) -> FiatSubscriptionResponse:
try:
params = {
"query": f"metadata['wallet_id']:'{correlation_id}'"
" AND "
f"metadata['subscription_request_id']:'{subscription_id}'"
}
r = await self.client.get(
"/v1/subscriptions/search",
params=params,
)
r.raise_for_status()
search_result = r.json()
data = search_result.get("data") or []
if not data or len(data) == 0:
return FiatSubscriptionResponse(
ok=False, error_message="Subscription not found."
)
subscription = data[0]
subscription_id = subscription.get("id")
if not subscription_id:
return FiatSubscriptionResponse(
ok=False, error_message="Subscription ID not found."
)
r = await self.client.delete(f"/v1/subscriptions/{subscription_id}")
r.raise_for_status()
return FiatSubscriptionResponse(ok=True)
except Exception as exc:
logger.warning(exc)
return FiatSubscriptionResponse(
ok=False, error_message="Unable to un subscribe."
)
async def pay_invoice(self, payment_request: str) -> FiatPaymentResponse:
raise NotImplementedError("Stripe does not support paying invoices directly.")
@ -146,6 +266,11 @@ class StripeWallet(FiatProvider):
r.raise_for_status()
return self._status_from_payment_intent(r.json())
if stripe_id.startswith("in_"):
r = await self.client.get(f"/v1/invoices/{stripe_id}")
r.raise_for_status()
return self._status_from_invoice(r.json())
logger.debug(f"Unknown Stripe id prefix: {checking_id}")
return FiatPaymentPendingStatus()
@ -176,9 +301,9 @@ class StripeWallet(FiatProvider):
currency: str,
payment_hash: str,
memo: str | None,
opts: StripeCreateInvoiceOptions,
opts: StripeCheckoutOptions | None = None,
) -> FiatInvoiceResponse:
co = opts.checkout or StripeCheckoutOptions()
co = opts or StripeCheckoutOptions()
success_url = (
co.success_url
or settings.stripe_payment_success_url
@ -190,6 +315,7 @@ class StripeWallet(FiatProvider):
("mode", "payment"),
("success_url", success_url),
("metadata[payment_hash]", payment_hash),
("metadata[lnbits_action]", "invoice"),
("line_items[0][price_data][currency]", currency.lower()),
("line_items[0][price_data][product_data][name]", line_item_name),
("line_items[0][price_data][unit_amount]", str(amount_cents)),
@ -228,9 +354,9 @@ class StripeWallet(FiatProvider):
amount_cents: int,
currency: str,
payment_hash: str,
opts: StripeCreateInvoiceOptions,
opts: StripeTerminalOptions | None = None,
) -> FiatInvoiceResponse:
term = opts.terminal or StripeTerminalOptions()
term = opts or StripeTerminalOptions()
data: dict[str, str] = {
"amount": str(amount_cents),
"currency": currency.lower(),
@ -265,6 +391,18 @@ class StripeWallet(FiatProvider):
ok=False, error_message=f"Unable to connect to {self.endpoint}."
)
def _create_subscription_invoice(
self,
opts: StripeSubscriptionOptions | None = None,
) -> FiatInvoiceResponse:
term = opts or StripeSubscriptionOptions()
return FiatInvoiceResponse(
ok=True,
checking_id=term.checking_id or urlsafe_short_hash(),
payment_request=term.payment_request or "",
)
def _normalize_stripe_id(self, checking_id: str) -> str:
"""Remove our internal prefix so Stripe sees a real id."""
return (
@ -308,6 +446,18 @@ class StripeWallet(FiatProvider):
return FiatPaymentPendingStatus()
def _status_from_invoice(self, invoice: dict) -> FiatPaymentStatus:
"""Map an Invoice to LNbits fiat status."""
status = invoice.get("status")
if status == "paid":
return FiatPaymentSuccessStatus()
if status in ["uncollectible", "void"]:
return FiatPaymentFailedStatus()
return FiatPaymentPendingStatus()
def _build_headers_form(self) -> dict[str, str]:
return {**self.headers, "Content-Type": "application/x-www-form-urlencoded"}
@ -316,7 +466,7 @@ class StripeWallet(FiatProvider):
) -> list[tuple[str, str]]:
out: list[tuple[str, str]] = []
for k, v in (md or {}).items():
out.append((f"{prefix}[{k}]", str(v)))
out.append((f"{prefix}[{k}]", str(v or "")))
return out
def _parse_create_opts(