diff --git a/crud.py b/crud.py index 3046f0e..450c091 100644 --- a/crud.py +++ b/crud.py @@ -1,3 +1,4 @@ +import json from datetime import datetime, timedelta, timezone from lnbits.db import Database @@ -8,47 +9,115 @@ from .models import CreateEvent, Event, Ticket, TicketExtra db = Database("ext_events") +def _parse_ticket_row(row) -> dict: + """Normalize a ticket row before constructing a Ticket model. + + - Empty-string sentinels in name/email (used because the DB columns are + NOT NULL but the Pydantic field is Optional when user_id is set) are + converted back to None. + - The `extra` JSON column may come back as a string when the row is + fetched without a model= argument; parse it so Pydantic can build + TicketExtra from a dict. + """ + ticket_data = dict(row) + + if ticket_data.get("name") == "": + ticket_data["name"] = None + if ticket_data.get("email") == "": + ticket_data["email"] = None + + extra = ticket_data.get("extra") + if isinstance(extra, str): + ticket_data["extra"] = json.loads(extra) if extra else {} + + return ticket_data + + async def create_ticket( - payment_hash: str, wallet: str, event: str, name: str, email: str, extra: dict + payment_hash: str, + wallet: str, + event: str, + name: str | None = None, + email: str | None = None, + user_id: str | None = None, + extra: dict | None = None, ) -> Ticket: now = datetime.now(timezone.utc) - ticket = Ticket( + + # name/email columns are NOT NULL in the schema, so we store "" when only + # user_id is supplied. _parse_ticket_row reverses this on read. + if user_id: + db_name = "" + db_email = "" + else: + db_name = name or "" + db_email = email or "" + + db_ticket = Ticket( id=payment_hash, wallet=wallet, event=event, - name=name, - email=email, + name=db_name, + email=db_email, + user_id=user_id, + registered=False, + paid=False, + reg_timestamp=now, + time=now, + extra=TicketExtra(**extra) if extra else TicketExtra(), + ) + await db.insert("events.ticket", db_ticket) + + return Ticket( + id=payment_hash, + wallet=wallet, + event=event, + name=name, + email=email, + user_id=user_id, registered=False, paid=False, reg_timestamp=now, time=now, extra=TicketExtra(**extra) if extra else TicketExtra(), ) - await db.insert("events.ticket", ticket) - return ticket async def update_ticket(ticket: Ticket) -> Ticket: - await db.update("events.ticket", ticket) + ticket_dict = ticket.dict() + if ticket_dict.get("name") is None: + ticket_dict["name"] = "" + if ticket_dict.get("email") is None: + ticket_dict["email"] = "" + await db.update("events.ticket", Ticket(**ticket_dict)) return ticket async def get_ticket(payment_hash: str) -> Ticket | None: - return await db.fetchone( + row = await db.fetchone( "SELECT * FROM events.ticket WHERE id = :id", {"id": payment_hash}, - Ticket, ) + if not row: + return None + return Ticket(**_parse_ticket_row(row)) async def get_tickets(wallet_ids: str | list[str]) -> list[Ticket]: if isinstance(wallet_ids, str): wallet_ids = [wallet_ids] q = ",".join([f"'{wallet_id}'" for wallet_id in wallet_ids]) - return await db.fetchall( - f"SELECT * FROM events.ticket WHERE wallet IN ({q})", - model=Ticket, + rows = await db.fetchall(f"SELECT * FROM events.ticket WHERE wallet IN ({q})") + return [Ticket(**_parse_ticket_row(row)) for row in rows] + + +async def get_tickets_by_user_id(user_id: str) -> list[Ticket]: + """All tickets owned by the given LNbits user_id.""" + rows = await db.fetchall( + "SELECT * FROM events.ticket WHERE user_id = :user_id ORDER BY time DESC", + {"user_id": user_id}, ) + return [Ticket(**_parse_ticket_row(row)) for row in rows] async def delete_ticket(payment_hash: str) -> None: @@ -107,8 +176,8 @@ async def delete_event(event_id: str) -> None: async def get_event_tickets(event_id: str) -> list[Ticket]: - return await db.fetchall( + rows = await db.fetchall( "SELECT * FROM events.ticket WHERE event = :event", {"event": event_id}, - Ticket, ) + return [Ticket(**_parse_ticket_row(row)) for row in rows] diff --git a/models.py b/models.py index fa1569f..0216653 100644 --- a/models.py +++ b/models.py @@ -1,7 +1,7 @@ from datetime import datetime from fastapi import Query -from pydantic import BaseModel, EmailStr, Field, validator +from pydantic import BaseModel, EmailStr, Field, root_validator, validator class PromoCode(BaseModel): @@ -94,21 +94,36 @@ class TicketExtra(BaseModel): class CreateTicket(BaseModel): - name: str - email: EmailStr + name: str | None = None + email: EmailStr | None = None + user_id: str | None = None # LNbits user id (alternative to name+email) promo_code: str | None = None refund_address: str | None = None nostr_identifier: str | None = None payment_method: str | None = None fiat_provider: str | None = None + @root_validator + def validate_identifiers(cls, values): + name = values.get("name") + email = values.get("email") + user_id = values.get("user_id") + if not user_id and not (name and email): + raise ValueError( + "Either user_id or both name and email must be provided" + ) + if user_id and (name or email): + raise ValueError("Cannot provide both user_id and name/email") + return values + class Ticket(BaseModel): id: str wallet: str event: str - name: str - email: str + name: str | None = None + email: str | None = None + user_id: str | None = None registered: bool paid: bool time: datetime @@ -118,7 +133,7 @@ class Ticket(BaseModel): class PublicTicket(BaseModel): event: str - name: str + name: str | None = None registered: bool paid: bool time: datetime