diff --git a/crud.py b/crud.py index 3046f0e..450c091 100644 --- a/crud.py +++ b/crud.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timedelta, timezone from lnbits.db import Database @@ -8,47 +9,115 @@ from .models import CreateEvent, Event, Ticket, TicketExtra db = Database("ext_events") +def _parse_ticket_row(row) -> dict: + """Normalize a ticket row before constructing a Ticket model. + + - Empty-string sentinels in name/email (used because the DB columns are + NOT NULL but the Pydantic field is Optional when user_id is set) are + converted back to None. + - The `extra` JSON column may come back as a string when the row is + fetched without a model= argument; parse it so Pydantic can build + TicketExtra from a dict. + """ + ticket_data = dict(row) + + if ticket_data.get("name") == "": + ticket_data["name"] = None + if ticket_data.get("email") == "": + ticket_data["email"] = None + + extra = ticket_data.get("extra") + if isinstance(extra, str): + ticket_data["extra"] = json.loads(extra) if extra else {} + + return ticket_data + + async def create_ticket( - payment_hash: str, wallet: str, event: str, name: str, email: str, extra: dict + payment_hash: str, + wallet: str, + event: str, + name: str | None = None, + email: str | None = None, + user_id: str | None = None, + extra: dict | None = None, ) -> Ticket: now = datetime.now(timezone.utc) - ticket = Ticket( + + # name/email columns are NOT NULL in the schema, so we store "" when only + # user_id is supplied. _parse_ticket_row reverses this on read. + if user_id: + db_name = "" + db_email = "" + else: + db_name = name or "" + db_email = email or "" + + db_ticket = Ticket( id=payment_hash, wallet=wallet, event=event, - name=name, - email=email, + name=db_name, + email=db_email, + user_id=user_id, + registered=False, + paid=False, + reg_timestamp=now, + time=now, + extra=TicketExtra(**extra) if extra else TicketExtra(), + ) + await db.insert("events.ticket", db_ticket) + + return Ticket( + id=payment_hash, + wallet=wallet, + event=event, + name=name, + email=email, + user_id=user_id, registered=False, paid=False, reg_timestamp=now, time=now, extra=TicketExtra(**extra) if extra else TicketExtra(), ) - await db.insert("events.ticket", ticket) - return ticket async def update_ticket(ticket: Ticket) -> Ticket: - await db.update("events.ticket", ticket) + ticket_dict = ticket.dict() + if ticket_dict.get("name") is None: + ticket_dict["name"] = "" + if ticket_dict.get("email") is None: + ticket_dict["email"] = "" + await db.update("events.ticket", Ticket(**ticket_dict)) return ticket async def get_ticket(payment_hash: str) -> Ticket | None: - return await db.fetchone( + row = await db.fetchone( "SELECT * FROM events.ticket WHERE id = :id", {"id": payment_hash}, - Ticket, ) + if not row: + return None + return Ticket(**_parse_ticket_row(row)) async def get_tickets(wallet_ids: str | list[str]) -> list[Ticket]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] q = ",".join([f"'{wallet_id}'" for wallet_id in wallet_ids]) - return await db.fetchall( - f"SELECT * FROM events.ticket WHERE wallet IN ({q})", - model=Ticket, + rows = await db.fetchall(f"SELECT * FROM events.ticket WHERE wallet IN ({q})") + return [Ticket(**_parse_ticket_row(row)) for row in rows] + + +async def get_tickets_by_user_id(user_id: str) -> list[Ticket]: + """All tickets owned by the given LNbits user_id.""" + rows = await db.fetchall( + "SELECT * FROM events.ticket WHERE user_id = :user_id ORDER BY time DESC", + {"user_id": user_id}, ) + return [Ticket(**_parse_ticket_row(row)) for row in rows] async def delete_ticket(payment_hash: str) -> None: @@ -107,8 +176,8 @@ async def delete_event(event_id: str) -> None: async def get_event_tickets(event_id: str) -> list[Ticket]: - return await db.fetchall( + rows = await db.fetchall( "SELECT * FROM events.ticket WHERE event = :event", {"event": event_id}, - Ticket, ) + return [Ticket(**_parse_ticket_row(row)) for row in rows] diff --git a/migrations.py b/migrations.py index c055617..3664d69 100644 --- a/migrations.py +++ b/migrations.py @@ -189,3 +189,14 @@ async def m006_add_extra_fields(db): ) await _alter_add_column_safe(db, "ALTER TABLE events.events ADD COLUMN extra TEXT") await _alter_add_column_safe(db, "ALTER TABLE events.ticket ADD COLUMN extra TEXT") + + +async def m007_add_user_id_support(db): + """ + Add user_id column to ticket table so a ticket can reference an LNbits + user id instead of (name, email). Application logic enforces that exactly + one identifier scheme is used per ticket. + """ + await _alter_add_column_safe( + db, "ALTER TABLE events.ticket ADD COLUMN user_id TEXT" + ) diff --git a/models.py b/models.py index 14547d1..415a2e7 100644 --- a/models.py +++ b/models.py @@ -1,7 +1,7 @@ from datetime import datetime from fastapi import Query -from pydantic import BaseModel, EmailStr, Field, validator +from pydantic import BaseModel, EmailStr, Field, root_validator, validator class PromoCode(BaseModel): @@ -77,18 +77,33 @@ class TicketExtra(BaseModel): class CreateTicket(BaseModel): - name: str - email: EmailStr + name: str | None = None + email: EmailStr | None = None + user_id: str | None = None # LNbits user id (alternative to name+email) promo_code: str | None = None refund_address: str | None = None + @root_validator + def validate_identifiers(cls, values): + name = values.get("name") + email = values.get("email") + user_id = values.get("user_id") + if not user_id and not (name and email): + raise ValueError( + "Either user_id or both name and email must be provided" + ) + if user_id and (name or email): + raise ValueError("Cannot provide both user_id and name/email") + return values + class Ticket(BaseModel): id: str wallet: str event: str - name: str - email: str + name: str | None = None + email: str | None = None + user_id: str | None = None registered: bool paid: bool time: datetime @@ -98,7 +113,7 @@ class Ticket(BaseModel): class PublicTicket(BaseModel): event: str - name: str + name: str | None = None registered: bool paid: bool time: datetime diff --git a/views_api.py b/views_api.py index 73cecd3..d1d8631 100644 --- a/views_api.py +++ b/views_api.py @@ -33,6 +33,7 @@ from .crud import ( get_events, get_ticket, get_tickets, + get_tickets_by_user_id, purge_unpaid_tickets, update_event, update_ticket, @@ -177,6 +178,16 @@ async def api_tickets( return await get_tickets(wallet_ids) +@tickets_api_router.get("/user/{user_id}") +async def api_tickets_by_user_id(user_id: str) -> list[Ticket]: + """Tickets bound to an LNbits user_id (used by external integrations). + + Declared before /{ticket_id} so FastAPI matches the literal `/user/` + prefix instead of treating "user" as a ticket id. + """ + return await get_tickets_by_user_id(user_id) + + @tickets_api_router.get("/{ticket_id}", response_model=PublicTicket) async def api_get_ticket(ticket_id: str) -> Ticket: ticket = await get_ticket(ticket_id) @@ -193,7 +204,9 @@ async def api_get_ticket(ticket_id: str) -> Ticket: @tickets_api_router.post("/{event_id}") -async def api_ticket_create(event_id: str, data: CreateTicket) -> TicketPaymentRequest: +async def api_ticket_create( + event_id: str, data: CreateTicket +) -> TicketPaymentRequest: event = await get_event(event_id) if not event: raise HTTPException( @@ -206,6 +219,14 @@ async def api_ticket_create(event_id: str, data: CreateTicket) -> TicketPaymentR if event.amount_tickets > 0 and event.sold >= event.amount_tickets: raise HTTPException(status_code=HTTPStatus.GONE, detail="Event is sold out.") + if data.user_id: + return await _create_user_id_ticket(event, data.user_id) + return await _create_named_ticket(event, data) + + +async def _create_named_ticket( + event: Event, data: CreateTicket +) -> TicketPaymentRequest: name = data.name email = data.email promo_code = data.promo_code.upper() if data.promo_code else None @@ -214,12 +235,10 @@ async def api_ticket_create(event_id: str, data: CreateTicket) -> TicketPaymentR extra: dict[str, Any] = {"tag": "events", "name": name, "email": email} if promo_code: - # check if promo_code exists in event.extra.promo_codes if promo_code not in [pc.code for pc in event.extra.promo_codes]: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, detail="Invalid promo code." ) - # get the promocode promo = next(pc for pc in event.extra.promo_codes if pc.code == promo_code) extra["promo_code"] = promo.code price = event.price_per_ticket * (1 - promo.discount_percent / 100) @@ -229,13 +248,12 @@ async def api_ticket_create(event_id: str, data: CreateTicket) -> TicketPaymentR extra["currency"] = event.currency extra["fiatAmount"] = price extra["rate"] = await get_fiat_rate_satoshis(event.currency) - price = await fiat_amount_as_satoshis(price, event.currency) payment = await create_invoice( wallet_id=event.wallet, amount=price, - memo=f"{event_id}", + memo=f"{event.id}", extra=extra, ) await create_ticket( @@ -250,7 +268,36 @@ async def api_ticket_create(event_id: str, data: CreateTicket) -> TicketPaymentR "sats_paid": int(price), }, ) + return TicketPaymentRequest( + payment_hash=payment.payment_hash, payment_request=payment.bolt11 + ) + +async def _create_user_id_ticket( + event: Event, user_id: str +) -> TicketPaymentRequest: + price = event.price_per_ticket + extra: dict[str, Any] = {"tag": "events", "user_id": user_id} + + if event.currency != "sats": + price = await fiat_amount_as_satoshis(event.price_per_ticket, event.currency) + extra["fiat"] = True + extra["currency"] = event.currency + extra["fiatAmount"] = event.price_per_ticket + extra["rate"] = await get_fiat_rate_satoshis(event.currency) + + payment = await create_invoice( + wallet_id=event.wallet, + amount=price, + memo=f"{event.id}", + extra=extra, + ) + await create_ticket( + payment_hash=payment.payment_hash, + wallet=event.wallet, + event=event.id, + user_id=user_id, + ) return TicketPaymentRequest( payment_hash=payment.payment_hash, payment_request=payment.bolt11 )