feat: parse nested pydantic models fetchone and fetchall + add shortcuts for insert_query and update_query into Database (#2714)

* feat: add shortcuts for insert_query and update_query into `Database`
example: await db.insert("table_name", base_model)
* remove where from argument
* chore: code clean-up
* extension manager
* lnbits-qrcode  components
* parse date from dict
* refactor: make `settings` a fixture
* chore: remove verbose key names
* fix: time column
* fix: cast balance to `int`
* extension toggle vue3
* vue3 @input migration
* fix: payment extra and payment hash
* fix dynamic fields and ext db migration
* remove shadow on cards in dark theme
* screwed up and made more css pushes to this branch
* attempt to make chip component in settings dynamic fields
* dynamic chips
* qrscanner
* clean init admin settings
* make get_user better
* add dbversion model
* remove update_payment_status/extra/details
* traces for value and assertion errors
* refactor services
* add PaymentFiatAmount
* return Payment on api endpoints
* rename to get_user_from_account
* refactor: just refactor (#2740)
* rc5
* Fix db cache (#2741)
* [refactor] split services.py (#2742)
* refactor: spit `core.py` (#2743)
* refactor: make QR more customizable
* fix: print.html
* fix: qrcode options
* fix: white shadow on dark theme
* fix: datetime wasnt parsed in dict_to_model
* add timezone for conversion
* only parse timestamp for sqlite, postgres does it
* log internal payment success
* fix: export wallet to phone QR
* Adding a customisable border theme, like gradient (#2746)
* fixed mobile scan btn
* fix test websocket
* fix get_payments tests
* dict_to_model skip none values
* preimage none instead of defaulting to 0000...
* fixup test real invoice tests
* fixed pheonixd for wss
* fix nodemanager test settings
* fix lnbits funding
* only insert extension when they dont exist

---------

Co-authored-by: Vlad Stan <stan.v.vlad@gmail.com>
Co-authored-by: Tiago Vasconcelos <talvasconcelos@gmail.com>
Co-authored-by: Arc <ben@arc.wales>
Co-authored-by: Arc <33088785+arcbtc@users.noreply.github.com>
This commit is contained in:
dni ⚡ 2024-10-29 09:58:22 +01:00 committed by GitHub
parent ae4eda04ba
commit 2940cf97c5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
84 changed files with 4220 additions and 3776 deletions

View file

@ -6,7 +6,7 @@ import shutil
import sys import sys
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional from typing import Callable, Optional
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -17,12 +17,13 @@ from slowapi.util import get_remote_address
from starlette.middleware.sessions import SessionMiddleware from starlette.middleware.sessions import SessionMiddleware
from lnbits.core.crud import ( from lnbits.core.crud import (
add_installed_extension, get_db_version,
get_dbversions,
get_installed_extensions, get_installed_extensions,
update_installed_extension_state, update_installed_extension_state,
) )
from lnbits.core.extensions.extension_manager import deactivate_extension from lnbits.core.extensions.extension_manager import (
deactivate_extension,
)
from lnbits.core.extensions.helpers import version_parse from lnbits.core.extensions.helpers import version_parse
from lnbits.core.helpers import migrate_extension_database from lnbits.core.helpers import migrate_extension_database
from lnbits.core.tasks import ( # watchdog_task from lnbits.core.tasks import ( # watchdog_task
@ -47,7 +48,7 @@ from lnbits.wallets import get_funding_source, set_funding_source
from .commands import migrate_databases from .commands import migrate_databases
from .core import init_core_routers from .core import init_core_routers
from .core.db import core_app_extra from .core.db import core_app_extra
from .core.extensions.models import Extension, InstallableExtension from .core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from .core.services import check_admin_settings, check_webpush_settings from .core.services import check_admin_settings, check_webpush_settings
from .middleware import ( from .middleware import (
CustomGZipMiddleware, CustomGZipMiddleware,
@ -252,7 +253,7 @@ async def check_installed_extensions(app: FastAPI):
async def build_all_installed_extensions_list( async def build_all_installed_extensions_list(
include_deactivated: Optional[bool] = True, include_deactivated: Optional[bool] = True,
) -> List[InstallableExtension]: ) -> list[InstallableExtension]:
""" """
Returns a list of all the installed extensions plus the extensions that Returns a list of all the installed extensions plus the extensions that
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL). MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
@ -272,8 +273,13 @@ async def build_all_installed_extensions_list(
release = next((e for e in ext_releases if e.is_version_compatible), None) release = next((e for e in ext_releases if e.is_version_compatible), None)
if release: if release:
ext_meta = ExtensionMeta(installed_release=release)
ext_info = InstallableExtension( ext_info = InstallableExtension(
id=ext_id, name=ext_id, installed_release=release, icon=release.icon id=ext_id,
name=ext_id,
version=release.version,
icon=release.icon,
meta=ext_meta,
) )
installed_extensions.append(ext_info) installed_extensions.append(ext_info)
@ -304,14 +310,13 @@ async def check_installed_extension_files(ext: InstallableExtension) -> bool:
async def restore_installed_extension(app: FastAPI, ext: InstallableExtension): async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
await add_installed_extension(ext)
await update_installed_extension_state(ext_id=ext.id, active=True) await update_installed_extension_state(ext_id=ext.id, active=True)
extension = Extension.from_installable_ext(ext) extension = Extension.from_installable_ext(ext)
register_ext_routes(app, extension) register_ext_routes(app, extension)
current_version = (await get_dbversions()).get(ext.id, 0) current_version = await get_db_version(ext.id)
await migrate_extension_database(extension, current_version) await migrate_extension_database(ext, current_version)
# mount routes for the new version # mount routes for the new version
core_app_extra.register_new_ext_routes(extension) core_app_extra.register_new_ext_routes(extension)

View file

@ -3,7 +3,7 @@ import importlib
import time import time
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple from typing import Optional
import click import click
import httpx import httpx
@ -17,12 +17,13 @@ from lnbits.core.crud import (
delete_unused_wallets, delete_unused_wallets,
delete_wallet_by_id, delete_wallet_by_id,
delete_wallet_payment, delete_wallet_payment,
get_dbversions, get_db_versions,
get_installed_extension, get_installed_extension,
get_installed_extensions, get_installed_extensions,
get_payment,
get_payments, get_payments,
remove_deleted_wallets, remove_deleted_wallets,
update_payment_status, update_payment,
) )
from lnbits.core.extensions.models import ( from lnbits.core.extensions.models import (
CreateExtension, CreateExtension,
@ -122,7 +123,7 @@ def database_migrate():
async def db_versions(): async def db_versions():
"""Show current database versions""" """Show current database versions"""
async with core_db.connect() as conn: async with core_db.connect() as conn:
click.echo(await get_dbversions(conn)) click.echo(await get_db_versions(conn))
@db.command("cleanup-wallets") @db.command("cleanup-wallets")
@ -172,9 +173,10 @@ async def database_delete_wallet_payment(wallet: str, checking_id: str):
async def database_revert_payment(checking_id: str): async def database_revert_payment(checking_id: str):
"""Mark payment as pending""" """Mark payment as pending"""
async with core_db.connect() as conn: async with core_db.connect() as conn:
await update_payment_status( payment = await get_payment(checking_id=checking_id, conn=conn)
status=PaymentState.PENDING, checking_id=checking_id, conn=conn payment.status = PaymentState.PENDING
) await update_payment(payment, conn=conn)
click.echo(f"Payment '{checking_id}' marked as pending.")
@db.command("cleanup-accounts") @db.command("cleanup-accounts")
@ -231,7 +233,7 @@ async def check_invalid_payments(
click.echo("Funding source: " + str(funding_source)) click.echo("Funding source: " + str(funding_source))
# payments that are settled in the DB, but not at the Funding source level # payments that are settled in the DB, but not at the Funding source level
invalid_payments: List[Payment] = [] invalid_payments: list[Payment] = []
invalid_wallets = {} invalid_wallets = {}
for db_payment in settled_db_payments: for db_payment in settled_db_payments:
if verbose: if verbose:
@ -277,8 +279,10 @@ async def extensions_list():
from lnbits.app import build_all_installed_extensions_list from lnbits.app import build_all_installed_extensions_list
for ext in await build_all_installed_extensions_list(): for ext in await build_all_installed_extensions_list():
assert ext.installed_release, f"Extension {ext.id} has no installed_release" assert (
click.echo(f" - {ext.id} ({ext.installed_release.version})") ext.meta and ext.meta.installed_release
), f"Extension {ext.id} has no installed_release"
click.echo(f" - {ext.id} ({ext.meta.installed_release.version})")
@extensions.command("update") @extensions.command("update")
@ -461,7 +465,7 @@ async def install_extension(
source_repo: Optional[str] = None, source_repo: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
admin_user: Optional[str] = None, admin_user: Optional[str] = None,
) -> Tuple[bool, str]: ) -> tuple[bool, str]:
try: try:
release = await _select_release(extension, repo_index, source_repo) release = await _select_release(extension, repo_index, source_repo)
if not release: if not release:
@ -490,7 +494,7 @@ async def update_extension(
source_repo: Optional[str] = None, source_repo: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
admin_user: Optional[str] = None, admin_user: Optional[str] = None,
) -> Tuple[bool, str]: ) -> tuple[bool, str]:
try: try:
click.echo(f"Updating '{extension}' extension.") click.echo(f"Updating '{extension}' extension.")
installed_ext = await get_installed_extension(extension) installed_ext = await get_installed_extension(extension)
@ -503,7 +507,7 @@ async def update_extension(
click.echo(f"Current '{extension}' version: {installed_ext.installed_version}.") click.echo(f"Current '{extension}' version: {installed_ext.installed_version}.")
assert ( assert (
installed_ext.installed_release installed_ext.meta and installed_ext.meta.installed_release
), "Cannot find previously installed release. Please uninstall first." ), "Cannot find previously installed release. Please uninstall first."
release = await _select_release(extension, repo_index, source_repo) release = await _select_release(extension, repo_index, source_repo)
@ -511,7 +515,7 @@ async def update_extension(
return False, "No release selected." return False, "No release selected."
if ( if (
release.version == installed_ext.installed_version release.version == installed_ext.installed_version
and release.source_repo == installed_ext.installed_release.source_repo and release.source_repo == installed_ext.meta.installed_release.source_repo
): ):
click.echo(f"Extension '{extension}' already up to date.") click.echo(f"Extension '{extension}' already up to date.")
return False, "Already up to date" return False, "Already up to date"

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,162 @@
from .db_versions import (
delete_dbversion,
get_db_version,
get_db_versions,
update_migration_version,
)
from .extensions import (
create_installed_extension,
create_user_extension,
delete_installed_extension,
drop_extension_db,
get_installed_extension,
get_installed_extensions,
get_user_active_extensions_ids,
get_user_extension,
update_installed_extension,
update_installed_extension_state,
update_user_extension,
)
from .payments import (
DateTrunc,
check_internal,
create_payment,
delete_expired_invoices,
delete_wallet_payment,
get_latest_payments_by_extension,
get_payment,
get_payments,
get_payments_history,
get_payments_paginated,
get_standalone_payment,
get_wallet_payment,
is_internal_status_success,
mark_webhook_sent,
update_payment,
update_payment_checking_id,
update_payment_extra,
)
from .settings import (
create_admin_settings,
delete_admin_settings,
get_admin_settings,
get_super_settings,
update_admin_settings,
update_super_user,
)
from .tinyurl import create_tinyurl, delete_tinyurl, get_tinyurl, get_tinyurl_by_url
from .users import (
create_account,
delete_account,
delete_accounts_no_wallets,
get_account,
get_account_by_email,
get_account_by_pubkey,
get_account_by_username,
get_account_by_username_or_email,
get_accounts,
get_user,
get_user_from_account,
update_account,
)
from .wallets import (
create_wallet,
delete_unused_wallets,
delete_wallet,
delete_wallet_by_id,
force_delete_wallet,
get_total_balance,
get_wallet,
get_wallet_for_key,
get_wallets,
remove_deleted_wallets,
update_wallet,
)
from .webpush import (
create_webpush_subscription,
delete_webpush_subscription,
delete_webpush_subscriptions,
get_webpush_subscription,
get_webpush_subscriptions_for_user,
)
__all__ = [
# db_versions
"get_db_version",
"get_db_versions",
"update_migration_version",
"delete_dbversion",
# extensions
"create_installed_extension",
"create_user_extension",
"delete_installed_extension",
"drop_extension_db",
"get_installed_extension",
"get_installed_extensions",
"get_user_active_extensions_ids",
"get_user_extension",
"update_installed_extension",
"update_installed_extension_state",
"update_user_extension",
# payments
"DateTrunc",
"check_internal",
"create_payment",
"delete_expired_invoices",
"delete_wallet_payment",
"get_latest_payments_by_extension",
"get_payment",
"get_payments",
"get_payments_history",
"get_payments_paginated",
"get_standalone_payment",
"get_wallet_payment",
"is_internal_status_success",
"mark_webhook_sent",
"update_payment",
"update_payment_checking_id",
"update_payment_extra",
# settings
"create_admin_settings",
"delete_admin_settings",
"get_admin_settings",
"get_super_settings",
"update_admin_settings",
"update_super_user",
# tinyurl
"create_tinyurl",
"delete_tinyurl",
"get_tinyurl",
"get_tinyurl_by_url",
# users
"create_account",
"delete_account",
"delete_accounts_no_wallets",
"get_account",
"get_account_by_email",
"get_account_by_pubkey",
"get_account_by_username",
"get_account_by_username_or_email",
"get_accounts",
"get_user",
"get_user_from_account",
"update_account",
# wallets
"create_wallet",
"delete_unused_wallets",
"delete_wallet",
"delete_wallet_by_id",
"force_delete_wallet",
"get_total_balance",
"get_wallet",
"get_wallet_for_key",
"get_wallets",
"remove_deleted_wallets",
"update_wallet",
# webpush
"create_webpush_subscription",
"delete_webpush_subscription",
"delete_webpush_subscriptions",
"get_webpush_subscription",
"get_webpush_subscriptions_for_user",
]

View file

@ -0,0 +1,39 @@
from typing import Optional
from lnbits.core.db import db
from lnbits.db import Connection
from ..models import DbVersion
async def get_db_version(
ext_id: str, conn: Optional[Connection] = None
) -> Optional[DbVersion]:
return await (conn or db).fetchone(
"SELECT * FROM dbversions WHERE db = :ext_id",
{"ext_id": ext_id},
model=DbVersion,
)
async def get_db_versions(conn: Optional[Connection] = None) -> list[DbVersion]:
return await (conn or db).fetchall("SELECT * FROM dbversions", model=DbVersion)
async def update_migration_version(conn, db_name, version):
await (conn or db).execute(
"""
INSERT INTO dbversions (db, version) VALUES (:db, :version)
ON CONFLICT (db) DO UPDATE SET version = :version
""",
{"db": db_name, "version": version},
)
async def delete_dbversion(*, ext_id: str, conn: Optional[Connection] = None) -> None:
await (conn or db).execute(
"""
DELETE FROM dbversions WHERE db = :ext
""",
{"ext": ext_id},
)

View file

@ -0,0 +1,137 @@
from typing import Optional
from lnbits.core.db import db
from lnbits.core.extensions.models import (
InstallableExtension,
UserExtension,
)
from lnbits.db import Connection, Database
async def create_installed_extension(
ext: InstallableExtension,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).insert("installed_extensions", ext)
async def update_installed_extension(
ext: InstallableExtension,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).update("installed_extensions", ext)
async def update_installed_extension_state(
*, ext_id: str, active: bool, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"""
UPDATE installed_extensions SET active = :active WHERE id = :ext
""",
{"ext": ext_id, "active": active},
)
async def delete_installed_extension(
*, ext_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"""
DELETE from installed_extensions WHERE id = :ext
""",
{"ext": ext_id},
)
async def drop_extension_db(ext_id: str, conn: Optional[Connection] = None) -> None:
row: dict = await (conn or db).fetchone(
"SELECT * FROM dbversions WHERE db = :id",
{"id": ext_id},
)
# Check that 'ext_id' is a valid extension id and not a malicious string
assert row, f"Extension '{ext_id}' db version cannot be found"
is_file_based_db = await Database.clean_ext_db_files(ext_id)
if is_file_based_db:
return
# String formatting is required, params are not accepted for 'DROP SCHEMA'.
# The `ext_id` value is verified above.
await (conn or db).execute(
f"DROP SCHEMA IF EXISTS {ext_id} CASCADE",
)
async def get_installed_extension(
ext_id: str, conn: Optional[Connection] = None
) -> Optional[InstallableExtension]:
extension = await (conn or db).fetchone(
"SELECT * FROM installed_extensions WHERE id = :id",
{"id": ext_id},
InstallableExtension,
)
return extension
async def get_installed_extensions(
active: Optional[bool] = None,
conn: Optional[Connection] = None,
) -> list[InstallableExtension]:
where = "WHERE active = :active" if active is not None else ""
values = {"active": active} if active is not None else {}
all_extensions = await (conn or db).fetchall(
f"SELECT * FROM installed_extensions {where}",
values,
model=InstallableExtension,
)
return all_extensions
async def get_user_extension(
user_id: str, extension: str, conn: Optional[Connection] = None
) -> Optional[UserExtension]:
return await (conn or db).fetchone(
"""
SELECT * FROM extensions
WHERE "user" = :user AND extension = :ext
""",
{"user": user_id, "ext": extension},
model=UserExtension,
)
async def get_user_extensions(
user_id: str, conn: Optional[Connection] = None
) -> list[UserExtension]:
return await (conn or db).fetchall(
"""SELECT * FROM extensions WHERE "user" = :user""",
{"user": user_id},
model=UserExtension,
)
async def create_user_extension(
user_extension: UserExtension, conn: Optional[Connection] = None
) -> None:
await (conn or db).insert("extensions", user_extension)
async def update_user_extension(
user_extension: UserExtension, conn: Optional[Connection] = None
) -> None:
where = """WHERE extension = :extension AND "user" = :user"""
await (conn or db).update("extensions", user_extension, where)
async def get_user_active_extensions_ids(
user_id: str, conn: Optional[Connection] = None
) -> list[str]:
exts = await (conn or db).fetchall(
"""
SELECT * FROM extensions WHERE "user" = :user AND active
""",
{"user": user_id},
UserExtension,
)
return [ext.extension for ext in exts]

View file

@ -0,0 +1,385 @@
from time import time
from typing import Literal, Optional
from lnbits.core.crud.wallets import get_total_balance, get_wallet
from lnbits.core.db import db
from lnbits.core.models import PaymentState
from lnbits.db import DB_TYPE, SQLITE, Connection, Filters, Page
from ..models import (
CreatePayment,
Payment,
PaymentFilters,
PaymentHistoryPoint,
)
DateTrunc = Literal["hour", "day", "month"]
sqlite_formats = {
"hour": "%Y-%m-%d %H:00:00",
"day": "%Y-%m-%d 00:00:00",
"month": "%Y-%m-01 00:00:00",
}
def update_payment_extra():
pass
async def get_payment(checking_id: str, conn: Optional[Connection] = None) -> Payment:
return await (conn or db).fetchone(
"SELECT * FROM apipayments WHERE checking_id = :checking_id",
{"checking_id": checking_id},
Payment,
)
async def get_standalone_payment(
checking_id_or_hash: str,
conn: Optional[Connection] = None,
incoming: Optional[bool] = False,
wallet_id: Optional[str] = None,
) -> Optional[Payment]:
clause: str = "checking_id = :checking_id OR payment_hash = :hash"
values = {
"wallet_id": wallet_id,
"checking_id": checking_id_or_hash,
"hash": checking_id_or_hash,
}
if incoming:
clause = f"({clause}) AND amount > 0"
if wallet_id:
clause = f"({clause}) AND wallet_id = :wallet_id"
row = await (conn or db).fetchone(
f"""
SELECT * FROM apipayments
WHERE {clause}
ORDER BY amount LIMIT 1
""",
values,
Payment,
)
return row
async def get_wallet_payment(
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
payment = await (conn or db).fetchone(
"""
SELECT *
FROM apipayments
WHERE wallet_id = :wallet AND payment_hash = :hash
""",
{"wallet": wallet_id, "hash": payment_hash},
Payment,
)
return payment
async def get_latest_payments_by_extension(
ext_name: str, ext_id: str, limit: int = 5
) -> list[Payment]:
return await db.fetchall(
f"""
SELECT * FROM apipayments
WHERE status = '{PaymentState.SUCCESS}'
AND extra LIKE :ext_name
AND extra LIKE :ext_id
ORDER BY time DESC LIMIT {limit}
""",
{"ext_name": f"%{ext_name}%", "ext_id": f"%{ext_id}%"},
Payment,
)
async def get_payments_paginated(
*,
wallet_id: Optional[str] = None,
complete: bool = False,
pending: bool = False,
outgoing: bool = False,
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
filters: Optional[Filters[PaymentFilters]] = None,
conn: Optional[Connection] = None,
) -> Page[Payment]:
"""
Filters payments to be returned by complete | pending | outgoing | incoming.
"""
values: dict = {
"wallet_id": wallet_id,
"time": since,
}
clause: list[str] = []
if since is not None:
clause.append(f"time > {db.timestamp_placeholder('time')}")
if wallet_id:
clause.append("wallet_id = :wallet_id")
if complete and pending:
pass
elif complete:
clause.append(
f"((amount > 0 AND status = '{PaymentState.SUCCESS}') OR amount < 0)"
)
elif pending:
clause.append(f"status = '{PaymentState.PENDING}'")
else:
pass
if outgoing and incoming:
pass
elif outgoing:
clause.append("amount < 0")
elif incoming:
clause.append("amount > 0")
else:
pass
if exclude_uncheckable: # checkable means it has a checking_id that isn't internal
clause.append("checking_id NOT LIKE 'temp_%'")
clause.append("checking_id NOT LIKE 'internal_%'")
return await (conn or db).fetch_page(
"SELECT * FROM apipayments",
clause,
values,
filters=filters,
model=Payment,
)
async def get_payments(
*,
wallet_id: Optional[str] = None,
complete: bool = False,
pending: bool = False,
outgoing: bool = False,
incoming: bool = False,
since: Optional[int] = None,
exclude_uncheckable: bool = False,
filters: Optional[Filters[PaymentFilters]] = None,
conn: Optional[Connection] = None,
limit: Optional[int] = None,
offset: Optional[int] = None,
) -> list[Payment]:
"""
Filters payments to be returned by complete | pending | outgoing | incoming.
"""
filters = filters or Filters()
filters.sortby = filters.sortby or "time"
filters.direction = filters.direction or "desc"
filters.limit = limit or filters.limit
filters.offset = offset or filters.offset
page = await get_payments_paginated(
wallet_id=wallet_id,
complete=complete,
pending=pending,
outgoing=outgoing,
incoming=incoming,
since=since,
exclude_uncheckable=exclude_uncheckable,
filters=filters,
conn=conn,
)
return page.data
async def delete_expired_invoices(
conn: Optional[Connection] = None,
) -> None:
# first we delete all invoices older than one month
await (conn or db).execute(
f"""
DELETE FROM apipayments
WHERE status = '{PaymentState.PENDING}' AND amount > 0
AND time < {db.timestamp_placeholder("delta")}
""",
{"delta": int(time() - 2592000)},
)
# then we delete all invoices whose expiry date is in the past
await (conn or db).execute(
f"""
DELETE FROM apipayments
WHERE status = '{PaymentState.PENDING}' AND amount > 0
AND expiry < {db.timestamp_placeholder("now")}
""",
{"now": int(time())},
)
async def create_payment(
checking_id: str,
data: CreatePayment,
status: PaymentState = PaymentState.PENDING,
conn: Optional[Connection] = None,
) -> Payment:
# we don't allow the creation of the same invoice twice
# note: this can be removed if the db uniqueness constraints are set appropriately
previous_payment = await get_standalone_payment(checking_id, conn=conn)
assert previous_payment is None, "Payment already exists"
payment = Payment(
checking_id=checking_id,
status=status,
wallet_id=data.wallet_id,
payment_hash=data.payment_hash,
bolt11=data.bolt11,
amount=data.amount_msat,
memo=data.memo,
preimage=data.preimage,
expiry=data.expiry,
webhook=data.webhook,
fee=data.fee,
extra=data.extra or {},
)
await (conn or db).insert("apipayments", payment)
return payment
async def update_payment_checking_id(
checking_id: str, new_checking_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"UPDATE apipayments SET checking_id = :new_id WHERE checking_id = :old_id",
{"new_id": new_checking_id, "old_id": checking_id},
)
async def update_payment(
payment: Payment,
new_checking_id: Optional[str] = None,
conn: Optional[Connection] = None,
) -> None:
await (conn or db).update(
"apipayments", payment, "WHERE checking_id = :checking_id"
)
if new_checking_id and new_checking_id != payment.checking_id:
await update_payment_checking_id(payment.checking_id, new_checking_id, conn)
async def get_payments_history(
wallet_id: Optional[str] = None,
group: DateTrunc = "day",
filters: Optional[Filters] = None,
) -> list[PaymentHistoryPoint]:
if not filters:
filters = Filters()
if DB_TYPE == SQLITE and group in sqlite_formats:
date_trunc = f"strftime('{sqlite_formats[group]}', time, 'unixepoch')"
elif group in ("day", "hour", "month"):
date_trunc = f"date_trunc('{group}', time)"
else:
raise ValueError(f"Invalid group value: {group}")
values = {
"wallet_id": wallet_id,
}
where = [
f"wallet_id = :wallet_id AND (status = '{PaymentState.SUCCESS}' OR amount < 0)"
]
transactions: list[dict] = await db.fetchall(
f"""
SELECT {date_trunc} date,
SUM(CASE WHEN amount > 0 THEN amount ELSE 0 END) income,
SUM(CASE WHEN amount < 0 THEN abs(amount) + abs(fee) ELSE 0 END) spending
FROM apipayments
{filters.where(where)}
GROUP BY date
ORDER BY date DESC
""",
filters.values(values),
)
if wallet_id:
wallet = await get_wallet(wallet_id)
if wallet:
balance = wallet.balance_msat
else:
raise ValueError("Unknown wallet")
else:
balance = await get_total_balance()
# since we dont know the balance at the starting point,
# we take the current balance and walk backwards
results: list[PaymentHistoryPoint] = []
for row in transactions:
results.insert(
0,
PaymentHistoryPoint(
balance=balance,
date=row.get("date", 0),
income=row.get("income", 0),
spending=row.get("spending", 0),
),
)
balance -= row.get("income", 0) - row.get("spending", 0)
return results
async def delete_wallet_payment(
checking_id: str, wallet_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"DELETE FROM apipayments WHERE checking_id = :checking_id AND wallet = :wallet",
{"checking_id": checking_id, "wallet": wallet_id},
)
async def check_internal(
payment_hash: str, conn: Optional[Connection] = None
) -> Optional[Payment]:
"""
Returns the checking_id of the internal payment if it exists,
otherwise None
"""
return await (conn or db).fetchone(
f"""
SELECT * FROM apipayments
WHERE payment_hash = :hash AND status = '{PaymentState.PENDING}' AND amount > 0
""",
{"hash": payment_hash},
Payment,
)
async def is_internal_status_success(
payment_hash: str, conn: Optional[Connection] = None
) -> bool:
"""
Returns True if the internal payment was found and is successful,
"""
payment = await (conn or db).fetchone(
"""
SELECT * FROM apipayments
WHERE payment_hash = :payment_hash AND amount > 0
""",
{"payment_hash": payment_hash},
Payment,
)
if not payment:
return False
return payment.status == PaymentState.SUCCESS.value
async def mark_webhook_sent(payment_hash: str, status: int) -> None:
await db.execute(
"""
UPDATE apipayments SET webhook_status = :status
WHERE payment_hash = :hash
""",
{"status": status, "hash": payment_hash},
)

View file

@ -0,0 +1,71 @@
import json
from typing import Optional
from lnbits.core.db import db
from lnbits.settings import (
AdminSettings,
EditableSettings,
SuperSettings,
settings,
)
async def get_super_settings() -> Optional[SuperSettings]:
row: dict = await db.fetchone("SELECT * FROM settings")
if not row:
return None
editable_settings = json.loads(row["editable_settings"])
return SuperSettings(**{"super_user": row["super_user"], **editable_settings})
async def get_admin_settings(is_super_user: bool = False) -> Optional[AdminSettings]:
sets = await get_super_settings()
if not sets:
return None
row_dict = dict(sets)
row_dict.pop("super_user")
row_dict.pop("auth_all_methods")
admin_settings = AdminSettings(
is_super_user=is_super_user,
lnbits_allowed_funding_sources=settings.lnbits_allowed_funding_sources,
**row_dict,
)
return admin_settings
async def delete_admin_settings() -> None:
await db.execute("DELETE FROM settings")
async def update_admin_settings(data: EditableSettings) -> None:
row: dict = await db.fetchone("SELECT editable_settings FROM settings")
editable_settings = json.loads(row["editable_settings"]) if row else {}
editable_settings.update(data.dict(exclude_unset=True))
await db.execute(
"UPDATE settings SET editable_settings = :settings",
{"settings": json.dumps(editable_settings)},
)
async def update_super_user(super_user: str) -> SuperSettings:
await db.execute(
"UPDATE settings SET super_user = :user",
{"user": super_user},
)
settings = await get_super_settings()
assert settings, "updated super_user settings could not be retrieved"
return settings
async def create_admin_settings(super_user: str, new_settings: dict):
await db.execute(
"""
INSERT INTO settings (super_user, editable_settings)
VALUES (:user, :settings)
""",
{"user": super_user, "settings": json.dumps(new_settings)},
)
settings = await get_super_settings()
assert settings, "created admin settings could not be retrieved"
return settings

View file

@ -0,0 +1,42 @@
from typing import Optional
import shortuuid
from lnbits.core.db import db
from ..models import TinyURL
async def create_tinyurl(domain: str, endless: bool, wallet: str):
tinyurl_id = shortuuid.uuid()[:8]
await db.execute(
"""
INSERT INTO tiny_url (id, url, endless, wallet)
VALUES (:tinyurl, :domain, :endless, :wallet)
""",
{"tinyurl": tinyurl_id, "domain": domain, "endless": endless, "wallet": wallet},
)
return await get_tinyurl(tinyurl_id)
async def get_tinyurl(tinyurl_id: str) -> Optional[TinyURL]:
return await db.fetchone(
"SELECT * FROM tiny_url WHERE id = :tinyurl",
{"tinyurl": tinyurl_id},
TinyURL,
)
async def get_tinyurl_by_url(url: str) -> list[TinyURL]:
return await db.fetchall(
"SELECT * FROM tiny_url WHERE url = :url",
{"url": url},
TinyURL,
)
async def delete_tinyurl(tinyurl_id: str):
await db.execute(
"DELETE FROM tiny_url WHERE id = :tinyurl",
{"tinyurl": tinyurl_id},
)

170
lnbits/core/crud/users.py Normal file
View file

@ -0,0 +1,170 @@
from datetime import datetime, timezone
from time import time
from typing import Optional
from uuid import uuid4
from lnbits.core.crud.extensions import get_user_active_extensions_ids
from lnbits.core.crud.wallets import get_wallets
from lnbits.core.db import db
from lnbits.db import Connection, Filters, Page
from ..models import (
Account,
AccountFilters,
AccountOverview,
User,
)
async def create_account(
account: Optional[Account] = None,
conn: Optional[Connection] = None,
) -> Account:
if not account:
now = datetime.now(timezone.utc)
account = Account(id=uuid4().hex, created_at=now, updated_at=now)
await (conn or db).insert("accounts", account)
return account
async def update_account(account: Account) -> Account:
account.updated_at = datetime.now(timezone.utc)
await db.update("accounts", account)
return account
async def delete_account(user_id: str, conn: Optional[Connection] = None) -> None:
await (conn or db).execute(
"DELETE from accounts WHERE id = :user",
{"user": user_id},
)
async def get_accounts(
filters: Optional[Filters[AccountFilters]] = None,
conn: Optional[Connection] = None,
) -> Page[AccountOverview]:
return await (conn or db).fetch_page(
"""
SELECT
accounts.id,
accounts.username,
accounts.email,
SUM(COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0)) as balance_msat,
SUM((
SELECT COUNT(*) FROM apipayments WHERE wallet_id = wallets.id
)) as transaction_count,
(
SELECT COUNT(*) FROM wallets WHERE wallets.user = accounts.id
) as wallet_count,
MAX((
SELECT time FROM apipayments
WHERE wallet_id = wallets.id ORDER BY time DESC LIMIT 1
)) as last_payment
FROM accounts LEFT JOIN wallets ON accounts.id = wallets.user
""",
[],
{},
filters=filters,
model=AccountOverview,
group_by=["accounts.id"],
)
async def get_account(
user_id: str, conn: Optional[Connection] = None
) -> Optional[Account]:
return await (conn or db).fetchone(
"SELECT * FROM accounts WHERE id = :id",
{"id": user_id},
Account,
)
async def delete_accounts_no_wallets(
time_delta: int,
conn: Optional[Connection] = None,
) -> None:
delta = int(time()) - time_delta
await (conn or db).execute(
f"""
DELETE FROM accounts
WHERE NOT EXISTS (
SELECT wallets.id FROM wallets WHERE wallets.user = accounts.id
) AND (
(updated_at is null AND created_at < :delta)
OR updated_at < {db.timestamp_placeholder("delta")}
)
""",
{"delta": delta},
)
async def get_account_by_username(
username: str, conn: Optional[Connection] = None
) -> Optional[Account]:
return await (conn or db).fetchone(
"SELECT * FROM accounts WHERE username = :username",
{"username": username},
Account,
)
async def get_account_by_pubkey(
pubkey: str, conn: Optional[Connection] = None
) -> Optional[Account]:
return await (conn or db).fetchone(
"SELECT * FROM accounts WHERE pubkey = :pubkey",
{"pubkey": pubkey},
Account,
)
async def get_account_by_email(
email: str, conn: Optional[Connection] = None
) -> Optional[Account]:
return await (conn or db).fetchone(
"SELECT * FROM accounts WHERE email = :email",
{"email": email},
Account,
)
async def get_account_by_username_or_email(
username_or_email: str, conn: Optional[Connection] = None
) -> Optional[Account]:
return await (conn or db).fetchone(
"SELECT * FROM accounts WHERE email = :value or username = :value",
{"value": username_or_email},
Account,
)
async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[User]:
account = await get_account(user_id, conn)
if not account:
return None
return await get_user_from_account(account, conn)
async def get_user_from_account(
account: Account, conn: Optional[Connection] = None
) -> Optional[User]:
extensions = await get_user_active_extensions_ids(account.id, conn)
wallets = await get_wallets(account.id, False, conn=conn)
return User(
id=account.id,
email=account.email,
username=account.username,
pubkey=account.pubkey,
extra=account.extra,
created_at=account.created_at,
updated_at=account.updated_at,
extensions=extensions,
wallets=wallets,
admin=account.is_admin,
super_user=account.is_super_user,
has_password=account.password_hash is not None,
)

157
lnbits/core/crud/wallets.py Normal file
View file

@ -0,0 +1,157 @@
from datetime import datetime, timezone
from time import time
from typing import Optional
from uuid import uuid4
from lnbits.core.db import db
from lnbits.db import Connection
from lnbits.settings import settings
from ..models import Wallet
async def create_wallet(
*,
user_id: str,
wallet_name: Optional[str] = None,
conn: Optional[Connection] = None,
) -> Wallet:
wallet_id = uuid4().hex
wallet = Wallet(
id=wallet_id,
name=wallet_name or settings.lnbits_default_wallet_name,
user=user_id,
adminkey=uuid4().hex,
inkey=uuid4().hex,
)
await (conn or db).insert("wallets", wallet)
return wallet
async def update_wallet(
wallet: Wallet,
conn: Optional[Connection] = None,
) -> Optional[Wallet]:
wallet.updated_at = datetime.now(timezone.utc)
await (conn or db).update("wallets", wallet)
return wallet
async def delete_wallet(
*,
user_id: str,
wallet_id: str,
deleted: bool = True,
conn: Optional[Connection] = None,
) -> None:
now = int(time())
await (conn or db).execute(
f"""
UPDATE wallets
SET deleted = :deleted, updated_at = {db.timestamp_placeholder('now')}
WHERE id = :wallet AND "user" = :user
""",
{"wallet": wallet_id, "user": user_id, "deleted": deleted, "now": now},
)
async def force_delete_wallet(
wallet_id: str, conn: Optional[Connection] = None
) -> None:
await (conn or db).execute(
"DELETE FROM wallets WHERE id = :wallet",
{"wallet": wallet_id},
)
async def delete_wallet_by_id(
wallet_id: str, conn: Optional[Connection] = None
) -> Optional[int]:
now = int(time())
result = await (conn or db).execute(
f"""
UPDATE wallets
SET deleted = true, updated_at = {db.timestamp_placeholder('now')}
WHERE id = :wallet
""",
{"wallet": wallet_id, "now": now},
)
return result.rowcount
async def remove_deleted_wallets(conn: Optional[Connection] = None) -> None:
await (conn or db).execute("DELETE FROM wallets WHERE deleted = true")
async def delete_unused_wallets(
time_delta: int,
conn: Optional[Connection] = None,
) -> None:
delta = int(time()) - time_delta
await (conn or db).execute(
"""
DELETE FROM wallets
WHERE (
SELECT COUNT(*) FROM apipayments WHERE wallet_id = wallets.id
) = 0 AND (
(updated_at is null AND created_at < :delta)
OR updated_at < :delta
)
""",
{"delta": delta},
)
async def get_wallet(
wallet_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> Optional[Wallet]:
where = "AND deleted = :deleted" if deleted is not None else ""
return await (conn or db).fetchone(
f"""
SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0) AS balance_msat FROM wallets
WHERE id = :wallet {where}
""",
{"wallet": wallet_id, "deleted": deleted},
Wallet,
)
async def get_wallets(
user_id: str, deleted: Optional[bool] = None, conn: Optional[Connection] = None
) -> list[Wallet]:
where = "AND deleted = :deleted" if deleted is not None else ""
return await (conn or db).fetchall(
f"""
SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0) AS balance_msat FROM wallets
WHERE "user" = :user {where}
""",
{"user": user_id, "deleted": deleted},
Wallet,
)
async def get_wallet_for_key(
key: str,
conn: Optional[Connection] = None,
) -> Optional[Wallet]:
return await (conn or db).fetchone(
"""
SELECT *, COALESCE((
SELECT balance FROM balances WHERE wallet_id = wallets.id
), 0)
AS balance_msat FROM wallets
WHERE (adminkey = :key OR inkey = :key) AND deleted = false
""",
{"key": key},
Wallet,
)
async def get_total_balance(conn: Optional[Connection] = None):
result = await (conn or db).execute("SELECT SUM(balance) FROM balances")
row = result.mappings().first()
return row.get("balance", 0)

View file

@ -0,0 +1,59 @@
from typing import Optional
from lnbits.core.db import db
from ..models import WebPushSubscription
async def get_webpush_subscription(
endpoint: str, user: str
) -> Optional[WebPushSubscription]:
return await db.fetchone(
"""
SELECT * FROM webpush_subscriptions
WHERE endpoint = :endpoint AND "user" = :user
""",
{"endpoint": endpoint, "user": user},
WebPushSubscription,
)
async def get_webpush_subscriptions_for_user(user: str) -> list[WebPushSubscription]:
return await db.fetchall(
"""SELECT * FROM webpush_subscriptions WHERE "user" = :user""",
{"user": user},
WebPushSubscription,
)
async def create_webpush_subscription(
endpoint: str, user: str, data: str, host: str
) -> WebPushSubscription:
await db.execute(
"""
INSERT INTO webpush_subscriptions (endpoint, "user", data, host)
VALUES (:endpoint, :user, :data, :host)
""",
{"endpoint": endpoint, "user": user, "data": data, "host": host},
)
subscription = await get_webpush_subscription(endpoint, user)
assert subscription, "Newly created webpush subscription couldn't be retrieved"
return subscription
async def delete_webpush_subscription(endpoint: str, user: str) -> int:
resp = await db.execute(
"""
DELETE FROM webpush_subscriptions WHERE endpoint = :endpoint AND "user" = :user
""",
{"endpoint": endpoint, "user": user},
)
return resp.rowcount
async def delete_webpush_subscriptions(endpoint: str) -> int:
resp = await db.execute(
"DELETE FROM webpush_subscriptions WHERE endpoint = :endpoint",
{"endpoint": endpoint},
)
return resp.rowcount

View file

@ -3,14 +3,14 @@ import importlib
from loguru import logger from loguru import logger
from lnbits.core import core_app_extra
from lnbits.core.crud import ( from lnbits.core.crud import (
add_installed_extension, create_installed_extension,
delete_installed_extension, delete_installed_extension,
get_dbversions, get_db_version,
get_installed_extension, get_installed_extension,
update_installed_extension_state, update_installed_extension_state,
) )
from lnbits.core.db import core_app_extra
from lnbits.core.helpers import migrate_extension_database from lnbits.core.helpers import migrate_extension_database
from lnbits.settings import settings from lnbits.settings import settings
@ -18,22 +18,27 @@ from .models import Extension, InstallableExtension
async def install_extension(ext_info: InstallableExtension) -> Extension: async def install_extension(ext_info: InstallableExtension) -> Extension:
ext_id = ext_info.id
extension = Extension.from_installable_ext(ext_info) extension = Extension.from_installable_ext(ext_info)
installed_ext = await get_installed_extension(ext_info.id) installed_ext = await get_installed_extension(ext_id)
ext_info.payments = installed_ext.payments if installed_ext else [] if installed_ext:
ext_info.meta = installed_ext.meta
await ext_info.download_archive() await ext_info.download_archive()
ext_info.extract_archive() ext_info.extract_archive()
db_version = (await get_dbversions()).get(ext_info.id, 0) db_version = await get_db_version(ext_id)
await migrate_extension_database(extension, db_version) await migrate_extension_database(ext_info, db_version)
await add_installed_extension(ext_info) # if the extensions does not exist in the installed extensions table, create it
# if it does exist, it will be activated later in the code
if not installed_ext:
await create_installed_extension(ext_info)
if extension.is_upgrade_extension: if extension.is_upgrade_extension:
# call stop while the old routes are still active # call stop while the old routes are still active
await stop_extension_background_work(ext_info.id) await stop_extension_background_work(ext_id)
return extension return extension

View file

@ -109,8 +109,8 @@ class ReleasePaymentInfo(BaseModel):
class PayToEnableInfo(BaseModel): class PayToEnableInfo(BaseModel):
required: Optional[bool] = False amount: int
amount: Optional[int] = None required: bool = False
wallet: Optional[str] = None wallet: Optional[str] = None
@ -120,6 +120,7 @@ class UserExtensionInfo(BaseModel):
class UserExtension(BaseModel): class UserExtension(BaseModel):
user: str
extension: str extension: str
active: bool active: bool
extra: Optional[UserExtensionInfo] = None extra: Optional[UserExtensionInfo] = None
@ -372,29 +373,37 @@ class ExtensionRelease(BaseModel):
return None return None
class ExtensionMeta(BaseModel):
installed_release: Optional[ExtensionRelease] = None
latest_release: Optional[ExtensionRelease] = None
pay_to_enable: Optional[PayToEnableInfo] = None
payments: list[ReleasePaymentInfo] = []
dependencies: list[str] = []
archive: Optional[str] = None
featured: bool = False
class InstallableExtension(BaseModel): class InstallableExtension(BaseModel):
id: str id: str
name: str name: str
version: str
active: Optional[bool] = False active: Optional[bool] = False
short_description: Optional[str] = None short_description: Optional[str] = None
icon: Optional[str] = None icon: Optional[str] = None
dependencies: list[str] = []
is_admin_only: bool = False
stars: int = 0 stars: int = 0
featured = False meta: Optional[ExtensionMeta] = None
latest_release: Optional[ExtensionRelease] = None
installed_release: Optional[ExtensionRelease] = None @property
payments: list[ReleasePaymentInfo] = [] def is_admin_only(self) -> bool:
pay_to_enable: Optional[PayToEnableInfo] = None return self.id in settings.lnbits_admin_extensions
archive: Optional[str] = None
@property @property
def hash(self) -> str: def hash(self) -> str:
if self.installed_release: if self.meta and self.meta.installed_release:
if self.installed_release.hash: if self.meta.installed_release.hash:
return self.installed_release.hash return self.meta.installed_release.hash
m = hashlib.sha256() m = hashlib.sha256()
m.update(f"{self.installed_release.archive}".encode()) m.update(f"{self.meta.installed_release.archive}".encode())
return m.hexdigest() return m.hexdigest()
return "not-installed" return "not-installed"
@ -432,15 +441,15 @@ class InstallableExtension(BaseModel):
@property @property
def installed_version(self) -> str: def installed_version(self) -> str:
if self.installed_release: if self.meta and self.meta.installed_release:
return self.installed_release.version return self.meta.installed_release.version
return "" return ""
@property @property
def requires_payment(self) -> bool: def requires_payment(self) -> bool:
if not self.pay_to_enable: if not self.meta or not self.meta.pay_to_enable:
return False return False
return self.pay_to_enable.required is True return self.meta.pay_to_enable.required is True
async def download_archive(self): async def download_archive(self):
logger.info(f"Downloading extension {self.name} ({self.installed_version}).") logger.info(f"Downloading extension {self.name} ({self.installed_version}).")
@ -448,12 +457,14 @@ class InstallableExtension(BaseModel):
if ext_zip_file.is_file(): if ext_zip_file.is_file():
os.remove(ext_zip_file) os.remove(ext_zip_file)
try: try:
assert self.installed_release, "installed_release is none." assert (
self.meta and self.meta.installed_release
), "installed_release is none."
self._restore_payment_info() self._restore_payment_info()
await asyncio.to_thread( await asyncio.to_thread(
download_url, self.installed_release.archive_url, ext_zip_file download_url, self.meta.installed_release.archive_url, ext_zip_file
) )
self._remember_payment_info() self._remember_payment_info()
@ -463,7 +474,11 @@ class InstallableExtension(BaseModel):
raise AssertionError("Cannot fetch extension archive file") from exc raise AssertionError("Cannot fetch extension archive file") from exc
archive_hash = file_hash(ext_zip_file) archive_hash = file_hash(ext_zip_file)
if self.installed_release.hash and self.installed_release.hash != archive_hash: if (
self.meta
and self.meta.installed_release.hash
and self.meta.installed_release.hash != archive_hash
):
# remove downloaded archive # remove downloaded archive
if ext_zip_file.is_file(): if ext_zip_file.is_file():
os.remove(ext_zip_file) os.remove(ext_zip_file)
@ -497,17 +512,18 @@ class InstallableExtension(BaseModel):
self.short_description = config_json.get("short_description") self.short_description = config_json.get("short_description")
if ( if (
self.installed_release self.meta
and self.installed_release.is_github_release and self.meta.installed_release
and self.meta.installed_release.is_github_release
and config_json.get("tile") and config_json.get("tile")
): ):
self.icon = icon_to_github_url( self.icon = icon_to_github_url(
self.installed_release.source_repo, config_json.get("tile") self.meta.installed_release.source_repo, config_json.get("tile")
) )
shutil.rmtree(self.ext_dir, True) shutil.rmtree(self.ext_dir, True)
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir)) shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
logger.success(f"Extension {self.name} ({self.installed_version}) installed.") logger.info(f"Extension {self.name} ({self.installed_version}) extracted.")
def clean_extension_files(self): def clean_extension_files(self):
# remove downloaded archive # remove downloaded archive
@ -522,64 +538,54 @@ class InstallableExtension(BaseModel):
def check_latest_version(self, release: Optional[ExtensionRelease]): def check_latest_version(self, release: Optional[ExtensionRelease]):
if not release: if not release:
return return
if not self.latest_release: if not self.meta or not self.meta.latest_release:
self.latest_release = release meta = self.meta or ExtensionMeta()
meta.latest_release = release
self.meta = meta
return return
if version_parse(self.latest_release.version) < version_parse(release.version): if version_parse(self.meta.latest_release.version) < version_parse(
self.latest_release = release release.version
):
self.meta.latest_release = release
def find_existing_payment( def find_existing_payment(
self, pay_link: Optional[str] self, pay_link: Optional[str]
) -> Optional[ReleasePaymentInfo]: ) -> Optional[ReleasePaymentInfo]:
if not pay_link: if not pay_link or not self.meta or not self.meta.payments:
return None return None
return next( return next(
(p for p in self.payments if p.pay_link == pay_link), (p for p in self.meta.payments if p.pay_link == pay_link),
None, None,
) )
def _restore_payment_info(self): def _restore_payment_info(self):
if not self.installed_release: if (
not self.meta
or not self.meta.installed_release
or not self.meta.installed_release.pay_link
or not self.meta.installed_release.payment_hash
):
return return
if not self.installed_release.pay_link: payment_info = self.find_existing_payment(self.meta.installed_release.pay_link)
return
if self.installed_release.payment_hash:
return
payment_info = self.find_existing_payment(self.installed_release.pay_link)
if payment_info: if payment_info:
self.installed_release.payment_hash = payment_info.payment_hash self.meta.installed_release.payment_hash = payment_info.payment_hash
def _remember_payment_info(self): def _remember_payment_info(self):
if not self.installed_release or not self.installed_release.pay_link: if (
not self.meta
or not self.meta.installed_release
or not self.meta.installed_release.pay_link
):
return return
payment_info = ReleasePaymentInfo( payment_info = ReleasePaymentInfo(
amount=self.installed_release.cost_sats, amount=self.meta.installed_release.cost_sats,
pay_link=self.installed_release.pay_link, pay_link=self.meta.installed_release.pay_link,
payment_hash=self.installed_release.payment_hash, payment_hash=self.meta.installed_release.payment_hash,
) )
self.payments = [ self.meta.payments = [
p for p in self.payments if p.pay_link != payment_info.pay_link p for p in self.meta.payments if p.pay_link != payment_info.pay_link
] ]
self.payments.append(payment_info) self.meta.payments.append(payment_info)
@classmethod
def from_row(cls, data: dict) -> InstallableExtension:
meta = json.loads(data["meta"])
ext = InstallableExtension(**data)
if "installed_release" in meta:
ext.installed_release = ExtensionRelease(**meta["installed_release"])
if meta.get("pay_to_enable"):
ext.pay_to_enable = PayToEnableInfo(**meta["pay_to_enable"])
if meta.get("payments"):
ext.payments = [ReleasePaymentInfo(**p) for p in meta["payments"]]
return ext
@classmethod
def from_rows(cls, rows: Optional[list[Any]] = None) -> list[InstallableExtension]:
if rows is None:
rows = []
return [InstallableExtension.from_row(row) for row in rows]
@classmethod @classmethod
async def from_github_release( async def from_github_release(
@ -593,14 +599,17 @@ class InstallableExtension(BaseModel):
return InstallableExtension( return InstallableExtension(
id=github_release.id, id=github_release.id,
name=config.name, name=config.name,
version=latest_release.tag_name,
short_description=config.short_description, short_description=config.short_description,
stars=int(repo.stargazers_count), stars=int(repo.stargazers_count),
icon=icon_to_github_url( icon=icon_to_github_url(
source_repo, source_repo,
config.tile, config.tile,
), ),
latest_release=ExtensionRelease.from_github_release( meta=ExtensionMeta(
source_repo, latest_release latest_release=ExtensionRelease.from_github_release(
source_repo, latest_release
),
), ),
) )
except Exception as e: except Exception as e:
@ -609,13 +618,14 @@ class InstallableExtension(BaseModel):
@classmethod @classmethod
def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension: def from_explicit_release(cls, e: ExplicitRelease) -> InstallableExtension:
meta = ExtensionMeta(archive=e.archive, dependencies=e.dependencies)
return InstallableExtension( return InstallableExtension(
id=e.id, id=e.id,
name=e.name, name=e.name,
archive=e.archive, version=e.version,
short_description=e.short_description, short_description=e.short_description,
icon=e.icon, icon=e.icon,
dependencies=e.dependencies, meta=meta,
) )
@classmethod @classmethod
@ -636,11 +646,13 @@ class InstallableExtension(BaseModel):
existing_ext = next( existing_ext = next(
(ee for ee in extension_list if ee.id == r.id), None (ee for ee in extension_list if ee.id == r.id), None
) )
if existing_ext: if existing_ext and ext.meta:
existing_ext.check_latest_version(ext.latest_release) existing_ext.check_latest_version(ext.meta.latest_release)
continue continue
ext.featured = ext.id in manifest.featured meta = ext.meta or ExtensionMeta()
meta.featured = ext.id in manifest.featured
ext.meta = meta
extension_list += [ext] extension_list += [ext]
extension_id_list += [ext.id] extension_id_list += [ext.id]
@ -654,7 +666,9 @@ class InstallableExtension(BaseModel):
continue continue
ext = InstallableExtension.from_explicit_release(e) ext = InstallableExtension.from_explicit_release(e)
ext.check_latest_version(release) ext.check_latest_version(release)
ext.featured = ext.id in manifest.featured meta = ext.meta or ExtensionMeta()
meta.featured = ext.id in manifest.featured
ext.meta = meta
extension_list += [ext] extension_list += [ext]
extension_id_list += [e.id] extension_id_list += [e.id]
except Exception as e: except Exception as e:

View file

@ -1,6 +1,6 @@
import importlib import importlib
import re import re
from typing import Any from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import UUID from uuid import UUID
@ -8,39 +8,45 @@ from loguru import logger
from lnbits.core import migrations as core_migrations from lnbits.core import migrations as core_migrations
from lnbits.core.crud import ( from lnbits.core.crud import (
get_dbversions, get_db_versions,
get_installed_extensions, get_installed_extensions,
update_migration_version, update_migration_version,
) )
from lnbits.core.db import db as core_db from lnbits.core.db import db as core_db
from lnbits.core.extensions.models import ( from lnbits.core.extensions.models import InstallableExtension
Extension, from lnbits.core.models import DbVersion
)
from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection from lnbits.db import COCKROACH, POSTGRES, SQLITE, Connection
from lnbits.settings import settings from lnbits.settings import settings
async def migrate_extension_database(ext: Extension, current_version): async def migrate_extension_database(
ext: InstallableExtension, current_version: Optional[DbVersion] = None
):
try: try:
ext_migrations = importlib.import_module(f"{ext.module_name}.migrations") ext_migrations = importlib.import_module(f"{ext.module_name}.migrations")
ext_db = importlib.import_module(ext.module_name).db ext_db = importlib.import_module(ext.module_name).db
except ImportError as exc: except ImportError as exc:
logger.error(exc) logger.error(exc)
raise ImportError(f"Cannot import module for extension '{ext.code}'.") from exc raise ImportError(f"Cannot import module for extension '{ext.id}'.") from exc
async with ext_db.connect() as ext_conn: async with ext_db.connect() as ext_conn:
await run_migration(ext_conn, ext_migrations, ext.code, current_version) await run_migration(ext_conn, ext_migrations, ext.id, current_version)
async def run_migration( async def run_migration(
db: Connection, migrations_module: Any, db_name: str, current_version: int db: Connection,
migrations_module: Any,
db_name: str,
current_version: Optional[DbVersion] = None,
): ):
matcher = re.compile(r"^m(\d\d\d)_") matcher = re.compile(r"^m(\d\d\d)_")
for key, migrate in migrations_module.__dict__.items():
for key, migrate in list(migrations_module.__dict__.items()):
match = matcher.match(key) match = matcher.match(key)
if match: if match:
version = int(match.group(1)) version = int(match.group(1))
if version > current_version: if not current_version or version > current_version.version:
logger.debug(f"running migration {db_name}.{version}") logger.debug(f"running migration {db_name}.{version}")
print(f"running migration {db_name}.{version}") print(f"running migration {db_name}.{version}")
await migrate(db) await migrate(db)
@ -87,21 +93,31 @@ async def migrate_databases():
if not exists: if not exists:
await core_migrations.m000_create_migrations_table(conn) await core_migrations.m000_create_migrations_table(conn)
current_versions = await get_dbversions(conn) current_versions = await get_db_versions(conn)
core_version = current_versions.get("core", 0) core_version = next(
(v for v in current_versions if v.db == "core"),
DbVersion(db="core", version=0),
)
await run_migration(conn, core_migrations, "core", core_version) await run_migration(conn, core_migrations, "core", core_version)
# here is the first place we can be sure that the # here is the first place we can be sure that the
# `installed_extensions` table has been created # `installed_extensions` table has been created
await load_disabled_extension_list() await load_disabled_extension_list()
# todo: revisit, use installed extensions for ext in await get_installed_extensions():
for ext in Extension.get_valid_extensions(False): current_version = next(
current_version = current_versions.get(ext.code, 0) (v for v in current_versions if v.db == ext.id),
DbVersion(db=ext.id, version=0),
)
if current_version is None:
logger.warning(
f"Extension {ext.id} has no migration version. This should not happen."
)
continue
try: try:
await migrate_extension_database(ext, current_version) await migrate_extension_database(ext, current_version)
except Exception as e: except Exception as e:
logger.exception(f"Error migrating extension {ext.code}: {e}") logger.exception(f"Error migrating extension {ext.id}: {e}")
logger.info("✔️ All migrations done.") logger.info("✔️ All migrations done.")

View file

@ -1,9 +1,11 @@
import json
from time import time from time import time
from loguru import logger from loguru import logger
from sqlalchemy.exc import OperationalError from sqlalchemy.exc import OperationalError
from lnbits import bolt11 from lnbits import bolt11
from lnbits.db import Connection
async def m000_create_migrations_table(db): async def m000_create_migrations_table(db):
@ -99,9 +101,8 @@ async def m002_add_fields_to_apipayments(db):
await db.execute("ALTER TABLE apipayments ADD COLUMN bolt11 TEXT") await db.execute("ALTER TABLE apipayments ADD COLUMN bolt11 TEXT")
await db.execute("ALTER TABLE apipayments ADD COLUMN extra TEXT") await db.execute("ALTER TABLE apipayments ADD COLUMN extra TEXT")
import json result = await db.execute("SELECT * FROM apipayments")
rows = result.mappings().all()
rows = await db.fetchall("SELECT * FROM apipayments")
for row in rows: for row in rows:
if not row["memo"] or not row["memo"].startswith("#"): if not row["memo"] or not row["memo"].startswith("#"):
continue continue
@ -211,7 +212,7 @@ async def m007_set_invoice_expiries(db):
Precomputes invoice expiry for existing pending incoming payments. Precomputes invoice expiry for existing pending incoming payments.
""" """
try: try:
rows = await db.fetchall( result = await db.execute(
f""" f"""
SELECT bolt11, checking_id SELECT bolt11, checking_id
FROM apipayments FROM apipayments
@ -222,6 +223,7 @@ async def m007_set_invoice_expiries(db):
AND time < {db.timestamp_now} AND time < {db.timestamp_now}
""" """
) )
rows = result.mappings().all()
if len(rows): if len(rows):
logger.info(f"Migration: Checking expiry of {len(rows)} invoices") logger.info(f"Migration: Checking expiry of {len(rows)} invoices")
for i, ( for i, (
@ -339,7 +341,7 @@ async def m014_set_deleted_wallets(db):
Sets deleted column to wallets. Sets deleted column to wallets.
""" """
try: try:
rows = await db.fetchall( result = await db.execute(
""" """
SELECT * SELECT *
FROM wallets FROM wallets
@ -348,12 +350,13 @@ async def m014_set_deleted_wallets(db):
AND inkey LIKE 'del:%' AND inkey LIKE 'del:%'
""" """
) )
rows = result.mappings().all()
for row in rows: for row in rows:
try: try:
user = row[2].split(":")[1] user = row["user"].split(":")[1]
adminkey = row[3].split(":")[1] adminkey = row["adminkey"].split(":")[1]
inkey = row[4].split(":")[1] inkey = row["inkey"].split(":")[1]
await db.execute( await db.execute(
""" """
UPDATE wallets SET UPDATE wallets SET
@ -541,8 +544,6 @@ async def m021_add_success_failed_to_apipayments(db):
GROUP BY apipayments.wallet GROUP BY apipayments.wallet
""" """
) )
# TODO: drop column in next release
# await db.execute("ALTER TABLE apipayments DROP COLUMN pending")
async def m022_add_pubkey_to_accounts(db): async def m022_add_pubkey_to_accounts(db):
@ -553,3 +554,78 @@ async def m022_add_pubkey_to_accounts(db):
await db.execute("ALTER TABLE accounts ADD COLUMN pubkey TEXT") await db.execute("ALTER TABLE accounts ADD COLUMN pubkey TEXT")
except OperationalError: except OperationalError:
pass pass
async def m023_add_column_column_to_apipayments(db):
"""
renames hash to payment_hash and drops unused index
"""
await db.execute("DROP INDEX by_hash")
await db.execute("ALTER TABLE apipayments RENAME COLUMN hash TO payment_hash")
await db.execute("ALTER TABLE apipayments RENAME COLUMN wallet TO wallet_id")
await db.execute("ALTER TABLE accounts RENAME COLUMN pass TO password_hash")
await db.execute("CREATE INDEX by_hash ON apipayments (payment_hash)")
async def m024_drop_pending(db):
await db.execute("ALTER TABLE apipayments DROP COLUMN pending")
async def m025_refresh_view(db):
await db.execute("DROP VIEW balances")
await db.execute(
"""
CREATE VIEW balances AS
SELECT apipayments.wallet_id,
SUM(apipayments.amount - ABS(apipayments.fee)) AS balance
FROM wallets
LEFT JOIN apipayments ON apipayments.wallet_id = wallets.id
WHERE (wallets.deleted = false OR wallets.deleted is NULL)
AND (
(apipayments.status = 'success' AND apipayments.amount > 0)
OR (apipayments.status IN ('success', 'pending') AND apipayments.amount < 0)
)
GROUP BY apipayments.wallet_id
"""
)
async def m026_update_payment_table(db):
await db.execute("ALTER TABLE apipayments ADD COLUMN tag TEXT")
await db.execute("ALTER TABLE apipayments ADD COLUMN extension TEXT")
await db.execute("ALTER TABLE apipayments ADD COLUMN created_at TIMESTAMP")
await db.execute("ALTER TABLE apipayments ADD COLUMN updated_at TIMESTAMP")
async def m027_update_apipayments_data(db: Connection):
result = None
try:
result = await db.execute("SELECT * FROM apipayments")
except Exception as exc:
logger.warning("Could not select, trying again after cache cleared.")
logger.debug(exc)
await db.execute("COMMIT")
result = await db.execute("SELECT * FROM apipayments")
payments = result.mappings().all()
for payment in payments:
tag = None
created_at = payment.get("time")
if payment.get("extra"):
extra = json.loads(payment.get("extra"))
tag = extra.get("tag")
tsph = db.timestamp_placeholder("created_at")
await db.execute(
f"""
UPDATE apipayments
SET tag = :tag, created_at = {tsph}, updated_at = {tsph}
WHERE checking_id = :checking_id
""",
{
"tag": tag,
"created_at": created_at,
"checking_id": payment.get("checking_id"),
},
)

View file

@ -1,19 +1,18 @@
from __future__ import annotations from __future__ import annotations
import datetime
import hashlib import hashlib
import hmac import hmac
import json
import time
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from ecdsa import SECP256k1, SigningKey from ecdsa import SECP256k1, SigningKey
from fastapi import Query from fastapi import Query
from pydantic import BaseModel, validator from passlib.context import CryptContext
from pydantic import BaseModel, Field, validator
from lnbits.db import FilterModel, FromRowModel from lnbits.db import FilterModel
from lnbits.helpers import url_for from lnbits.helpers import url_for
from lnbits.lnurl import encode as lnurl_encode from lnbits.lnurl import encode as lnurl_encode
from lnbits.settings import settings from lnbits.settings import settings
@ -35,16 +34,21 @@ class BaseWallet(BaseModel):
balance_msat: int balance_msat: int
class Wallet(BaseWallet): class Wallet(BaseModel):
id: str
user: str user: str
currency: Optional[str] name: str
deleted: bool adminkey: str
created_at: Optional[int] = None inkey: str
updated_at: Optional[int] = None deleted: bool = False
created_at: datetime = datetime.now(timezone.utc)
updated_at: datetime = datetime.now(timezone.utc)
currency: Optional[str] = None
balance_msat: int = Field(default=0, no_database=True)
@property @property
def balance(self) -> int: def balance(self) -> int:
return self.balance_msat // 1000 return int(self.balance_msat // 1000)
@property @property
def withdrawable_balance(self) -> int: def withdrawable_balance(self) -> int:
@ -68,11 +72,6 @@ class Wallet(BaseWallet):
linking_key, curve=SECP256k1, hashfunc=hashlib.sha256 linking_key, curve=SECP256k1, hashfunc=hashlib.sha256
) )
async def get_payment(self, payment_hash: str) -> Optional[Payment]:
from .crud import get_standalone_payment
return await get_standalone_payment(payment_hash)
class KeyType(Enum): class KeyType(Enum):
admin = 0 admin = 0
@ -90,7 +89,7 @@ class WalletTypeInfo:
wallet: Wallet wallet: Wallet
class UserConfig(BaseModel): class UserExtra(BaseModel):
email_verified: Optional[bool] = False email_verified: Optional[bool] = False
first_name: Optional[str] = None first_name: Optional[str] = None
last_name: Optional[str] = None last_name: Optional[str] = None
@ -103,16 +102,43 @@ class UserConfig(BaseModel):
provider: Optional[str] = "lnbits" # auth provider provider: Optional[str] = "lnbits" # auth provider
class Account(FromRowModel): class Account(BaseModel):
id: str id: str
is_super_user: Optional[bool] = False
is_admin: Optional[bool] = False
username: Optional[str] = None username: Optional[str] = None
password_hash: Optional[str] = None
pubkey: Optional[str] = None
email: Optional[str] = None email: Optional[str] = None
balance_msat: Optional[int] = 0 extra: UserExtra = UserExtra()
created_at: datetime = datetime.now(timezone.utc)
updated_at: datetime = datetime.now(timezone.utc)
@property
def is_super_user(self) -> bool:
return self.id == settings.super_user
@property
def is_admin(self) -> bool:
return self.id in settings.lnbits_admin_users or self.is_super_user
def hash_password(self, password: str) -> str:
"""sets and returns the hashed password"""
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
self.password_hash = pwd_context.hash(password)
return self.password_hash
def verify_password(self, password: str) -> bool:
"""returns True if the password matches the hash"""
if not self.password_hash:
return False
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
return pwd_context.verify(password, self.password_hash)
class AccountOverview(Account):
transaction_count: Optional[int] = 0 transaction_count: Optional[int] = 0
wallet_count: Optional[int] = 0 wallet_count: Optional[int] = 0
last_payment: Optional[datetime.datetime] = None balance_msat: Optional[int] = 0
last_payment: Optional[datetime] = None
class AccountFilters(FilterModel): class AccountFilters(FilterModel):
@ -127,7 +153,7 @@ class AccountFilters(FilterModel):
] ]
id: str id: str
last_payment: Optional[datetime.datetime] = None last_payment: Optional[datetime] = None
transaction_count: Optional[int] = None transaction_count: Optional[int] = None
wallet_count: Optional[int] = None wallet_count: Optional[int] = None
username: Optional[str] = None username: Optional[str] = None
@ -136,6 +162,8 @@ class AccountFilters(FilterModel):
class User(BaseModel): class User(BaseModel):
id: str id: str
created_at: datetime
updated_at: datetime
email: Optional[str] = None email: Optional[str] = None
username: Optional[str] = None username: Optional[str] = None
pubkey: Optional[str] = None pubkey: Optional[str] = None
@ -144,9 +172,7 @@ class User(BaseModel):
admin: bool = False admin: bool = False
super_user: bool = False super_user: bool = False
has_password: bool = False has_password: bool = False
config: Optional[UserConfig] = None extra: UserExtra = UserExtra()
created_at: Optional[int] = None
updated_at: Optional[int] = None
@property @property
def wallet_ids(self) -> list[str]: def wallet_ids(self) -> list[str]:
@ -178,7 +204,7 @@ class UpdateUser(BaseModel):
user_id: str user_id: str
email: Optional[str] = Query(default=None) email: Optional[str] = Query(default=None)
username: Optional[str] = Query(default=..., min_length=2, max_length=20) username: Optional[str] = Query(default=..., min_length=2, max_length=20)
config: Optional[UserConfig] = None extra: Optional[UserExtra] = None
class UpdateUserPassword(BaseModel): class UpdateUserPassword(BaseModel):
@ -231,36 +257,55 @@ class PaymentState(str, Enum):
return self.value return self.value
class PaymentExtra(BaseModel):
comment: Optional[str] = None
success_action: Optional[str] = None
lnurl_response: Optional[str] = None
class PayInvoice(BaseModel):
payment_request: str
description: Optional[str] = None
max_sat: Optional[int] = None
extra: Optional[dict] = {}
class CreatePayment(BaseModel): class CreatePayment(BaseModel):
wallet_id: str wallet_id: str
payment_request: str
payment_hash: str payment_hash: str
amount: int bolt11: str
amount_msat: int
memo: str memo: str
extra: Optional[dict] = {}
preimage: Optional[str] = None preimage: Optional[str] = None
expiry: Optional[datetime.datetime] = None expiry: Optional[datetime] = None
extra: Optional[dict] = None
webhook: Optional[str] = None webhook: Optional[str] = None
fee: int = 0 fee: int = 0
class Payment(FromRowModel): class Payment(BaseModel):
status: str
# TODO should be removed in the future, backward compatibility
pending: bool
checking_id: str checking_id: str
payment_hash: str
wallet_id: str
amount: int amount: int
fee: int fee: int
memo: Optional[str]
time: int
bolt11: str bolt11: str
preimage: str status: str = PaymentState.PENDING
payment_hash: str memo: Optional[str] = None
expiry: Optional[float] expiry: Optional[datetime] = None
extra: Optional[dict] webhook: Optional[str] = None
wallet_id: str webhook_status: Optional[int] = None
webhook: Optional[str] preimage: Optional[str] = None
webhook_status: Optional[int] tag: Optional[str] = None
extension: Optional[str] = None
time: datetime = datetime.now(timezone.utc)
created_at: datetime = datetime.now(timezone.utc)
updated_at: datetime = datetime.now(timezone.utc)
extra: dict = {}
@property
def pending(self) -> bool:
return self.status == PaymentState.PENDING.value
@property @property
def success(self) -> bool: def success(self) -> bool:
@ -270,33 +315,6 @@ class Payment(FromRowModel):
def failed(self) -> bool: def failed(self) -> bool:
return self.status == PaymentState.FAILED.value return self.status == PaymentState.FAILED.value
@classmethod
def from_row(cls, row: dict):
return cls(
checking_id=row["checking_id"],
payment_hash=row["hash"] or "0" * 64,
bolt11=row["bolt11"] or "",
preimage=row["preimage"] or "0" * 64,
extra=json.loads(row["extra"] or "{}"),
status=row["status"],
# TODO should be removed in the future, backward compatibility
pending=row["status"] == PaymentState.PENDING.value,
amount=row["amount"],
fee=row["fee"],
memo=row["memo"],
time=row["time"],
expiry=row["expiry"],
wallet_id=row["wallet"],
webhook=row["webhook"],
webhook_status=row["webhook_status"],
)
@property
def tag(self) -> Optional[str]:
if self.extra is None:
return ""
return self.extra.get("tag")
@property @property
def msat(self) -> int: def msat(self) -> int:
return self.amount return self.amount
@ -315,7 +333,7 @@ class Payment(FromRowModel):
@property @property
def is_expired(self) -> bool: def is_expired(self) -> bool:
return self.expiry < time.time() if self.expiry else False return self.expiry < datetime.now(timezone.utc) if self.expiry else False
@property @property
def is_internal(self) -> bool: def is_internal(self) -> bool:
@ -343,11 +361,11 @@ class PaymentFilters(FilterModel):
amount: int amount: int
fee: int fee: int
memo: Optional[str] memo: Optional[str]
time: datetime.datetime time: datetime
bolt11: str bolt11: str
preimage: str preimage: str
payment_hash: str payment_hash: str
expiry: Optional[datetime.datetime] expiry: Optional[datetime]
extra: dict = {} extra: dict = {}
wallet_id: str wallet_id: str
webhook: Optional[str] webhook: Optional[str]
@ -355,7 +373,7 @@ class PaymentFilters(FilterModel):
class PaymentHistoryPoint(BaseModel): class PaymentHistoryPoint(BaseModel):
date: datetime.datetime date: datetime
income: int income: int
spending: int spending: int
balance: int balance: int
@ -377,10 +395,6 @@ class TinyURL(BaseModel):
wallet: str wallet: str
time: float time: float
@classmethod
def from_row(cls, row: dict):
return cls(**dict(row))
class ConversionData(BaseModel): class ConversionData(BaseModel):
from_: str = "sat" from_: str = "sat"
@ -425,7 +439,6 @@ class CreateInvoice(BaseModel):
def unit_is_from_allowed_currencies(cls, v): def unit_is_from_allowed_currencies(cls, v):
if v != "sat" and v not in allowed_currencies(): if v != "sat" and v not in allowed_currencies():
raise ValueError("The provided unit is not supported") raise ValueError("The provided unit is not supported")
return v return v
@ -451,7 +464,7 @@ class WebPushSubscription(BaseModel):
user: str user: str
data: str data: str
host: str host: str
timestamp: str timestamp: datetime
class BalanceDelta(BaseModel): class BalanceDelta(BaseModel):
@ -466,3 +479,8 @@ class BalanceDelta(BaseModel):
class SimpleStatus(BaseModel): class SimpleStatus(BaseModel):
success: bool success: bool
message: str message: str
class DbVersion(BaseModel):
db: str
version: int

View file

@ -1,926 +0,0 @@
import asyncio
import json
import time
from io import BytesIO
from pathlib import Path
from typing import Optional
from urllib.parse import parse_qs, urlparse
from uuid import UUID, uuid4
import httpx
from bolt11 import MilliSatoshi
from bolt11 import decode as bolt11_decode
from cryptography.hazmat.primitives import serialization
from fastapi import Depends, WebSocket
from loguru import logger
from passlib.context import CryptContext
from py_vapid import Vapid
from py_vapid.utils import b64urlencode
from lnbits.core.db import db
from lnbits.db import Connection
from lnbits.decorators import (
WalletTypeInfo,
check_user_extension_access,
require_admin_key,
)
from lnbits.exceptions import InvoiceError, PaymentError
from lnbits.helpers import url_for
from lnbits.lnurl import LnurlErrorResponse
from lnbits.lnurl import decode as decode_lnurl
from lnbits.settings import (
EditableSettings,
SuperSettings,
readonly_variables,
send_admin_user_to_saas,
settings,
)
from lnbits.utils.exchange_rates import fiat_amount_as_satoshis, satoshis_amount_as_fiat
from lnbits.wallets import fake_wallet, get_funding_source, set_funding_source
from lnbits.wallets.base import (
PaymentPendingStatus,
PaymentResponse,
PaymentStatus,
PaymentSuccessStatus,
)
from .crud import (
check_internal,
check_internal_pending,
create_account,
create_admin_settings,
create_payment,
create_wallet,
get_account,
get_account_by_email,
get_account_by_username,
get_payments,
get_standalone_payment,
get_super_settings,
get_total_balance,
get_wallet,
get_wallet_payment,
update_admin_settings,
update_payment_details,
update_payment_status,
update_super_user,
update_user_extension,
)
from .helpers import to_valid_user_id
from .models import (
BalanceDelta,
CreatePayment,
Payment,
PaymentState,
User,
UserConfig,
Wallet,
)
async def calculate_fiat_amounts(
amount: float,
wallet_id: str,
currency: Optional[str] = None,
extra: Optional[dict] = None,
conn: Optional[Connection] = None,
) -> tuple[int, Optional[dict]]:
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "invalid wallet_id"
wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency
if currency and currency != "sat":
amount_sat = await fiat_amount_as_satoshis(amount, currency)
extra = extra or {}
if currency != wallet_currency:
extra["fiat_currency"] = currency
extra["fiat_amount"] = round(amount, ndigits=3)
extra["fiat_rate"] = amount_sat / amount
else:
amount_sat = int(amount)
if wallet_currency:
if wallet_currency == currency:
fiat_amount = amount
else:
fiat_amount = await satoshis_amount_as_fiat(amount_sat, wallet_currency)
extra = extra or {}
extra["wallet_fiat_currency"] = wallet_currency
extra["wallet_fiat_amount"] = round(fiat_amount, ndigits=3)
extra["wallet_fiat_rate"] = amount_sat / fiat_amount
logger.debug(
f"Calculated fiat amounts {wallet.id=} {amount=} {currency=}: {extra=}"
)
return amount_sat, extra
async def create_invoice(
*,
wallet_id: str,
amount: float,
currency: Optional[str] = "sat",
memo: str,
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
expiry: Optional[int] = None,
extra: Optional[dict] = None,
webhook: Optional[str] = None,
internal: Optional[bool] = False,
conn: Optional[Connection] = None,
) -> tuple[str, str]:
if not amount > 0:
raise InvoiceError("Amountless invoices not supported.", status="failed")
user_wallet = await get_wallet(wallet_id, conn=conn)
if not user_wallet:
raise InvoiceError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
invoice_memo = None if description_hash else memo
# use the fake wallet if the invoice is for internal use only
funding_source = fake_wallet if internal else get_funding_source()
amount_sat, extra = await calculate_fiat_amounts(
amount, wallet_id, currency=currency, extra=extra, conn=conn
)
if settings.is_wallet_max_balance_exceeded(
user_wallet.balance_msat / 1000 + amount_sat
):
raise InvoiceError(
f"Wallet balance cannot exceed "
f"{settings.lnbits_wallet_limit_max_balance} sats.",
status="failed",
)
(
ok,
checking_id,
payment_request,
error_message,
) = await funding_source.create_invoice(
amount=amount_sat,
memo=invoice_memo,
description_hash=description_hash,
unhashed_description=unhashed_description,
expiry=expiry or settings.lightning_invoice_expiry,
)
if not ok or not payment_request or not checking_id:
raise InvoiceError(
error_message or "unexpected backend error.", status="pending"
)
invoice = bolt11_decode(payment_request)
create_payment_model = CreatePayment(
wallet_id=wallet_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=amount_sat * 1000,
expiry=invoice.expiry_date,
memo=memo,
extra=extra,
webhook=webhook,
)
await create_payment(
checking_id=checking_id,
data=create_payment_model,
conn=conn,
)
return invoice.payment_hash, payment_request
async def pay_invoice(
*,
wallet_id: str,
payment_request: str,
max_sat: Optional[int] = None,
extra: Optional[dict] = None,
description: str = "",
conn: Optional[Connection] = None,
) -> str:
"""
Pay a Lightning invoice.
First, we create a temporary payment in the database with fees set to the reserve
fee. We then check whether the balance of the payer would go negative.
We then attempt to pay the invoice through the backend. If the payment is
successful, we update the payment in the database with the payment details.
If the payment is unsuccessful, we delete the temporary payment.
If the payment is still in flight, we hope that some other process
will regularly check for the payment.
"""
try:
invoice = bolt11_decode(payment_request)
except Exception as exc:
raise PaymentError("Bolt11 decoding failed.", status="failed") from exc
if not invoice.amount_msat or not invoice.amount_msat > 0:
raise PaymentError("Amountless invoices not supported.", status="failed")
if max_sat and invoice.amount_msat > max_sat * 1000:
raise PaymentError("Amount in invoice is too high.", status="failed")
await check_wallet_limits(wallet_id, conn, invoice.amount_msat)
async with db.reuse_conn(conn) if conn else db.connect() as conn:
temp_id = invoice.payment_hash
internal_id = f"internal_{invoice.payment_hash}"
_, extra = await calculate_fiat_amounts(
invoice.amount_msat / 1000, wallet_id, extra=extra, conn=conn
)
create_payment_model = CreatePayment(
wallet_id=wallet_id,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=-invoice.amount_msat,
expiry=invoice.expiry_date,
memo=description or invoice.description or "",
extra=extra,
)
# we check if an internal invoice exists that has already been paid
# (not pending anymore)
if not await check_internal_pending(invoice.payment_hash, conn=conn):
raise PaymentError("Internal invoice already paid.", status="failed")
# check_internal() returns the checking_id of the invoice we're waiting for
# (pending only)
internal_checking_id = await check_internal(invoice.payment_hash, conn=conn)
if internal_checking_id:
# perform additional checks on the internal payment
# the payment hash is not enough to make sure that this is the same invoice
internal_invoice = await get_standalone_payment(
internal_checking_id, incoming=True, conn=conn
)
assert internal_invoice is not None
if (
internal_invoice.amount != invoice.amount_msat
or internal_invoice.bolt11 != payment_request.lower()
):
raise PaymentError("Invalid invoice.", status="failed")
logger.debug(f"creating temporary internal payment with id {internal_id}")
# create a new payment from this wallet
fee_reserve_total_msat = fee_reserve_total(
invoice.amount_msat, internal=True
)
create_payment_model.fee = service_fee(invoice.amount_msat, True)
new_payment = await create_payment(
checking_id=internal_id,
data=create_payment_model,
status=PaymentState.SUCCESS,
conn=conn,
)
else:
new_payment = await _create_external_payment(
temp_id=temp_id,
amount_msat=invoice.amount_msat,
data=create_payment_model,
conn=conn,
)
# do the balance check
wallet = await get_wallet(wallet_id, conn=conn)
assert wallet, "Wallet for balancecheck could not be fetched"
fee_reserve_total_msat = fee_reserve_total(invoice.amount_msat, internal=False)
_check_wallet_balance(wallet, fee_reserve_total_msat, internal_checking_id)
if extra and "tag" in extra:
# check if the payment is made for an extension that the user disabled
status = await check_user_extension_access(wallet.user, extra["tag"])
if not status.success:
raise PaymentError(status.message)
if internal_checking_id:
service_fee_msat = service_fee(invoice.amount_msat, internal=True)
logger.debug(f"marking temporary payment as not pending {internal_checking_id}")
# mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure
# the payer has enough to deduct from
async with db.connect() as conn:
await update_payment_status(
checking_id=internal_checking_id,
status=PaymentState.SUCCESS,
conn=conn,
)
await send_payment_notification(wallet, new_payment)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_queue
logger.debug(f"enqueuing internal invoice {internal_checking_id}")
await internal_invoice_queue.put(internal_checking_id)
else:
fee_reserve_msat = fee_reserve(invoice.amount_msat, internal=False)
service_fee_msat = service_fee(invoice.amount_msat, internal=False)
logger.debug(f"backend: sending payment {temp_id}")
# actually pay the external invoice
funding_source = get_funding_source()
payment: PaymentResponse = await funding_source.pay_invoice(
payment_request, fee_reserve_msat
)
if payment.checking_id and payment.checking_id != temp_id:
logger.warning(
f"backend sent unexpected checking_id (expected: {temp_id} got:"
f" {payment.checking_id})"
)
logger.debug(f"backend: pay_invoice finished {temp_id}, {payment}")
if payment.checking_id and payment.ok is not False:
# payment.ok can be True (paid) or None (pending)!
logger.debug(f"updating payment {temp_id}")
async with db.connect() as conn:
await update_payment_details(
checking_id=temp_id,
status=(
PaymentState.SUCCESS
if payment.ok is True
else PaymentState.PENDING
),
fee=-(
abs(payment.fee_msat if payment.fee_msat else 0)
+ abs(service_fee_msat)
),
preimage=payment.preimage,
new_checking_id=payment.checking_id,
conn=conn,
)
wallet = await get_wallet(wallet_id, conn=conn)
updated = await get_wallet_payment(
wallet_id, payment.checking_id, conn=conn
)
if wallet and updated and updated.success:
await send_payment_notification(wallet, updated)
logger.success(f"payment successful {payment.checking_id}")
elif payment.checking_id is None and payment.ok is False:
# payment failed
logger.debug(f"payment failed {temp_id}, {payment.error_message}")
async with db.connect() as conn:
await update_payment_status(
checking_id=temp_id,
status=PaymentState.FAILED,
conn=conn,
)
raise PaymentError(
f"Payment failed: {payment.error_message}"
or "Payment failed, but backend didn't give us an error message.",
status="failed",
)
else:
logger.warning(
"didn't receive checking_id from backend, payment may be stuck in"
f" database: {temp_id}"
)
# credit service fee wallet
if settings.lnbits_service_fee_wallet and service_fee_msat:
create_payment_model = CreatePayment(
wallet_id=settings.lnbits_service_fee_wallet,
payment_request=payment_request,
payment_hash=invoice.payment_hash,
amount=abs(service_fee_msat),
memo="Service fee",
)
new_payment = await create_payment(
checking_id=f"service_fee_{temp_id}",
data=create_payment_model,
status=PaymentState.SUCCESS,
)
return invoice.payment_hash
async def _create_external_payment(
temp_id: str,
amount_msat: MilliSatoshi,
data: CreatePayment,
conn: Optional[Connection],
) -> Payment:
fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False)
# check if there is already a payment with the same checking_id
old_payment = await get_standalone_payment(temp_id, conn=conn)
if old_payment:
# fail on pending payments
if old_payment.pending:
raise PaymentError("Payment is still pending.", status="pending")
if old_payment.success:
raise PaymentError("Payment already paid.", status="success")
if old_payment.failed:
status = await old_payment.check_status()
if status.success:
# payment was successful on the fundingsource
await update_payment_status(
checking_id=temp_id, status=PaymentState.SUCCESS, conn=conn
)
raise PaymentError(
"Failed payment was already paid on the fundingsource.",
status="success",
)
if status.failed:
raise PaymentError(
"Payment is failed node, retrying is not possible.", status="failed"
)
# status.pending fall through and try again
return old_payment
logger.debug(f"creating temporary payment with id {temp_id}")
# create a temporary payment here so we can check if
# the balance is enough in the next step
try:
data.fee = -abs(fee_reserve_total_msat)
new_payment = await create_payment(
checking_id=temp_id,
data=data,
conn=conn,
)
return new_payment
except Exception as exc:
logger.error(f"could not create temporary payment: {exc}")
# happens if the same wallet tries to pay an invoice twice
raise PaymentError("Could not make payment", status="failed") from exc
def _check_wallet_balance(
wallet: Wallet,
fee_reserve_total_msat: int,
internal_checking_id: Optional[str] = None,
):
if wallet.balance_msat < 0:
logger.debug("balance is too low, deleting temporary payment")
if not internal_checking_id and wallet.balance_msat > -fee_reserve_total_msat:
raise PaymentError(
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
raise PaymentError("Insufficient balance.", status="failed")
async def check_wallet_limits(wallet_id, conn, amount_msat):
await check_time_limit_between_transactions(conn, wallet_id)
await check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat)
async def check_time_limit_between_transactions(conn, wallet_id):
limit = settings.lnbits_wallet_limit_secs_between_trans
if not limit or limit <= 0:
return
payments = await get_payments(
since=int(time.time()) - limit,
wallet_id=wallet_id,
limit=1,
conn=conn,
)
if len(payments) == 0:
return
raise PaymentError(
status="failed",
message=f"The time limit of {limit} seconds between payments has been reached.",
)
async def check_wallet_daily_withdraw_limit(conn, wallet_id, amount_msat):
limit = settings.lnbits_wallet_limit_daily_max_withdraw
if not limit:
return
if limit < 0:
raise ValueError("It is not allowed to spend funds from this server.")
payments = await get_payments(
since=int(time.time()) - 60 * 60 * 24,
outgoing=True,
wallet_id=wallet_id,
limit=1,
conn=conn,
)
if len(payments) == 0:
return
total = 0
for pay in payments:
total += pay.amount
total = total - amount_msat
if limit * 1000 + total < 0:
raise ValueError(
"Daily withdrawal limit of "
+ str(settings.lnbits_wallet_limit_daily_max_withdraw)
+ " sats reached."
)
async def redeem_lnurl_withdraw(
wallet_id: str,
lnurl_request: str,
memo: Optional[str] = None,
extra: Optional[dict] = None,
wait_seconds: int = 0,
conn: Optional[Connection] = None,
) -> None:
if not lnurl_request:
return None
res = {}
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
lnurl = decode_lnurl(lnurl_request)
r = await client.get(str(lnurl))
res = r.json()
try:
_, payment_request = await create_invoice(
wallet_id=wallet_id,
amount=int(res["maxWithdrawable"] / 1000),
memo=memo or res["defaultDescription"] or "",
extra=extra,
conn=conn,
)
except Exception:
logger.warning(
f"failed to create invoice on redeem_lnurl_withdraw "
f"from {lnurl}. params: {res}"
)
return None
if wait_seconds:
await asyncio.sleep(wait_seconds)
params = {"k1": res["k1"], "pr": payment_request}
try:
params["balanceNotify"] = url_for(
f"/withdraw/notify/{urlparse(lnurl_request).netloc}",
external=True,
wal=wallet_id,
)
except Exception:
pass
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
try:
await client.get(res["callback"], params=params)
except Exception:
pass
async def perform_lnurlauth(
callback: str,
wallet: WalletTypeInfo = Depends(require_admin_key),
) -> Optional[LnurlErrorResponse]:
cb = urlparse(callback)
k1 = bytes.fromhex(parse_qs(cb.query)["k1"][0])
key = wallet.wallet.lnurlauth_key(cb.netloc)
def int_to_bytes_suitable_der(x: int) -> bytes:
"""for strict DER we need to encode the integer with some quirks"""
b = x.to_bytes((x.bit_length() + 7) // 8, "big")
if len(b) == 0:
# ensure there's at least one byte when the int is zero
return bytes([0])
if b[0] & 0x80 != 0:
# ensure it doesn't start with a 0x80 and so it isn't
# interpreted as a negative number
return bytes([0]) + b
return b
def encode_strict_der(r: int, s: int, order: int):
# if s > order/2 verification will fail sometimes
# so we must fix it here see:
# https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147
if s > order // 2:
s = order - s
# now we do the strict DER encoding copied from
# https://github.com/KiriKiri/bip66 (without any checks)
r_temp = int_to_bytes_suitable_der(r)
s_temp = int_to_bytes_suitable_der(s)
r_len = len(r_temp)
s_len = len(s_temp)
sign_len = 6 + r_len + s_len
signature = BytesIO()
signature.write(0x30.to_bytes(1, "big", signed=False))
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(r_len.to_bytes(1, "big", signed=False))
signature.write(r_temp)
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(s_len.to_bytes(1, "big", signed=False))
signature.write(s_temp)
return signature.getvalue()
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
assert key.verifying_key, "LNURLauth verifying_key does not exist"
r = await client.get(
callback,
params={
"k1": k1.hex(),
"key": key.verifying_key.to_string("compressed").hex(),
"sig": sig.hex(),
},
)
try:
resp = json.loads(r.text)
if resp["status"] == "OK":
return None
return LnurlErrorResponse(reason=resp["reason"])
except (KeyError, json.decoder.JSONDecodeError):
return LnurlErrorResponse(
reason=r.text[:200] + "..." if len(r.text) > 200 else r.text
)
async def check_transaction_status(
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> PaymentStatus:
payment: Optional[Payment] = await get_wallet_payment(
wallet_id, payment_hash, conn=conn
)
if not payment:
return PaymentPendingStatus()
if payment.status == PaymentState.SUCCESS.value:
return PaymentSuccessStatus(fee_msat=payment.fee)
return await payment.check_status()
# WARN: this same value must be used for balance check and passed to
# funding_source.pay_invoice(), it may cause a vulnerability if the values differ
def fee_reserve(amount_msat: int, internal: bool = False) -> int:
if internal:
return 0
reserve_min = settings.lnbits_reserve_fee_min
reserve_percent = settings.lnbits_reserve_fee_percent
return max(int(reserve_min), int(amount_msat * reserve_percent / 100.0))
def service_fee(amount_msat: int, internal: bool = False) -> int:
amount_msat = abs(amount_msat)
service_fee_percent = settings.lnbits_service_fee
fee_max = settings.lnbits_service_fee_max * 1000
if settings.lnbits_service_fee_wallet:
if internal and settings.lnbits_service_fee_ignore_internal:
return 0
fee_percentage = int(amount_msat / 100 * service_fee_percent)
if fee_max > 0 and fee_percentage > fee_max:
return fee_max
else:
return fee_percentage
else:
return 0
def fee_reserve_total(amount_msat: int, internal: bool = False) -> int:
return fee_reserve(amount_msat, internal) + service_fee(amount_msat, internal)
async def send_payment_notification(wallet: Wallet, payment: Payment):
await websocket_updater(
wallet.inkey,
json.dumps(
{
"wallet_balance": wallet.balance,
"payment": payment.dict(),
}
),
)
await websocket_updater(
payment.payment_hash, json.dumps({"pending": payment.pending})
)
async def update_wallet_balance(wallet_id: str, amount: int):
payment_hash, _ = await create_invoice(
wallet_id=wallet_id,
amount=amount,
memo="Admin top up",
internal=True,
)
async with db.connect() as conn:
checking_id = await check_internal(payment_hash, conn=conn)
assert checking_id, "newly created checking_id cannot be retrieved"
await update_payment_status(
checking_id=checking_id, status=PaymentState.SUCCESS, conn=conn
)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_queue
await internal_invoice_queue.put(checking_id)
async def check_admin_settings():
if settings.super_user:
settings.super_user = to_valid_user_id(settings.super_user).hex
if settings.lnbits_admin_ui:
settings_db = await get_super_settings()
if not settings_db:
# create new settings if table is empty
logger.warning("Settings DB empty. Inserting default settings.")
settings_db = await init_admin_settings(settings.super_user)
logger.warning("Initialized settings from environment variables.")
if settings.super_user and settings.super_user != settings_db.super_user:
# .env super_user overwrites DB super_user
settings_db = await update_super_user(settings.super_user)
update_cached_settings(settings_db.dict())
# saving superuser to {data_dir}/.super_user file
with open(Path(settings.lnbits_data_folder) / ".super_user", "w") as file:
file.write(settings.super_user)
# callback for saas
if (
settings.lnbits_saas_callback
and settings.lnbits_saas_secret
and settings.lnbits_saas_instance_id
):
send_admin_user_to_saas()
account = await get_account(settings.super_user)
if account and account.config and account.config.provider == "env":
settings.first_install = True
logger.success(
"✔️ Admin UI is enabled. run `poetry run lnbits-cli superuser` "
"to get the superuser."
)
async def check_webpush_settings():
if not settings.lnbits_webpush_privkey:
vapid = Vapid()
vapid.generate_keys()
privkey = vapid.private_pem()
assert vapid.public_key, "VAPID public key does not exist"
pubkey = b64urlencode(
vapid.public_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)
)
push_settings = {
"lnbits_webpush_privkey": privkey.decode(),
"lnbits_webpush_pubkey": pubkey,
}
update_cached_settings(push_settings)
await update_admin_settings(EditableSettings(**push_settings))
logger.info("Initialized webpush settings with generated VAPID key pair.")
logger.info(f"Pubkey: {settings.lnbits_webpush_pubkey}")
def update_cached_settings(sets_dict: dict):
for key, value in sets_dict.items():
if key in readonly_variables:
continue
if key not in settings.dict().keys():
continue
try:
setattr(settings, key, value)
except Exception:
logger.warning(f"Failed overriding setting: {key}, value: {value}")
if "super_user" in sets_dict:
settings.super_user = sets_dict["super_user"]
async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
account = None
if super_user:
account = await get_account(super_user)
if not account:
account = await create_account(
user_id=super_user, user_config=UserConfig(provider="env")
)
if not account.wallets or len(account.wallets) == 0:
await create_wallet(user_id=account.id)
editable_settings = EditableSettings.from_dict(settings.dict())
return await create_admin_settings(account.id, editable_settings.dict())
async def create_user_account(
user_id: Optional[str] = None,
email: Optional[str] = None,
username: Optional[str] = None,
pubkey: Optional[str] = None,
password: Optional[str] = None,
wallet_name: Optional[str] = None,
user_config: Optional[UserConfig] = None,
) -> User:
if not settings.new_accounts_allowed:
raise ValueError("Account creation is disabled.")
if username and await get_account_by_username(username):
raise ValueError("Username already exists.")
if email and await get_account_by_email(email):
raise ValueError("Email already exists.")
if user_id:
user_uuid4 = UUID(hex=user_id, version=4)
assert user_uuid4.hex == user_id, "User ID is not valid UUID4 hex string"
else:
user_id = uuid4().hex
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
password = pwd_context.hash(password) if password else None
account = await create_account(
user_id, username, pubkey, email, password, user_config
)
wallet = await create_wallet(user_id=account.id, wallet_name=wallet_name)
account.wallets = [wallet]
for ext_id in settings.lnbits_user_default_extensions:
await update_user_extension(user_id=account.id, extension=ext_id, active=True)
return account
class WebsocketConnectionManager:
def __init__(self) -> None:
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}")
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_data(self, message: str, item_id: str):
for connection in self.active_connections:
if connection.path_params["item_id"] == item_id:
await connection.send_text(message)
websocket_manager = WebsocketConnectionManager()
async def websocket_updater(item_id, data):
return await websocket_manager.send_data(f"{data}", item_id)
async def switch_to_voidwallet() -> None:
funding_source = get_funding_source()
if funding_source.__class__.__name__ == "VoidWallet":
return
set_funding_source("VoidWallet")
settings.lnbits_backend_wallet_class = "VoidWallet"
async def get_balance_delta() -> BalanceDelta:
funding_source = get_funding_source()
status = await funding_source.status()
lnbits_balance = await get_total_balance()
return BalanceDelta(
lnbits_balance_msats=lnbits_balance,
node_balance_msats=status.balance_msat,
)
async def update_pending_payments(wallet_id: str):
pending_payments = await get_payments(
wallet_id=wallet_id,
pending=True,
exclude_uncheckable=True,
)
for payment in pending_payments:
status = await payment.check_status()
if status.failed:
await update_payment_status(
checking_id=payment.checking_id,
status=PaymentState.FAILED,
)
elif status.success:
await update_payment_status(
checking_id=payment.checking_id,
status=PaymentState.SUCCESS,
)

View file

@ -0,0 +1,55 @@
from .funding_source import (
get_balance_delta,
switch_to_voidwallet,
)
from .lnurl import perform_lnurlauth, redeem_lnurl_withdraw
from .payments import (
calculate_fiat_amounts,
check_transaction_status,
check_wallet_limits,
create_invoice,
fee_reserve,
fee_reserve_total,
pay_invoice,
send_payment_notification,
service_fee,
update_pending_payments,
update_wallet_balance,
)
from .settings import (
check_webpush_settings,
update_cached_settings,
)
from .users import check_admin_settings, create_user_account, init_admin_settings
from .websockets import websocket_manager, websocket_updater
__all__ = [
# funding source
"get_balance_delta",
"switch_to_voidwallet",
# lnurl
"redeem_lnurl_withdraw",
"perform_lnurlauth",
# payments
"calculate_fiat_amounts",
"check_transaction_status",
"check_wallet_limits",
"create_invoice",
"fee_reserve",
"fee_reserve_total",
"pay_invoice",
"send_payment_notification",
"service_fee",
"update_pending_payments",
"update_wallet_balance",
# settings
"check_webpush_settings",
"update_cached_settings",
# users
"check_admin_settings",
"create_user_account",
"init_admin_settings",
# websockets
"websocket_manager",
"websocket_updater",
]

View file

@ -0,0 +1,23 @@
from lnbits.settings import settings
from lnbits.wallets import get_funding_source, set_funding_source
from ..crud import get_total_balance
from ..models import BalanceDelta
async def switch_to_voidwallet() -> None:
funding_source = get_funding_source()
if funding_source.__class__.__name__ == "VoidWallet":
return
set_funding_source("VoidWallet")
settings.lnbits_backend_wallet_class = "VoidWallet"
async def get_balance_delta() -> BalanceDelta:
funding_source = get_funding_source()
status = await funding_source.status()
lnbits_balance = await get_total_balance()
return BalanceDelta(
lnbits_balance_msats=lnbits_balance,
node_balance_msats=status.balance_msat,
)

View file

@ -0,0 +1,155 @@
import asyncio
import json
from io import BytesIO
from typing import Optional
from urllib.parse import parse_qs, urlparse
import httpx
from fastapi import Depends
from loguru import logger
from lnbits.db import Connection
from lnbits.decorators import (
WalletTypeInfo,
require_admin_key,
)
from lnbits.helpers import url_for
from lnbits.lnurl import LnurlErrorResponse
from lnbits.lnurl import decode as decode_lnurl
from lnbits.settings import settings
from .payments import create_invoice
async def redeem_lnurl_withdraw(
wallet_id: str,
lnurl_request: str,
memo: Optional[str] = None,
extra: Optional[dict] = None,
wait_seconds: int = 0,
conn: Optional[Connection] = None,
) -> None:
if not lnurl_request:
return None
res = {}
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
lnurl = decode_lnurl(lnurl_request)
r = await client.get(str(lnurl))
res = r.json()
try:
_, payment_request = await create_invoice(
wallet_id=wallet_id,
amount=int(res["maxWithdrawable"] / 1000),
memo=memo or res["defaultDescription"] or "",
extra=extra,
conn=conn,
)
except Exception:
logger.warning(
f"failed to create invoice on redeem_lnurl_withdraw "
f"from {lnurl}. params: {res}"
)
return None
if wait_seconds:
await asyncio.sleep(wait_seconds)
params = {"k1": res["k1"], "pr": payment_request}
try:
params["balanceNotify"] = url_for(
f"/withdraw/notify/{urlparse(lnurl_request).netloc}",
external=True,
wal=wallet_id,
)
except Exception:
pass
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
try:
await client.get(res["callback"], params=params)
except Exception:
pass
async def perform_lnurlauth(
callback: str,
wallet: WalletTypeInfo = Depends(require_admin_key),
) -> Optional[LnurlErrorResponse]:
cb = urlparse(callback)
k1 = bytes.fromhex(parse_qs(cb.query)["k1"][0])
key = wallet.wallet.lnurlauth_key(cb.netloc)
def int_to_bytes_suitable_der(x: int) -> bytes:
"""for strict DER we need to encode the integer with some quirks"""
b = x.to_bytes((x.bit_length() + 7) // 8, "big")
if len(b) == 0:
# ensure there's at least one byte when the int is zero
return bytes([0])
if b[0] & 0x80 != 0:
# ensure it doesn't start with a 0x80 and so it isn't
# interpreted as a negative number
return bytes([0]) + b
return b
def encode_strict_der(r: int, s: int, order: int):
# if s > order/2 verification will fail sometimes
# so we must fix it here see:
# https://github.com/indutny/elliptic/blob/e71b2d9359c5fe9437fbf46f1f05096de447de57/lib/elliptic/ec/index.js#L146-L147
if s > order // 2:
s = order - s
# now we do the strict DER encoding copied from
# https://github.com/KiriKiri/bip66 (without any checks)
r_temp = int_to_bytes_suitable_der(r)
s_temp = int_to_bytes_suitable_der(s)
r_len = len(r_temp)
s_len = len(s_temp)
sign_len = 6 + r_len + s_len
signature = BytesIO()
signature.write(0x30.to_bytes(1, "big", signed=False))
signature.write((sign_len - 2).to_bytes(1, "big", signed=False))
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(r_len.to_bytes(1, "big", signed=False))
signature.write(r_temp)
signature.write(0x02.to_bytes(1, "big", signed=False))
signature.write(s_len.to_bytes(1, "big", signed=False))
signature.write(s_temp)
return signature.getvalue()
sig = key.sign_digest_deterministic(k1, sigencode=encode_strict_der)
headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client:
assert key.verifying_key, "LNURLauth verifying_key does not exist"
r = await client.get(
callback,
params={
"k1": k1.hex(),
"key": key.verifying_key.to_string("compressed").hex(),
"sig": sig.hex(),
},
)
try:
resp = json.loads(r.text)
if resp["status"] == "OK":
return None
return LnurlErrorResponse(reason=resp["reason"])
except (KeyError, json.decoder.JSONDecodeError):
return LnurlErrorResponse(
reason=r.text[:200] + "..." if len(r.text) > 200 else r.text
)

View file

@ -0,0 +1,580 @@
import json
import time
from typing import Optional
from bolt11 import decode as bolt11_decode
from bolt11.types import Bolt11
from loguru import logger
from lnbits.core.db import db
from lnbits.db import Connection
from lnbits.decorators import check_user_extension_access
from lnbits.exceptions import InvoiceError, PaymentError
from lnbits.settings import settings
from lnbits.utils.exchange_rates import fiat_amount_as_satoshis, satoshis_amount_as_fiat
from lnbits.wallets import fake_wallet, get_funding_source
from lnbits.wallets.base import (
PaymentPendingStatus,
PaymentResponse,
PaymentStatus,
PaymentSuccessStatus,
)
from ..crud import (
check_internal,
create_payment,
get_payments,
get_standalone_payment,
get_wallet,
get_wallet_payment,
is_internal_status_success,
update_payment,
)
from ..models import (
CreatePayment,
Payment,
PaymentState,
Wallet,
)
from .websockets import websocket_manager
async def pay_invoice(
*,
wallet_id: str,
payment_request: str,
max_sat: Optional[int] = None,
extra: Optional[dict] = None,
description: str = "",
tag: str = "",
conn: Optional[Connection] = None,
) -> Payment:
invoice = _validate_payment_request(payment_request, max_sat)
assert invoice.amount_msat
async with db.reuse_conn(conn) if conn else db.connect() as conn:
amount_msat = invoice.amount_msat
wallet = await _check_wallet_for_payment(wallet_id, tag, amount_msat, conn)
if await is_internal_status_success(invoice.payment_hash, conn):
raise PaymentError("Internal invoice already paid.", status="failed")
_, extra = await calculate_fiat_amounts(amount_msat / 1000, wallet, extra=extra)
create_payment_model = CreatePayment(
wallet_id=wallet_id,
bolt11=payment_request,
payment_hash=invoice.payment_hash,
amount_msat=-amount_msat,
expiry=invoice.expiry_date,
memo=description or invoice.description or "",
extra=extra,
)
payment = await _pay_invoice(wallet, create_payment_model, conn)
await _credit_service_fee_wallet(payment, conn)
return payment
async def create_invoice(
*,
wallet_id: str,
amount: float,
currency: Optional[str] = "sat",
memo: str,
description_hash: Optional[bytes] = None,
unhashed_description: Optional[bytes] = None,
expiry: Optional[int] = None,
extra: Optional[dict] = None,
webhook: Optional[str] = None,
internal: Optional[bool] = False,
conn: Optional[Connection] = None,
) -> Payment:
if not amount > 0:
raise InvoiceError("Amountless invoices not supported.", status="failed")
user_wallet = await get_wallet(wallet_id, conn=conn)
if not user_wallet:
raise InvoiceError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
invoice_memo = None if description_hash else memo
# use the fake wallet if the invoice is for internal use only
funding_source = fake_wallet if internal else get_funding_source()
amount_sat, extra = await calculate_fiat_amounts(
amount, user_wallet, currency, extra
)
if settings.is_wallet_max_balance_exceeded(
user_wallet.balance_msat / 1000 + amount_sat
):
raise InvoiceError(
f"Wallet balance cannot exceed "
f"{settings.lnbits_wallet_limit_max_balance} sats.",
status="failed",
)
(
ok,
checking_id,
payment_request,
error_message,
) = await funding_source.create_invoice(
amount=amount_sat,
memo=invoice_memo,
description_hash=description_hash,
unhashed_description=unhashed_description,
expiry=expiry or settings.lightning_invoice_expiry,
)
if not ok or not payment_request or not checking_id:
raise InvoiceError(
error_message or "unexpected backend error.", status="pending"
)
invoice = bolt11_decode(payment_request)
create_payment_model = CreatePayment(
wallet_id=wallet_id,
bolt11=payment_request,
payment_hash=invoice.payment_hash,
amount_msat=amount_sat * 1000,
expiry=invoice.expiry_date,
memo=memo,
extra=extra,
webhook=webhook,
)
payment = await create_payment(
checking_id=checking_id,
data=create_payment_model,
conn=conn,
)
return payment
async def update_pending_payments(wallet_id: str):
pending_payments = await get_payments(
wallet_id=wallet_id,
pending=True,
exclude_uncheckable=True,
)
for payment in pending_payments:
status = await payment.check_status()
if status.failed:
payment.status = PaymentState.FAILED
await update_payment(payment)
elif status.success:
payment.status = PaymentState.SUCCESS
await update_payment(payment)
def fee_reserve_total(amount_msat: int, internal: bool = False) -> int:
return fee_reserve(amount_msat, internal) + service_fee(amount_msat, internal)
# WARN: this same value must be used for balance check and passed to
# funding_source.pay_invoice(), it may cause a vulnerability if the values differ
def fee_reserve(amount_msat: int, internal: bool = False) -> int:
if internal:
return 0
reserve_min = settings.lnbits_reserve_fee_min
reserve_percent = settings.lnbits_reserve_fee_percent
return max(int(reserve_min), int(amount_msat * reserve_percent / 100.0))
def service_fee(amount_msat: int, internal: bool = False) -> int:
amount_msat = abs(amount_msat)
service_fee_percent = settings.lnbits_service_fee
fee_max = settings.lnbits_service_fee_max * 1000
if settings.lnbits_service_fee_wallet:
if internal and settings.lnbits_service_fee_ignore_internal:
return 0
fee_percentage = int(amount_msat / 100 * service_fee_percent)
if fee_max > 0 and fee_percentage > fee_max:
return fee_max
else:
return fee_percentage
else:
return 0
async def update_wallet_balance(wallet_id: str, amount: int):
async with db.connect() as conn:
payment = await create_invoice(
wallet_id=wallet_id,
amount=amount,
memo="Admin top up",
internal=True,
conn=conn,
)
payment.status = PaymentState.SUCCESS
await update_payment(payment, conn=conn)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_queue
await internal_invoice_queue.put(payment.checking_id)
async def send_payment_notification(wallet: Wallet, payment: Payment):
# TODO: websocket message should be a clean payment model
# await websocket_manager.send_data(payment.json(), wallet.inkey)
# TODO: figure out why we send the balance with the payment here.
# cleaner would be to have a separate message for the balance
# and send it with the id of the wallet so wallets can subscribe to it
await websocket_manager.send_data(
json.dumps(
{
"wallet_balance": wallet.balance,
# use pydantic json serialization to get the correct datetime format
"payment": json.loads(payment.json()),
},
),
wallet.inkey,
)
await websocket_manager.send_data(
json.dumps({"pending": payment.pending}), payment.payment_hash
)
async def check_wallet_limits(
wallet_id: str, amount_msat: int, conn: Optional[Connection] = None
):
await check_time_limit_between_transactions(wallet_id, conn)
await check_wallet_daily_withdraw_limit(wallet_id, amount_msat, conn)
async def check_time_limit_between_transactions(
wallet_id: str, conn: Optional[Connection] = None
):
limit = settings.lnbits_wallet_limit_secs_between_trans
if not limit or limit <= 0:
return
payments = await get_payments(
since=int(time.time()) - limit,
wallet_id=wallet_id,
limit=1,
conn=conn,
)
if len(payments) == 0:
return
raise PaymentError(
status="failed",
message=f"The time limit of {limit} seconds between payments has been reached.",
)
async def check_wallet_daily_withdraw_limit(
wallet_id: str, amount_msat: int, conn: Optional[Connection] = None
):
limit = settings.lnbits_wallet_limit_daily_max_withdraw
if not limit:
return
if limit < 0:
raise ValueError("It is not allowed to spend funds from this server.")
payments = await get_payments(
since=int(time.time()) - 60 * 60 * 24,
outgoing=True,
wallet_id=wallet_id,
limit=1,
conn=conn,
)
if len(payments) == 0:
return
total = 0
for pay in payments:
total += pay.amount
total = total - amount_msat
if limit * 1000 + total < 0:
raise ValueError(
"Daily withdrawal limit of "
+ str(settings.lnbits_wallet_limit_daily_max_withdraw)
+ " sats reached."
)
async def calculate_fiat_amounts(
amount: float,
wallet: Wallet,
currency: Optional[str] = None,
extra: Optional[dict] = None,
) -> tuple[int, dict]:
wallet_currency = wallet.currency or settings.lnbits_default_accounting_currency
fiat_amounts: dict = extra or {}
if currency and currency != "sat":
amount_sat = await fiat_amount_as_satoshis(amount, currency)
if currency != wallet_currency:
fiat_amounts["fiat_currency"] = currency
fiat_amounts["fiat_amount"] = round(amount, ndigits=3)
fiat_amounts["fiat_rate"] = amount_sat / amount
else:
amount_sat = int(amount)
if wallet_currency:
if wallet_currency == currency:
fiat_amount = amount
else:
fiat_amount = await satoshis_amount_as_fiat(amount_sat, wallet_currency)
fiat_amounts["wallet_fiat_currency"] = wallet_currency
fiat_amounts["wallet_fiat_amount"] = round(fiat_amount, ndigits=3)
fiat_amounts["wallet_fiat_rate"] = amount_sat / fiat_amount
logger.debug(
f"Calculated fiat amounts {wallet.id=} {amount=} {currency=}: {fiat_amounts=}"
)
return amount_sat, fiat_amounts
async def check_transaction_status(
wallet_id: str, payment_hash: str, conn: Optional[Connection] = None
) -> PaymentStatus:
payment: Optional[Payment] = await get_wallet_payment(
wallet_id, payment_hash, conn=conn
)
if not payment:
return PaymentPendingStatus()
if payment.status == PaymentState.SUCCESS.value:
return PaymentSuccessStatus(fee_msat=payment.fee)
return await payment.check_status()
async def _pay_invoice(wallet, create_payment_model, conn):
payment = await _pay_internal_invoice(wallet, create_payment_model, conn)
if not payment:
payment = await _pay_external_invoice(wallet, create_payment_model, conn)
return payment
async def _pay_internal_invoice(
wallet: Wallet,
create_payment_model: CreatePayment,
conn: Optional[Connection] = None,
) -> Optional[Payment]:
"""
Pay an internal payment.
returns None if the payment is not internal.
"""
# check_internal() returns the payment of the invoice we're waiting for
# (pending only)
internal_payment = await check_internal(
create_payment_model.payment_hash, conn=conn
)
if not internal_payment:
return None
# perform additional checks on the internal payment
# the payment hash is not enough to make sure that this is the same invoice
internal_invoice = await get_standalone_payment(
internal_payment.checking_id, incoming=True, conn=conn
)
if not internal_invoice:
raise PaymentError("Internal payment not found.", status="failed")
amount_msat = create_payment_model.amount_msat
if (
internal_invoice.amount != abs(amount_msat)
or internal_invoice.bolt11 != create_payment_model.bolt11.lower()
):
raise PaymentError("Invalid invoice. Bolt11 changed.", status="failed")
fee_reserve_total_msat = fee_reserve_total(abs(amount_msat), internal=True)
create_payment_model.fee = abs(fee_reserve_total_msat)
if wallet.balance_msat < abs(amount_msat) + fee_reserve_total_msat:
raise PaymentError("Insufficient balance.", status="failed")
internal_id = f"internal_{create_payment_model.payment_hash}"
logger.debug(f"creating temporary internal payment with id {internal_id}")
payment = await create_payment(
checking_id=internal_id,
data=create_payment_model,
status=PaymentState.SUCCESS,
conn=conn,
)
# mark the invoice from the other side as not pending anymore
# so the other side only has access to his new money when we are sure
# the payer has enough to deduct from
internal_payment.status = PaymentState.SUCCESS
await update_payment(internal_payment, conn=conn)
logger.success(f"internal payment successful {internal_payment.checking_id}")
await send_payment_notification(wallet, payment)
# notify receiver asynchronously
from lnbits.tasks import internal_invoice_queue
logger.debug(f"enqueuing internal invoice {internal_payment.checking_id}")
await internal_invoice_queue.put(internal_payment.checking_id)
return payment
async def _pay_external_invoice(
wallet: Wallet,
create_payment_model: CreatePayment,
conn: Optional[Connection] = None,
) -> Payment:
checking_id = create_payment_model.payment_hash
amount_msat = create_payment_model.amount_msat
fee_reserve_total_msat = fee_reserve_total(amount_msat, internal=False)
if wallet.balance_msat < abs(amount_msat) + fee_reserve_total_msat:
raise PaymentError(
f"You must reserve at least ({round(fee_reserve_total_msat/1000)}"
" sat) to cover potential routing fees.",
status="failed",
)
# check if there is already a payment with the same checking_id
old_payment = await get_standalone_payment(checking_id, conn=conn)
if old_payment:
return await _verify_external_payment(old_payment, conn)
create_payment_model.fee = -abs(fee_reserve_total_msat)
payment = await create_payment(
checking_id=checking_id,
data=create_payment_model,
conn=conn,
)
fee_reserve_msat = fee_reserve(amount_msat, internal=False)
service_fee_msat = service_fee(amount_msat, internal=False)
funding_source = get_funding_source()
logger.debug(f"fundingsource: sending payment {checking_id}")
payment_response: PaymentResponse = await funding_source.pay_invoice(
create_payment_model.bolt11, fee_reserve_msat
)
logger.debug(f"backend: pay_invoice finished {checking_id}, {payment_response}")
if payment_response.checking_id and payment_response.checking_id != checking_id:
logger.warning(
f"backend sent unexpected checking_id (expected: {checking_id} got:"
f" {payment_response.checking_id})"
)
if payment_response.checking_id and payment_response.ok is not False:
# payment.ok can be True (paid) or None (pending)!
logger.debug(f"updating payment {checking_id}")
payment.status = (
PaymentState.SUCCESS
if payment_response.ok is True
else PaymentState.PENDING
)
payment.fee = -(abs(payment_response.fee_msat or 0) + abs(service_fee_msat))
payment.preimage = payment_response.preimage
await update_payment(payment, payment_response.checking_id, conn=conn)
payment.checking_id = payment_response.checking_id
if payment.success:
await send_payment_notification(wallet, payment)
logger.success(f"payment successful {payment_response.checking_id}")
elif payment_response.checking_id is None and payment_response.ok is False:
# payment failed
logger.debug(f"payment failed {checking_id}, {payment_response.error_message}")
payment.status = PaymentState.FAILED
await update_payment(payment, conn=conn)
raise PaymentError(
f"Payment failed: {payment_response.error_message}"
or "Payment failed, but backend didn't give us an error message.",
status="failed",
)
else:
logger.warning(
"didn't receive checking_id from backend, payment may be stuck in"
f" database: {checking_id}"
)
return payment
async def _verify_external_payment(
payment: Payment, conn: Optional[Connection] = None
) -> Payment:
# fail on pending payments
if payment.pending:
raise PaymentError("Payment is still pending.", status="pending")
if payment.success:
raise PaymentError("Payment already paid.", status="success")
# payment failed
status = await payment.check_status()
if status.failed:
raise PaymentError(
"Payment is failed node, retrying is not possible.", status="failed"
)
if status.success:
# payment was successful on the fundingsource
payment.status = PaymentState.SUCCESS
await update_payment(payment, conn=conn)
raise PaymentError(
"Failed payment was already paid on the fundingsource.",
status="success",
)
# status.pending fall through and try again
return payment
async def _check_wallet_for_payment(
wallet_id: str,
tag: str,
amount_msat: int,
conn: Optional[Connection],
):
wallet = await get_wallet(wallet_id, conn=conn)
if not wallet:
raise PaymentError(f"Could not fetch wallet '{wallet_id}'.", status="failed")
# check if the payment is made for an extension that the user disabled
status = await check_user_extension_access(wallet.user, tag)
if not status.success:
raise PaymentError(status.message)
await check_wallet_limits(wallet_id, amount_msat, conn)
return wallet
def _validate_payment_request(
payment_request: str, max_sat: Optional[int] = None
) -> Bolt11:
try:
invoice = bolt11_decode(payment_request)
except Exception as exc:
raise PaymentError("Bolt11 decoding failed.", status="failed") from exc
if not invoice.amount_msat or not invoice.amount_msat > 0:
raise PaymentError("Amountless invoices not supported.", status="failed")
if max_sat and invoice.amount_msat > max_sat * 1000:
raise PaymentError("Amount in invoice is too high.", status="failed")
return invoice
async def _credit_service_fee_wallet(
payment: Payment, conn: Optional[Connection] = None
):
service_fee_msat = service_fee(payment.amount, internal=payment.is_internal)
if not settings.lnbits_service_fee_wallet or not service_fee_msat:
return
create_payment_model = CreatePayment(
wallet_id=settings.lnbits_service_fee_wallet,
bolt11=payment.bolt11,
payment_hash=payment.payment_hash,
amount_msat=abs(service_fee_msat),
memo="Service fee",
)
await create_payment(
checking_id=f"service_fee_{payment.payment_hash}",
data=create_payment_model,
status=PaymentState.SUCCESS,
conn=conn,
)

View file

@ -0,0 +1,49 @@
from cryptography.hazmat.primitives import serialization
from loguru import logger
from py_vapid import Vapid
from py_vapid.utils import b64urlencode
from lnbits.settings import (
EditableSettings,
readonly_variables,
settings,
)
from ..crud import update_admin_settings
async def check_webpush_settings():
if not settings.lnbits_webpush_privkey:
vapid = Vapid()
vapid.generate_keys()
privkey = vapid.private_pem()
assert vapid.public_key, "VAPID public key does not exist"
pubkey = b64urlencode(
vapid.public_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)
)
push_settings = {
"lnbits_webpush_privkey": privkey.decode(),
"lnbits_webpush_pubkey": pubkey,
}
update_cached_settings(push_settings)
await update_admin_settings(EditableSettings(**push_settings))
logger.info("Initialized webpush settings with generated VAPID key pair.")
logger.info(f"Pubkey: {settings.lnbits_webpush_pubkey}")
def update_cached_settings(sets_dict: dict):
for key, value in sets_dict.items():
if key in readonly_variables:
continue
if key not in settings.dict().keys():
continue
try:
setattr(settings, key, value)
except Exception:
logger.warning(f"Failed overriding setting: {key}, value: {value}")
if "super_user" in sets_dict:
settings.super_user = sets_dict["super_user"]

View file

@ -0,0 +1,128 @@
from pathlib import Path
from typing import Optional
from uuid import UUID, uuid4
from loguru import logger
from lnbits.core.extensions.models import UserExtension
from lnbits.settings import (
EditableSettings,
SuperSettings,
send_admin_user_to_saas,
settings,
)
from ..crud import (
create_account,
create_admin_settings,
create_wallet,
get_account,
get_account_by_email,
get_account_by_pubkey,
get_account_by_username,
get_super_settings,
get_user_from_account,
update_super_user,
update_user_extension,
)
from ..helpers import to_valid_user_id
from ..models import (
Account,
User,
UserExtra,
)
from .settings import update_cached_settings
async def create_user_account(
account: Optional[Account] = None, wallet_name: Optional[str] = None
) -> User:
if not settings.new_accounts_allowed:
raise ValueError("Account creation is disabled.")
if account:
if account.username and await get_account_by_username(account.username):
raise ValueError("Username already exists.")
if account.email and await get_account_by_email(account.email):
raise ValueError("Email already exists.")
if account.pubkey and await get_account_by_pubkey(account.pubkey):
raise ValueError("Pubkey already exists.")
if account.id:
user_uuid4 = UUID(hex=account.id, version=4)
assert user_uuid4.hex == account.id, "User ID is not valid UUID4 hex string"
else:
account.id = uuid4().hex
account = await create_account(account)
await create_wallet(
user_id=account.id,
wallet_name=wallet_name or settings.lnbits_default_wallet_name,
)
for ext_id in settings.lnbits_user_default_extensions:
user_ext = UserExtension(user=account.id, extension=ext_id, active=True)
await update_user_extension(user_ext)
user = await get_user_from_account(account)
assert user, "Cannot find user for account."
return user
async def check_admin_settings():
if settings.super_user:
settings.super_user = to_valid_user_id(settings.super_user).hex
if settings.lnbits_admin_ui:
settings_db = await get_super_settings()
if not settings_db:
# create new settings if table is empty
logger.warning("Settings DB empty. Inserting default settings.")
settings_db = await init_admin_settings(settings.super_user)
logger.warning("Initialized settings from environment variables.")
if settings.super_user and settings.super_user != settings_db.super_user:
# .env super_user overwrites DB super_user
settings_db = await update_super_user(settings.super_user)
update_cached_settings(settings_db.dict())
# saving superuser to {data_dir}/.super_user file
with open(Path(settings.lnbits_data_folder) / ".super_user", "w") as file:
file.write(settings.super_user)
# callback for saas
if (
settings.lnbits_saas_callback
and settings.lnbits_saas_secret
and settings.lnbits_saas_instance_id
):
send_admin_user_to_saas()
account = await get_account(settings.super_user)
if account and account.extra and account.extra.provider == "env":
settings.first_install = True
logger.success(
"✔️ Admin UI is enabled. run `poetry run lnbits-cli superuser` "
"to get the superuser."
)
async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings:
account = None
if super_user:
account = await get_account(super_user)
if not account:
account_id = super_user or uuid4().hex
account = Account(
id=account_id,
extra=UserExtra(provider="env"),
)
await create_account(account)
await create_wallet(user_id=account.id)
editable_settings = EditableSettings.from_dict(settings.dict())
return await create_admin_settings(account.id, editable_settings.dict())

View file

@ -0,0 +1,27 @@
from fastapi import WebSocket
from loguru import logger
class WebsocketConnectionManager:
def __init__(self) -> None:
self.active_connections: list[WebSocket] = []
async def connect(self, websocket: WebSocket, item_id: str):
logger.debug(f"Websocket connected to {item_id}")
await websocket.accept()
self.active_connections.append(websocket)
def disconnect(self, websocket: WebSocket):
self.active_connections.remove(websocket)
async def send_data(self, message: str, item_id: str):
for connection in self.active_connections:
if connection.path_params["item_id"] == item_id:
await connection.send_text(message)
websocket_manager = WebsocketConnectionManager()
async def websocket_updater(item_id: str, data: str):
return await websocket_manager.send_data(data, item_id)

View file

@ -16,14 +16,14 @@
<q-item dense class="q-pa-none"> <q-item dense class="q-pa-none">
<q-item-section> <q-item-section>
<q-item-label> <q-item-label>
<strong>Wallet ID: </strong><em>{{ wallet.id }}</em> <strong>Wallet ID: </strong><em v-text="wallet.id"></em>
</q-item-label> </q-item-label>
</q-item-section> </q-item-section>
<q-item-section side> <q-item-section side>
<q-icon <q-icon
name="content_copy" name="content_copy"
class="cursor-pointer" class="cursor-pointer"
@click="copyText('{{ wallet.id }}')" @click="copyText(wallet.id)"
></q-icon> ></q-icon>
</q-item-section> </q-item-section>
</q-item> </q-item>
@ -32,7 +32,7 @@
<q-item-label> <q-item-label>
<strong>Admin key: </strong <strong>Admin key: </strong
><em ><em
v-text="adminkeyHidden ? '****************' : `{{ wallet.adminkey }}`" v-text="adminkeyHidden ? '****************' : wallet.adminkey"
></em> ></em>
</q-item-label> </q-item-label>
</q-item-section> </q-item-section>
@ -46,7 +46,7 @@
<q-icon <q-icon
name="content_copy" name="content_copy"
class="cursor-pointer q-ml-sm" class="cursor-pointer q-ml-sm"
@click="copyText('{{ wallet.adminkey }}')" @click="copyText(wallet.adminkey)"
></q-icon> ></q-icon>
</div> </div>
</q-item-section> </q-item-section>
@ -55,9 +55,7 @@
<q-item-section> <q-item-section>
<q-item-label> <q-item-label>
<strong>Invoice/read key: </strong <strong>Invoice/read key: </strong
><em ><em v-text="inkeyHidden ? '****************' : wallet.inkey"></em>
v-text="inkeyHidden ? '****************' : `{{ wallet.inkey }}`"
></em>
</q-item-label> </q-item-label>
</q-item-section> </q-item-section>
<q-item-section side> <q-item-section side>
@ -70,7 +68,7 @@
<q-icon <q-icon
name="content_copy" name="content_copy"
class="cursor-pointer q-ml-sm" class="cursor-pointer q-ml-sm"
@click="copyText('{{ wallet.inkey }}')" @click="copyText(wallet.inkey)"
></q-icon> ></q-icon>
</div> </div>
</q-item-section> </q-item-section>
@ -87,7 +85,7 @@
<q-card-section> <q-card-section>
<code><span class="text-light-green">GET</span> /api/v1/wallet</code> <code><span class="text-light-green">GET</span> /api/v1/wallet</code>
<h5 class="text-caption q-mt-sm q-mb-none">Headers</h5> <h5 class="text-caption q-mt-sm q-mb-none">Headers</h5>
<code>{"X-Api-Key": "<i>{{ wallet.inkey }}</i>"}</code><br /> <code>{"X-Api-Key": "<i v-text="wallet.inkey"></i>"}</code><br />
<h5 class="text-caption q-mt-sm q-mb-none"> <h5 class="text-caption q-mt-sm q-mb-none">
Returns 200 OK (application/json) Returns 200 OK (application/json)
</h5> </h5>
@ -97,12 +95,13 @@
> >
<h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5> <h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5>
<code <code
>curl {{ request.base_url }}api/v1/wallet -H "X-Api-Key: >curl <span v-text="baseUrl"></span>api/v1/wallet -H "X-Api-Key:
<i>{{ wallet.inkey }}</i>"</code <i v-text="wallet.inkey"></i>"</code
> >
</q-card-section> </q-card-section>
</q-card> </q-card>
</q-expansion-item> </q-expansion-item>
<q-expansion-item <q-expansion-item
group="api" group="api"
dense dense
@ -113,7 +112,7 @@
<q-card-section> <q-card-section>
<code><span class="text-light-green">POST</span> /api/v1/payments</code> <code><span class="text-light-green">POST</span> /api/v1/payments</code>
<h5 class="text-caption q-mt-sm q-mb-none">Headers</h5> <h5 class="text-caption q-mt-sm q-mb-none">Headers</h5>
<code>{"X-Api-Key": "<i>{{ wallet.inkey }}</i>"}</code><br /> <code>{"X-Api-Key": "<i v-text="wallet.inkey"></i>"}</code><br />
<h5 class="text-caption q-mt-sm q-mb-none">Body (application/json)</h5> <h5 class="text-caption q-mt-sm q-mb-none">Body (application/json)</h5>
<code <code
>{"out": false, "amount": &lt;int&gt;, "memo": &lt;string&gt;, >{"out": false, "amount": &lt;int&gt;, "memo": &lt;string&gt;,
@ -129,9 +128,10 @@
> >
<h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5> <h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5>
<code <code
>curl -X POST {{ request.base_url }}api/v1/payments -d '{"out": false, >curl -X POST <span v-text="baseUrl"></span>api/v1/payments -d
"amount": &lt;int&gt;, "memo": &lt;string&gt;}' -H "X-Api-Key: '{"out": false, "amount": &lt;int&gt;, "memo": &lt;string&gt;}' -H
<i>{{ wallet.inkey }}</i>" -H "Content-type: application/json"</code "X-Api-Key: <i v-text="wallet.inkey"></i>" -H "Content-type:
application/json"</code
> >
</q-card-section> </q-card-section>
</q-card> </q-card>
@ -155,9 +155,9 @@
<code>{"payment_hash": &lt;string&gt;}</code> <code>{"payment_hash": &lt;string&gt;}</code>
<h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5> <h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5>
<code <code
>curl -X POST {{ request.base_url }}api/v1/payments -d '{"out": true, >curl -X POST <span v-text="baseUrl"></span>api/v1/payments -d
"bolt11": &lt;string&gt;}' -H "X-Api-Key: '{"out": true, "bolt11": &lt;string&gt;}' -H "X-Api-Key:
<i>{{ wallet.adminkey }}"</i> -H "Content-type: <i v-text="wallet.adminkey"></i>" -H "Content-type:
application/json"</code application/json"</code
> >
</q-card-section> </q-card-section>
@ -183,7 +183,7 @@
</h5> </h5>
<h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5> <h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5>
<code <code
>curl -X POST {{ request.base_url }}api/v1/payments/decode -d >curl -X POST <span v-text="baseUrl"></span>api/v1/payments/decode -d
'{"data": &lt;bolt11/lnurl, string&gt;}' -H "Content-type: '{"data": &lt;bolt11/lnurl, string&gt;}' -H "Content-type:
application/json"</code application/json"</code
> >
@ -211,9 +211,10 @@
<code>{"paid": &lt;bool&gt;}</code> <code>{"paid": &lt;bool&gt;}</code>
<h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5> <h5 class="text-caption q-mt-sm q-mb-none">Curl example</h5>
<code <code
>curl -X GET {{ request.base_url >curl -X GET
}}api/v1/payments/&lt;payment_hash&gt; -H "X-Api-Key: <span v-text="baseUrl"></span>api/v1/payments/&lt;payment_hash&gt; -H
<i>{{ wallet.inkey }}"</i> -H "Content-type: application/json"</code "X-Api-Key: <i v-text="wallet.inkey"></i>" -H "Content-type:
application/json"</code
> >
</q-card-section> </q-card-section>
</q-card> </q-card>

View file

@ -38,9 +38,9 @@
</div> </div>
<div class="col"> <div class="col">
<q-img <q-img
v-if="user.config.picture" v-if="user.extra.picture"
style="max-width: 100px" style="max-width: 100px"
:src="user.config.picture" :src="user.extra.picture"
class="float-right" class="float-right"
></q-img> ></q-img>
</div> </div>
@ -133,9 +133,9 @@
<div class="row"> <div class="row">
<div class="col"> <div class="col">
<q-img <q-img
v-if="user.config.picture" v-if="user.extra.picture"
style="max-width: 100px" style="max-width: 100px"
:src="user.config.picture" :src="user.extra.picture"
class="float-right" class="float-right"
></q-img> ></q-img>
</div> </div>
@ -236,9 +236,9 @@
</div> </div>
</q-card-section> </q-card-section>
<q-card-section v-if="user.config"> <q-card-section v-if="user.extra">
<q-input <q-input
v-model="user.config.first_name" v-model="user.extra.first_name"
:label="$t('first_name')" :label="$t('first_name')"
filled filled
dense dense
@ -246,7 +246,7 @@
> >
</q-input> </q-input>
<q-input <q-input
v-model="user.config.last_name" v-model="user.extra.last_name"
:label="$t('last_name')" :label="$t('last_name')"
filled filled
dense dense
@ -254,7 +254,7 @@
> >
</q-input> </q-input>
<q-input <q-input
v-model="user.config.provider" v-model="user.extra.provider"
:label="$t('auth_provider')" :label="$t('auth_provider')"
filled filled
dense dense
@ -263,7 +263,7 @@
> >
</q-input> </q-input>
<q-input <q-input
v-model="user.config.picture" v-model="user.extra.picture"
:label="$t('picture')" :label="$t('picture')"
filled filled
class="q-mb-md" class="q-mb-md"
@ -452,6 +452,23 @@
</q-btn> </q-btn>
</div> </div>
</div> </div>
<div class="row q-mb-md">
<div class="col-4">
<span v-text="$t('border_choices')"></span>
</div>
<div class="col-8">
<q-select
v-model="borderChoice"
:options="borderOptions"
label="Reactions"
@update:model-value="applyBorder"
>
<q-tooltip
><span v-text="$t('border_choices')"></span
></q-tooltip>
</q-select>
</div>
</div>
<div class="row q-mb-md"> <div class="row q-mb-md">
<div class="col-4">Notifications</div> <div class="col-4">Notifications</div>
<div class="col-8"> <div class="col-8">
@ -470,7 +487,7 @@
v-model="reactionChoice" v-model="reactionChoice"
:options="reactionOptions" :options="reactionOptions"
label="Reactions" label="Reactions"
@input="reactionChoiceFunc" @update:model-value="reactionChoiceFunc"
> >
<q-tooltip <q-tooltip
><span v-text="$t('payment_reactions')"></span ><span v-text="$t('payment_reactions')"></span

View file

@ -107,7 +107,7 @@
color="secondary" color="secondary"
style="" style=""
v-model="extension.isActive" v-model="extension.isActive"
@input="toggleExtension(extension)" @update:model-value="toggleExtension(extension)"
><q-tooltip> ><q-tooltip>
&nbsp; &nbsp;
<span <span
@ -659,11 +659,9 @@
<a <a
:href="'lightning:' + selectedExtension.payToEnable.paymentRequest" :href="'lightning:' + selectedExtension.payToEnable.paymentRequest"
> >
<q-responsive :ratio="1" class="q-mx-xl"> <lnbits-qrcode
<lnbits-qrcode :value="'lightning:' + selectedExtension.payToEnable.paymentRequest.toUpperCase()"
:value="'lightning:' + selectedExtension.payToEnable.paymentRequest.toUpperCase()" ></lnbits-qrcode>
></lnbits-qrcode>
</q-responsive>
</a> </a>
</div> </div>
<div v-else class="col"> <div v-else class="col">
@ -1060,7 +1058,7 @@
extension.inProgress = false extension.inProgress = false
}) })
}, },
toggleExtension: function (extension) { toggleExtension(extension) {
const action = extension.isActive ? 'activate' : 'deactivate' const action = extension.isActive ? 'activate' : 'deactivate'
LNbits.api LNbits.api
.request( .request(

View file

@ -6,7 +6,7 @@
<script src="{{ static_url_for('static', 'js/wallet.js') }}"></script> <script src="{{ static_url_for('static', 'js/wallet.js') }}"></script>
{% endblock %} {% endblock %}
<!----> <!---->
{% block title %} {{ wallet.name }} - {{ SITE_TITLE }} {% endblock %} {% block title %}{{ wallet_name }} - {{ SITE_TITLE }} {% endblock %}
<!----> <!---->
{% block page %} {% block page %}
<div class="row q-col-gutter-md"> <div class="row q-col-gutter-md">
@ -36,9 +36,8 @@
<q-card-section> <q-card-section>
<h3 class="q-my-none text-no-wrap"> <h3 class="q-my-none text-no-wrap">
<strong v-text="formattedBalance"></strong> <strong v-text="formattedBalance"></strong>
<small>{{LNBITS_DENOMINATION}}</small> <small> {{LNBITS_DENOMINATION}}</small>
<lnbits-update-balance <lnbits-update-balance
v-if="'{{user.super_user}}' == 'True'"
:wallet_id="this.g.wallet.id" :wallet_id="this.g.wallet.id"
flat flat
:callback="updateBalanceCallback" :callback="updateBalanceCallback"
@ -119,7 +118,7 @@
<q-card-section> <q-card-section>
<h6 class="text-subtitle1 q-mt-none q-mb-sm"> <h6 class="text-subtitle1 q-mt-none q-mb-sm">
{{ SITE_TITLE }} Wallet: {{ SITE_TITLE }} Wallet:
<strong><em>{{wallet.name}}</em></strong> <strong><em>{{wallet_name}}</em></strong>
</h6> </h6>
</q-card-section> </q-card-section>
<q-card-section class="q-pa-none"> <q-card-section class="q-pa-none">
@ -154,17 +153,15 @@
<q-card> <q-card>
<q-card-section class="text-center"> <q-card-section class="text-center">
<p v-text="$t('export_to_phone_desc')"></p> <p v-text="$t('export_to_phone_desc')"></p>
<qrcode-vue <lnbits-qrcode :value="exportUrl"></lnbits-qrcode>
:value="'{{request.base_url}}wallet?usr={{user.id}}&wal={{wallet.id}}'"
:options="{ width: 256 }"
></qrcode-vue>
</q-card-section> </q-card-section>
<span v-text="exportWalletQR"></span>
<q-card-actions class="flex-center q-pb-md"> <q-card-actions class="flex-center q-pb-md">
<q-btn <q-btn
outline outline
color="grey" color="grey"
:label="$t('copy_wallet_url')" :label="$t('copy_wallet_url')"
@click="copyText('{{request.base_url}}wallet?usr={{user.id}}&wal={{wallet.id}}')" @click="copyText(exportUrl)"
></q-btn> ></q-btn>
</q-card-actions> </q-card-actions>
</q-card> </q-card>
@ -183,7 +180,6 @@
v-model.trim="update.name" v-model.trim="update.name"
label="Name" label="Name"
dense dense
@update:model-value="(e) => console.log(e)"
/> />
</div> </div>
<q-btn <q-btn
@ -370,11 +366,9 @@
> >
<div class="text-center q-mb-lg"> <div class="text-center q-mb-lg">
<a :href="'lightning:' + receive.paymentReq"> <a :href="'lightning:' + receive.paymentReq">
<q-responsive :ratio="1" class="q-mx-xl"> <lnbits-qrcode
<lnbits-qrcode :value="'lightning:' + receive.paymentReq.toUpperCase()"
:value="'lightning:' + receive.paymentReq.toUpperCase()" ></lnbits-qrcode>
></lnbits-qrcode>
</q-responsive>
</a> </a>
</div> </div>
<div class="row q-mt-lg"> <div class="row q-mt-lg">
@ -627,8 +621,8 @@
<div v-else> <div v-else>
<q-responsive :ratio="1"> <q-responsive :ratio="1">
<qrcode-stream <qrcode-stream
@decode="decodeQR" @detect="decodeQR"
@init="onInitQR" @camera-on="onInitQR"
class="rounded-borders" class="rounded-borders"
></qrcode-stream> ></qrcode-stream>
</q-responsive> </q-responsive>
@ -651,8 +645,8 @@
<q-card class="q-pa-lg q-pt-xl"> <q-card class="q-pa-lg q-pt-xl">
<div class="text-center q-mb-lg"> <div class="text-center q-mb-lg">
<qrcode-stream <qrcode-stream
@decode="decodeQR" @detect="decodeQR"
@init="onInitQR" @camera-on="onInitQR"
class="rounded-borders" class="rounded-borders"
></qrcode-stream> ></qrcode-stream>
</div> </div>
@ -667,22 +661,28 @@
</div> </div>
</q-card> </q-card>
</q-dialog> </q-dialog>
<div
<q-tabs
class="lt-md fixed-bottom left-0 right-0 bg-primary text-white shadow-2 z-top" class="lt-md fixed-bottom left-0 right-0 bg-primary text-white shadow-2 z-top"
active-class="px-0"
indicator-color="transparent"
align="justify"
> >
<q-tab <q-tabs
icon="file_download" active-class="px-0"
@click="showReceiveDialog" indicator-color="transparent"
:label="$t('receive')" align="justify"
> >
</q-tab> <q-tab
icon="file_download"
@click="showReceiveDialog"
:label="$t('receive')"
>
</q-tab>
<q-tab @click="showParseDialog" icon="file_upload" :label="$t('send')"> <q-tab
</q-tab> @click="showParseDialog"
icon="file_upload"
:label="$t('send')"
>
</q-tab>
</q-tabs>
<q-btn <q-btn
round round
size="35px" size="35px"
@ -692,8 +692,7 @@
class="text-white bg-primary z-top vertical-bottom absolute-center absolute" class="text-white bg-primary z-top vertical-bottom absolute-center absolute"
> >
</q-btn> </q-btn>
</q-tabs> </div>
<q-dialog v-model="disclaimerDialog.show" position="top"> <q-dialog v-model="disclaimerDialog.show" position="top">
<q-card class="q-pa-lg"> <q-card class="q-pa-lg">
<h6 <h6

View file

@ -160,7 +160,7 @@
<q-table <q-table
dense dense
flat flat
:data="this.filteredChannels" :rows="this.filteredChannels"
:filter="channels.filter" :filter="channels.filter"
no-data-label="No channels opened" no-data-label="No channels opened"
> >
@ -239,7 +239,7 @@
<q-table <q-table
dense dense
flat flat
:data="peers.data" :rows="peers.data"
:filter="peers.filter" :filter="peers.filter"
no-data-label="No transactions made yet" no-data-label="No transactions made yet"
> >

View file

@ -42,11 +42,7 @@
</div> </div>
{% endblock %} {% block scripts %} {{ window_vars(user) }} {% endblock %} {% block scripts %} {{ window_vars(user) }}
<script src="{{ static_url_for('static', 'js/node.js') }}"></script>
<script> <script>
Vue.component(VueQrcode.name, VueQrcode)
Vue.use(VueQrcodeReader)
window.app = Vue.createApp({ window.app = Vue.createApp({
el: '#vue', el: '#vue',
config: { config: {
@ -367,10 +363,13 @@
this.transactionDetailsDialog.data = details this.transactionDetailsDialog.data = details
console.log('details', details) console.log('details', details)
}, },
exportCSV: function () {}, shortenNodeId(nodeId) {
shortenNodeId return nodeId
? nodeId.substring(0, 5) + '...' + nodeId.substring(nodeId.length - 5)
: '...'
}
} }
}) })
</script> </script>
<script src="{{ static_url_for('static', 'js/node.js') }}"></script>
{% endblock %} {% endblock %}

View file

@ -37,6 +37,7 @@
icon="content_copy" icon="content_copy"
size="sm" size="sm"
color="primary" color="primary"
class="q-ml-xs"
@click="copyText(props.row.id)" @click="copyText(props.row.id)"
> >
<q-tooltip>Copy Wallet ID</q-tooltip> <q-tooltip>Copy Wallet ID</q-tooltip>
@ -45,6 +46,7 @@
v-if="!props.row.deleted" v-if="!props.row.deleted"
:wallet_id="props.row.id" :wallet_id="props.row.id"
:callback="topupCallback" :callback="topupCallback"
class="q-ml-xs"
></lnbits-update-balance> ></lnbits-update-balance>
<q-btn <q-btn
round round
@ -52,6 +54,7 @@
icon="vpn_key" icon="vpn_key"
size="sm" size="sm"
color="primary" color="primary"
class="q-ml-xs"
@click="copyText(props.row.adminkey)" @click="copyText(props.row.adminkey)"
> >
<q-tooltip>Copy Admin Key</q-tooltip> <q-tooltip>Copy Admin Key</q-tooltip>
@ -62,6 +65,7 @@
icon="vpn_key" icon="vpn_key"
size="sm" size="sm"
color="secondary" color="secondary"
class="q-ml-xs"
@click="copyText(props.row.inkey)" @click="copyText(props.row.inkey)"
> >
<q-tooltip>Copy Invoice Key</q-tooltip> <q-tooltip>Copy Invoice Key</q-tooltip>
@ -72,6 +76,7 @@
icon="toggle_off" icon="toggle_off"
size="sm" size="sm"
color="secondary" color="secondary"
class="q-ml-xs"
@click="undeleteUserWallet(props.row.user, props.row.id)" @click="undeleteUserWallet(props.row.user, props.row.id)"
> >
<q-tooltip>Undelete Wallet</q-tooltip> <q-tooltip>Undelete Wallet</q-tooltip>
@ -81,6 +86,7 @@
icon="delete" icon="delete"
size="sm" size="sm"
color="negative" color="negative"
class="q-ml-xs"
@click="deleteUserWallet(props.row.user, props.row.id, props.row.deleted)" @click="deleteUserWallet(props.row.user, props.row.id, props.row.deleted)"
> >
<q-tooltip>Delete Wallet</q-tooltip> <q-tooltip>Delete Wallet</q-tooltip>

View file

@ -61,6 +61,7 @@ include "users/_createWalletDialog.html" %}
icon="content_copy" icon="content_copy"
size="sm" size="sm"
color="primary" color="primary"
class="q-ml-xs"
@click="copyText(props.row.id)" @click="copyText(props.row.id)"
> >
<q-tooltip>Copy User ID</q-tooltip> <q-tooltip>Copy User ID</q-tooltip>
@ -71,6 +72,7 @@ include "users/_createWalletDialog.html" %}
icon="build" icon="build"
size="sm" size="sm"
:color="props.row.is_admin ? 'primary' : 'grey'" :color="props.row.is_admin ? 'primary' : 'grey'"
class="q-ml-xs"
@click="toggleAdmin(props.row.id)" @click="toggleAdmin(props.row.id)"
> >
<q-tooltip>Toggle Admin</q-tooltip> <q-tooltip>Toggle Admin</q-tooltip>
@ -81,6 +83,7 @@ include "users/_createWalletDialog.html" %}
icon="build" icon="build"
size="sm" size="sm"
color="positive" color="positive"
class="q-ml-xs"
> >
<q-tooltip>Super User</q-tooltip> <q-tooltip>Super User</q-tooltip>
</q-btn> </q-btn>
@ -98,6 +101,7 @@ include "users/_createWalletDialog.html" %}
icon="delete" icon="delete"
size="sm" size="sm"
color="negative" color="negative"
class="q-ml-xs"
@click="deleteUser(props.row.id, props)" @click="deleteUser(props.row.id, props)"
> >
<q-tooltip>Delete User</q-tooltip> <q-tooltip>Delete User</q-tooltip>
@ -111,7 +115,10 @@ include "users/_createWalletDialog.html" %}
<q-td auto-width v-text="props.row.transaction_count"></q-td> <q-td auto-width v-text="props.row.transaction_count"></q-td>
<q-td auto-width v-text="props.row.username"></q-td> <q-td auto-width v-text="props.row.username"></q-td>
<q-td auto-width v-text="props.row.email"></q-td> <q-td auto-width v-text="props.row.email"></q-td>
<q-td auto-width v-text="props.row.last_payment"></q-td> <q-td
auto-width
v-text="formatDate(props.row.last_payment)"
></q-td>
</q-tr> </q-tr>
</template> </template>
</q-table> </q-table>

View file

@ -3,7 +3,7 @@ import json
from http import HTTPStatus from http import HTTPStatus
from io import BytesIO from io import BytesIO
from time import time from time import time
from typing import Any, Dict, List from typing import Any
from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse from urllib.parse import ParseResult, parse_qs, urlencode, urlparse, urlunparse
import httpx import httpx
@ -13,7 +13,7 @@ from fastapi import (
Depends, Depends,
) )
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
from starlette.responses import StreamingResponse from fastapi.responses import StreamingResponse
from lnbits.core.crud import get_user from lnbits.core.crud import get_user
from lnbits.core.models import ( from lnbits.core.models import (
@ -43,10 +43,6 @@ from lnbits.wallets.base import StatusResponse
from ..services import create_user_account, perform_lnurlauth from ..services import create_user_account, perform_lnurlauth
# backwards compatibility for extension
# TODO: remove api_payment and pay_invoice imports from extensions
from .payment_api import api_payment, pay_invoice # noqa: F401
api_router = APIRouter(tags=["Core"]) api_router = APIRouter(tags=["Core"])
@ -87,20 +83,16 @@ async def health_check(wallet: WalletTypeInfo = Depends(require_invoice_key)) ->
"/api/v1/wallets", "/api/v1/wallets",
name="Wallets", name="Wallets",
description="Get basic info for all of user's wallets.", description="Get basic info for all of user's wallets.",
response_model=list[BaseWallet],
) )
async def api_wallets(user: User = Depends(check_user_exists)) -> List[BaseWallet]: async def api_wallets(user: User = Depends(check_user_exists)) -> list[Wallet]:
return [BaseWallet(**w.dict()) for w in user.wallets] return user.wallets
@api_router.post("/api/v1/account", response_model=Wallet) @api_router.post("/api/v1/account")
async def api_create_account(data: CreateWallet) -> Wallet: async def api_create_account(data: CreateWallet) -> Wallet:
if not settings.new_accounts_allowed: user = await create_user_account(wallet_name=data.name)
raise HTTPException( return user.wallets[0]
status_code=HTTPStatus.FORBIDDEN,
detail="Account creation is disabled.",
)
account = await create_user_account(wallet_name=data.name)
return account.wallets[0]
@api_router.get("/api/v1/lnurlscan/{code}") @api_router.get("/api/v1/lnurlscan/{code}")
@ -128,7 +120,7 @@ async def api_lnurlscan(
) from exc ) from exc
# params is what will be returned to the client # params is what will be returned to the client
params: Dict = {"domain": domain} params: dict = {"domain": domain}
if "tag=login" in url: if "tag=login" in url:
params.update(kind="auth") params.update(kind="auth")
@ -177,7 +169,7 @@ async def api_lnurlscan(
# callback with k1 already in it # callback with k1 already in it
parsed_callback: ParseResult = urlparse(data["callback"]) parsed_callback: ParseResult = urlparse(data["callback"])
qs: Dict = parse_qs(parsed_callback.query) qs: dict = parse_qs(parsed_callback.query)
qs["k1"] = data["k1"] qs["k1"] = data["k1"]
# balanceCheck/balanceNotify # balanceCheck/balanceNotify
@ -234,13 +226,13 @@ async def api_perform_lnurlauth(
@api_router.get("/api/v1/rate/{currency}") @api_router.get("/api/v1/rate/{currency}")
async def api_check_fiat_rate(currency: str) -> Dict[str, float]: async def api_check_fiat_rate(currency: str) -> dict[str, float]:
rate = await get_fiat_rate_satoshis(currency) rate = await get_fiat_rate_satoshis(currency)
return {"rate": rate} return {"rate": rate}
@api_router.get("/api/v1/currencies") @api_router.get("/api/v1/currencies")
async def api_list_currencies_available() -> List[str]: async def api_list_currencies_available() -> list[str]:
return allowed_currencies() return allowed_currencies()

View file

@ -1,19 +1,15 @@
import base64 import base64
import importlib import importlib
import json import json
from http import HTTPStatus
from time import time from time import time
from typing import Callable, Optional from typing import Callable, Optional
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, RedirectResponse from fastapi.responses import JSONResponse, RedirectResponse
from fastapi_sso.sso.base import OpenID, SSOBase from fastapi_sso.sso.base import OpenID, SSOBase
from loguru import logger from loguru import logger
from starlette.status import (
HTTP_400_BAD_REQUEST,
HTTP_401_UNAUTHORIZED,
HTTP_403_FORBIDDEN,
HTTP_500_INTERNAL_SERVER_ERROR,
)
from lnbits.core.services import create_user_account from lnbits.core.services import create_user_account
from lnbits.decorators import access_token_payload, check_user_exists from lnbits.decorators import access_token_payload, check_user_exists
@ -31,16 +27,14 @@ from ..crud import (
get_account, get_account,
get_account_by_email, get_account_by_email,
get_account_by_pubkey, get_account_by_pubkey,
get_account_by_username,
get_account_by_username_or_email, get_account_by_username_or_email,
get_user, get_user_from_account,
get_user_password,
update_account, update_account,
update_user_password,
update_user_pubkey,
verify_user_password,
) )
from ..models import ( from ..models import (
AccessTokenPayload, AccessTokenPayload,
Account,
CreateUser, CreateUser,
LoginUsernamePassword, LoginUsernamePassword,
LoginUsr, LoginUsr,
@ -50,7 +44,7 @@ from ..models import (
UpdateUserPassword, UpdateUserPassword,
UpdateUserPubkey, UpdateUserPubkey,
User, User,
UserConfig, UserExtra,
) )
auth_router = APIRouter(prefix="/api/v1/auth", tags=["Auth"]) auth_router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
@ -65,65 +59,43 @@ async def get_auth_user(user: User = Depends(check_user_exists)) -> User:
async def login(data: LoginUsernamePassword) -> JSONResponse: async def login(data: LoginUsernamePassword) -> JSONResponse:
if not settings.is_auth_method_allowed(AuthMethods.username_and_password): if not settings.is_auth_method_allowed(AuthMethods.username_and_password):
raise HTTPException( raise HTTPException(
HTTP_401_UNAUTHORIZED, "Login by 'Username and Password' not allowed." HTTPStatus.UNAUTHORIZED, "Login by 'Username and Password' not allowed."
) )
account = await get_account_by_username_or_email(data.username)
try: if not account or not account.verify_password(data.password):
user = await get_account_by_username_or_email(data.username) raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid credentials.")
return _auth_success_response(account.username, account.id, account.email)
if not user:
raise HTTPException(HTTP_401_UNAUTHORIZED, "Invalid credentials.")
if not await verify_user_password(user.id, data.password):
raise HTTPException(HTTP_401_UNAUTHORIZED, "Invalid credentials.")
return _auth_success_response(user.username, user.id, user.email)
except HTTPException as exc:
raise exc
except Exception as exc:
logger.debug(exc)
raise HTTPException(HTTP_500_INTERNAL_SERVER_ERROR, "Cannot login.") from exc
@auth_router.post("/nostr", description="Login via Nostr") @auth_router.post("/nostr", description="Login via Nostr")
async def nostr_login(request: Request) -> JSONResponse: async def nostr_login(request: Request) -> JSONResponse:
if not settings.is_auth_method_allowed(AuthMethods.nostr_auth_nip98): if not settings.is_auth_method_allowed(AuthMethods.nostr_auth_nip98):
raise HTTPException(HTTP_401_UNAUTHORIZED, "Login with Nostr Auth not allowed.") raise HTTPException(
HTTPStatus.UNAUTHORIZED, "Login with Nostr Auth not allowed."
try: )
event = _nostr_nip98_event(request) event = _nostr_nip98_event(request)
account = await get_account_by_pubkey(event["pubkey"])
user = await get_account_by_pubkey(event["pubkey"]) if not account:
if not user: account = Account(
user = await create_user_account( id=uuid4().hex,
pubkey=event["pubkey"], user_config=UserConfig(provider="nostr") pubkey=event["pubkey"],
) extra=UserExtra(provider="nostr"),
)
return _auth_success_response(user.username or "", user.id, user.email) await create_user_account(account)
except HTTPException as exc: return _auth_success_response(account.username or "", account.id, account.email)
raise exc
except AssertionError as exc:
raise HTTPException(HTTP_401_UNAUTHORIZED, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
raise HTTPException(HTTP_500_INTERNAL_SERVER_ERROR, "Cannot login.") from exc
@auth_router.post("/usr", description="Login via the User ID") @auth_router.post("/usr", description="Login via the User ID")
async def login_usr(data: LoginUsr) -> JSONResponse: async def login_usr(data: LoginUsr) -> JSONResponse:
if not settings.is_auth_method_allowed(AuthMethods.user_id_only): if not settings.is_auth_method_allowed(AuthMethods.user_id_only):
raise HTTPException(HTTP_401_UNAUTHORIZED, "Login by 'User ID' not allowed.") raise HTTPException(
HTTPStatus.UNAUTHORIZED,
try: "Login by 'User ID' not allowed.",
user = await get_user(data.usr) )
if not user: account = await get_account(data.usr)
raise HTTPException(HTTP_401_UNAUTHORIZED, "User ID does not exist.") if not account:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User ID does not exist.")
return _auth_success_response(user.username or "", user.id, user.email) return _auth_success_response(account.username, account.id, account.email)
except HTTPException as exc:
raise exc
except Exception as exc:
logger.debug(exc)
raise HTTPException(HTTP_500_INTERNAL_SERVER_ERROR, "Cannot login.") from exc
@auth_router.get("/{provider}", description="SSO Provider") @auth_router.get("/{provider}", description="SSO Provider")
@ -133,7 +105,8 @@ async def login_with_sso_provider(
provider_sso = _new_sso(provider) provider_sso = _new_sso(provider)
if not provider_sso: if not provider_sso:
raise HTTPException( raise HTTPException(
HTTP_401_UNAUTHORIZED, f"Login by '{provider}' not allowed." HTTPStatus.UNAUTHORIZED,
f"Login by '{provider}' not allowed.",
) )
provider_sso.redirect_uri = str(request.base_url) + f"api/v1/auth/{provider}/token" provider_sso.redirect_uri = str(request.base_url) + f"api/v1/auth/{provider}/token"
@ -147,31 +120,22 @@ async def handle_oauth_token(request: Request, provider: str) -> RedirectRespons
provider_sso = _new_sso(provider) provider_sso = _new_sso(provider)
if not provider_sso: if not provider_sso:
raise HTTPException( raise HTTPException(
HTTP_401_UNAUTHORIZED, f"Login by '{provider}' not allowed." HTTPStatus.UNAUTHORIZED,
f"Login by '{provider}' not allowed.",
) )
try: with provider_sso:
with provider_sso: userinfo = await provider_sso.verify_and_process(request)
userinfo = await provider_sso.verify_and_process(request) if not userinfo:
assert userinfo is not None raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid user info.")
user_id = decrypt_internal_message(provider_sso.state) user_id = decrypt_internal_message(provider_sso.state)
request.session.pop("user", None) request.session.pop("user", None)
return await _handle_sso_login(userinfo, user_id) return await _handle_sso_login(userinfo, user_id)
except HTTPException as exc:
raise exc
except ValueError as exc:
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc
except Exception as exc:
logger.debug(exc)
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR,
f"Cannot authenticate user with {provider} Auth.",
) from exc
@auth_router.post("/logout") @auth_router.post("/logout")
async def logout() -> JSONResponse: async def logout() -> JSONResponse:
response = JSONResponse({"status": "success"}, status_code=status.HTTP_200_OK) response = JSONResponse({"status": "success"}, HTTPStatus.OK)
response.delete_cookie("cookie_access_token") response.delete_cookie("cookie_access_token")
response.delete_cookie("is_lnbits_user_authorized") response.delete_cookie("is_lnbits_user_authorized")
response.delete_cookie("is_access_token_expired") response.delete_cookie("is_access_token_expired")
@ -184,62 +148,32 @@ async def logout() -> JSONResponse:
async def register(data: CreateUser) -> JSONResponse: async def register(data: CreateUser) -> JSONResponse:
if not settings.is_auth_method_allowed(AuthMethods.username_and_password): if not settings.is_auth_method_allowed(AuthMethods.username_and_password):
raise HTTPException( raise HTTPException(
HTTP_401_UNAUTHORIZED, "Register by 'Username and Password' not allowed." HTTPStatus.UNAUTHORIZED,
"Register by 'Username and Password' not allowed.",
) )
if data.password != data.password_repeat: if data.password != data.password_repeat:
raise HTTPException(HTTP_400_BAD_REQUEST, "Passwords do not match.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Passwords do not match.")
if not data.username: if not data.username:
raise HTTPException(HTTP_400_BAD_REQUEST, "Missing username.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Missing username.")
if not is_valid_username(data.username): if not is_valid_username(data.username):
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid username.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid username.")
if await get_account_by_username(data.username):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
if data.email and not is_valid_email_address(data.email): if data.email and not is_valid_email_address(data.email):
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid email.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid email.")
try: account = Account(
user = await create_user_account( id=uuid4().hex,
email=data.email, username=data.username, password=data.password email=data.email,
) username=data.username,
return _auth_success_response(user.username, user.id, user.email) )
account.hash_password(data.password)
except ValueError as exc: await create_user_account(account)
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc return _auth_success_response(account.username, account.id, account.email)
except Exception as exc:
logger.debug(exc)
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot create user."
) from exc
@auth_router.put("/password")
async def update_password(
data: UpdateUserPassword,
user: User = Depends(check_user_exists),
payload: AccessTokenPayload = Depends(access_token_payload),
) -> Optional[User]:
if data.user_id != user.id:
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid user ID.")
try:
if data.username and not user.username:
await update_account(user_id=user.id, username=data.username)
# old accounts do not have a pasword
if await get_user_password(data.user_id):
assert data.password_old, "Missing old password"
old_pwd_ok = await verify_user_password(data.user_id, data.password_old)
assert old_pwd_ok, "Invalid credentials."
return await update_user_password(data, payload.auth_time or 0)
except AssertionError as exc:
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc
except Exception as exc:
logger.debug(exc)
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot update user password."
) from exc
@auth_router.put("/pubkey") @auth_router.put("/pubkey")
@ -249,62 +183,89 @@ async def update_pubkey(
payload: AccessTokenPayload = Depends(access_token_payload), payload: AccessTokenPayload = Depends(access_token_payload),
) -> Optional[User]: ) -> Optional[User]:
if data.user_id != user.id: if data.user_id != user.id:
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid user ID.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
try: _validate_auth_timeout(payload.auth_time)
data.pubkey = normalize_public_key(data.pubkey) if (
return await update_user_pubkey(data, payload.auth_time or 0) data.pubkey
and data.pubkey != user.pubkey
and await get_account_by_pubkey(data.pubkey)
):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Public key already in use.")
except AssertionError as exc: account = await get_account(user.id)
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc if not account:
except Exception as exc: raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
logger.debug(exc)
raise HTTPException( account.pubkey = normalize_public_key(data.pubkey)
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot update user pubkey." await update_account(account)
) from exc return await get_user_from_account(account)
@auth_router.put("/password")
async def update_password(
data: UpdateUserPassword,
user: User = Depends(check_user_exists),
payload: AccessTokenPayload = Depends(access_token_payload),
) -> Optional[User]:
_validate_auth_timeout(payload.auth_time)
assert data.user_id == user.id, "Invalid user ID."
if (
data.username
and user.username != data.username
and await get_account_by_username(data.username)
):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
account = await get_account(user.id)
assert account, "Account not found."
# old accounts do not have a password
if account.password_hash:
assert data.password_old, "Missing old password."
assert account.verify_password(data.password_old), "Invalid old password."
account.username = data.username
account.hash_password(data.password)
await update_account(account)
_user = await get_user_from_account(account)
if not _user:
raise HTTPException(HTTPStatus.NOT_FOUND, "User not found.")
return _user
@auth_router.put("/reset") @auth_router.put("/reset")
async def reset_password(data: ResetUserPassword) -> JSONResponse: async def reset_password(data: ResetUserPassword) -> JSONResponse:
if not settings.is_auth_method_allowed(AuthMethods.username_and_password): if not settings.is_auth_method_allowed(AuthMethods.username_and_password):
raise HTTPException( raise HTTPException(
HTTP_401_UNAUTHORIZED, "Auth by 'Username and Password' not allowed." HTTPStatus.UNAUTHORIZED, "Auth by 'Username and Password' not allowed."
) )
assert data.password == data.password_repeat, "Passwords do not match."
assert data.reset_key[:10].startswith("reset_key_"), "This is not a reset key."
try: try:
assert data.reset_key[:10] == "reset_key_", "This is not a reset key." reset_key = base64.b64decode(data.reset_key[10:]).decode()
reset_data_json = decrypt_internal_message(reset_key)
reset_data_json = decrypt_internal_message(
base64.b64decode(data.reset_key[10:]).decode()
)
assert reset_data_json, "Cannot process reset key."
action, user_id, request_time = json.loads(reset_data_json)
assert action == "reset", "Expected reset action."
assert user_id is not None, "Missing user ID."
assert request_time is not None, "Missing reset time."
user = await get_account(user_id)
assert user, "User not found."
update_pwd = UpdateUserPassword(
user_id=user.id,
username=user.username or "",
password=data.password,
password_repeat=data.password_repeat,
)
user = await update_user_password(update_pwd, request_time)
return _auth_success_response(
username=user.username, user_id=user_id, email=user.email
)
except AssertionError as exc:
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc
except Exception as exc: except Exception as exc:
logger.warning(exc) raise ValueError("Invalid reset key.") from exc
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot reset user password." assert reset_data_json, "Cannot process reset key."
) from exc
action, user_id, request_time = json.loads(reset_data_json)
assert action, "Missing action."
assert user_id, "Missing user ID."
assert request_time, "Missing reset time."
_validate_auth_timeout(request_time)
account = await get_account(user_id)
if not account:
raise HTTPException(HTTPStatus.NOT_FOUND, "User not found.")
account.hash_password(data.password)
await update_account(account)
return _auth_success_response(account.username, user_id, account.email)
@auth_router.put("/update") @auth_router.put("/update")
@ -312,80 +273,83 @@ async def update(
data: UpdateUser, user: User = Depends(check_user_exists) data: UpdateUser, user: User = Depends(check_user_exists)
) -> Optional[User]: ) -> Optional[User]:
if data.user_id != user.id: if data.user_id != user.id:
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid user ID.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid user ID.")
if data.username and not is_valid_username(data.username): if data.username and not is_valid_username(data.username):
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid username.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid username.")
if data.email != user.email: if data.email != user.email:
raise HTTPException(HTTP_400_BAD_REQUEST, "Email mismatch.")
try:
return await update_account(user.id, data.username, None, data.config)
except AssertionError as exc:
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc
except Exception as exc:
logger.debug(exc)
raise HTTPException( raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot update user." HTTPStatus.BAD_REQUEST,
) from exc "Email mismatch.",
)
if (
data.username
and user.username != data.username
and await get_account_by_username(data.username)
):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Username already exists.")
if (
data.email
and data.email != user.email
and await get_account_by_email(data.email)
):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Email already exists.")
account = await get_account(user.id)
if not account:
raise HTTPException(HTTPStatus.NOT_FOUND, "Account not found.")
if data.username:
account.username = data.username
if data.email:
account.email = data.email
if data.extra:
account.extra = data.extra
await update_account(account)
return await get_user_from_account(account)
@auth_router.put("/first_install") @auth_router.put("/first_install")
async def first_install(data: UpdateSuperuserPassword) -> JSONResponse: async def first_install(data: UpdateSuperuserPassword) -> JSONResponse:
if not settings.first_install: if not settings.first_install:
raise HTTPException(HTTP_401_UNAUTHORIZED, "This is not your first install") raise HTTPException(HTTPStatus.UNAUTHORIZED, "This is not your first install")
try: account = await get_account(settings.super_user)
await update_account( if not account:
user_id=settings.super_user, raise HTTPException(HTTPStatus.INTERNAL_SERVER_ERROR, "Superuser not found.")
username=data.username, account.username = data.username
user_config=UserConfig(provider="lnbits"), account.extra = account.extra or UserExtra()
) account.extra.provider = "lnbits"
super_user = UpdateUserPassword( account.hash_password(data.password)
user_id=settings.super_user, await update_account(account)
password=data.password, settings.first_install = False
password_repeat=data.password_repeat, return _auth_success_response(account.username, account.id, account.email)
username=data.username,
)
user = await update_user_password(super_user, int(time()))
settings.first_install = False
return _auth_success_response(user.username, user.id, user.email)
except AssertionError as exc:
raise HTTPException(HTTP_403_FORBIDDEN, str(exc)) from exc
except Exception as exc:
logger.debug(exc)
raise HTTPException(
HTTP_500_INTERNAL_SERVER_ERROR, "Cannot init user password."
) from exc
async def _handle_sso_login(userinfo: OpenID, verified_user_id: Optional[str] = None): async def _handle_sso_login(userinfo: OpenID, verified_user_id: Optional[str] = None):
email = userinfo.email email = userinfo.email
if not email or not is_valid_email_address(email): if not email or not is_valid_email_address(email):
raise HTTPException(HTTP_400_BAD_REQUEST, "Invalid email.") raise HTTPException(HTTPStatus.BAD_REQUEST, "Invalid email.")
redirect_path = "/wallet" redirect_path = "/wallet"
user_config = UserConfig(**dict(userinfo))
user_config.email_verified = True
account = await get_account_by_email(email) account = await get_account_by_email(email)
if verified_user_id: if verified_user_id:
if account: if account:
raise HTTPException(HTTP_401_UNAUTHORIZED, "Email already used.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "Email already used.")
account = await get_account(verified_user_id) account = await get_account(verified_user_id)
if not account: if not account:
raise HTTPException(HTTP_401_UNAUTHORIZED, "Cannot verify user email.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "Cannot verify user email.")
redirect_path = "/account" redirect_path = "/account"
if account: if account:
user = await update_account(account.id, email=email, user_config=user_config) account.extra = account.extra or UserExtra()
account.extra.email_verified = True
await update_account(account)
else: else:
if not settings.new_accounts_allowed: account = Account(
raise HTTPException(HTTP_400_BAD_REQUEST, "Account creation is disabled.") id=uuid4().hex, email=email, extra=UserExtra(email_verified=True)
user = await create_user_account(email=email, user_config=user_config) )
await create_user_account(account)
if not user:
raise HTTPException(HTTP_401_UNAUTHORIZED, "User not found.")
return _auth_redirect_response(redirect_path, email) return _auth_redirect_response(redirect_path, email)
@ -461,23 +425,23 @@ def _find_auth_provider_class(provider: str) -> Callable:
def _nostr_nip98_event(request: Request) -> dict: def _nostr_nip98_event(request: Request) -> dict:
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
assert auth_header, "Nostr Auth header missing." if not auth_header:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Nostr Auth header missing.")
scheme, token = auth_header.split() scheme, token = auth_header.split()
assert scheme.lower() == "nostr", "Authorization header is not nostr." if scheme.lower() != "nostr":
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid Authorization scheme.")
event = None event = None
try: try:
event_json = base64.b64decode(token.encode("ascii")) event_json = base64.b64decode(token.encode("ascii"))
event = json.loads(event_json) event = json.loads(event_json)
except Exception as exc: except Exception as exc:
logger.warning(exc) logger.warning(exc)
assert event, "Nostr login event cannot be parsed." assert event, "Nostr login event cannot be parsed."
assert verify_event(event), "Nostr login event is not valid." if not verify_event(event):
raise HTTPException(HTTPStatus.BAD_REQUEST, "Nostr login event is not valid.")
assert event["kind"] == 27_235, "Invalid event kind." assert event["kind"] == 27_235, "Invalid event kind."
auth_threshold = settings.auth_credetials_update_threshold auth_threshold = settings.auth_credetials_update_threshold
assert ( assert (
abs(time() - event["created_at"]) < auth_threshold abs(time() - event["created_at"]) < auth_threshold
@ -485,11 +449,22 @@ def _nostr_nip98_event(request: Request) -> dict:
method: Optional[str] = next((v for k, v in event["tags"] if k == "method"), None) method: Optional[str] = next((v for k, v in event["tags"] if k == "method"), None)
assert method, "Tag 'method' is missing." assert method, "Tag 'method' is missing."
assert method.upper() == "POST", "Incorrect value for tag 'method'." assert method.upper() == "POST", "Invalid value for tag 'method'."
url = next((v for k, v in event["tags"] if k == "u"), None) url = next((v for k, v in event["tags"] if k == "u"), None)
assert url, "Tag 'u' for URL is missing." assert url, "Tag 'u' for URL is missing."
accepted_urls = [f"{u}/nostr" for u in settings.nostr_absolute_request_urls] accepted_urls = [f"{u}/nostr" for u in settings.nostr_absolute_request_urls]
assert url in accepted_urls, f"Incorrect value for tag 'u': '{url}'." assert url in accepted_urls, f"Invalid value for tag 'u': '{url}'."
return event return event
def _validate_auth_timeout(auth_time: Optional[int] = 0):
if abs(time() - (auth_time or 0)) > settings.auth_credetials_update_threshold:
raise HTTPException(
HTTPStatus.BAD_REQUEST,
"You can only update your credentials in the first"
f" {settings.auth_credetials_update_threshold} seconds."
" Please login again or ask a new reset key!",
)

View file

@ -1,7 +1,6 @@
import sys
import traceback
from http import HTTPStatus from http import HTTPStatus
from typing import (
List,
)
from bolt11 import decode as bolt11_decode from bolt11 import decode as bolt11_decode
from fastapi import ( from fastapi import (
@ -21,10 +20,12 @@ from lnbits.core.extensions.models import (
CreateExtension, CreateExtension,
Extension, Extension,
ExtensionConfig, ExtensionConfig,
ExtensionMeta,
ExtensionRelease, ExtensionRelease,
InstallableExtension, InstallableExtension,
PayToEnableInfo, PayToEnableInfo,
ReleasePaymentInfo, ReleasePaymentInfo,
UserExtension,
UserExtensionInfo, UserExtensionInfo,
) )
from lnbits.core.models import ( from lnbits.core.models import (
@ -38,15 +39,15 @@ from lnbits.decorators import (
) )
from ..crud import ( from ..crud import (
create_user_extension,
delete_dbversion, delete_dbversion,
drop_extension_db, drop_extension_db,
get_dbversions, get_db_version,
get_installed_extension, get_installed_extension,
get_installed_extensions, get_installed_extensions,
get_user_extension, get_user_extension,
update_extension_pay_to_enable, update_installed_extension,
update_user_extension, update_user_extension,
update_user_extension_extra,
) )
extension_router = APIRouter( extension_router = APIRouter(
@ -71,8 +72,13 @@ async def api_install_extension(data: CreateExtension):
) )
release.payment_hash = data.payment_hash release.payment_hash = data.payment_hash
ext_meta = ExtensionMeta(installed_release=release)
ext_info = InstallableExtension( ext_info = InstallableExtension(
id=data.ext_id, name=data.ext_id, installed_release=release, icon=release.icon id=data.ext_id,
name=data.ext_id,
version=data.version,
meta=ext_meta,
icon=release.icon,
) )
try: try:
@ -80,6 +86,8 @@ async def api_install_extension(data: CreateExtension):
except Exception as exc: except Exception as exc:
logger.warning(exc) logger.warning(exc)
etype, _, tb = sys.exc_info()
traceback.print_exception(etype, exc, tb)
ext_info.clean_extension_files() ext_info.clean_extension_files()
detail = ( detail = (
str(exc) str(exc)
@ -109,33 +117,28 @@ async def api_install_extension(data: CreateExtension):
) from exc ) from exc
@extension_router.get("/{ext_id}/details", dependencies=[Depends(check_user_exists)]) @extension_router.get("/{ext_id}/details")
async def api_extension_details( async def api_extension_details(
ext_id: str, ext_id: str,
details_link: str, details_link: str,
): ):
all_releases = await InstallableExtension.get_extension_releases(ext_id)
try: release = next((r for r in all_releases if r.details_link == details_link), None)
all_releases = await InstallableExtension.get_extension_releases(ext_id) if not release:
release = next(
(r for r in all_releases if r.details_link == details_link), None
)
assert release, "Details not found for release"
release_details = await ExtensionRelease.fetch_release_details(details_link)
assert release_details, "Cannot fetch details for release"
release_details["icon"] = release.icon
release_details["repo"] = release.repo
return release_details
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
raise HTTPException( raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.NOT_FOUND, detail="Release not found"
f"Failed to get details for extension {ext_id}.", )
) from exc
release_details = await ExtensionRelease.fetch_release_details(details_link)
if not release_details:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Cannot fetch details for release",
)
release_details["icon"] = release.icon
release_details["repo"] = release.repo
return release_details
@extension_router.put("/{ext_id}/sell") @extension_router.put("/{ext_id}/sell")
@ -144,22 +147,21 @@ async def api_update_pay_to_enable(
data: PayToEnableInfo, data: PayToEnableInfo,
user: User = Depends(check_admin), user: User = Depends(check_admin),
) -> SimpleStatus: ) -> SimpleStatus:
try: if data.wallet not in user.wallet_ids:
assert (
data.wallet in user.wallet_ids
), "Wallet does not belong to this admin user."
await update_extension_pay_to_enable(ext_id, data)
return SimpleStatus(
success=True, message=f"Payment info updated for '{ext_id}' extension."
)
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, HTTPStatus.BAD_REQUEST, "Wallet does not belong to this admin user."
detail=(f"Failed to update pay to install data for extension '{ext_id}' "), )
) from exc extension = await get_installed_extension(ext_id)
if not extension:
raise HTTPException(HTTPStatus.NOT_FOUND, f"Extension '{ext_id}' not found.")
if extension.meta:
extension.meta.pay_to_enable = data
else:
extension.meta = ExtensionMeta(pay_to_enable=data)
await update_installed_extension(extension)
return SimpleStatus(
success=True, message=f"Payment info updated for '{ext_id}' extension."
)
@extension_router.put("/{ext_id}/enable") @extension_router.put("/{ext_id}/enable")
@ -176,28 +178,34 @@ async def api_enable_extension(
assert ext, f"Extension '{ext_id}' is not installed." assert ext, f"Extension '{ext_id}' is not installed."
assert ext.active, f"Extension '{ext_id}' is not activated." assert ext.active, f"Extension '{ext_id}' is not activated."
user_ext = await get_user_extension(user.id, ext_id)
if not user_ext:
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
await create_user_extension(user_ext)
if user.admin or not ext.requires_payment: if user.admin or not ext.requires_payment:
await update_user_extension(user_id=user.id, extension=ext_id, active=True) user_ext.active = True
await update_user_extension(user_ext)
return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.") return SimpleStatus(success=True, message=f"Extension '{ext_id}' enabled.")
user_ext = await get_user_extension(user.id, ext_id) if not (user_ext.extra and user_ext.extra.payment_hash_to_enable):
if not (user_ext and user_ext.extra and user_ext.extra.payment_hash_to_enable):
raise HTTPException( raise HTTPException(
HTTPStatus.PAYMENT_REQUIRED, f"Extension '{ext_id}' requires payment." HTTPStatus.PAYMENT_REQUIRED, f"Extension '{ext_id}' requires payment."
) )
if user_ext.is_paid: if user_ext.is_paid:
await update_user_extension(user_id=user.id, extension=ext_id, active=True) user_ext.active = True
await update_user_extension(user_ext)
return SimpleStatus( return SimpleStatus(
success=True, message=f"Paid extension '{ext_id}' enabled." success=True, message=f"Paid extension '{ext_id}' enabled."
) )
assert ( assert (
ext.pay_to_enable and ext.pay_to_enable.wallet ext.meta and ext.meta.pay_to_enable and ext.meta.pay_to_enable.wallet
), f"Extension '{ext_id}' is missing payment wallet." ), f"Extension '{ext_id}' is missing payment wallet."
payment_status = await check_transaction_status( payment_status = await check_transaction_status(
wallet_id=ext.pay_to_enable.wallet, wallet_id=ext.meta.pay_to_enable.wallet,
payment_hash=user_ext.extra.payment_hash_to_enable, payment_hash=user_ext.extra.payment_hash_to_enable,
) )
@ -207,10 +215,9 @@ async def api_enable_extension(
f"Invoice generated but not paid for enabeling extension '{ext_id}'.", f"Invoice generated but not paid for enabeling extension '{ext_id}'.",
) )
user_ext.active = True
user_ext.extra.paid_to_enable = True user_ext.extra.paid_to_enable = True
await update_user_extension_extra(user.id, ext_id, user_ext.extra) await update_user_extension(user_ext)
await update_user_extension(user_id=user.id, extension=ext_id, active=True)
return SimpleStatus(success=True, message=f"Paid extension '{ext_id}' enabled.") return SimpleStatus(success=True, message=f"Paid extension '{ext_id}' enabled.")
except AssertionError as exc: except AssertionError as exc:
@ -233,16 +240,15 @@ async def api_disable_extension(
raise HTTPException( raise HTTPException(
HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist." HTTPStatus.BAD_REQUEST, f"Extension '{ext_id}' doesn't exist."
) )
try: user_ext = await get_user_extension(user.id, ext_id)
logger.info(f"Disabeling extension: {ext_id}.") if not user_ext or not user_ext.active:
await update_user_extension(user_id=user.id, extension=ext_id, active=False) return SimpleStatus(
return SimpleStatus(success=True, message=f"Extension '{ext_id}' disabled.") success=True, message=f"Extension '{ext_id}' already disabled."
except Exception as exc: )
logger.warning(exc) logger.info(f"Disabeling extension: {ext_id}.")
raise HTTPException( user_ext.active = False
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, await update_user_extension(user_ext)
detail=(f"Failed to disable '{ext_id}'."), return SimpleStatus(success=True, message=f"Extension '{ext_id}' disabled.")
) from exc
@extension_router.put("/{ext_id}/activate", dependencies=[Depends(check_admin)]) @extension_router.put("/{ext_id}/activate", dependencies=[Depends(check_admin)])
@ -298,7 +304,11 @@ async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
installed_ext = next( installed_ext = next(
(ext for ext in installed_extensions if ext.id == valid_ext_id), None (ext for ext in installed_extensions if ext.id == valid_ext_id), None
) )
if installed_ext and ext_id in installed_ext.dependencies: if (
installed_ext
and installed_ext.meta
and ext_id in installed_ext.meta.dependencies
):
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=( detail=(
@ -319,9 +329,9 @@ async def api_uninstall_extension(ext_id: str) -> SimpleStatus:
@extension_router.get("/{ext_id}/releases", dependencies=[Depends(check_admin)]) @extension_router.get("/{ext_id}/releases", dependencies=[Depends(check_admin)])
async def get_extension_releases(ext_id: str) -> List[ExtensionRelease]: async def get_extension_releases(ext_id: str) -> list[ExtensionRelease]:
try: try:
extension_releases: List[ExtensionRelease] = ( extension_releases: list[ExtensionRelease] = (
await InstallableExtension.get_extension_releases(ext_id) await InstallableExtension.get_extension_releases(ext_id)
) )
@ -386,45 +396,59 @@ async def get_pay_to_install_invoice(
async def get_pay_to_enable_invoice( async def get_pay_to_enable_invoice(
ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists) ext_id: str, data: PayToEnableInfo, user: User = Depends(check_user_exists)
): ):
try: if not data.amount or data.amount <= 0:
assert data.amount and data.amount > 0, "A non-zero amount must be specified."
ext = await get_installed_extension(ext_id)
assert ext, f"Extension '{ext_id}' not found."
assert ext.pay_to_enable, f"Payment Info not found for extension '{ext_id}'."
assert (
ext.pay_to_enable.required
), f"Payment not required for extension '{ext_id}'."
assert ext.pay_to_enable.wallet and ext.pay_to_enable.amount, (
f"Payment wallet or amount missing for extension '{ext_id}'."
"Please contact the administrator."
)
assert (
data.amount >= ext.pay_to_enable.amount
), f"Minimum amount is {ext.pay_to_enable.amount} sats."
payment_hash, payment_request = await create_invoice(
wallet_id=ext.pay_to_enable.wallet,
amount=data.amount,
memo=f"Enable '{ext.name}' extension.",
)
user_ext = await get_user_extension(user.id, ext_id)
user_ext_info = (
user_ext.extra if user_ext and user_ext.extra else UserExtensionInfo()
)
user_ext_info.payment_hash_to_enable = payment_hash
await update_user_extension_extra(user.id, ext_id, user_ext_info)
return {"payment_hash": payment_hash, "payment_request": payment_request}
except AssertionError as exc:
raise HTTPException(HTTPStatus.BAD_REQUEST, str(exc)) from exc
except Exception as exc:
logger.warning(exc)
raise HTTPException( raise HTTPException(
HTTPStatus.INTERNAL_SERVER_ERROR, "Cannot request invoice." status_code=HTTPStatus.BAD_REQUEST, detail="Amount must be greater than 0."
) from exc )
ext = await get_installed_extension(ext_id)
if not ext:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND, detail=f"Extension '{ext_id}' not found."
)
if not ext.meta or not ext.meta.pay_to_enable:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment info not found for extension '{ext_id}'.",
)
if not ext.meta.pay_to_enable.required:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment not required for extension '{ext_id}'.",
)
if not ext.meta.pay_to_enable.wallet or not ext.meta.pay_to_enable.amount:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=f"Payment wallet or amount missing for extension '{ext_id}'.",
)
if data.amount < ext.meta.pay_to_enable.amount:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail=(
f"Amount {data.amount} sats is less than required "
f"{ext.meta.pay_to_enable.amount} sats."
),
)
payment = await create_invoice(
wallet_id=ext.meta.pay_to_enable.wallet,
amount=data.amount,
memo=f"Enable '{ext.name}' extension.",
)
user_ext = await get_user_extension(user.id, ext_id)
if not user_ext:
user_ext = UserExtension(user=user.id, extension=ext_id, active=False)
await create_user_extension(user_ext)
user_ext_info = user_ext.extra if user_ext.extra else UserExtensionInfo()
user_ext_info.payment_hash_to_enable = payment.payment_hash
user_ext.extra = user_ext_info
await update_user_extension(user_ext)
return {"payment_hash": payment.payment_hash, "payment_request": payment.bolt11}
@extension_router.get( @extension_router.get(
@ -454,7 +478,7 @@ async def get_extension_release(org: str, repo: str, tag_name: str):
) )
async def delete_extension_db(ext_id: str): async def delete_extension_db(ext_id: str):
try: try:
db_version = (await get_dbversions()).get(ext_id, None) db_version = await get_db_version(ext_id)
if not db_version: if not db_version:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,

View file

@ -9,13 +9,12 @@ from fastapi.exceptions import HTTPException
from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse from fastapi.responses import FileResponse, HTMLResponse, RedirectResponse
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from lnurl import decode as lnurl_decode from lnurl import decode as lnurl_decode
from loguru import logger
from pydantic.types import UUID4 from pydantic.types import UUID4
from lnbits.core.extensions.models import Extension, InstallableExtension from lnbits.core.extensions.models import Extension, ExtensionMeta, InstallableExtension
from lnbits.core.helpers import to_valid_user_id from lnbits.core.helpers import to_valid_user_id
from lnbits.core.models import User from lnbits.core.models import User
from lnbits.core.services import create_invoice from lnbits.core.services import create_invoice, create_user_account
from lnbits.decorators import check_admin, check_user_exists from lnbits.decorators import check_admin, check_user_exists
from lnbits.helpers import template_renderer from lnbits.helpers import template_renderer
from lnbits.settings import settings from lnbits.settings import settings
@ -23,11 +22,11 @@ from lnbits.wallets import get_funding_source
from ...utils.exchange_rates import allowed_currencies, currencies from ...utils.exchange_rates import allowed_currencies, currencies
from ..crud import ( from ..crud import (
create_account,
create_wallet, create_wallet,
get_dbversions, get_db_versions,
get_installed_extensions, get_installed_extensions,
get_user, get_user,
get_wallet,
) )
generic_router = APIRouter( generic_router = APIRouter(
@ -74,83 +73,87 @@ async def robots():
@generic_router.get("/extensions", name="extensions", response_class=HTMLResponse) @generic_router.get("/extensions", name="extensions", response_class=HTMLResponse)
async def extensions(request: Request, user: User = Depends(check_user_exists)): async def extensions(request: Request, user: User = Depends(check_user_exists)):
try: installed_exts: List[InstallableExtension] = await get_installed_extensions()
installed_exts: List[InstallableExtension] = await get_installed_extensions() installed_exts_ids = [e.id for e in installed_exts]
installed_exts_ids = [e.id for e in installed_exts]
installable_exts = await InstallableExtension.get_installable_extensions() installable_exts = await InstallableExtension.get_installable_extensions()
installable_exts_ids = [e.id for e in installable_exts] installable_exts_ids = [e.id for e in installable_exts]
installable_exts += [ installable_exts += [e for e in installed_exts if e.id not in installable_exts_ids]
e for e in installed_exts if e.id not in installable_exts_ids
]
for e in installable_exts: for e in installable_exts:
installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None) installed_ext = next((ie for ie in installed_exts if e.id == ie.id), None)
if installed_ext: if installed_ext and installed_ext.meta:
e.installed_release = installed_ext.installed_release installed_release = installed_ext.meta.installed_release
if installed_ext.pay_to_enable and not user.admin: if installed_ext.meta.pay_to_enable and not user.admin:
# not a security leak, but better not to share the wallet id # not a security leak, but better not to share the wallet id
installed_ext.pay_to_enable.wallet = None installed_ext.meta.pay_to_enable.wallet = None
e.pay_to_enable = installed_ext.pay_to_enable pay_to_enable = installed_ext.meta.pay_to_enable
# use the installed extension values if e.meta:
e.name = installed_ext.name e.meta.installed_release = installed_release
e.short_description = installed_ext.short_description e.meta.pay_to_enable = pay_to_enable
e.icon = installed_ext.icon else:
e.meta = ExtensionMeta(
installed_release=installed_release,
pay_to_enable=pay_to_enable,
)
# use the installed extension values
e.name = installed_ext.name
e.short_description = installed_ext.short_description
e.icon = installed_ext.icon
except Exception as ex: all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()]
logger.warning(ex) inactive_extensions = [e.id for e in await get_installed_extensions(active=False)]
installable_exts = [] db_versions = await get_db_versions()
installed_exts_ids = []
try: extensions = [
all_ext_ids = [ext.code for ext in Extension.get_valid_extensions()] {
inactive_extensions = [ "id": ext.id,
e.id for e in await get_installed_extensions(active=False) "name": ext.name,
] "icon": ext.icon,
db_version = await get_dbversions() "shortDescription": ext.short_description,
extensions = [ "stars": ext.stars,
{ "isFeatured": ext.meta.featured if ext.meta else False,
"id": ext.id, "dependencies": ext.meta.dependencies if ext.meta else "",
"name": ext.name, "isInstalled": ext.id in installed_exts_ids,
"icon": ext.icon, "hasDatabaseTables": next(
"shortDescription": ext.short_description, (True for version in db_versions if version.db == ext.id), False
"stars": ext.stars, ),
"isFeatured": ext.featured, "isAvailable": ext.id in all_ext_ids,
"dependencies": ext.dependencies, "isAdminOnly": ext.id in settings.lnbits_admin_extensions,
"isInstalled": ext.id in installed_exts_ids, "isActive": ext.id not in inactive_extensions,
"hasDatabaseTables": ext.id in db_version, "latestRelease": (
"isAvailable": ext.id in all_ext_ids, dict(ext.meta.latest_release)
"isAdminOnly": ext.id in settings.lnbits_admin_extensions, if ext.meta and ext.meta.latest_release
"isActive": ext.id not in inactive_extensions, else None
"latestRelease": ( ),
dict(ext.latest_release) if ext.latest_release else None "installedRelease": (
), dict(ext.meta.installed_release)
"installedRelease": ( if ext.meta and ext.meta.installed_release
dict(ext.installed_release) if ext.installed_release else None else None
), ),
"payToEnable": (dict(ext.pay_to_enable) if ext.pay_to_enable else {}), "payToEnable": (
"isPaymentRequired": ext.requires_payment, dict(ext.meta.pay_to_enable)
} if ext.meta and ext.meta.pay_to_enable
for ext in installable_exts else {}
] ),
"isPaymentRequired": ext.requires_payment,
}
for ext in installable_exts
]
# refresh user state. Eg: enabled extensions. # refresh user state. Eg: enabled extensions.
user = await get_user(user.id) or user # TODO: refactor
# user = await get_user(user.id) or user
return template_renderer().TemplateResponse( return template_renderer().TemplateResponse(
request, request,
"core/extensions.html", "core/extensions.html",
{ {
"user": user.dict(), "user": user.json(),
"extensions": extensions, "extensions": extensions,
}, },
) )
except Exception as exc:
logger.warning(exc)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
@generic_router.get( @generic_router.get(
@ -165,18 +168,16 @@ async def wallet(
wal: Optional[UUID4] = Query(None), wal: Optional[UUID4] = Query(None),
): ):
if wal: if wal:
wallet_id = wal.hex wallet = await get_wallet(wal.hex)
elif len(user.wallets) == 0: elif len(user.wallets) == 0:
wallet = await create_wallet(user_id=user.id) wallet = await create_wallet(user_id=user.id)
user = await get_user(user_id=user.id) or user user.wallets.append(wallet)
wallet_id = wallet.id
elif lnbits_last_active_wallet and user.get_wallet(lnbits_last_active_wallet): elif lnbits_last_active_wallet and user.get_wallet(lnbits_last_active_wallet):
wallet_id = lnbits_last_active_wallet wallet = await get_wallet(lnbits_last_active_wallet)
else: else:
wallet_id = user.wallets[0].id wallet = user.wallets[0]
user_wallet = user.get_wallet(wallet_id) if not wallet or wallet.deleted:
if not user_wallet or user_wallet.deleted:
return template_renderer().TemplateResponse( return template_renderer().TemplateResponse(
request, "error.html", {"err": "Wallet not found"}, HTTPStatus.NOT_FOUND request, "error.html", {"err": "Wallet not found"}, HTTPStatus.NOT_FOUND
) )
@ -185,15 +186,16 @@ async def wallet(
request, request,
"core/wallet.html", "core/wallet.html",
{ {
"user": user.dict(), "user": user.json(),
"wallet": user_wallet.dict(), "wallet": wallet.json(),
"wallet_name": wallet.name,
"currencies": allowed_currencies(), "currencies": allowed_currencies(),
"service_fee": settings.lnbits_service_fee, "service_fee": settings.lnbits_service_fee,
"service_fee_max": settings.lnbits_service_fee_max, "service_fee_max": settings.lnbits_service_fee_max,
"web_manifest": f"/manifest/{user.id}.webmanifest", "web_manifest": f"/manifest/{user.id}.webmanifest",
}, },
) )
resp.set_cookie("lnbits_last_active_wallet", wallet_id) resp.set_cookie("lnbits_last_active_wallet", wallet.id)
return resp return resp
@ -209,7 +211,9 @@ async def account(
return template_renderer().TemplateResponse( return template_renderer().TemplateResponse(
request, request,
"core/account.html", "core/account.html",
{"user": user.dict()}, {
"user": user.json(),
},
) )
@ -228,11 +232,9 @@ async def service_worker(request: Request):
@generic_router.get("/manifest/{usr}.webmanifest") @generic_router.get("/manifest/{usr}.webmanifest")
async def manifest(request: Request, usr: str): async def manifest(request: Request, usr: str):
host = urlparse(str(request.url)).netloc host = urlparse(str(request.url)).netloc
user = await get_user(usr) user = await get_user(usr)
if not user: if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return { return {
"short_name": settings.lnbits_site_title, "short_name": settings.lnbits_site_title,
"name": settings.lnbits_site_title + " Wallet", "name": settings.lnbits_site_title + " Wallet",
@ -320,10 +322,10 @@ async def node(request: Request, user: User = Depends(check_admin)):
request, request,
"node/index.html", "node/index.html",
{ {
"user": user.dict(), "user": user.json(),
"settings": settings.dict(), "settings": settings.dict(),
"balance": balance, "balance": balance,
"wallets": user.wallets[0].dict(), "wallets": user.wallets[0].json(),
}, },
) )
@ -358,7 +360,7 @@ async def admin_index(request: Request, user: User = Depends(check_admin)):
request, request,
"admin/index.html", "admin/index.html",
{ {
"user": user.dict(), "user": user.json(),
"settings": settings.dict(), "settings": settings.dict(),
"balance": balance, "balance": balance,
"currencies": list(currencies.keys()), "currencies": list(currencies.keys()),
@ -375,7 +377,7 @@ async def users_index(request: Request, user: User = Depends(check_admin)):
"users/index.html", "users/index.html",
{ {
"request": request, "request": request,
"user": user.dict(), "user": user.json(),
"settings": settings.dict(), "settings": settings.dict(),
"currencies": list(currencies.keys()), "currencies": list(currencies.keys()),
}, },
@ -424,7 +426,7 @@ async def lnurlwallet(request: Request):
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="Invalid lnurl. Expected maxWithdrawable", detail="Invalid lnurl. Expected maxWithdrawable",
) )
account = await create_account() account = await create_user_account()
wallet = await create_wallet(user_id=account.id) wallet = await create_wallet(user_id=account.id)
_, payment_request = await create_invoice( _, payment_request = await create_invoice(
wallet_id=wallet.id, wallet_id=wallet.id,

View file

@ -3,13 +3,12 @@ import json
import uuid import uuid
from http import HTTPStatus from http import HTTPStatus
from math import ceil from math import ceil
from typing import List, Optional, Union from typing import List, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Body,
Depends, Depends,
Header, Header,
HTTPException, HTTPException,
@ -21,7 +20,6 @@ from loguru import logger
from sse_starlette.sse import EventSourceResponse from sse_starlette.sse import EventSourceResponse
from lnbits import bolt11 from lnbits import bolt11
from lnbits.core.db import db
from lnbits.core.models import ( from lnbits.core.models import (
CreateInvoice, CreateInvoice,
CreateLnurl, CreateLnurl,
@ -121,7 +119,7 @@ async def api_payments_paginated(
return page return page
async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet): async def _api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
description_hash = b"" description_hash = b""
unhashed_description = b"" unhashed_description = b""
memo = data.memo or settings.lnbits_site_title memo = data.memo or settings.lnbits_site_title
@ -145,60 +143,42 @@ async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
# do not save memo if description_hash or unhashed_description is set # do not save memo if description_hash or unhashed_description is set
memo = "" memo = ""
async with db.connect() as conn: payment = await create_invoice(
payment_hash, payment_request = await create_invoice( wallet_id=wallet.id,
wallet_id=wallet.id, amount=data.amount,
amount=data.amount, memo=memo,
memo=memo, currency=data.unit,
currency=data.unit, description_hash=description_hash,
description_hash=description_hash, unhashed_description=unhashed_description,
unhashed_description=unhashed_description, expiry=data.expiry,
expiry=data.expiry, extra=data.extra,
extra=data.extra, webhook=data.webhook,
webhook=data.webhook, internal=data.internal,
internal=data.internal, )
conn=conn,
)
# NOTE: we get the checking_id with a seperate query because create_invoice
# does not return it and it would be a big hustle to change its return type
# (used across extensions)
payment_db = await get_standalone_payment(payment_hash, conn=conn)
assert payment_db is not None, "payment not found"
checking_id = payment_db.checking_id
invoice = bolt11.decode(payment_request) # lnurl_response is not saved in the database
lnurl_response: Union[None, bool, str] = None
if data.lnurl_callback: if data.lnurl_callback:
headers = {"User-Agent": settings.user_agent} headers = {"User-Agent": settings.user_agent}
async with httpx.AsyncClient(headers=headers) as client: async with httpx.AsyncClient(headers=headers) as client:
try: try:
r = await client.get( r = await client.get(
data.lnurl_callback, data.lnurl_callback,
params={ params={"pr": payment.bolt11},
"pr": payment_request,
},
timeout=10, timeout=10,
) )
if r.is_error: if r.is_error:
lnurl_response = r.text payment.extra["lnurl_response"] = r.text
else: else:
resp = json.loads(r.text) resp = json.loads(r.text)
if resp["status"] != "OK": if resp["status"] != "OK":
lnurl_response = resp["reason"] payment.extra["lnurl_response"] = resp["reason"]
else: else:
lnurl_response = True payment.extra["lnurl_response"] = True
except (httpx.ConnectError, httpx.RequestError) as ex: except (httpx.ConnectError, httpx.RequestError) as ex:
logger.error(ex) logger.error(ex)
lnurl_response = False payment.extra["lnurl_response"] = False
return { return payment
"payment_hash": invoice.payment_hash,
"payment_request": payment_request,
"lnurl_response": lnurl_response,
# maintain backwards compatibility with API clients:
"checking_id": checking_id,
}
@payment_router.post( @payment_router.post(
@ -220,30 +200,25 @@ async def api_payments_create_invoice(data: CreateInvoice, wallet: Wallet):
}, },
) )
async def api_payments_create( async def api_payments_create(
invoice_data: CreateInvoice,
wallet: WalletTypeInfo = Depends(require_invoice_key), wallet: WalletTypeInfo = Depends(require_invoice_key),
invoice_data: CreateInvoice = Body(...), ) -> Payment:
):
if invoice_data.out is True and wallet.key_type == KeyType.admin: if invoice_data.out is True and wallet.key_type == KeyType.admin:
if not invoice_data.bolt11: if not invoice_data.bolt11:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail="BOLT11 string is invalid or not given", detail="Missing BOLT11 invoice",
) )
payment = await pay_invoice(
payment_hash = await pay_invoice(
wallet_id=wallet.wallet.id, wallet_id=wallet.wallet.id,
payment_request=invoice_data.bolt11, payment_request=invoice_data.bolt11,
extra=invoice_data.extra, extra=invoice_data.extra,
) )
return { return payment
"payment_hash": payment_hash,
# maintain backwards compatibility with API clients:
"checking_id": payment_hash,
}
elif not invoice_data.out: elif not invoice_data.out:
# invoice key # invoice key
return await api_payments_create_invoice(invoice_data, wallet.wallet) return await _api_payments_create_invoice(invoice_data, wallet.wallet)
else: else:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED, status_code=HTTPStatus.UNAUTHORIZED,
@ -269,7 +244,7 @@ async def api_payments_fee_reserve(invoice: str = Query("invoice")) -> JSONRespo
@payment_router.post("/lnurl") @payment_router.post("/lnurl")
async def api_payments_pay_lnurl( async def api_payments_pay_lnurl(
data: CreateLnurl, wallet: WalletTypeInfo = Depends(require_admin_key) data: CreateLnurl, wallet: WalletTypeInfo = Depends(require_admin_key)
): ) -> Payment:
domain = urlparse(data.callback).netloc domain = urlparse(data.callback).netloc
headers = {"User-Agent": settings.user_agent} headers = {"User-Agent": settings.user_agent}
@ -313,15 +288,12 @@ async def api_payments_pay_lnurl(
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST, status_code=HTTPStatus.BAD_REQUEST,
detail=( detail=(
( f"{domain} returned an invalid invoice. Expected"
f"{domain} returned an invalid invoice. Expected" f" {amount_msat} msat, got {invoice.amount_msat}."
f" {amount_msat} msat, got {invoice.amount_msat}."
),
), ),
) )
extra = {} extra = {}
if params.get("successAction"): if params.get("successAction"):
extra["success_action"] = params["successAction"] extra["success_action"] = params["successAction"]
if data.comment: if data.comment:
@ -330,19 +302,14 @@ async def api_payments_pay_lnurl(
extra["fiat_currency"] = data.unit extra["fiat_currency"] = data.unit
extra["fiat_amount"] = data.amount / 1000 extra["fiat_amount"] = data.amount / 1000
assert data.description is not None, "description is required" assert data.description is not None, "description is required"
payment_hash = await pay_invoice(
payment = await pay_invoice(
wallet_id=wallet.wallet.id, wallet_id=wallet.wallet.id,
payment_request=params["pr"], payment_request=params["pr"],
description=data.description, description=data.description,
extra=extra, extra=extra,
) )
return payment
return {
"success_action": params.get("successAction"),
"payment_hash": payment_hash,
# maintain backwards compatibility with API clients:
"checking_id": payment_hash,
}
async def subscribe_wallet_invoices(request: Request, wallet: Wallet): async def subscribe_wallet_invoices(request: Request, wallet: Wallet):

View file

@ -5,7 +5,7 @@ from http import HTTPStatus
from typing import List from typing import List
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from starlette.exceptions import HTTPException from fastapi.exceptions import HTTPException
from lnbits.core.crud import ( from lnbits.core.crud import (
delete_account, delete_account,
@ -17,8 +17,8 @@ from lnbits.core.crud import (
update_admin_settings, update_admin_settings,
) )
from lnbits.core.models import ( from lnbits.core.models import (
Account,
AccountFilters, AccountFilters,
AccountOverview,
CreateTopup, CreateTopup,
User, User,
Wallet, Wallet,
@ -40,42 +40,33 @@ users_router = APIRouter(prefix="/users/api/v1", dependencies=[Depends(check_adm
) )
async def api_get_users( async def api_get_users(
filters: Filters = Depends(parse_filters(AccountFilters)), filters: Filters = Depends(parse_filters(AccountFilters)),
) -> Page[Account]: ) -> Page[AccountOverview]:
try: return await get_accounts(filters=filters)
filtered = await get_accounts(filters=filters)
for user in filtered.data:
user.is_super_user = user.id == settings.super_user
user.is_admin = user.id in settings.lnbits_admin_users or user.is_super_user
return filtered
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"Could not fetch users. {exc!s}",
) from exc
@users_router.delete("/user/{user_id}", status_code=HTTPStatus.OK) @users_router.delete("/user/{user_id}", status_code=HTTPStatus.OK)
async def api_users_delete_user( async def api_users_delete_user(
user_id: str, user: User = Depends(check_admin) user_id: str, user: User = Depends(check_admin)
) -> None: ) -> None:
wallets = await get_wallets(user_id)
try: if len(wallets) > 0:
wallets = await get_wallets(user_id)
if len(wallets) > 0:
raise Exception("Cannot delete user with wallets.")
if user_id == settings.super_user:
raise Exception("Cannot delete super user.")
if user_id in settings.lnbits_admin_users and not user.super_user:
raise Exception("Only super_user can delete admin user.")
await delete_account(user_id)
except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.BAD_REQUEST,
detail=f"{exc!s}", detail="Cannot delete user with wallets.",
) from exc )
if user_id == settings.super_user:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Cannot delete super user.",
)
if user_id in settings.lnbits_admin_users and not user.super_user:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Only super_user can delete admin user.",
)
await delete_account(user_id)
@users_router.put( @users_router.put(
@ -98,66 +89,53 @@ async def api_users_reset_password(user_id: str) -> str:
@users_router.get("/user/{user_id}/admin", dependencies=[Depends(check_super_user)]) @users_router.get("/user/{user_id}/admin", dependencies=[Depends(check_super_user)])
async def api_users_toggle_admin(user_id: str) -> None: async def api_users_toggle_admin(user_id: str) -> None:
try: if user_id == settings.super_user:
if user_id == settings.super_user:
raise Exception("Cannot change super user.")
if user_id in settings.lnbits_admin_users:
settings.lnbits_admin_users.remove(user_id)
else:
settings.lnbits_admin_users.append(user_id)
update_settings = EditableSettings(
lnbits_admin_users=settings.lnbits_admin_users
)
await update_admin_settings(update_settings)
except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.BAD_REQUEST,
detail=f"Could not update admin settings. {exc}", detail="Cannot change super user.",
) from exc )
if user_id in settings.lnbits_admin_users:
settings.lnbits_admin_users.remove(user_id)
else:
settings.lnbits_admin_users.append(user_id)
update_settings = EditableSettings(lnbits_admin_users=settings.lnbits_admin_users)
await update_admin_settings(update_settings)
@users_router.get("/user/{user_id}/wallet") @users_router.get("/user/{user_id}/wallet")
async def api_users_get_user_wallet(user_id: str) -> List[Wallet]: async def api_users_get_user_wallet(user_id: str) -> List[Wallet]:
try: return await get_wallets(user_id)
return await get_wallets(user_id)
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail=f"Could not fetch user wallets. {exc}",
) from exc
@users_router.get("/user/{user_id}/wallet/{wallet}/undelete") @users_router.get("/user/{user_id}/wallet/{wallet}/undelete")
async def api_users_undelete_user_wallet(user_id: str, wallet: str) -> None: async def api_users_undelete_user_wallet(user_id: str, wallet: str) -> None:
try: wal = await get_wallet(wallet)
wal = await get_wallet(wallet) if not wal:
if not wal:
raise Exception("Wallet does not exist.")
if user_id != wal.user:
raise Exception("Wallet does not belong to user.")
if wal.deleted:
await delete_wallet(user_id=user_id, wallet_id=wallet, deleted=False)
except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.NOT_FOUND,
detail=f"{exc!s}", detail="Wallet does not exist.",
) from exc )
if user_id != wal.user:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Wallet does not belong to user.",
)
if wal.deleted:
await delete_wallet(user_id=user_id, wallet_id=wallet, deleted=False)
@users_router.delete("/user/{user_id}/wallet/{wallet}") @users_router.delete("/user/{user_id}/wallet/{wallet}")
async def api_users_delete_user_wallet(user_id: str, wallet: str) -> None: async def api_users_delete_user_wallet(user_id: str, wallet: str) -> None:
try: wal = await get_wallet(wallet)
wal = await get_wallet(wallet) if not wal:
if not wal:
raise Exception("Wallet does not exist.")
if wal.deleted:
await force_delete_wallet(wallet)
await delete_wallet(user_id=user_id, wallet_id=wallet)
except Exception as exc:
raise HTTPException( raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, status_code=HTTPStatus.NOT_FOUND,
detail=f"{exc!s}", detail="Wallet does not exist.",
) from exc )
if wal.deleted:
await force_delete_wallet(wallet)
await delete_wallet(user_id=user_id, wallet_id=wallet)
@users_router.put( @users_router.put(
@ -167,14 +145,9 @@ async def api_users_delete_user_wallet(user_id: str, wallet: str) -> None:
dependencies=[Depends(check_super_user)], dependencies=[Depends(check_super_user)],
) )
async def api_topup_balance(data: CreateTopup) -> dict[str, str]: async def api_topup_balance(data: CreateTopup) -> dict[str, str]:
try: await get_wallet(data.id)
await get_wallet(data.id) if settings.lnbits_backend_wallet_class == "VoidWallet":
if settings.lnbits_backend_wallet_class == "VoidWallet": raise Exception("VoidWallet active")
raise Exception("VoidWallet active")
await update_wallet_balance(wallet_id=data.id, amount=int(data.amount)) await update_wallet_balance(wallet_id=data.id, amount=int(data.amount))
return {"status": "Success"} return {"status": "Success"}
except Exception as exc:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=f"{exc!s}"
) from exc

View file

@ -1,9 +1,11 @@
from http import HTTPStatus
from typing import Optional from typing import Optional
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Body, Body,
Depends, Depends,
HTTPException,
) )
from lnbits.core.models import ( from lnbits.core.models import (
@ -20,6 +22,7 @@ from lnbits.decorators import (
from ..crud import ( from ..crud import (
create_wallet, create_wallet,
delete_wallet, delete_wallet,
get_wallet,
update_wallet, update_wallet,
) )
@ -27,35 +30,45 @@ wallet_router = APIRouter(prefix="/api/v1/wallet", tags=["Wallet"])
@wallet_router.get("") @wallet_router.get("")
async def api_wallet(wallet: WalletTypeInfo = Depends(require_invoice_key)): async def api_wallet(key_info: WalletTypeInfo = Depends(require_invoice_key)):
res = { res = {
"name": wallet.wallet.name, "name": key_info.wallet.name,
"balance": wallet.wallet.balance_msat, "balance": key_info.wallet.balance_msat,
} }
if wallet.key_type == KeyType.admin: if key_info.key_type == KeyType.admin:
res["id"] = wallet.wallet.id res["id"] = key_info.wallet.id
return res return res
@wallet_router.put("/{new_name}") @wallet_router.put("/{new_name}")
async def api_update_wallet_name( async def api_update_wallet_name(
new_name: str, wallet: WalletTypeInfo = Depends(require_admin_key) new_name: str, key_info: WalletTypeInfo = Depends(require_admin_key)
): ):
await update_wallet(wallet.wallet.id, new_name) wallet = await get_wallet(key_info.wallet.id)
if not wallet:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
wallet.name = new_name
await update_wallet(wallet)
return { return {
"id": wallet.wallet.id, "id": wallet.id,
"name": wallet.wallet.name, "name": wallet.name,
"balance": wallet.wallet.balance_msat, "balance": wallet.balance_msat,
} }
@wallet_router.patch("", response_model=Wallet) @wallet_router.patch("")
async def api_update_wallet( async def api_update_wallet(
name: Optional[str] = Body(None), name: Optional[str] = Body(None),
currency: Optional[str] = Body(None), currency: Optional[str] = Body(None),
wallet: WalletTypeInfo = Depends(require_admin_key), key_info: WalletTypeInfo = Depends(require_admin_key),
): ) -> Wallet:
return await update_wallet(wallet.wallet.id, name, currency) wallet = await get_wallet(key_info.wallet.id)
if not wallet:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail="Wallet not found")
wallet.name = name or wallet.name
wallet.currency = currency if currency is not None else wallet.currency
await update_wallet(wallet)
return wallet
@wallet_router.delete("") @wallet_router.delete("")
@ -68,9 +81,9 @@ async def api_delete_wallet(
) )
@wallet_router.post("", response_model=Wallet) @wallet_router.post("")
async def api_create_wallet( async def api_create_wallet(
data: CreateWallet, data: CreateWallet,
wallet: WalletTypeInfo = Depends(require_admin_key), key_info: WalletTypeInfo = Depends(require_admin_key),
) -> Wallet: ) -> Wallet:
return await create_wallet(user_id=wallet.wallet.user, wallet_name=data.name) return await create_wallet(user_id=key_info.wallet.user, wallet_name=data.name)

View file

@ -1,13 +1,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import datetime import json
import os import os
import re import re
import time import time
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Any, Generic, Literal, Optional, TypeVar from typing import Any, Generic, Literal, Optional, TypeVar, Union
from loguru import logger from loguru import logger
from pydantic import BaseModel, ValidationError, root_validator from pydantic import BaseModel, ValidationError, root_validator
@ -50,7 +51,7 @@ def compat_timestamp_placeholder(key: str):
def get_placeholder(model: Any, field: str) -> str: def get_placeholder(model: Any, field: str) -> str:
type_ = model.__fields__[field].type_ type_ = model.__fields__[field].type_
if type_ == datetime.datetime: if type_ == datetime:
return compat_timestamp_placeholder(field) return compat_timestamp_placeholder(field)
else: else:
return f":{field}" return f":{field}"
@ -67,7 +68,7 @@ class Compat:
return f"{seconds}" return f"{seconds}"
return "<nothing>" return "<nothing>"
def datetime_to_timestamp(self, date: datetime.datetime): def datetime_to_timestamp(self, date: datetime):
if self.type in {POSTGRES, COCKROACH}: if self.type in {POSTGRES, COCKROACH}:
return date.strftime("%Y-%m-%d %H:%M:%S") return date.strftime("%Y-%m-%d %H:%M:%S")
elif self.type == SQLITE: elif self.type == SQLITE:
@ -134,7 +135,7 @@ class Connection(Compat):
for key, raw_value in values.items(): for key, raw_value in values.items():
if isinstance(raw_value, str): if isinstance(raw_value, str):
clean_values[key] = re.sub(clean_regex, "", raw_value) clean_values[key] = re.sub(clean_regex, "", raw_value)
elif isinstance(raw_value, datetime.datetime): elif isinstance(raw_value, datetime):
ts = raw_value.timestamp() ts = raw_value.timestamp()
if self.type == SQLITE: if self.type == SQLITE:
clean_values[key] = int(ts) clean_values[key] = int(ts)
@ -144,29 +145,59 @@ class Connection(Compat):
clean_values[key] = raw_value clean_values[key] = raw_value
return clean_values return clean_values
async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]: async def fetchall(
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> list[TModel]:
params = self.rewrite_values(values) if values else {} params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params) result = await self.conn.execute(text(self.rewrite_query(query)), params)
row = result.mappings().all() row = result.mappings().all()
result.close() result.close()
if not row:
return []
if model:
return [dict_to_model(r, model) for r in row]
return row return row
async def fetchone(self, query: str, values: Optional[dict] = None) -> dict: async def fetchone(
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> TModel:
params = self.rewrite_values(values) if values else {} params = self.rewrite_values(values) if values else {}
result = await self.conn.execute(text(self.rewrite_query(query)), params) result = await self.conn.execute(text(self.rewrite_query(query)), params)
row = result.mappings().first() row = result.mappings().first()
result.close() result.close()
if model and row:
return dict_to_model(row, model)
return row return row
async def update(
self, table_name: str, model: BaseModel, where: str = "WHERE id = :id"
):
await self.conn.execute(
text(update_query(table_name, model, where)), model_to_dict(model)
)
await self.conn.commit()
async def insert(self, table_name: str, model: BaseModel):
await self.conn.execute(
text(insert_query(table_name, model)), model_to_dict(model)
)
await self.conn.commit()
async def fetch_page( async def fetch_page(
self, self,
query: str, query: str,
where: Optional[list[str]] = None, where: Optional[list[str]] = None,
values: Optional[dict] = None, values: Optional[dict] = None,
filters: Optional[Filters] = None, filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None, model: Optional[type[TModel]] = None,
group_by: Optional[list[str]] = None, group_by: Optional[list[str]] = None,
) -> Page[TRowModel]: ) -> Page[TModel]:
if not filters: if not filters:
filters = Filters() filters = Filters()
clause = filters.where(where) clause = filters.where(where)
@ -190,11 +221,12 @@ class Connection(Compat):
{filters.pagination()} {filters.pagination()}
""", """,
self.rewrite_values(parsed_values), self.rewrite_values(parsed_values),
model,
) )
if rows: if rows:
# no need for extra query if no pagination is specified # no need for extra query if no pagination is specified
if filters.offset or filters.limit: if filters.offset or filters.limit:
result = await self.fetchone( result = await self.execute(
f""" f"""
SELECT COUNT(*) as count FROM ( SELECT COUNT(*) as count FROM (
{query} {query}
@ -204,14 +236,16 @@ class Connection(Compat):
""", """,
parsed_values, parsed_values,
) )
count = int(result.get("count", 0)) row = result.mappings().first()
result.close()
count = int(row.get("count", 0))
else: else:
count = len(rows) count = len(rows)
else: else:
count = 0 count = 0
return Page( return Page(
data=[model.from_row(row) for row in rows] if model else [], data=rows,
total=count, total=count,
) )
@ -251,21 +285,19 @@ class Database(Compat):
@event.listens_for(self.engine.sync_engine, "connect") @event.listens_for(self.engine.sync_engine, "connect")
def register_custom_types(dbapi_connection, *_): def register_custom_types(dbapi_connection, *_):
def _parse_timestamp(value): def _parse_date(value) -> datetime:
if value is None: if value is None:
return None value = "1970-01-01 00:00:00"
f = "%Y-%m-%d %H:%M:%S.%f" f = "%Y-%m-%d %H:%M:%S.%f"
if "." not in value: if "." not in value:
f = "%Y-%m-%d %H:%M:%S" f = "%Y-%m-%d %H:%M:%S"
return int( return datetime.strptime(value, f)
time.mktime(datetime.datetime.strptime(value, f).timetuple())
)
dbapi_connection.run_async( dbapi_connection.run_async(
lambda connection: connection.set_type_codec( lambda connection: connection.set_type_codec(
"TIMESTAMP", "TIMESTAMP",
encoder=datetime.datetime, encoder=datetime,
decoder=_parse_timestamp, decoder=_parse_date,
schema="pg_catalog", schema="pg_catalog",
) )
) )
@ -296,13 +328,33 @@ class Database(Compat):
finally: finally:
self.lock.release() self.lock.release()
async def fetchall(self, query: str, values: Optional[dict] = None) -> list[dict]: async def fetchall(
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> list[TModel]:
async with self.connect() as conn: async with self.connect() as conn:
return await conn.fetchall(query, values) return await conn.fetchall(query, values, model)
async def fetchone(self, query: str, values: Optional[dict] = None) -> dict: async def fetchone(
self,
query: str,
values: Optional[dict] = None,
model: Optional[type[TModel]] = None,
) -> TModel:
async with self.connect() as conn: async with self.connect() as conn:
return await conn.fetchone(query, values) return await conn.fetchone(query, values, model)
async def insert(self, table_name: str, model: BaseModel) -> None:
async with self.connect() as conn:
await conn.insert(table_name, model)
async def update(
self, table_name: str, model: BaseModel, where: str = "WHERE id = :id"
) -> None:
async with self.connect() as conn:
await conn.update(table_name, model, where)
async def fetch_page( async def fetch_page(
self, self,
@ -310,9 +362,9 @@ class Database(Compat):
where: Optional[list[str]] = None, where: Optional[list[str]] = None,
values: Optional[dict] = None, values: Optional[dict] = None,
filters: Optional[Filters] = None, filters: Optional[Filters] = None,
model: Optional[type[TRowModel]] = None, model: Optional[type[TModel]] = None,
group_by: Optional[list[str]] = None, group_by: Optional[list[str]] = None,
) -> Page[TRowModel]: ) -> Page[TModel]:
async with self.connect() as conn: async with self.connect() as conn:
return await conn.fetch_page(query, where, values, filters, model, group_by) return await conn.fetch_page(query, where, values, filters, model, group_by)
@ -372,12 +424,6 @@ class Operator(Enum):
raise ValueError("Unknown SQL Operator") raise ValueError("Unknown SQL Operator")
class FromRowModel(BaseModel):
@classmethod
def from_row(cls, row: dict):
return cls(**row)
class FilterModel(BaseModel): class FilterModel(BaseModel):
__search_fields__: list[str] = [] __search_fields__: list[str] = []
__sort_fields__: Optional[list[str]] = None __sort_fields__: Optional[list[str]] = None
@ -385,7 +431,6 @@ class FilterModel(BaseModel):
T = TypeVar("T") T = TypeVar("T")
TModel = TypeVar("TModel", bound=BaseModel) TModel = TypeVar("TModel", bound=BaseModel)
TRowModel = TypeVar("TRowModel", bound=FromRowModel)
TFilterModel = TypeVar("TFilterModel", bound=FilterModel) TFilterModel = TypeVar("TFilterModel", bound=FilterModel)
@ -435,10 +480,7 @@ class Filter(BaseModel, Generic[TFilterModel]):
stmt = [] stmt = []
for key in self.values.keys() if self.values else []: for key in self.values.keys() if self.values else []:
clean_key = key.split("__")[0] clean_key = key.split("__")[0]
if ( if self.model and self.model.__fields__[clean_key].type_ == datetime:
self.model
and self.model.__fields__[clean_key].type_ == datetime.datetime
):
placeholder = compat_timestamp_placeholder(key) placeholder = compat_timestamp_placeholder(key)
else: else:
placeholder = f":{key}" placeholder = f":{key}"
@ -518,3 +560,111 @@ class Filters(BaseModel, Generic[TFilterModel]):
if self.search and self.model: if self.search and self.model:
values["search"] = f"%{self.search}%" values["search"] = f"%{self.search}%"
return values return values
def insert_query(table_name: str, model: BaseModel) -> str:
"""
Generate an insert query with placeholders for a given table and model
:param table_name: Name of the table
:param model: Pydantic model
"""
placeholders = []
keys = model_to_dict(model).keys()
for field in keys:
placeholders.append(get_placeholder(model, field))
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields = ", ".join([f'"{key}"' for key in keys])
values = ", ".join(placeholders)
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"
def update_query(
table_name: str, model: BaseModel, where: str = "WHERE id = :id"
) -> str:
"""
Generate an update query with placeholders for a given table and model
:param table_name: Name of the table
:param model: Pydantic model
:param where: Where string, default to `WHERE id = :id`
"""
fields = []
for field in model_to_dict(model).keys():
placeholder = get_placeholder(model, field)
# add quotes to keys to avoid SQL conflicts (e.g. `user` is a reserved keyword)
fields.append(f'"{field}" = {placeholder}')
query = ", ".join(fields)
return f"UPDATE {table_name} SET {query} {where}"
def model_to_dict(model: BaseModel) -> dict:
"""
Convert a Pydantic model to a dictionary with JSON-encoded nested models
private fields starting with _ are ignored
:param model: Pydantic model
"""
_dict: dict = {}
for key, value in model.dict().items():
type_ = model.__fields__[key].type_
if model.__fields__[key].field_info.extra.get("no_database", False):
continue
if isinstance(value, datetime):
_dict[key] = value.timestamp()
continue
if type(type_) is type(BaseModel) or type_ is dict:
_dict[key] = json.dumps(value)
continue
_dict[key] = value
return _dict
def dict_to_submodel(model: type[TModel], value: Union[dict, str]) -> Optional[TModel]:
"""convert a dictionary or JSON string to a Pydantic model"""
if isinstance(value, str):
if value == "null":
return None
_subdict = json.loads(value)
elif isinstance(value, dict):
_subdict = value
else:
logger.warning(f"Expected str or dict, got {type(value)}")
return None
# recursively convert nested models
return dict_to_model(_subdict, model)
def dict_to_model(_row: dict, model: type[TModel]) -> TModel:
"""
Convert a dictionary with JSON-encoded nested models to a Pydantic model
:param _dict: Dictionary from database
:param model: Pydantic model
"""
_dict: dict = {}
for key, value in _row.items():
if value is None:
continue
if key not in model.__fields__:
logger.warning(f"Converting {key} to model `{model}`.")
continue
type_ = model.__fields__[key].type_
if issubclass(type_, bool):
_dict[key] = bool(value)
continue
if issubclass(type_, datetime):
if DB_TYPE == SQLITE:
_dict[key] = datetime.fromtimestamp(value, timezone.utc)
else:
_dict[key] = value
continue
if issubclass(type_, BaseModel) and value:
_dict[key] = dict_to_submodel(type_, value)
continue
# TODO: remove this when all sub models are migrated to Pydantic
# NOTE: this is for type dict on BaseModel, (used in Payment class)
if type_ is dict and value:
_dict[key] = json.loads(value)
continue
_dict[key] = value
continue
_model = model.construct(**_dict)
return _model

View file

@ -14,12 +14,13 @@ from lnbits.core.crud import (
get_account, get_account,
get_account_by_email, get_account_by_email,
get_account_by_username, get_account_by_username,
get_user,
get_user_active_extensions_ids, get_user_active_extensions_ids,
get_user_from_account,
get_wallet_for_key, get_wallet_for_key,
) )
from lnbits.core.models import ( from lnbits.core.models import (
AccessTokenPayload, AccessTokenPayload,
Account,
KeyType, KeyType,
SimpleStatus, SimpleStatus,
User, User,
@ -65,7 +66,7 @@ class KeyChecker(SecurityBase):
name="X-API-KEY", name="X-API-KEY",
description="Wallet API Key - HEADER", description="Wallet API Key - HEADER",
) )
self.model: APIKey = openapi_model self.model: APIKey = openapi_model # type: ignore
async def __call__(self, request: Request) -> WalletTypeInfo: async def __call__(self, request: Request) -> WalletTypeInfo:
@ -144,14 +145,16 @@ async def check_user_exists(
else: else:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "Missing user ID or access token.")
if not account or not settings.is_user_allowed(account.id): if not account:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
if not settings.is_user_allowed(account.id):
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not allowed.") raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not allowed.")
user = await get_user(account.id) user = await get_user_from_account(account)
assert user, "User not found for account." if not user:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "User not found.")
await _check_user_extension_access(user.id, r["path"]) await _check_user_extension_access(user.id, r["path"])
return user return user
@ -261,7 +264,7 @@ async def _check_user_extension_access(user_id: str, current_path: str):
) )
async def _get_account_from_token(access_token) -> Optional[User]: async def _get_account_from_token(access_token) -> Optional[Account]:
try: try:
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"]) payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
user = await _get_user_from_jwt_payload(payload) user = await _get_user_from_jwt_payload(payload)
@ -281,7 +284,7 @@ async def _get_account_from_token(access_token) -> Optional[User]:
raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid access token.") from exc raise HTTPException(HTTPStatus.UNAUTHORIZED, "Invalid access token.") from exc
async def _get_user_from_jwt_payload(payload) -> Optional[User]: async def _get_user_from_jwt_payload(payload) -> Optional[Account]:
if "sub" in payload and payload.get("sub"): if "sub" in payload and payload.get("sub"):
return await get_account_by_username(str(payload.get("sub"))) return await get_account_by_username(str(payload.get("sub")))
if "usr" in payload and payload.get("usr"): if "usr" in payload and payload.get("usr"):

View file

@ -23,14 +23,6 @@ class InvoiceError(Exception):
self.status = status self.status = status
def register_exception_handlers(app: FastAPI):
register_exception_handler(app)
register_request_validation_exception_handler(app)
register_http_exception_handler(app)
register_payment_error_handler(app)
register_invoice_error_handler(app)
def render_html_error(request: Request, exc: Exception) -> Optional[Response]: def render_html_error(request: Request, exc: Exception) -> Optional[Response]:
# Only the browser sends "text/html" request # Only the browser sends "text/html" request
# not fail proof, but everything else get's a JSON response # not fail proof, but everything else get's a JSON response
@ -63,7 +55,9 @@ def render_html_error(request: Request, exc: Exception) -> Optional[Response]:
return None return None
def register_exception_handler(app: FastAPI): def register_exception_handlers(app: FastAPI):
"""Register exception handlers for the FastAPI app"""
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def exception_handler(request: Request, exc: Exception): async def exception_handler(request: Request, exc: Exception):
etype, _, tb = sys.exc_info() etype, _, tb = sys.exc_info()
@ -74,8 +68,26 @@ def register_exception_handler(app: FastAPI):
content={"detail": str(exc)}, content={"detail": str(exc)},
) )
@app.exception_handler(AssertionError)
async def assert_error_handler(request: Request, exc: AssertionError):
etype, _, tb = sys.exc_info()
traceback.print_exception(etype, exc, tb)
logger.warning(f"AssertionError: {exc!s}")
return render_html_error(request, exc) or JSONResponse(
status_code=HTTPStatus.BAD_REQUEST,
content={"detail": str(exc)},
)
@app.exception_handler(ValueError)
async def value_error_handler(request: Request, exc: ValueError):
etype, _, tb = sys.exc_info()
traceback.print_exception(etype, exc, tb)
logger.warning(f"ValueError: {exc!s}")
return render_html_error(request, exc) or JSONResponse(
status_code=HTTPStatus.BAD_REQUEST,
content={"detail": str(exc)},
)
def register_request_validation_exception_handler(app: FastAPI):
@app.exception_handler(RequestValidationError) @app.exception_handler(RequestValidationError)
async def validation_exception_handler( async def validation_exception_handler(
request: Request, exc: RequestValidationError request: Request, exc: RequestValidationError
@ -86,8 +98,6 @@ def register_request_validation_exception_handler(app: FastAPI):
content={"detail": str(exc)}, content={"detail": str(exc)},
) )
def register_http_exception_handler(app: FastAPI):
@app.exception_handler(HTTPException) @app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException): async def http_exception_handler(request: Request, exc: HTTPException):
logger.error(f"HTTPException {exc.status_code}: {exc.detail}") logger.error(f"HTTPException {exc.status_code}: {exc.detail}")
@ -96,8 +106,6 @@ def register_http_exception_handler(app: FastAPI):
content={"detail": exc.detail}, content={"detail": exc.detail},
) )
def register_payment_error_handler(app: FastAPI):
@app.exception_handler(PaymentError) @app.exception_handler(PaymentError)
async def payment_error_handler(request: Request, exc: PaymentError): async def payment_error_handler(request: Request, exc: PaymentError):
logger.error(f"{exc.message}, {exc.status}") logger.error(f"{exc.message}, {exc.status}")
@ -106,8 +114,6 @@ def register_payment_error_handler(app: FastAPI):
content={"detail": exc.message, "status": exc.status}, content={"detail": exc.message, "status": exc.status},
) )
def register_invoice_error_handler(app: FastAPI):
@app.exception_handler(InvoiceError) @app.exception_handler(InvoiceError)
async def invoice_error_handler(request: Request, exc: InvoiceError): async def invoice_error_handler(request: Request, exc: InvoiceError):
logger.error(f"{exc.message}, Status: {exc.status}") logger.error(f"{exc.message}, Status: {exc.status}")

View file

@ -1,17 +1,15 @@
import json import json
import re import re
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Type from typing import Any, Optional, Type
import jinja2 import jinja2
import jwt import jwt
import shortuuid import shortuuid
from pydantic import BaseModel
from pydantic.schema import field_schema from pydantic.schema import field_schema
from lnbits.core.extensions.models import Extension from lnbits.core.extensions.models import Extension
from lnbits.db import get_placeholder
from lnbits.jinja2_templating import Jinja2Templates from lnbits.jinja2_templating import Jinja2Templates
from lnbits.nodes import get_node_class from lnbits.nodes import get_node_class
from lnbits.requestvars import g from lnbits.requestvars import g
@ -51,7 +49,7 @@ def static_url_for(static: str, path: str) -> str:
return f"/{static}/{path}?v={settings.server_startup_time}" return f"/{static}/{path}?v={settings.server_startup_time}"
def template_renderer(additional_folders: Optional[List] = None) -> Jinja2Templates: def template_renderer(additional_folders: Optional[list] = None) -> Jinja2Templates:
folders = ["lnbits/templates", "lnbits/core/templates"] folders = ["lnbits/templates", "lnbits/core/templates"]
if additional_folders: if additional_folders:
additional_folders += [ additional_folders += [
@ -175,37 +173,6 @@ def generate_filter_params_openapi(model: Type[FilterModel], keep_optional=False
} }
def insert_query(table_name: str, model: BaseModel) -> str:
"""
Generate an insert query with placeholders for a given table and model
:param table_name: Name of the table
:param model: Pydantic model
"""
placeholders = []
for field in model.dict().keys():
placeholders.append(get_placeholder(model, field))
fields = ", ".join(model.dict().keys())
values = ", ".join(placeholders)
return f"INSERT INTO {table_name} ({fields}) VALUES ({values})"
def update_query(
table_name: str, model: BaseModel, where: str = "WHERE id = :id"
) -> str:
"""
Generate an update query with placeholders for a given table and model
:param table_name: Name of the table
:param model: Pydantic model
:param where: Where string, default to `WHERE id = :id`
"""
fields = []
for field in model.dict().keys():
placeholder = get_placeholder(model, field)
fields.append(f"{field} = {placeholder}")
query = ", ".join(fields)
return f"UPDATE {table_name} SET {query} {where}"
def is_valid_email_address(email: str) -> bool: def is_valid_email_address(email: str) -> bool:
email_regex = r"[A-Za-z0-9\._%+-]+@[A-Za-z0-9\.-]+\.[A-Za-z]{2,63}" email_regex = r"[A-Za-z0-9\._%+-]+@[A-Za-z0-9\.-]+\.[A-Za-z]{2,63}"
return re.fullmatch(email_regex, email) is not None return re.fullmatch(email_regex, email) is not None
@ -217,7 +184,9 @@ def is_valid_username(username: str) -> bool:
def create_access_token(data: dict): def create_access_token(data: dict):
expire = datetime.utcnow() + timedelta(minutes=settings.auth_token_expire_minutes) expire = datetime.now(timezone.utc) + timedelta(
minutes=settings.auth_token_expire_minutes
)
to_encode = data.copy() to_encode = data.copy()
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
return jwt.encode(to_encode, settings.auth_secret_key, "HS256") return jwt.encode(to_encode, settings.auth_secret_key, "HS256")

View file

@ -7,7 +7,6 @@ import json
from enum import Enum from enum import Enum
from hashlib import sha256 from hashlib import sha256
from os import path from os import path
from sqlite3 import Row
from time import time from time import time
from typing import Any, Optional from typing import Any, Optional
@ -635,11 +634,6 @@ class ReadOnlySettings(
class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettings): class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettings):
@classmethod
def from_row(cls, row: Row) -> Settings:
data = dict(row)
return cls(**data)
class Config: class Config:
env_file = ".env" env_file = ".env"
env_file_encoding = "utf-8" env_file_encoding = "utf-8"

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -530,23 +530,26 @@ video {
overflow-wrap: break-word; overflow-wrap: break-word;
} }
.qrcode__wrapper canvas { .qrcode__wrapper {
position: relative; position: relative;
display: flex;
align-items: center;
justify-content: center;
}
.qrcode__wrapper canvas {
width: 100% !important; width: 100% !important;
max-width: 100%; height: 100% !important;
max-height: 100%; max-width: 350px;
} }
.qrcode__image { .qrcode__image {
position: absolute;
max-width: 52px;
width: 15%; width: 15%;
height: 15%;
overflow: hidden; overflow: hidden;
background: #fff; background: #fff;
left: 50%;
overflow: hidden; overflow: hidden;
position: absolute;
top: 50%;
transform: translate(-50%, -50%);
padding: 0.2rem; padding: 0.2rem;
border-radius: 0.2rem; border-radius: 0.2rem;
} }

View file

@ -268,5 +268,6 @@ window.localisation.en = {
contributors: 'Contributors', contributors: 'Contributors',
license: 'License', license: 'License',
reset_key: 'Reset Key', reset_key: 'Reset Key',
reset_password: 'Reset Password' reset_password: 'Reset Password',
border_choices: 'Border Choices'
} }

View file

@ -12,6 +12,7 @@ window.app = Vue.createApp({
'confettiFireworks', 'confettiFireworks',
'confettiStars' 'confettiStars'
], ],
borderOptions: ['retro-border', 'hard-border', 'no-border'],
tab: 'user', tab: 'user',
credentialsData: { credentialsData: {
show: false, show: false,
@ -63,6 +64,27 @@ window.app = Vue.createApp({
this.$q.localStorage.set('lnbits.gradientBg', false) this.$q.localStorage.set('lnbits.gradientBg', false)
} }
}, },
applyBorder: function () {
slef = this
if (this.borderChoice) {
this.$q.localStorage.setItem('lnbits.border', this.borderChoice)
}
let borderStyle = this.$q.localStorage.getItem('lnbits.border')
this.borderChoice = borderStyle
let borderStyleCSS
if (borderStyle == 'hard-border') {
borderStyleCSS = `box-shadow: 0 0 0 1px rgba(0,0,0,.12), 0 0 0 1px #ffffff47; border: none;`
}
if (borderStyle == 'no-border') {
borderStyleCSS = `box-shadow: none; border: none;`
}
if (borderStyle == 'retro-border') {
borderStyleCSS = `border: none; border-color: rgba(255, 255, 255, 0.28); box-shadow: 0 1px 5px rgba(255, 255, 255, 0.2), 0 2px 2px rgba(255, 255, 255, 0.14), 0 3px 1px -2px rgba(255, 255, 255, 0.12);`
}
let style = document.createElement('style')
style.innerHTML = `body[data-theme="${this.$q.localStorage.getItem('lnbits.theme')}"] .q-card.q-card--dark, .q-date--dark { ${borderStyleCSS} }`
document.head.appendChild(style)
},
toggleGradient: function () { toggleGradient: function () {
this.gradientChoice = !this.gradientChoice this.gradientChoice = !this.gradientChoice
this.applyGradient() this.applyGradient()
@ -92,7 +114,7 @@ window.app = Vue.createApp({
user_id: this.user.id, user_id: this.user.id,
username: this.user.username, username: this.user.username,
email: this.user.email, email: this.user.email,
config: this.user.config extra: this.user.extra
} }
) )
this.user = data this.user = data
@ -183,12 +205,15 @@ window.app = Vue.createApp({
const {data} = await LNbits.api.getAuthenticatedUser() const {data} = await LNbits.api.getAuthenticatedUser()
this.user = data this.user = data
this.hasUsername = !!data.username this.hasUsername = !!data.username
if (!this.user.config) this.user.config = {} if (!this.user.extra) this.user.extra = {}
} catch (e) { } catch (e) {
LNbits.utils.notifyApiError(e) LNbits.utils.notifyApiError(e)
} }
if (this.$q.localStorage.getItem('lnbits.gradientBg')) { if (this.$q.localStorage.getItem('lnbits.gradientBg')) {
this.applyGradient() this.applyGradient()
} }
if (this.$q.localStorage.getItem('lnbits.border')) {
this.applyBorder()
}
} }
}) })

View file

@ -278,7 +278,7 @@ window.LNbits = {
preimage: data.preimage, preimage: data.preimage,
payment_hash: data.payment_hash, payment_hash: data.payment_hash,
expiry: data.expiry, expiry: data.expiry,
extra: data.extra, extra: data.extra ?? {},
wallet_id: data.wallet_id, wallet_id: data.wallet_id,
webhook: data.webhook, webhook: data.webhook,
webhook_status: data.webhook_status, webhook_status: data.webhook_status,
@ -286,13 +286,10 @@ window.LNbits = {
fiat_currency: data.fiat_currency fiat_currency: data.fiat_currency
} }
obj.date = Quasar.date.formatDate( obj.date = Quasar.date.formatDate(new Date(obj.time), 'YYYY-MM-DD HH:mm')
new Date(obj.time * 1000),
'YYYY-MM-DD HH:mm'
)
obj.dateFrom = moment(obj.date).fromNow() obj.dateFrom = moment(obj.date).fromNow()
obj.expirydate = Quasar.date.formatDate( obj.expirydate = Quasar.date.formatDate(
new Date(obj.expiry * 1000), new Date(obj.expiry),
'YYYY-MM-DD HH:mm' 'YYYY-MM-DD HH:mm'
) )
obj.expirydateFrom = moment(obj.expirydate).fromNow() obj.expirydateFrom = moment(obj.expirydate).fromNow()
@ -337,6 +334,12 @@ window.LNbits = {
.join('') .join('')
return hashHex return hashHex
}, },
formatDate: function (timestamp) {
return Quasar.date.formatDate(
new Date(timestamp * 1000),
'YYYY-MM-DD HH:mm'
)
},
formatCurrency: function (value, currency) { formatCurrency: function (value, currency) {
return new Intl.NumberFormat(window.LOCALE, { return new Intl.NumberFormat(window.LOCALE, {
style: 'currency', style: 'currency',
@ -476,6 +479,7 @@ window.windowMixin = {
return { return {
toggleSubs: true, toggleSubs: true,
reactionChoice: 'confettiBothSides', reactionChoice: 'confettiBothSides',
borderChoice: '',
gradientChoice: gradientChoice:
this.$q.localStorage.getItem('lnbits.gradientBg') || false, this.$q.localStorage.getItem('lnbits.gradientBg') || false,
isUserAuthorized: false, isUserAuthorized: false,
@ -517,6 +521,30 @@ window.windowMixin = {
document.head.appendChild(style) document.head.appendChild(style)
} }
}, },
applyBorder: function () {
if (this.borderChoice) {
this.$q.localStorage.setItem('lnbits.border', this.borderChoice)
}
let borderStyle = this.$q.localStorage.getItem('lnbits.border')
if (!borderStyle) {
this.$q.localStorage.set('lnbits.border', 'retro-border')
borderStyle = 'hard-border'
}
this.borderChoice = borderStyle
let borderStyleCSS
if (borderStyle == 'hard-border') {
borderStyleCSS = `box-shadow: 0 0 0 1px rgba(0,0,0,.12), 0 0 0 1px #ffffff47; border: none;`
}
if (borderStyle == 'no-border') {
borderStyleCSS = `box-shadow: none; border: none;`
}
if (borderStyle == 'retro-border') {
borderStyleCSS = `border: none; border-color: rgba(255, 255, 255, 0.28); box-shadow: 0 1px 5px rgba(255, 255, 255, 0.2), 0 2px 2px rgba(255, 255, 255, 0.14), 0 3px 1px -2px rgba(255, 255, 255, 0.12);`
}
let style = document.createElement('style')
style.innerHTML = `body[data-theme="${this.$q.localStorage.getItem('lnbits.theme')}"] .q-card.q-card--dark, .q-date--dark { ${borderStyleCSS} }`
document.head.appendChild(style)
},
setColors: function () { setColors: function () {
this.$q.localStorage.set( this.$q.localStorage.set(
'lnbits.primaryColor', 'lnbits.primaryColor',
@ -592,6 +620,7 @@ window.windowMixin = {
const theme = params.get('theme') const theme = params.get('theme')
const darkMode = params.get('dark') const darkMode = params.get('dark')
const gradient = params.get('gradient') const gradient = params.get('gradient')
const border = params.get('border')
if ( if (
theme && theme &&
@ -617,6 +646,9 @@ window.windowMixin = {
this.$q.localStorage.set('lnbits.darkMode', true) this.$q.localStorage.set('lnbits.darkMode', true)
} }
} }
if (border) {
this.$q.localStorage.set('lnbits.border', border)
}
// Remove processed parameters // Remove processed parameters
fields.forEach(param => params.delete(param)) fields.forEach(param => params.delete(param))
@ -678,6 +710,7 @@ window.windowMixin = {
} }
this.applyGradient() this.applyGradient()
this.applyBorder()
if (window.user) { if (window.user) {
this.g.user = Object.freeze(window.LNbits.map.user(window.user)) this.g.user = Object.freeze(window.LNbits.map.user(window.user))

View file

@ -200,11 +200,25 @@ window.app.component('lnbits-qrcode', {
components: { components: {
QrcodeVue QrcodeVue
}, },
props: ['value'], props: {
value: {
type: String,
required: true
},
options: Object
},
data() { data() {
return { return {
logo: LNBITS_QR_LOGO custom: {
margin: 1,
width: 350,
size: 350,
logo: LNBITS_QR_LOGO
}
} }
},
created() {
this.custom = {...this.custom, ...this.options}
} }
}) })
@ -405,7 +419,7 @@ window.app.component('lnbits-notifications-btn', {
window.app.component('lnbits-dynamic-fields', { window.app.component('lnbits-dynamic-fields', {
template: '#lnbits-dynamic-fields', template: '#lnbits-dynamic-fields',
mixins: [window.windowMixin], mixins: [window.windowMixin],
props: ['options', 'value'], props: ['options', 'modelValue'],
data() { data() {
return { return {
formData: null, formData: null,
@ -427,11 +441,42 @@ window.app.component('lnbits-dynamic-fields', {
}, {}) }, {})
}, },
handleValueChanged() { handleValueChanged() {
this.$emit('input', this.formData) this.$emit('update:model-value', this.formData)
} }
}, },
created() { created() {
this.formData = this.buildData(this.options, this.value) this.formData = this.buildData(this.options, this.modelValue)
}
})
window.app.component('lnbits-dynamic-chips', {
template: '#lnbits-dynamic-chips',
mixins: [window.windowMixin],
props: ['modelValue'],
data() {
return {
chip: '',
chips: []
}
},
methods: {
addChip() {
if (!this.chip) return
this.chips.push(this.chip)
this.chip = ''
this.$emit('update:model-value', this.chips.join(','))
},
removeChip(index) {
this.chips.splice(index, 1)
this.$emit('update:model-value', this.chips.join(','))
}
},
created() {
if (typeof this.modelValue === 'string') {
this.chips = this.modelValue.split(',')
} else {
this.chips = [...this.modelValue]
}
} }
}) })
@ -444,7 +489,7 @@ window.app.component('lnbits-update-balance', {
return LNBITS_DENOMINATION return LNBITS_DENOMINATION
}, },
admin() { admin() {
return this.g.user.admin return user.super_user
} }
}, },
data: function () { data: function () {

View file

@ -1,9 +1,3 @@
function shortenNodeId(nodeId) {
return nodeId
? nodeId.substring(0, 5) + '...' + nodeId.substring(nodeId.length - 5)
: '...'
}
window.app.component('lnbits-node-ranks', { window.app.component('lnbits-node-ranks', {
props: ['ranks'], props: ['ranks'],
data: function () { data: function () {
@ -141,7 +135,11 @@ window.app.component('lnbits-node-info', {
}, },
mixins: [window.windowMixin], mixins: [window.windowMixin],
methods: { methods: {
shortenNodeId shortenNodeId(nodeId) {
return nodeId
? nodeId.substring(0, 5) + '...' + nodeId.substring(nodeId.length - 5)
: '...'
}
}, },
template: ` template: `
<div class='row items-baseline q-gutter-x-sm'> <div class='row items-baseline q-gutter-x-sm'>

View file

@ -165,32 +165,22 @@ window.app = Vue.createApp({
type: 'bubble', type: 'bubble',
options: { options: {
scales: { scales: {
xAxes: [ x: {
{ type: 'linear',
type: 'linear', beginAtZero: true,
ticks: { title: {
beginAtZero: true text: 'Transaction count'
},
scaleLabel: {
display: true,
labelString: 'Tx count'
}
} }
], },
yAxes: [ y: {
{ type: 'linear',
type: 'linear', beginAtZero: true,
ticks: { title: {
beginAtZero: true text: 'User balance in million sats'
},
scaleLabel: {
display: true,
labelString: 'User balance in million sats'
}
} }
] }
}, },
tooltips: { tooltip: {
callbacks: { callbacks: {
label: function (tooltipItem, data) { label: function (tooltipItem, data) {
const dataset = data.datasets[tooltipItem.datasetIndex] const dataset = data.datasets[tooltipItem.datasetIndex]
@ -215,6 +205,9 @@ window.app = Vue.createApp({
}) })
}, },
methods: { methods: {
formatDate: function (value) {
return LNbits.utils.formatDate(value)
},
formatSat: function (value) { formatSat: function (value) {
return LNbits.utils.formatSat(Math.floor(value / 1000)) return LNbits.utils.formatSat(Math.floor(value / 1000))
}, },

View file

@ -5,7 +5,10 @@ window.app = Vue.createApp({
return { return {
updatePayments: false, updatePayments: false,
origin: window.location.origin, origin: window.location.origin,
wallet: LNbits.map.wallet(window.wallet),
user: LNbits.map.user(window.user), user: LNbits.map.user(window.user),
exportUrl: `${window.location.origin}/wallet?usr=${window.user.id}&wal=${window.wallet.id}`,
baseUrl: `${window.location.protocol}//${window.location.host}/`,
receive: { receive: {
show: false, show: false,
status: 'pending', status: 'pending',
@ -142,9 +145,11 @@ window.app = Vue.createApp({
) )
.then(response => { .then(response => {
this.receive.status = 'success' this.receive.status = 'success'
this.receive.paymentReq = response.data.payment_request this.receive.paymentReq = response.data.bolt11
this.receive.paymentHash = response.data.payment_hash this.receive.paymentHash = response.data.payment_hash
// TODO: lnurl_callback and lnurl_response
// WITHDRAW
if (response.data.lnurl_response !== null) { if (response.data.lnurl_response !== null) {
if (response.data.lnurl_response === false) { if (response.data.lnurl_response === false) {
response.data.lnurl_response = `Unable to connect` response.data.lnurl_response = `Unable to connect`
@ -255,7 +260,7 @@ window.app = Vue.createApp({
}) })
}, },
decodeQR: function (res) { decodeQR: function (res) {
this.parse.data.request = res this.parse.data.request = res[0].rawValue
this.decodeRequest() this.decodeRequest()
this.parse.camera.show = false this.parse.camera.show = false
}, },
@ -391,12 +396,13 @@ window.app = Vue.createApp({
dismissPaymentMsg() dismissPaymentMsg()
clearInterval(this.parse.paymentChecker) clearInterval(this.parse.paymentChecker)
// show lnurlpay success action // show lnurlpay success action
if (response.data.success_action) { const extra = response.data.extra
switch (response.data.success_action.tag) { if (extra.success_action) {
switch (extra.success_action.tag) {
case 'url': case 'url':
Quasar.Notify.create({ Quasar.Notify.create({
message: `<a target="_blank" style="color: inherit" href="${response.data.success_action.url}">${response.data.success_action.url}</a>`, message: `<a target="_blank" style="color: inherit" href="${extra.success_action.url}">${extra.success_action.url}</a>`,
caption: response.data.success_action.description, caption: extra.success_action.description,
html: true, html: true,
type: 'positive', type: 'positive',
timeout: 0, timeout: 0,
@ -405,7 +411,7 @@ window.app = Vue.createApp({
break break
case 'message': case 'message':
Quasar.Notify.create({ Quasar.Notify.create({
message: response.data.success_action.message, message: extra.success_action.message,
type: 'positive', type: 'positive',
timeout: 0, timeout: 0,
closeBtn: true closeBtn: true
@ -416,14 +422,14 @@ window.app = Vue.createApp({
.getPayment(this.g.wallet, response.data.payment_hash) .getPayment(this.g.wallet, response.data.payment_hash)
.then(({data: payment}) => .then(({data: payment}) =>
decryptLnurlPayAES( decryptLnurlPayAES(
response.data.success_action, extra.success_action,
payment.preimage payment.preimage
) )
) )
.then(value => { .then(value => {
Quasar.Notify.create({ Quasar.Notify.create({
message: value, message: value,
caption: response.data.success_action.description, caption: extra.success_action.description,
html: true, html: true,
type: 'positive', type: 'positive',
timeout: 0, timeout: 0,

View file

@ -207,23 +207,24 @@ video {
} }
// qrcode // qrcode
.qrcode__wrapper canvas { .qrcode__wrapper {
position: relative; position: relative;
width: 100% !important; // important to override qrcode inline width display: flex;
max-width: 100%; align-items: center;
max-height: 100%; justify-content: center;
}
.qrcode__wrapper canvas {
width: 100% !important; // important to override qrcode inline width
height: 100% !important;
max-width: 350px; // default width of <lnbits-qrcode> component
} }
.qrcode__image { .qrcode__image {
position: absolute;
max-width: 52px;
width: 15%; width: 15%;
height: 15%;
overflow: hidden; overflow: hidden;
background: #fff; background: #fff;
left: 50%;
overflow: hidden; overflow: hidden;
position: absolute;
top: 50%;
transform: translate(-50%, -50%);
padding: 0.2rem; padding: 0.2rem;
border-radius: 0.2rem; border-radius: 0.2rem;
} }

File diff suppressed because one or more lines are too long

View file

@ -20,8 +20,7 @@ from lnbits.core.crud import (
delete_webpush_subscriptions, delete_webpush_subscriptions,
get_payments, get_payments,
get_standalone_payment, get_standalone_payment,
update_payment_details, update_payment,
update_payment_status,
) )
from lnbits.core.models import Payment, PaymentState from lnbits.core.models import Payment, PaymentState
from lnbits.settings import settings from lnbits.settings import settings
@ -181,17 +180,14 @@ async def check_pending_payments():
status = await payment.check_status() status = await payment.check_status()
prefix = f"payment ({i+1} / {count})" prefix = f"payment ({i+1} / {count})"
if status.failed: if status.failed:
await update_payment_status( payment.status = PaymentState.FAILED
payment.checking_id, status=PaymentState.FAILED await update_payment(payment)
)
logger.debug(f"{prefix} failed {payment.checking_id}") logger.debug(f"{prefix} failed {payment.checking_id}")
elif status.success: elif status.success:
await update_payment_details( payment.fee = status.fee_msat or 0
checking_id=payment.checking_id, payment.preimage = status.preimage
fee=status.fee_msat, payment.status = PaymentState.SUCCESS
preimage=status.preimage, await update_payment(payment)
status=PaymentState.SUCCESS,
)
logger.debug(f"{prefix} success {payment.checking_id}") logger.debug(f"{prefix} success {payment.checking_id}")
else: else:
logger.debug(f"{prefix} pending {payment.checking_id}") logger.debug(f"{prefix} pending {payment.checking_id}")
@ -211,14 +207,10 @@ async def invoice_callback_dispatcher(checking_id: str, is_internal: bool = Fals
payment = await get_standalone_payment(checking_id, incoming=True) payment = await get_standalone_payment(checking_id, incoming=True)
if payment and payment.is_in: if payment and payment.is_in:
status = await payment.check_status() status = await payment.check_status()
await update_payment_details( payment.fee = status.fee_msat or 0
checking_id=payment.checking_id, payment.preimage = status.preimage
fee=status.fee_msat, payment.status = PaymentState.SUCCESS
preimage=status.preimage, await update_payment(payment)
status=PaymentState.SUCCESS,
)
payment = await get_standalone_payment(checking_id, incoming=True)
assert payment, "updated payment not found"
internal = "internal" if is_internal else "" internal = "internal" if is_internal else ""
logger.success(f"{internal} invoice {checking_id} settled") logger.success(f"{internal} invoice {checking_id} settled")
for name, send_chan in invoice_listeners.items(): for name, send_chan in invoice_listeners.items():

View file

@ -251,6 +251,17 @@
/> />
</div> </div>
<div class="text-wrap">
<b style="white-space: nowrap" v-text="$t('Invoice')"></b>:&nbsp;
<q-icon
name="content_copy"
@click="copyText(payment.bolt11)"
size="1em"
color="grey"
class="q-mb-xs cursor-pointer"
/>
</div>
<div class="text-wrap"> <div class="text-wrap">
<b style="white-space: nowrap" v-text="$t('memo')"></b>:&nbsp; <b style="white-space: nowrap" v-text="$t('memo')"></b>:&nbsp;
<span v-text="payment.memo"></span> <span v-text="payment.memo"></span>
@ -301,7 +312,7 @@
v-if="o.options?.length" v-if="o.options?.length"
:options="o.options" :options="o.options"
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
class="q-ml-xl" class="q-ml-xl"
> >
</lnbits-dynamic-fields> </lnbits-dynamic-fields>
@ -310,7 +321,7 @@
v-if="o.type === 'number'" v-if="o.type === 'number'"
type="number" type="number"
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
:label="o.label || o.name" :label="o.label || o.name"
:hint="o.description" :hint="o.description"
:rules="applyRules(o.required)" :rules="applyRules(o.required)"
@ -322,7 +333,7 @@
type="textarea" type="textarea"
rows="5" rows="5"
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
:label="o.label || o.name" :label="o.label || o.name"
:hint="o.description" :hint="o.description"
:rules="applyRules(o.required)" :rules="applyRules(o.required)"
@ -332,7 +343,7 @@
<q-input <q-input
v-else-if="o.type === 'password'" v-else-if="o.type === 'password'"
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
type="password" type="password"
:label="o.label || o.name" :label="o.label || o.name"
:hint="o.description" :hint="o.description"
@ -343,7 +354,7 @@
<q-select <q-select
v-else-if="o.type === 'select'" v-else-if="o.type === 'select'"
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
:label="o.label || o.name" :label="o.label || o.name"
:hint="o.description" :hint="o.description"
:options="o.values" :options="o.values"
@ -352,7 +363,7 @@
<q-select <q-select
v-else-if="o.isList" v-else-if="o.isList"
v-model.trim="formData[o.name]" v-model.trim="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
input-debounce="0" input-debounce="0"
new-value-mode="add-unique" new-value-mode="add-unique"
:label="o.label || o.name" :label="o.label || o.name"
@ -371,7 +382,7 @@
<q-item-section avatar top> <q-item-section avatar top>
<q-checkbox <q-checkbox
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
/> />
</q-item-section> </q-item-section>
<q-item-section> <q-item-section>
@ -391,10 +402,16 @@
style="display: none" style="display: none"
:rules="applyRules(o.required)" :rules="applyRules(o.required)"
></q-input> ></q-input>
<div v-else-if="o.type === 'chips'">
<lnbits-dynamic-chips
v-model="formData[o.name]"
@update:model-value="handleValueChanged"
></lnbits-dynamic-chips>
</div>
<q-input <q-input
v-else v-else
v-model="formData[o.name]" v-model="formData[o.name]"
@input="handleValueChanged" @update:model-value="handleValueChanged"
:hint="o.description" :hint="o.description"
:label="o.label || o.name" :label="o.label || o.name"
:rules="applyRules(o.required)" :rules="applyRules(o.required)"
@ -407,6 +424,32 @@
</div> </div>
</template> </template>
<template id="lnbits-dynamic-chips">
<q-input
filled
v-model="chip"
@keydown.enter.prevent="addChip"
type="text"
label="wss://...."
hint="Add relays"
class="q-mb-md"
>
<q-btn @click="addChip" dense flat icon="add"></q-btn>
</q-input>
<div>
<q-chip
v-for="(chip, i) in chips"
:key="chip"
removable
@remove="removeChip(i)"
color="primary"
text-color="white"
:label="chip"
>
</q-chip>
</div>
</template>
<template id="lnbits-notifications-btn"> <template id="lnbits-notifications-btn">
<q-btn <q-btn
v-if="g.user.wallets" v-if="g.user.wallets"
@ -457,8 +500,20 @@
<template id="lnbits-qrcode"> <template id="lnbits-qrcode">
<div class="qrcode__wrapper"> <div class="qrcode__wrapper">
<qrcode-vue :value="value" size="350" class="rounded-borders"></qrcode-vue> <qrcode-vue
<img class="qrcode__image" :src="logo" alt="..." /> :value="value"
level="Q"
render-as="svg"
:margin="custom.margin"
:size="custom.width"
class="rounded-borders"
></qrcode-vue>
<img
v-if="custom.logo"
class="qrcode__image"
:src="custom.logo"
alt="qrcode icon"
/>
</div> </div>
</template> </template>
@ -585,12 +640,12 @@
:rows="paymentsOmitter" :rows="paymentsOmitter"
:row-key="paymentTableRowKey" :row-key="paymentTableRowKey"
:columns="paymentsTable.columns" :columns="paymentsTable.columns"
:pagination.sync="paymentsTable.pagination"
:no-data-label="$t('no_transactions')" :no-data-label="$t('no_transactions')"
:filter="paymentsTable.search" :filter="paymentsTable.search"
:loading="paymentsTable.loading" :loading="paymentsTable.loading"
:hide-header="mobileSimple" :hide-header="mobileSimple"
:hide-bottom="mobileSimple" :hide-bottom="mobileSimple"
v-model:pagination="paymentsTable.pagination"
@request="fetchPayments" @request="fetchPayments"
> >
<template v-slot:header="props"> <template v-slot:header="props">
@ -699,13 +754,9 @@
></lnbits-payment-details> ></lnbits-payment-details>
<div v-if="props.row.bolt11" class="text-center q-mb-lg"> <div v-if="props.row.bolt11" class="text-center q-mb-lg">
<a :href="'lightning:' + props.row.bolt11"> <a :href="'lightning:' + props.row.bolt11">
<q-responsive :ratio="1" class="q-mx-xl"> <lnbits-qrcode
<lnbits-qrcode :value="'lightning:' + props.row.bolt11.toUpperCase()"
:value=" ></lnbits-qrcode>
'lightning:' + props.row.bolt11.toUpperCase()
"
></lnbits-qrcode>
</q-responsive>
</a> </a>
</div> </div>
<div class="row q-mt-lg"> <div class="row q-mt-lg">
@ -797,7 +848,7 @@
</q-form> </q-form>
</template> </template>
<template id="lnbits-extension-btn-dialog"> <template id="lnbits-extension-settings-btn-dialog">
<q-btn <q-btn
v-if="options" v-if="options"
unelevated unelevated

View file

@ -39,26 +39,21 @@
</q-card-section> </q-card-section>
</q-card> </q-card>
</div> </div>
{% endblock %} {% block scripts %}
<script>
window.app = Vue.createApp({
el: '#vue',
mixins: [window.windowMixin],
data: function () {
return {}
},
methods: {
goBack: function () {
window.history.back()
},
goHome: function () {
window.location.href = '/'
}
}
})
</script>
{% endblock %}
</div> </div>
{% endblock %} {% block scripts %}
<script>
window.app = Vue.createApp({
el: '#vue',
mixins: [window.windowMixin],
methods: {
goBack: function () {
window.history.back()
},
goHome: function () {
window.location.href = '/'
}
}
})
</script>
{% endblock %}

View file

@ -5,10 +5,10 @@
window.currencies = {{ currencies | tojson | safe }}; window.currencies = {{ currencies | tojson | safe }};
{% endif %} {% endif %}
{% if user %} {% if user %}
window.user = {{ user | tojson | safe }}; window.user = JSON.parse({{ user | tojson | safe }});
{% endif %} {% endif %}
{% if wallet %} {% if wallet %}
window.wallet = {{ wallet | tojson | safe }}; window.wallet = JSON.parse({{ wallet | tojson | safe }});
{% endif %} {% endif %}
</script> </script>
{%- endmacro %} {%- endmacro %}

View file

@ -34,18 +34,26 @@
</head> </head>
<body> <body>
<q-layout id="vue" view="hHh lpR lfr" v-cloak> <div id="vue">
<q-page-container> <q-layout view="hHh lpR lfr" v-cloak>
<q-page class="q-px-md q-py-lg" :class="{'q-px-lg': $q.screen.gt.xs}"> <q-page-container>
{% block page %}{% endblock %} <q-page class="q-px-md q-py-lg" :class="{'q-px-lg': $q.screen.gt.xs}">
</q-page> {% block page %}{% endblock %}
</q-page-container> </q-page>
</q-layout> </q-page-container>
</q-layout>
</div>
{% for url in INCLUDED_JS %} {% include('components.vue') %}{% block vue_templates %}{% endblock %} {%
for url in INCLUDED_JS %}
<script src="{{ static_url_for('static', url) }}"></script> <script src="{{ static_url_for('static', url) }}"></script>
{% endfor %} {% endfor %}
<script>
const LNBITS_QR_LOGO = {{ LNBITS_QR_LOGO | tojson }}
</script>
<!----> <!---->
{% block scripts %}{% endblock %} {% block scripts %}{% endblock %} {% for url in INCLUDED_COMPONENTS %}
<script src="{{ static_url_for('static', url) }}"></script>
{% endfor %}
</body> </body>
</html> </html>

View file

@ -85,15 +85,13 @@ class LNbitsWallet(Wallet):
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
if r.is_error or "payment_request" not in data: if r.is_error or "bolt11" not in data:
error_message = data["detail"] if "detail" in data else r.text error_message = data["detail"] if "detail" in data else r.text
return InvoiceResponse( return InvoiceResponse(
False, None, None, f"Server error: '{error_message}'" False, None, None, f"Server error: '{error_message}'"
) )
return InvoiceResponse( return InvoiceResponse(True, data["checking_id"], data["bolt11"], None)
True, data["checking_id"], data["payment_request"], None
)
except json.JSONDecodeError: except json.JSONDecodeError:
return InvoiceResponse( return InvoiceResponse(
False, None, None, "Server error: 'invalid json response'" False, None, None, "Server error: 'invalid json response'"

View file

@ -36,8 +36,18 @@ class PhoenixdWallet(Wallet):
) )
self.endpoint = self.normalize_endpoint(settings.phoenixd_api_endpoint) self.endpoint = self.normalize_endpoint(settings.phoenixd_api_endpoint)
parsed_url = urllib.parse.urlparse(settings.phoenixd_api_endpoint)
self.ws_url = f"ws://{urllib.parse.urlsplit(self.endpoint).netloc}/websocket" if parsed_url.scheme == "http":
ws_protocol = "ws"
elif parsed_url.scheme == "https":
ws_protocol = "wss"
else:
raise ValueError(f"Unsupported scheme: {parsed_url.scheme}")
self.ws_url = (
f"{ws_protocol}://{urllib.parse.urlsplit(self.endpoint).netloc}/websocket"
)
password = settings.phoenixd_api_password password = settings.phoenixd_api_password
encoded_auth = base64.b64encode(f":{password}".encode()) encoded_auth = base64.b64encode(f":{password}".encode())
auth = str(encoded_auth, "utf-8") auth = str(encoded_auth, "utf-8")

7
package-lock.json generated
View file

@ -1320,9 +1320,10 @@
} }
}, },
"node_modules/vue-qrcode-reader": { "node_modules/vue-qrcode-reader": {
"version": "5.5.10", "version": "5.5.11",
"resolved": "https://registry.npmjs.org/vue-qrcode-reader/-/vue-qrcode-reader-5.5.10.tgz", "resolved": "https://registry.npmjs.org/vue-qrcode-reader/-/vue-qrcode-reader-5.5.11.tgz",
"integrity": "sha512-lj83FKqRyvo0VLMu49wrLsaHueonfXcwyX9r/GDw0y+myOY5xTfsl75hjBgmmByAxzFSlCPI+CGA9FxYVtRAFQ==", "integrity": "sha512-Ec/bVML1jgxSX+usbgdcXGhOFEFo4EzApCO2CNT1YK0Dcb0Mp7ASygz78RJJs22SU2oI7vz9iJDyr4ucSDTvjQ==",
"license": "MIT",
"dependencies": { "dependencies": {
"barcode-detector": "2.2.2", "barcode-detector": "2.2.2",
"webrtc-adapter": "8.2.3" "webrtc-adapter": "8.2.3"

View file

@ -1,6 +1,6 @@
[tool.poetry] [tool.poetry]
name = "lnbits" name = "lnbits"
version = "1.0.0-rc2" version = "1.0.0-rc5"
description = "LNbits, free and open-source Lightning wallet and accounts system." description = "LNbits, free and open-source Lightning wallet and accounts system."
authors = ["Alan Bits <alan@lnbits.com>"] authors = ["Alan Bits <alan@lnbits.com>"]
readme = "README.md" readme = "README.md"
@ -201,6 +201,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.lint.pep8-naming] [tool.ruff.lint.pep8-naming]
classmethod-decorators = [ classmethod-decorators = [
"root_validator", "root_validator",
"validator",
] ]
[tool.ruff.lint.mccabe] [tool.ruff.lint.mccabe]

View file

@ -1,6 +1,7 @@
import pytest import pytest
from lnbits.settings import settings from lnbits.core.models import User
from lnbits.settings import Settings
@pytest.mark.asyncio @pytest.mark.asyncio
@ -18,7 +19,7 @@ async def test_admin_get_settings(client, superuser):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_admin_update_settings(client, superuser): async def test_admin_update_settings(client, superuser: User, settings: Settings):
new_site_title = "UPDATED SITETITLE" new_site_title = "UPDATED SITETITLE"
response = await client.put( response = await client.put(
f"/admin/api/v1/settings?usr={superuser.id}", f"/admin/api/v1/settings?usr={superuser.id}",

View file

@ -5,7 +5,7 @@ import pytest
from lnbits import bolt11 from lnbits import bolt11
from lnbits.core.models import CreateInvoice, Payment from lnbits.core.models import CreateInvoice, Payment
from lnbits.core.views.payment_api import api_payment from lnbits.core.views.payment_api import api_payment
from lnbits.settings import settings from lnbits.settings import Settings
from ..helpers import ( from ..helpers import (
get_random_invoice_data, get_random_invoice_data,
@ -14,10 +14,13 @@ from ..helpers import (
# create account POST /api/v1/account # create account POST /api/v1/account
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_account(client): async def test_create_account(client, settings: Settings):
settings.lnbits_allow_new_accounts = False settings.lnbits_allow_new_accounts = False
response = await client.post("/api/v1/account", json={"name": "test"}) response = await client.post("/api/v1/account", json={"name": "test"})
assert response.status_code == 403
assert response.status_code == 400
assert response.json().get("detail") == "Account creation is disabled."
settings.lnbits_allow_new_accounts = True settings.lnbits_allow_new_accounts = True
response = await client.post("/api/v1/account", json={"name": "test"}) response = await client.post("/api/v1/account", json={"name": "test"})
assert response.status_code == 200 assert response.status_code == 200
@ -120,7 +123,7 @@ async def test_create_invoice(client, inkey_headers_to):
invoice = response.json() invoice = response.json()
assert "payment_hash" in invoice assert "payment_hash" in invoice
assert len(invoice["payment_hash"]) == 64 assert len(invoice["payment_hash"]) == 64
assert "payment_request" in invoice assert "bolt11" in invoice
assert "checking_id" in invoice assert "checking_id" in invoice
assert len(invoice["checking_id"]) assert len(invoice["checking_id"])
return invoice return invoice
@ -135,7 +138,7 @@ async def test_create_invoice_fiat_amount(client, inkey_headers_to):
) )
assert response.status_code == 201 assert response.status_code == 201
invoice = response.json() invoice = response.json()
decode = bolt11.decode(invoice["payment_request"]) decode = bolt11.decode(invoice["bolt11"])
assert decode.amount_msat != data["amount"] * 1000 assert decode.amount_msat != data["amount"] * 1000
assert decode.payment_hash assert decode.payment_hash
@ -177,7 +180,7 @@ async def test_create_internal_invoice(client, inkey_headers_to):
assert response.status_code == 201 assert response.status_code == 201
assert "payment_hash" in invoice assert "payment_hash" in invoice
assert len(invoice["payment_hash"]) == 64 assert len(invoice["payment_hash"]) == 64
assert "payment_request" in invoice assert "bolt11" in invoice
assert "checking_id" in invoice assert "checking_id" in invoice
assert len(invoice["checking_id"]) assert len(invoice["checking_id"])
return invoice return invoice
@ -194,26 +197,28 @@ async def test_create_invoice_custom_expiry(client, inkey_headers_to):
) )
assert response.status_code == 201 assert response.status_code == 201
invoice = response.json() invoice = response.json()
bolt11_invoice = bolt11.decode(invoice["payment_request"]) bolt11_invoice = bolt11.decode(invoice["bolt11"])
assert bolt11_invoice.expiry == expiry_seconds assert bolt11_invoice.expiry == expiry_seconds
# check POST /api/v1/payments: make payment # check POST /api/v1/payments: make payment
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_invoice(client, from_wallet_ws, invoice, adminkey_headers_from): async def test_pay_invoice(
data = {"out": True, "bolt11": invoice["payment_request"]} client, from_wallet_ws, invoice: Payment, adminkey_headers_from
):
data = {"out": True, "bolt11": invoice.bolt11}
response = await client.post( response = await client.post(
"/api/v1/payments", json=data, headers=adminkey_headers_from "/api/v1/payments", json=data, headers=adminkey_headers_from
) )
assert response.status_code < 300 assert response.status_code < 300
invoice = response.json() invoice_ = response.json()
assert len(invoice["payment_hash"]) == 64 assert len(invoice_["payment_hash"]) == 64
assert len(invoice["checking_id"]) > 0 assert len(invoice_["checking_id"]) > 0
data = from_wallet_ws.receive_json() ws_data = from_wallet_ws.receive_json()
assert "wallet_balance" in data assert "wallet_balance" in ws_data
payment = Payment(**data["payment"]) payment = Payment(**ws_data["payment"])
assert payment.payment_hash == invoice["payment_hash"] assert payment.payment_hash == invoice_["payment_hash"]
# websocket from to_wallet cant be tested before https://github.com/lnbits/lnbits/pull/1793 # websocket from to_wallet cant be tested before https://github.com/lnbits/lnbits/pull/1793
# data = to_wallet_ws.receive_json() # data = to_wallet_ws.receive_json()
@ -224,9 +229,9 @@ async def test_pay_invoice(client, from_wallet_ws, invoice, adminkey_headers_fro
# check GET /api/v1/payments/<hash>: payment status # check GET /api/v1/payments/<hash>: payment status
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_payment_without_key(client, invoice): async def test_check_payment_without_key(client, invoice: Payment):
# check the payment status # check the payment status
response = await client.get(f"/api/v1/payments/{invoice['payment_hash']}") response = await client.get(f"/api/v1/payments/{invoice.payment_hash}")
assert response.status_code < 300 assert response.status_code < 300
assert response.json()["paid"] is True assert response.json()["paid"] is True
assert invoice assert invoice
@ -240,10 +245,10 @@ async def test_check_payment_without_key(client, invoice):
# If sqlite: it will succeed only with adminkey_headers_to # If sqlite: it will succeed only with adminkey_headers_to
# TODO: fix this # TODO: fix this
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_check_payment_with_key(client, invoice, inkey_headers_from): async def test_check_payment_with_key(client, invoice: Payment, inkey_headers_from):
# check the payment status # check the payment status
response = await client.get( response = await client.get(
f"/api/v1/payments/{invoice['payment_hash']}", headers=inkey_headers_from f"/api/v1/payments/{invoice.payment_hash}", headers=inkey_headers_from
) )
assert response.status_code < 300 assert response.status_code < 300
assert response.json()["paid"] is True assert response.json()["paid"] is True
@ -255,7 +260,7 @@ async def test_check_payment_with_key(client, invoice, inkey_headers_from):
# check POST /api/v1/payments: payment with wrong key type # check POST /api/v1/payments: payment with wrong key type
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_invoice_wrong_key(client, invoice, adminkey_headers_from): async def test_pay_invoice_wrong_key(client, invoice, adminkey_headers_from):
data = {"out": True, "bolt11": invoice["payment_request"]} data = {"out": True, "bolt11": invoice.bolt11}
# try payment with wrong key # try payment with wrong key
wrong_adminkey_headers = adminkey_headers_from.copy() wrong_adminkey_headers = adminkey_headers_from.copy()
wrong_adminkey_headers["X-Api-Key"] = "wrong_key" wrong_adminkey_headers["X-Api-Key"] = "wrong_key"
@ -276,7 +281,7 @@ async def test_pay_invoice_self_payment(client, adminkey_headers_from):
) )
assert response.status_code < 300 assert response.status_code < 300
json_data = response.json() json_data = response.json()
data = {"out": True, "bolt11": json_data["payment_request"]} data = {"out": True, "bolt11": json_data["bolt11"]}
response = await client.post( response = await client.post(
"/api/v1/payments", json=data, headers=adminkey_headers_from "/api/v1/payments", json=data, headers=adminkey_headers_from
) )
@ -286,7 +291,7 @@ async def test_pay_invoice_self_payment(client, adminkey_headers_from):
# check POST /api/v1/payments: payment with invoice key [should fail] # check POST /api/v1/payments: payment with invoice key [should fail]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_invoice_invoicekey(client, invoice, inkey_headers_from): async def test_pay_invoice_invoicekey(client, invoice, inkey_headers_from):
data = {"out": True, "bolt11": invoice["payment_request"]} data = {"out": True, "bolt11": invoice.bolt11}
# try payment with invoice key # try payment with invoice key
response = await client.post( response = await client.post(
"/api/v1/payments", json=data, headers=inkey_headers_from "/api/v1/payments", json=data, headers=inkey_headers_from
@ -297,7 +302,7 @@ async def test_pay_invoice_invoicekey(client, invoice, inkey_headers_from):
# check POST /api/v1/payments: payment with admin key, trying to pay twice [should fail] # check POST /api/v1/payments: payment with admin key, trying to pay twice [should fail]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_invoice_adminkey(client, invoice, adminkey_headers_from): async def test_pay_invoice_adminkey(client, invoice, adminkey_headers_from):
data = {"out": True, "bolt11": invoice["payment_request"]} data = {"out": True, "bolt11": invoice.bolt11}
# try payment with admin key # try payment with admin key
response = await client.post( response = await client.post(
"/api/v1/payments", json=data, headers=adminkey_headers_from "/api/v1/payments", json=data, headers=adminkey_headers_from
@ -306,19 +311,20 @@ async def test_pay_invoice_adminkey(client, invoice, adminkey_headers_from):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_payments(client, adminkey_headers_from, fake_payments): async def test_get_payments(client, inkey_fresh_headers_to, fake_payments):
fake_data, filters = fake_payments fake_data, filters = fake_payments
async def get_payments(params: dict): async def get_payments(params: dict):
response = await client.get( response = await client.get(
"/api/v1/payments", "/api/v1/payments",
params=filters | params, params=filters | params,
headers=adminkey_headers_from, headers=inkey_fresh_headers_to,
) )
assert response.status_code == 200 assert response.status_code == 200
return [Payment(**payment) for payment in response.json()] return [Payment(**payment) for payment in response.json()]
payments = await get_payments({"sortby": "amount", "direction": "desc", "limit": 2}) payments = await get_payments({"sortby": "amount", "direction": "desc", "limit": 2})
assert len(payments) != 0
assert payments[-1].amount < payments[0].amount assert payments[-1].amount < payments[0].amount
assert len(payments) == 2 assert len(payments) == 2
@ -340,13 +346,13 @@ async def test_get_payments(client, adminkey_headers_from, fake_payments):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_payments_paginated(client, adminkey_headers_from, fake_payments): async def test_get_payments_paginated(client, inkey_fresh_headers_to, fake_payments):
fake_data, filters = fake_payments fake_data, filters = fake_payments
response = await client.get( response = await client.get(
"/api/v1/payments/paginated", "/api/v1/payments/paginated",
params=filters | {"limit": 2}, params=filters | {"limit": 2},
headers=adminkey_headers_from, headers=inkey_fresh_headers_to,
) )
assert response.status_code == 200 assert response.status_code == 200
paginated = response.json() paginated = response.json()
@ -355,13 +361,13 @@ async def test_get_payments_paginated(client, adminkey_headers_from, fake_paymen
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_payments_history(client, adminkey_headers_from, fake_payments): async def test_get_payments_history(client, inkey_fresh_headers_to, fake_payments):
fake_data, filters = fake_payments fake_data, filters = fake_payments
response = await client.get( response = await client.get(
"/api/v1/payments/history", "/api/v1/payments/history",
params=filters, params=filters,
headers=adminkey_headers_from, headers=inkey_fresh_headers_to,
) )
assert response.status_code == 200 assert response.status_code == 200
@ -377,7 +383,7 @@ async def test_get_payments_history(client, adminkey_headers_from, fake_payments
response = await client.get( response = await client.get(
"/api/v1/payments/history?group=INVALID", "/api/v1/payments/history?group=INVALID",
params=filters, params=filters,
headers=adminkey_headers_from, headers=inkey_fresh_headers_to,
) )
assert response.status_code == 400 assert response.status_code == 400
@ -385,21 +391,21 @@ async def test_get_payments_history(client, adminkey_headers_from, fake_payments
# check POST /api/v1/payments/decode # check POST /api/v1/payments/decode
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_decode_invoice(client, invoice): async def test_decode_invoice(client, invoice: Payment):
data = {"data": invoice["payment_request"]} data = {"data": invoice.bolt11}
response = await client.post( response = await client.post(
"/api/v1/payments/decode", "/api/v1/payments/decode",
json=data, json=data,
) )
assert response.status_code < 300 assert response.status_code < 300
assert response.json()["payment_hash"] == invoice["payment_hash"] assert response.json()["payment_hash"] == invoice.payment_hash
# check api_payment() internal function call (NOT API): payment status # check api_payment() internal function call (NOT API): payment status
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_payment_without_key(invoice): async def test_api_payment_without_key(invoice: Payment):
# check the payment status # check the payment status
response = await api_payment(invoice["payment_hash"]) response = await api_payment(invoice.payment_hash)
assert isinstance(response, dict) assert isinstance(response, dict)
assert response["paid"] is True assert response["paid"] is True
# no key, that's why no "details" # no key, that's why no "details"
@ -408,11 +414,9 @@ async def test_api_payment_without_key(invoice):
# check api_payment() internal function call (NOT API): payment status # check api_payment() internal function call (NOT API): payment status
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_payment_with_key(invoice, inkey_headers_from): async def test_api_payment_with_key(invoice: Payment, inkey_headers_from):
# check the payment status # check the payment status
response = await api_payment( response = await api_payment(invoice.payment_hash, inkey_headers_from["X-Api-Key"])
invoice["payment_hash"], inkey_headers_from["X-Api-Key"]
)
assert isinstance(response, dict) assert isinstance(response, dict)
assert response["paid"] is True assert response["paid"] is True
assert "details" in response assert "details" in response
@ -431,7 +435,7 @@ async def test_create_invoice_with_description_hash(client, inkey_headers_to):
) )
invoice = response.json() invoice = response.json()
invoice_bolt11 = bolt11.decode(invoice["payment_request"]) invoice_bolt11 = bolt11.decode(invoice["bolt11"])
assert invoice_bolt11.description_hash == descr_hash assert invoice_bolt11.description_hash == descr_hash
return invoice return invoice
@ -448,7 +452,7 @@ async def test_create_invoice_with_unhashed_description(client, inkey_headers_to
) )
invoice = response.json() invoice = response.json()
invoice_bolt11 = bolt11.decode(invoice["payment_request"]) invoice_bolt11 = bolt11.decode(invoice["bolt11"])
assert invoice_bolt11.description_hash == descr_hash assert invoice_bolt11.description_hash == descr_hash
assert invoice_bolt11.description is None assert invoice_bolt11.description is None
return invoice return invoice
@ -475,7 +479,7 @@ async def test_update_wallet(client, adminkey_headers_from):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fiat_tracking(client, adminkey_headers_from): async def test_fiat_tracking(client, adminkey_headers_from, settings: Settings):
async def create_invoice(): async def create_invoice():
data = await get_random_invoice_data() data = await get_random_invoice_data()
response = await client.post( response = await client.post(
@ -501,13 +505,15 @@ async def test_fiat_tracking(client, adminkey_headers_from):
settings.lnbits_default_accounting_currency = "USD" settings.lnbits_default_accounting_currency = "USD"
payment = await create_invoice() payment = await create_invoice()
assert payment["extra"]["wallet_fiat_currency"] == "USD" extra = payment["extra"]
assert payment["extra"]["wallet_fiat_amount"] != payment["amount"] assert extra["wallet_fiat_currency"] == "USD"
assert payment["extra"]["wallet_fiat_rate"] assert extra["wallet_fiat_amount"] != payment["amount"]
assert extra["wallet_fiat_rate"]
await update_currency("EUR") await update_currency("EUR")
payment = await create_invoice() payment = await create_invoice()
assert payment["extra"]["wallet_fiat_currency"] == "EUR" extra = payment["extra"]
assert payment["extra"]["wallet_fiat_amount"] != payment["amount"] assert extra["wallet_fiat_currency"] == "EUR"
assert payment["extra"]["wallet_fiat_rate"] assert extra["wallet_fiat_amount"] != payment["amount"]
assert extra["wallet_fiat_rate"]

View file

@ -11,7 +11,7 @@ from httpx import AsyncClient
from lnbits.core.models import AccessTokenPayload, User from lnbits.core.models import AccessTokenPayload, User
from lnbits.core.views.user_api import api_users_reset_password from lnbits.core.views.user_api import api_users_reset_password
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, Settings
from lnbits.utils.nostr import hex_to_npub, sign_event from lnbits.utils.nostr import hex_to_npub, sign_event
nostr_event = { nostr_event = {
@ -29,8 +29,6 @@ private_key = secp256k1.PrivateKey(
) )
pubkey_hex = private_key.pubkey.serialize().hex()[2:] pubkey_hex = private_key.pubkey.serialize().hex()[2:]
settings.auth_allowed_methods = AuthMethods.all()
################################ LOGIN ################################ ################################ LOGIN ################################
@pytest.mark.asyncio @pytest.mark.asyncio
@ -63,7 +61,9 @@ async def test_login_alan_usr(user_alan: User, http_client: AsyncClient):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_usr_not_allowed(user_alan: User, http_client: AsyncClient): async def test_login_usr_not_allowed(
user_alan: User, http_client: AsyncClient, settings: Settings
):
# exclude 'user_id_only' # exclude 'user_id_only'
settings.auth_allowed_methods = [AuthMethods.username_and_password.value] settings.auth_allowed_methods = [AuthMethods.username_and_password.value]
@ -83,7 +83,7 @@ async def test_login_usr_not_allowed(user_alan: User, http_client: AsyncClient):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_alan_username_password_ok( async def test_login_alan_username_password_ok(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
response = await http_client.post( response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"} "/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
@ -95,6 +95,7 @@ async def test_login_alan_username_password_ok(
payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"]) payload: dict = jwt.decode(access_token, settings.auth_secret_key, ["HS256"])
access_token_payload = AccessTokenPayload(**payload) access_token_payload = AccessTokenPayload(**payload)
assert access_token_payload.sub == "alan", "Subject is Alan." assert access_token_payload.sub == "alan", "Subject is Alan."
assert access_token_payload.email == "alan@lnbits.com" assert access_token_payload.email == "alan@lnbits.com"
assert access_token_payload.auth_time, "Auth time should be set by server." assert access_token_payload.auth_time, "Auth time should be set by server."
@ -113,7 +114,9 @@ async def test_login_alan_username_password_ok(
assert not user.admin, "Not admin." assert not user.admin, "Not admin."
assert not user.super_user, "Not superuser." assert not user.super_user, "Not superuser."
assert user.has_password, "Password configured." assert user.has_password, "Password configured."
assert len(user.wallets) == 1, "One default wallet." assert (
len(user.wallets) == 1
), f"Expected 1 default wallet, not {len(user.wallets)}."
@pytest.mark.asyncio @pytest.mark.asyncio
@ -139,7 +142,7 @@ async def test_login_alan_password_nok(user_alan: User, http_client: AsyncClient
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_username_password_not_allowed( async def test_login_username_password_not_allowed(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
# exclude 'username_password' # exclude 'username_password'
settings.auth_allowed_methods = [AuthMethods.user_id_only.value] settings.auth_allowed_methods = [AuthMethods.user_id_only.value]
@ -164,7 +167,7 @@ async def test_login_username_password_not_allowed(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_login_alan_change_auth_secret_key( async def test_login_alan_change_auth_secret_key(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
response = await http_client.post( response = await http_client.post(
"/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"} "/api/v1/auth", json={"username": user_alan.username, "password": "secret1234"}
@ -221,7 +224,9 @@ async def test_register_ok(http_client: AsyncClient):
assert not user.admin, "Not admin." assert not user.admin, "Not admin."
assert not user.super_user, "Not superuser." assert not user.super_user, "Not superuser."
assert user.has_password, "Password configured." assert user.has_password, "Password configured."
assert len(user.wallets) == 1, "One default wallet." assert (
len(user.wallets) == 1
), f"Expected 1 default wallet, not {len(user.wallets)}."
@pytest.mark.asyncio @pytest.mark.asyncio
@ -250,7 +255,8 @@ async def test_register_email_twice(http_client: AsyncClient):
"email": f"u21.{tiny_id}@lnbits.com", "email": f"u21.{tiny_id}@lnbits.com",
}, },
) )
assert response.status_code == 403, "Not allowed."
assert response.status_code == 400, "Not allowed."
assert response.json().get("detail") == "Email already exists." assert response.json().get("detail") == "Email already exists."
@ -280,7 +286,7 @@ async def test_register_username_twice(http_client: AsyncClient):
"email": f"u21.{tiny_id_2}@lnbits.com", "email": f"u21.{tiny_id_2}@lnbits.com",
}, },
) )
assert response.status_code == 403, "Not allowed." assert response.status_code == 400, "Not allowed."
assert response.json().get("detail") == "Username already exists." assert response.json().get("detail") == "Username already exists."
@ -320,7 +326,7 @@ async def test_register_bad_email(http_client: AsyncClient):
################################ CHANGE PASSWORD ################################ ################################ CHANGE PASSWORD ################################
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_password_ok(http_client: AsyncClient): async def test_change_password_ok(http_client: AsyncClient, settings: Settings):
tiny_id = shortuuid.uuid()[:8] tiny_id = shortuuid.uuid()[:8]
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/register", "/api/v1/auth/register",
@ -409,8 +415,8 @@ async def test_alan_change_password_old_nok(user_alan: User, http_client: AsyncC
}, },
) )
assert response.status_code == 403, "Old password bad." assert response.status_code == 400, "Old password bad."
assert response.json().get("detail") == "Invalid credentials." assert response.json().get("detail") == "Invalid old password."
@pytest.mark.asyncio @pytest.mark.asyncio
@ -441,7 +447,7 @@ async def test_alan_change_password_different_user(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_alan_change_password_auth_threshold_expired( async def test_alan_change_password_auth_threshold_expired(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
response = await http_client.post("/api/v1/auth/usr", json={"usr": user_alan.id}) response = await http_client.post("/api/v1/auth/usr", json={"usr": user_alan.id})
@ -464,7 +470,7 @@ async def test_alan_change_password_auth_threshold_expired(
}, },
) )
assert response.status_code == 403, "Treshold expired." assert response.status_code == 400
assert ( assert (
response.json().get("detail") == "You can only update your credentials" response.json().get("detail") == "You can only update your credentials"
" in the first 1 seconds." " in the first 1 seconds."
@ -476,7 +482,7 @@ async def test_alan_change_password_auth_threshold_expired(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_nostr_ok(http_client: AsyncClient): async def test_register_nostr_ok(http_client: AsyncClient, settings: Settings):
event = {**nostr_event} event = {**nostr_event}
event["created_at"] = int(time.time()) event["created_at"] = int(time.time())
@ -502,6 +508,7 @@ async def test_register_nostr_ok(http_client: AsyncClient):
response = await http_client.get( response = await http_client.get(
"/api/v1/auth", headers={"Authorization": f"Bearer {access_token}"} "/api/v1/auth", headers={"Authorization": f"Bearer {access_token}"}
) )
user = User(**response.json()) user = User(**response.json())
assert user.username is None, "No username." assert user.username is None, "No username."
assert user.email is None, "No email." assert user.email is None, "No email."
@ -509,11 +516,13 @@ async def test_register_nostr_ok(http_client: AsyncClient):
assert not user.admin, "Not admin." assert not user.admin, "Not admin."
assert not user.super_user, "Not superuser." assert not user.super_user, "Not superuser."
assert not user.has_password, "Password configured." assert not user.has_password, "Password configured."
assert len(user.wallets) == 1, "One default wallet." assert (
len(user.wallets) == 1
), f"Expected 1 default wallet, not {len(user.wallets)}."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_nostr_not_allowed(http_client: AsyncClient): async def test_register_nostr_not_allowed(http_client: AsyncClient, settings: Settings):
# exclude 'nostr_auth_nip98' # exclude 'nostr_auth_nip98'
settings.auth_allowed_methods = [AuthMethods.username_and_password.value] settings.auth_allowed_methods = [AuthMethods.username_and_password.value]
response = await http_client.post( response = await http_client.post(
@ -540,25 +549,25 @@ async def test_register_nostr_bad_header(http_client: AsyncClient):
) )
assert response.status_code == 401, "Non nostr header." assert response.status_code == 401, "Non nostr header."
assert response.json().get("detail") == "Authorization header is not nostr." assert response.json().get("detail") == "Invalid Authorization scheme."
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": "nostr xyz"}, headers={"Authorization": "nostr xyz"},
) )
assert response.status_code == 401, "Nostr not base64." assert response.status_code == 400, "Nostr not base64."
assert response.json().get("detail") == "Nostr login event cannot be parsed." assert response.json().get("detail") == "Nostr login event cannot be parsed."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_register_nostr_bad_event(http_client: AsyncClient): async def test_register_nostr_bad_event(http_client: AsyncClient, settings: Settings):
settings.auth_allowed_methods = AuthMethods.all() settings.auth_allowed_methods = AuthMethods.all()
base64_event = base64.b64encode(json.dumps(nostr_event).encode()).decode("ascii") base64_event = base64.b64encode(json.dumps(nostr_event).encode()).decode("ascii")
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event}"}, headers={"Authorization": f"nostr {base64_event}"},
) )
assert response.status_code == 401, "Nostr event expired." assert response.status_code == 400, "Nostr event expired."
assert ( assert (
response.json().get("detail") response.json().get("detail")
== f"More than {settings.auth_credetials_update_threshold}" == f"More than {settings.auth_credetials_update_threshold}"
@ -574,7 +583,7 @@ async def test_register_nostr_bad_event(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event}"}, headers={"Authorization": f"nostr {base64_event}"},
) )
assert response.status_code == 401, "Nostr event signature invalid." assert response.status_code == 400, "Nostr event signature invalid."
assert response.json().get("detail") == "Nostr login event is not valid." assert response.json().get("detail") == "Nostr login event is not valid."
@ -591,7 +600,7 @@ async def test_register_nostr_bad_event_kind(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event_bad_kind}"}, headers={"Authorization": f"nostr {base64_event_bad_kind}"},
) )
assert response.status_code == 401, "Nostr event kind invalid." assert response.status_code == 400, "Nostr event kind invalid."
assert response.json().get("detail") == "Invalid event kind." assert response.json().get("detail") == "Invalid event kind."
@ -610,7 +619,7 @@ async def test_register_nostr_bad_event_tag_u(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event_tag_kind}"}, headers={"Authorization": f"nostr {base64_event_tag_kind}"},
) )
assert response.status_code == 401, "Nostr event tag missing." assert response.status_code == 400, "Nostr event tag missing."
assert response.json().get("detail") == "Tag 'method' is missing." assert response.json().get("detail") == "Tag 'method' is missing."
event_bad_kind["tags"] = [["u", "http://localhost:5000/nostr"], ["method", "XYZ"]] event_bad_kind["tags"] = [["u", "http://localhost:5000/nostr"], ["method", "XYZ"]]
@ -623,8 +632,8 @@ async def test_register_nostr_bad_event_tag_u(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event_tag_kind}"}, headers={"Authorization": f"nostr {base64_event_tag_kind}"},
) )
assert response.status_code == 401, "Nostr event tag invalid." assert response.status_code == 400, "Nostr event tag invalid."
assert response.json().get("detail") == "Incorrect value for tag 'method'." assert response.json().get("detail") == "Invalid value for tag 'method'."
@pytest.mark.asyncio @pytest.mark.asyncio
@ -642,7 +651,7 @@ async def test_register_nostr_bad_event_tag_menthod(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event}"}, headers={"Authorization": f"nostr {base64_event}"},
) )
assert response.status_code == 401, "Nostr event tag missing." assert response.status_code == 400, "Nostr event tag missing."
assert response.json().get("detail") == "Tag 'u' for URL is missing." assert response.json().get("detail") == "Tag 'u' for URL is missing."
event_bad_kind["tags"] = [["u", "http://demo.lnbits.com/nostr"], ["method", "POST"]] event_bad_kind["tags"] = [["u", "http://demo.lnbits.com/nostr"], ["method", "POST"]]
@ -655,15 +664,15 @@ async def test_register_nostr_bad_event_tag_menthod(http_client: AsyncClient):
"/api/v1/auth/nostr", "/api/v1/auth/nostr",
headers={"Authorization": f"nostr {base64_event}"}, headers={"Authorization": f"nostr {base64_event}"},
) )
assert response.status_code == 401, "Nostr event tag invalid." assert response.status_code == 400, "Nostr event tag invalid."
assert ( assert (
response.json().get("detail") == "Incorrect value for tag 'u':" response.json().get("detail") == "Invalid value for tag 'u':"
" 'http://demo.lnbits.com/nostr'." " 'http://demo.lnbits.com/nostr'."
) )
################################ CHANGE PUBLIC KEY ################################ ################################ CHANGE PUBLIC KEY ################################
async def test_change_pubkey_npub_ok(http_client: AsyncClient, user_alan: User): async def test_change_pubkey_npub_ok(http_client: AsyncClient, settings: Settings):
tiny_id = shortuuid.uuid()[:8] tiny_id = shortuuid.uuid()[:8]
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/register", "/api/v1/auth/register",
@ -703,7 +712,9 @@ async def test_change_pubkey_npub_ok(http_client: AsyncClient, user_alan: User):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_change_pubkey_ok(http_client: AsyncClient, user_alan: User): async def test_change_pubkey_ok(
http_client: AsyncClient, user_alan: User, settings: Settings
):
tiny_id = shortuuid.uuid()[:8] tiny_id = shortuuid.uuid()[:8]
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/register", "/api/v1/auth/register",
@ -783,7 +794,7 @@ async def test_change_pubkey_ok(http_client: AsyncClient, user_alan: User):
}, },
) )
assert response.status_code == 403, "Pubkey already used." assert response.status_code == 400, "Pubkey already used."
assert response.json().get("detail") == "Public key already in use." assert response.json().get("detail") == "Public key already in use."
@ -825,7 +836,7 @@ async def test_change_pubkey_other_user(http_client: AsyncClient, user_alan: Use
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_alan_change_pubkey_auth_threshold_expired( async def test_alan_change_pubkey_auth_threshold_expired(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
response = await http_client.post("/api/v1/auth/usr", json={"usr": user_alan.id}) response = await http_client.post("/api/v1/auth/usr", json={"usr": user_alan.id})
@ -835,7 +846,7 @@ async def test_alan_change_pubkey_auth_threshold_expired(
assert access_token is not None assert access_token is not None
settings.auth_credetials_update_threshold = 1 settings.auth_credetials_update_threshold = 1
time.sleep(1.1) time.sleep(2.1)
response = await http_client.put( response = await http_client.put(
"/api/v1/auth/pubkey", "/api/v1/auth/pubkey",
headers={"Authorization": f"Bearer {access_token}"}, headers={"Authorization": f"Bearer {access_token}"},
@ -845,17 +856,17 @@ async def test_alan_change_pubkey_auth_threshold_expired(
}, },
) )
assert response.status_code == 403, "Treshold expired." assert response.status_code == 400, "Treshold expired."
assert ( assert (
response.json().get("detail") == "You can only update your credentials" response.json().get("detail") == "You can only update your credentials"
" in the first 1 seconds after login." " in the first 1 seconds."
" Please login again!" " Please login again or ask a new reset key!"
) )
################################ RESET PASSWORD ################################ ################################ RESET PASSWORD ################################
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_request_reset_key_ok(http_client: AsyncClient): async def test_request_reset_key_ok(http_client: AsyncClient, settings: Settings):
tiny_id = shortuuid.uuid()[:8] tiny_id = shortuuid.uuid()[:8]
response = await http_client.post( response = await http_client.post(
"/api/v1/auth/register", "/api/v1/auth/register",
@ -922,12 +933,14 @@ async def test_request_reset_key_user_not_found(http_client: AsyncClient):
}, },
) )
assert response.status_code == 403, "User does not exist." assert response.status_code == 404, "User does not exist."
assert response.json().get("detail") == "User not found." assert response.json().get("detail") == "User not found."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_username_password_not_allowed(http_client: AsyncClient): async def test_reset_username_password_not_allowed(
http_client: AsyncClient, settings: Settings
):
# exclude 'username_password' # exclude 'username_password'
settings.auth_allowed_methods = [AuthMethods.user_id_only.value] settings.auth_allowed_methods = [AuthMethods.user_id_only.value]
@ -968,7 +981,7 @@ async def test_reset_username_passwords_do_not_matcj(
}, },
) )
assert response.status_code == 403, "Passwords do not match." assert response.status_code == 400, "Passwords do not match."
assert response.json().get("detail") == "Passwords do not match." assert response.json().get("detail") == "Passwords do not match."
@ -983,13 +996,13 @@ async def test_reset_username_password_bad_key(http_client: AsyncClient):
"password_repeat": "secret0000", "password_repeat": "secret0000",
}, },
) )
assert response.status_code == 500, "Bad reset key." assert response.status_code == 400, "Bad reset key."
assert response.json().get("detail") == "Cannot reset user password." assert response.json().get("detail") == "Invalid reset key."
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_reset_password_auth_threshold_expired( async def test_reset_password_auth_threshold_expired(
user_alan: User, http_client: AsyncClient user_alan: User, http_client: AsyncClient, settings: Settings
): ):
reset_key = await api_users_reset_password(user_alan.id) reset_key = await api_users_reset_password(user_alan.id)
@ -1006,7 +1019,7 @@ async def test_reset_password_auth_threshold_expired(
}, },
) )
assert response.status_code == 403, "Treshold expired." assert response.status_code == 400, "Treshold expired."
assert ( assert (
response.json().get("detail") == "You can only update your credentials" response.json().get("detail") == "You can only update your credentials"
" in the first 1 seconds." " in the first 1 seconds."

View file

@ -1,5 +1,7 @@
import pytest import pytest
from lnbits.core.models import Payment
# check if the client is working # check if the client is working
@pytest.mark.asyncio @pytest.mark.asyncio
@ -10,17 +12,15 @@ async def test_core_views_generic(client):
# check GET /public/v1/payment/{payment_hash}: correct hash [should pass] # check GET /public/v1/payment/{payment_hash}: correct hash [should pass]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_public_payment_longpolling(client, invoice): async def test_api_public_payment_longpolling(client, invoice: Payment):
response = await client.get(f"/public/v1/payment/{invoice['payment_hash']}") response = await client.get(f"/public/v1/payment/{invoice.payment_hash}")
assert response.status_code < 300 assert response.status_code < 300
assert response.json()["status"] == "paid" assert response.json()["status"] == "paid"
# check GET /public/v1/payment/{payment_hash}: wrong hash [should fail] # check GET /public/v1/payment/{payment_hash}: wrong hash [should fail]
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_api_public_payment_longpolling_wrong_hash(client, invoice): async def test_api_public_payment_longpolling_wrong_hash(client, invoice: Payment):
response = await client.get( response = await client.get(f"/public/v1/payment/{invoice.payment_hash + '0'*64}")
f"/public/v1/payment/{invoice['payment_hash'] + '0'*64}"
)
assert response.status_code == 404 assert response.status_code == 404
assert response.json()["detail"] == "Payment does not exist." assert response.json()["detail"] == "Payment does not exist."

View file

@ -1,60 +1,60 @@
# ruff: noqa: E402 # ruff: noqa: E402
import asyncio import asyncio
from time import time from datetime import datetime, timezone
import uvloop import uvloop
from asgi_lifespan import LifespanManager
from lnbits.core.views.payment_api import _api_payments_create_invoice
from lnbits.wallets.fake import FakeWallet from lnbits.wallets.fake import FakeWallet
uvloop.install() asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from uuid import uuid4
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from asgi_lifespan import LifespanManager
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
from lnbits.app import create_app from lnbits.app import create_app
from lnbits.core.crud import ( from lnbits.core.crud import (
create_account,
create_wallet, create_wallet,
delete_account,
get_account,
get_account_by_username, get_account_by_username,
get_user, get_payment,
update_payment_status, get_user_from_account,
update_payment,
) )
from lnbits.core.models import CreateInvoice, PaymentState from lnbits.core.models import Account, CreateInvoice, PaymentState, User
from lnbits.core.services import create_user_account, update_wallet_balance from lnbits.core.services import create_user_account, update_wallet_balance
from lnbits.core.views.payment_api import api_payments_create_invoice
from lnbits.db import DB_TYPE, SQLITE, Database from lnbits.db import DB_TYPE, SQLITE, Database
from lnbits.settings import AuthMethods, settings from lnbits.settings import AuthMethods, Settings
from lnbits.settings import settings as lnbits_settings
from tests.helpers import ( from tests.helpers import (
get_random_invoice_data, get_random_invoice_data,
) )
# override settings for tests
settings.lnbits_data_folder = "./tests/data" @pytest_asyncio.fixture(scope="session")
settings.lnbits_admin_ui = True def settings():
settings.lnbits_extensions_default_install = [] # override settings for tests
settings.lnbits_extensions_deactivate_all = True lnbits_settings.lnbits_admin_extensions = []
lnbits_settings.lnbits_data_folder = "./tests/data"
lnbits_settings.lnbits_admin_ui = True
lnbits_settings.lnbits_extensions_default_install = []
lnbits_settings.lnbits_extensions_deactivate_all = True
yield lnbits_settings
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def run_before_and_after_tests(): def run_before_and_after_tests(settings: Settings):
"""Fixture to execute asserts before and after a test is run""" """Fixture to execute asserts before and after a test is run"""
##### BEFORE TEST RUN ##### _settings_cleanup(settings)
settings.lnbits_allow_new_accounts = True
settings.auth_allowed_methods = AuthMethods.all()
settings.auth_credetials_update_threshold = 120
settings.lnbits_reserve_fee_percent = 1
settings.lnbits_reserve_fee_min = 2000
settings.lnbits_service_fee = 0
settings.lnbits_wallet_limit_daily_max_withdraw = 0
settings.lnbits_admin_extensions = []
yield # this is where the testing happens yield # this is where the testing happens
_settings_cleanup(settings)
##### AFTER TEST RUN #####
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
@ -66,7 +66,7 @@ def event_loop():
# use session scope to run once before and once after all tests # use session scope to run once before and once after all tests
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def app(): async def app(settings: Settings):
app = create_app() app = create_app()
async with LifespanManager(app) as manager: async with LifespanManager(app) as manager:
settings.first_install = False settings.first_install = False
@ -74,7 +74,7 @@ async def app():
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def client(app): async def client(app, settings: Settings):
url = f"http://{settings.host}:{settings.port}" url = f"http://{settings.host}:{settings.port}"
async with AsyncClient(transport=ASGITransport(app=app), base_url=url) as client: async with AsyncClient(transport=ASGITransport(app=app), base_url=url) as client:
yield client yield client
@ -82,7 +82,7 @@ async def client(app):
# function scope # function scope
@pytest_asyncio.fixture(scope="function") @pytest_asyncio.fixture(scope="function")
async def http_client(app): async def http_client(app, settings: Settings):
url = f"http://{settings.host}:{settings.port}" url = f"http://{settings.host}:{settings.port}"
async with AsyncClient(transport=ASGITransport(app=app), base_url=url) as client: async with AsyncClient(transport=ASGITransport(app=app), base_url=url) as client:
@ -99,25 +99,33 @@ async def db():
yield Database("database") yield Database("database")
@pytest_asyncio.fixture(scope="package") @pytest_asyncio.fixture(scope="session")
async def user_alan(): async def user_alan():
user = await get_account_by_username("alan") account = await get_account_by_username("alan")
if not user: if account:
user = await create_user_account( await delete_account(account.id)
email="alan@lnbits.com", username="alan", password="secret1234"
) account = Account(
id=uuid4().hex,
email="alan@lnbits.com",
username="alan",
)
account.hash_password("secret1234")
user = await create_user_account(account)
yield user yield user
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def from_user(): async def from_user():
user = await create_account() user = await create_user_account()
yield user yield user
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def from_wallet(from_user): async def from_wallet(from_user):
user = from_user user = from_user
wallet = await create_wallet(user_id=user.id, wallet_name="test_wallet_from") wallet = await create_wallet(user_id=user.id, wallet_name="test_wallet_from")
await update_wallet_balance( await update_wallet_balance(
wallet_id=wallet.id, wallet_id=wallet.id,
@ -126,6 +134,15 @@ async def from_wallet(from_user):
yield wallet yield wallet
@pytest_asyncio.fixture(scope="session")
async def to_wallet_pagination_tests(to_user):
user = to_user
wallet = await create_wallet(
user_id=user.id, wallet_name="test_wallet_to_pagination_tests"
)
yield wallet
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def from_wallet_ws(from_wallet, test_client): async def from_wallet_ws(from_wallet, test_client):
# wait a bit in order to avoid receiving topup notification # wait a bit in order to avoid receiving topup notification
@ -136,12 +153,12 @@ async def from_wallet_ws(from_wallet, test_client):
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def to_user(): async def to_user():
user = await create_account() user = await create_user_account()
yield user yield user
@pytest.fixture() @pytest.fixture()
def from_super_user(from_user): def from_super_user(from_user: User, settings: Settings):
prev = settings.super_user prev = settings.super_user
settings.super_user = from_user.id settings.super_user = from_user.id
yield from_user yield from_user
@ -149,8 +166,10 @@ def from_super_user(from_user):
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def superuser(): async def superuser(settings: Settings):
user = await get_user(settings.super_user) account = await get_account(settings.super_user)
assert account, "Superuser not found"
user = await get_user_from_account(account)
yield user yield user
@ -165,6 +184,13 @@ async def to_wallet(to_user):
yield wallet yield wallet
@pytest_asyncio.fixture(scope="session")
async def to_fresh_wallet(to_user):
user = to_user
wallet = await create_wallet(user_id=user.id, wallet_name="test_wallet_to_fresh")
yield wallet
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def to_wallet_ws(to_wallet, test_client): async def to_wallet_ws(to_wallet, test_client):
# wait a bit in order to avoid receiving topup notification # wait a bit in order to avoid receiving topup notification
@ -173,6 +199,15 @@ async def to_wallet_ws(to_wallet, test_client):
yield ws yield ws
@pytest_asyncio.fixture(scope="session")
async def inkey_fresh_headers_to(to_fresh_wallet):
wallet = to_fresh_wallet
yield {
"X-Api-Key": wallet.inkey,
"Content-type": "application/json",
}
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def inkey_headers_from(from_wallet): async def inkey_headers_from(from_wallet):
wallet = from_wallet wallet = from_wallet
@ -213,7 +248,7 @@ async def adminkey_headers_to(to_wallet):
async def invoice(to_wallet): async def invoice(to_wallet):
data = await get_random_invoice_data() data = await get_random_invoice_data()
invoice_data = CreateInvoice(**data) invoice_data = CreateInvoice(**data)
invoice = await api_payments_create_invoice(invoice_data, to_wallet) invoice = await _api_payments_create_invoice(invoice_data, to_wallet)
yield invoice yield invoice
del invoice del invoice
@ -224,12 +259,14 @@ async def external_funding_source():
@pytest_asyncio.fixture(scope="session") @pytest_asyncio.fixture(scope="session")
async def fake_payments(client, adminkey_headers_from): async def fake_payments(client, inkey_fresh_headers_to):
ts = datetime.now(timezone.utc).timestamp()
# Because sqlite only stores timestamps with milliseconds # Because sqlite only stores timestamps with milliseconds
# we have to wait a second to ensure a different timestamp than previous invoices # we have to wait a second to ensure a different timestamp than previous invoices
if DB_TYPE == SQLITE: if DB_TYPE == SQLITE:
await asyncio.sleep(1) await asyncio.sleep(1)
ts = time()
fake_data = [ fake_data = [
CreateInvoice(amount=10, memo="aaaa", out=False), CreateInvoice(amount=10, memo="aaaa", out=False),
@ -239,12 +276,29 @@ async def fake_payments(client, adminkey_headers_from):
for invoice in fake_data: for invoice in fake_data:
response = await client.post( response = await client.post(
"/api/v1/payments", headers=adminkey_headers_from, json=invoice.dict() "/api/v1/payments", headers=inkey_fresh_headers_to, json=invoice.dict()
) )
assert response.is_success assert response.is_success
data = response.json() data = response.json()
assert data["checking_id"] assert data["checking_id"]
await update_payment_status(data["checking_id"], status=PaymentState.SUCCESS) payment = await get_payment(data["checking_id"])
payment.status = PaymentState.SUCCESS
await update_payment(payment)
params = {"time[ge]": ts, "time[le]": time()} params = {
"created_at[ge]": ts,
"created_at[le]": datetime.now(timezone.utc).timestamp(),
}
return fake_data, params return fake_data, params
def _settings_cleanup(settings: Settings):
settings.lnbits_allow_new_accounts = True
settings.lnbits_allowed_users = []
settings.auth_allowed_methods = AuthMethods.all()
settings.auth_credetials_update_threshold = 120
settings.lnbits_reserve_fee_percent = 1
settings.lnbits_reserve_fee_min = 2000
settings.lnbits_service_fee = 0
settings.lnbits_wallet_limit_daily_max_withdraw = 0
settings.lnbits_admin_extensions = []

View file

@ -2,7 +2,8 @@ import random
import string import string
from typing import Optional from typing import Optional
from lnbits.db import FromRowModel from pydantic import BaseModel
from lnbits.wallets import get_funding_source, set_funding_source from lnbits.wallets import get_funding_source, set_funding_source
@ -10,12 +11,26 @@ class FakeError(Exception):
pass pass
class DbTestModel(FromRowModel): class DbTestModel(BaseModel):
id: int id: int
name: str name: str
value: Optional[str] = None value: Optional[str] = None
class DbTestModel2(BaseModel):
id: int
label: str
description: Optional[str] = None
child: DbTestModel
class DbTestModel3(BaseModel):
id: int
user: str
child: DbTestModel2
active: bool = False
def get_random_string(iterations: int = 10): def get_random_string(iterations: int = 10):
return "".join( return "".join(
random.SystemRandom().choice(string.ascii_uppercase + string.digits) random.SystemRandom().choice(string.ascii_uppercase + string.digits)

View file

@ -39,9 +39,11 @@ docker_lightning_unconnected_cli = [
def run_cmd(cmd: list) -> str: def run_cmd(cmd: list) -> str:
timeout = 20 timeout = 10
process = Popen(cmd, stdout=PIPE, stderr=PIPE) process = Popen(cmd, stdout=PIPE, stderr=PIPE)
logger.debug(f"running command: {cmd}")
def process_communication(comm): def process_communication(comm):
stdout, stderr = comm stdout, stderr = comm
output = stdout.decode("utf-8").strip() output = stdout.decode("utf-8").strip()

View file

@ -4,7 +4,7 @@ import hashlib
import pytest import pytest
from lnbits import bolt11 from lnbits import bolt11
from lnbits.core.crud import get_standalone_payment, update_payment_details from lnbits.core.crud import get_standalone_payment, update_payment
from lnbits.core.models import CreateInvoice, Payment, PaymentState from lnbits.core.models import CreateInvoice, Payment, PaymentState
from lnbits.core.services import fee_reserve_total, get_balance_delta from lnbits.core.services import fee_reserve_total, get_balance_delta
from lnbits.tasks import create_task, wait_for_paid_invoices from lnbits.tasks import create_task, wait_for_paid_invoices
@ -99,7 +99,7 @@ async def test_create_real_invoice(client, adminkey_headers_from, inkey_headers_
raise FakeError() raise FakeError()
task = create_task(wait_for_paid_invoices("test_create_invoice", on_paid)()) task = create_task(wait_for_paid_invoices("test_create_invoice", on_paid)())
pay_real_invoice(invoice["payment_request"]) pay_real_invoice(invoice["bolt11"])
# wait for the task to exit # wait for the task to exit
with pytest.raises(FakeError): with pytest.raises(FakeError):
@ -143,7 +143,6 @@ async def test_pay_real_invoice_set_pending_and_check_state(
payment = await get_standalone_payment(invoice["payment_hash"]) payment = await get_standalone_payment(invoice["payment_hash"])
assert payment assert payment
assert payment.success assert payment.success
assert payment.pending is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -160,28 +159,19 @@ async def test_pay_hold_invoice_check_pending(
) )
) )
await asyncio.sleep(1) await asyncio.sleep(1)
# get payment hash from the invoice # get payment hash from the invoice
invoice_obj = bolt11.decode(invoice["payment_request"]) invoice_obj = bolt11.decode(invoice["payment_request"])
payment_db = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db
assert payment_db.pending is True
settle_invoice(preimage) settle_invoice(preimage)
payment_db = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db
response = await task response = await task
assert response.status_code < 300 assert response.status_code < 300
# check if paid # check if paid
await asyncio.sleep(1) await asyncio.sleep(1)
payment_db_after_settlement = await get_standalone_payment(invoice_obj.payment_hash) payment_db_after_settlement = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db_after_settlement assert payment_db_after_settlement
assert payment_db_after_settlement.pending is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -202,11 +192,6 @@ async def test_pay_hold_invoice_check_pending_and_fail(
# get payment hash from the invoice # get payment hash from the invoice
invoice_obj = bolt11.decode(invoice["payment_request"]) invoice_obj = bolt11.decode(invoice["payment_request"])
payment_db = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db
assert payment_db.pending is True
preimage_hash = hashlib.sha256(bytes.fromhex(preimage)).hexdigest() preimage_hash = hashlib.sha256(bytes.fromhex(preimage)).hexdigest()
# cancel the hodl invoice # cancel the hodl invoice
@ -221,7 +206,6 @@ async def test_pay_hold_invoice_check_pending_and_fail(
# payment should be in database as failed # payment should be in database as failed
payment_db_after_settlement = await get_standalone_payment(invoice_obj.payment_hash) payment_db_after_settlement = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db_after_settlement assert payment_db_after_settlement
assert payment_db_after_settlement.pending is False
assert payment_db_after_settlement.failed is True assert payment_db_after_settlement.failed is True
@ -243,11 +227,6 @@ async def test_pay_hold_invoice_check_pending_and_fail_cancel_payment_task_in_me
# get payment hash from the invoice # get payment hash from the invoice
invoice_obj = bolt11.decode(invoice["payment_request"]) invoice_obj = bolt11.decode(invoice["payment_request"])
payment_db = await get_standalone_payment(invoice_obj.payment_hash)
assert payment_db
assert payment_db.pending is True
# cancel payment task, this simulates the client dropping the connection # cancel payment task, this simulates the client dropping the connection
task.cancel() task.cancel()
@ -264,7 +243,7 @@ async def test_pay_hold_invoice_check_pending_and_fail_cancel_payment_task_in_me
assert payment_db_after_settlement is not None assert payment_db_after_settlement is not None
# payment is failed # payment is failed
status = await payment_db.check_status() status = await payment_db_after_settlement.check_status()
assert not status.paid assert not status.paid
assert status.failed assert status.failed
@ -307,16 +286,15 @@ async def test_receive_real_invoice_set_pending_and_check_state(
assert payment_status["paid"] assert payment_status["paid"]
assert payment assert payment
assert payment.pending is False
# set the incoming invoice to pending # set the incoming invoice to pending
await update_payment_details(payment.checking_id, status=PaymentState.PENDING) payment.status = PaymentState.PENDING
await update_payment(payment)
payment_pending = await get_standalone_payment( payment_pending = await get_standalone_payment(
invoice["payment_hash"], incoming=True invoice["payment_hash"], incoming=True
) )
assert payment_pending assert payment_pending
assert payment_pending.pending is True
assert payment_pending.success is False assert payment_pending.success is False
assert payment_pending.failed is False assert payment_pending.failed is False
@ -324,7 +302,7 @@ async def test_receive_real_invoice_set_pending_and_check_state(
raise FakeError() raise FakeError()
task = create_task(wait_for_paid_invoices("test_create_invoice", on_paid)()) task = create_task(wait_for_paid_invoices("test_create_invoice", on_paid)())
pay_real_invoice(invoice["payment_request"]) pay_real_invoice(invoice["bolt11"])
with pytest.raises(FakeError): with pytest.raises(FakeError):
await task await task
@ -349,7 +327,7 @@ async def test_check_fee_reserve(client, adminkey_headers_from):
) )
assert response.status_code < 300 assert response.status_code < 300
invoice = response.json() invoice = response.json()
payment_request = invoice["payment_request"] payment_request = invoice["bolt11"]
response = await client.get( response = await client.get(
f"/api/v1/payments/fee-reserve?invoice={payment_request}", f"/api/v1/payments/fee-reserve?invoice={payment_request}",

View file

@ -2,45 +2,45 @@ import pytest
from bolt11 import decode from bolt11 import decode
from lnbits.core.services import ( from lnbits.core.services import (
PaymentStatus,
create_invoice, create_invoice,
) )
from lnbits.wallets import get_funding_source from lnbits.wallets import get_funding_source
from lnbits.wallets.base import PaymentStatus
description = "test create invoice" description = "test create invoice"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_invoice(from_wallet): async def test_create_invoice(from_wallet):
payment_hash, pr = await create_invoice( payment = await create_invoice(
wallet_id=from_wallet.id, wallet_id=from_wallet.id,
amount=1000, amount=1000,
memo=description, memo=description,
) )
invoice = decode(pr) invoice = decode(payment.bolt11)
assert invoice.payment_hash == payment_hash assert invoice.payment_hash == payment.payment_hash
assert invoice.amount_msat == 1000000 assert invoice.amount_msat == 1000000
assert invoice.description == description assert invoice.description == description
funding_source = get_funding_source() funding_source = get_funding_source()
status = await funding_source.get_invoice_status(payment_hash) status = await funding_source.get_invoice_status(payment.payment_hash)
assert isinstance(status, PaymentStatus) assert isinstance(status, PaymentStatus)
assert status.pending assert status.pending
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_create_internal_invoice(from_wallet): async def test_create_internal_invoice(from_wallet):
payment_hash, pr = await create_invoice( payment = await create_invoice(
wallet_id=from_wallet.id, amount=1000, memo=description, internal=True wallet_id=from_wallet.id, amount=1000, memo=description, internal=True
) )
invoice = decode(pr) invoice = decode(payment.bolt11)
assert invoice.payment_hash == payment_hash assert invoice.payment_hash == payment.payment_hash
assert invoice.amount_msat == 1000000 assert invoice.amount_msat == 1000000
assert invoice.description == description assert invoice.description == description
# Internal invoices are not on fundingsource. so we should get some kind of error # Internal invoices are not on fundingsource. so we should get some kind of error
# that the invoice is not found, but we get status pending # that the invoice is not found, but we get status pending
funding_source = get_funding_source() funding_source = get_funding_source()
status = await funding_source.get_invoice_status(payment_hash) status = await funding_source.get_invoice_status(payment.payment_hash)
assert isinstance(status, PaymentStatus) assert isinstance(status, PaymentStatus)
assert status.pending assert status.pending

View file

@ -1,27 +1,23 @@
import pytest import pytest
from lnbits.core.crud import ( from lnbits.core.models import PaymentState
get_standalone_payment,
)
from lnbits.core.services import ( from lnbits.core.services import (
PaymentError,
pay_invoice, pay_invoice,
) )
from lnbits.exceptions import PaymentError
description = "test pay invoice" description = "test pay invoice"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_services_pay_invoice(to_wallet, real_invoice): async def test_services_pay_invoice(to_wallet, real_invoice):
payment_hash = await pay_invoice( payment = await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
payment_request=real_invoice.get("bolt11"), payment_request=real_invoice.get("bolt11"),
description=description, description=description,
) )
assert payment_hash
payment = await get_standalone_payment(payment_hash)
assert payment assert payment
assert not payment.pending assert payment.status == PaymentState.SUCCESS
assert payment.memo == description assert payment.memo == description

View file

@ -7,7 +7,7 @@ from pydantic import parse_obj_as
from lnbits import bolt11 from lnbits import bolt11
from lnbits.nodes.base import ChannelPoint, ChannelState, NodeChannel from lnbits.nodes.base import ChannelPoint, ChannelState, NodeChannel
from tests.conftest import pytest_asyncio, settings from tests.conftest import pytest_asyncio
from ..helpers import ( from ..helpers import (
funding_source, funding_source,
@ -25,7 +25,7 @@ pytestmark = pytest.mark.skipif(
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def node_client(client, from_super_user): async def node_client(client, from_super_user, settings):
settings.lnbits_node_ui = True settings.lnbits_node_ui = True
settings.lnbits_public_node_ui = False settings.lnbits_public_node_ui = False
settings.lnbits_node_ui_transactions = True settings.lnbits_node_ui_transactions = True
@ -37,14 +37,14 @@ async def node_client(client, from_super_user):
@pytest_asyncio.fixture() @pytest_asyncio.fixture()
async def public_node_client(node_client): async def public_node_client(node_client, settings):
settings.lnbits_public_node_ui = True settings.lnbits_public_node_ui = True
yield node_client yield node_client
settings.lnbits_public_node_ui = False settings.lnbits_public_node_ui = False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_node_info_not_found(client, from_super_user): async def test_node_info_not_found(client, from_super_user, settings):
settings.lnbits_node_ui = False settings.lnbits_node_ui = False
response = await client.get("/node/api/v1/info", params={"usr": from_super_user.id}) response = await client.get("/node/api/v1/info", params={"usr": from_super_user.id})
assert response.status_code == HTTPStatus.SERVICE_UNAVAILABLE assert response.status_code == HTTPStatus.SERVICE_UNAVAILABLE

View file

@ -1,28 +1,72 @@
import json
import pytest import pytest
from lnbits.helpers import ( from lnbits.db import (
dict_to_model,
insert_query, insert_query,
model_to_dict,
update_query, update_query,
) )
from tests.helpers import DbTestModel from tests.helpers import DbTestModel, DbTestModel2, DbTestModel3
test = DbTestModel(id=1, name="test", value="yes") test_data = DbTestModel3(
id=1,
user="userid",
child=DbTestModel2(
id=2,
label="test",
description="mydesc",
child=DbTestModel(id=3, name="myname", value="myvalue"),
),
active=True,
)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_helpers_insert_query(): async def test_helpers_insert_query():
q = insert_query("test_helpers_query", test) q = insert_query("test_helpers_query", test_data)
assert ( assert q == (
q == "INSERT INTO test_helpers_query (id, name, value) " """INSERT INTO test_helpers_query ("id", "user", "child", "active") """
"VALUES (:id, :name, :value)" "VALUES (:id, :user, :child, :active)"
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_helpers_update_query(): async def test_helpers_update_query():
q = update_query("test_helpers_query", test) q = update_query("test_helpers_query", test_data)
assert ( assert q == (
q == "UPDATE test_helpers_query " """UPDATE test_helpers_query SET "id" = :id, "user" = """
"SET id = :id, name = :name, value = :value " """:user, "child" = :child, "active" = :active WHERE id = :id"""
"WHERE id = :id"
) )
child_json = json.dumps(
{
"id": 2,
"label": "test",
"description": "mydesc",
"child": {"id": 3, "name": "myname", "value": "myvalue"},
}
)
test_dict = {"id": 1, "user": "userid", "child": child_json, "active": True}
@pytest.mark.asyncio
async def test_helpers_model_to_dict():
d = model_to_dict(test_data)
assert d.get("id") == test_data.id
assert d.get("active") == test_data.active
assert d.get("child") == child_json
assert d.get("user") == test_data.user
assert d == test_dict
@pytest.mark.asyncio
async def test_helpers_dict_to_model():
m = dict_to_model(test_dict, DbTestModel3)
assert m == test_data
assert type(m) is DbTestModel3
assert m.active is True
assert type(m.child) is DbTestModel2
assert type(m.child.child) is DbTestModel

View file

@ -12,7 +12,7 @@ from lnbits.core.crud import get_standalone_payment, get_wallet
from lnbits.core.models import Payment, PaymentState, Wallet from lnbits.core.models import Payment, PaymentState, Wallet
from lnbits.core.services import create_invoice, pay_invoice from lnbits.core.services import create_invoice, pay_invoice
from lnbits.exceptions import PaymentError from lnbits.exceptions import PaymentError
from lnbits.settings import settings from lnbits.settings import Settings
from lnbits.tasks import ( from lnbits.tasks import (
create_permanent_task, create_permanent_task,
internal_invoice_listener, internal_invoice_listener,
@ -49,43 +49,40 @@ async def test_amountless_invoice(to_wallet: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bad_wallet_id(to_wallet: Wallet): async def test_bad_wallet_id(to_wallet: Wallet):
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=31, memo="Bad Wallet")
wallet_id=to_wallet.id, amount=31, memo="Bad Wallet" bad_wallet_id = to_wallet.id[::-1]
) with pytest.raises(
with pytest.raises(AssertionError, match="invalid wallet_id"): PaymentError, match=f"Could not fetch wallet '{bad_wallet_id}'."
):
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id[::-1], wallet_id=bad_wallet_id,
payment_request=payment_request, payment_request=payment.bolt11,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_payment_limit(to_wallet: Wallet): async def test_payment_limit(to_wallet: Wallet):
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=101, memo="")
wallet_id=to_wallet.id, amount=101, memo=""
)
with pytest.raises(PaymentError, match="Amount in invoice is too high."): with pytest.raises(PaymentError, match="Amount in invoice is too high."):
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
max_sat=100, max_sat=100,
payment_request=payment_request, payment_request=payment.bolt11,
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_twice(to_wallet: Wallet): async def test_pay_twice(to_wallet: Wallet):
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=3, memo="Twice")
wallet_id=to_wallet.id, amount=3, memo="Twice"
)
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
payment_request=payment_request, payment_request=payment.bolt11,
) )
with pytest.raises(PaymentError, match="Internal invoice already paid."): with pytest.raises(PaymentError, match="Internal invoice already paid."):
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
payment_request=payment_request, payment_request=payment.bolt11,
) )
@ -106,15 +103,13 @@ async def test_fake_wallet_pay_external(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoice_changed(to_wallet: Wallet): async def test_invoice_changed(to_wallet: Wallet):
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=21, memo="original")
wallet_id=to_wallet.id, amount=21, memo="original"
)
invoice = bolt11_decode(payment_request) invoice = bolt11_decode(payment.bolt11)
invoice.amount_msat = MilliSatoshi(12000) invoice.amount_msat = MilliSatoshi(12000)
payment_request = bolt11_encode(invoice) payment_request = bolt11_encode(invoice)
with pytest.raises(PaymentError, match="Invalid invoice."): with pytest.raises(PaymentError, match="Invalid invoice. Bolt11 changed."):
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
payment_request=payment_request, payment_request=payment_request,
@ -132,24 +127,20 @@ async def test_invoice_changed(to_wallet: Wallet):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_pay_for_extension(to_wallet: Wallet): async def test_pay_for_extension(to_wallet: Wallet, settings: Settings):
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=3, memo="Allowed")
wallet_id=to_wallet.id, amount=3, memo="Allowed"
)
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, payment_request=payment_request, extra={"tag": "lnurlp"} wallet_id=to_wallet.id, payment_request=payment.bolt11, tag="lnurlp"
)
_, payment_request = await create_invoice(
wallet_id=to_wallet.id, amount=3, memo="Not Allowed"
) )
payment = await create_invoice(wallet_id=to_wallet.id, amount=3, memo="Not Allowed")
settings.lnbits_admin_extensions = ["lnurlp"] settings.lnbits_admin_extensions = ["lnurlp"]
with pytest.raises( with pytest.raises(
PaymentError, match="User not authorized for extension 'lnurlp'." PaymentError, match="User not authorized for extension 'lnurlp'."
): ):
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, wallet_id=to_wallet.id,
payment_request=payment_request, payment_request=payment.bolt11,
extra={"tag": "lnurlp"}, tag="lnurlp",
) )
@ -161,21 +152,19 @@ async def test_notification_for_internal_payment(to_wallet: Wallet):
invoice_queue: asyncio.Queue = asyncio.Queue() invoice_queue: asyncio.Queue = asyncio.Queue()
register_invoice_listener(invoice_queue, test_name) register_invoice_listener(invoice_queue, test_name)
_, payment_request = await create_invoice( payment = await create_invoice(wallet_id=to_wallet.id, amount=123, memo=test_name)
wallet_id=to_wallet.id, amount=123, memo=test_name
)
await pay_invoice( await pay_invoice(
wallet_id=to_wallet.id, payment_request=payment_request, extra={"tag": "lnurlp"} wallet_id=to_wallet.id, payment_request=payment.bolt11, extra={"tag": "lnurlp"}
) )
await asyncio.sleep(1) await asyncio.sleep(1)
while True: while True:
payment: Payment = invoice_queue.get_nowait() # raises if queue empty _payment: Payment = invoice_queue.get_nowait() # raises if queue empty
assert payment assert _payment
if payment.memo == test_name: if _payment.memo == test_name:
assert payment.status == PaymentState.SUCCESS.value assert _payment.status == PaymentState.SUCCESS.value
assert payment.bolt11 == payment_request assert _payment.bolt11 == payment.bolt11
assert payment.amount == 123_000 assert _payment.amount == 123_000
break # we found our payment, success break # we found our payment, success
@ -216,7 +205,7 @@ async def test_retry_failed_invoice(
assert external_invoice.payment_request assert external_invoice.payment_request
ws_notification = mocker.patch( ws_notification = mocker.patch(
"lnbits.core.services.send_payment_notification", "lnbits.core.services.payments.send_payment_notification",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) )
@ -293,24 +282,24 @@ async def test_pay_external_invoice_pending(
AsyncMock(return_value=payment_reponse_pending), AsyncMock(return_value=payment_reponse_pending),
) )
ws_notification = mocker.patch( ws_notification = mocker.patch(
"lnbits.core.services.send_payment_notification", "lnbits.core.services.payments.send_payment_notification",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) )
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
assert wallet assert wallet
balance_before = wallet.balance balance_before = wallet.balance
payment_hash = await pay_invoice( payment = await pay_invoice(
wallet_id=from_wallet.id, wallet_id=from_wallet.id,
payment_request=external_invoice.payment_request, payment_request=external_invoice.payment_request,
) )
payment = await get_standalone_payment(payment_hash) _payment = await get_standalone_payment(payment.payment_hash)
assert payment assert _payment
assert payment.status == PaymentState.PENDING.value assert _payment.status == PaymentState.PENDING.value
assert payment.checking_id == payment_hash assert _payment.checking_id == payment.payment_hash
assert payment.amount == -2103_000 assert _payment.amount == -2103_000
assert payment.bolt11 == external_invoice.payment_request assert _payment.bolt11 == external_invoice.payment_request
assert payment.preimage == preimage assert _payment.preimage == preimage
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
assert wallet assert wallet
@ -339,7 +328,7 @@ async def test_retry_pay_external_invoice_pending(
AsyncMock(return_value=payment_reponse_pending), AsyncMock(return_value=payment_reponse_pending),
) )
ws_notification = mocker.patch( ws_notification = mocker.patch(
"lnbits.core.services.send_payment_notification", "lnbits.core.services.payments.send_payment_notification",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) )
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
@ -384,24 +373,24 @@ async def test_pay_external_invoice_success(
AsyncMock(return_value=payment_reponse_pending), AsyncMock(return_value=payment_reponse_pending),
) )
ws_notification = mocker.patch( ws_notification = mocker.patch(
"lnbits.core.services.send_payment_notification", "lnbits.core.services.payments.send_payment_notification",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) )
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
assert wallet assert wallet
balance_before = wallet.balance balance_before = wallet.balance
payment_hash = await pay_invoice( payment = await pay_invoice(
wallet_id=from_wallet.id, wallet_id=from_wallet.id,
payment_request=external_invoice.payment_request, payment_request=external_invoice.payment_request,
) )
payment = await get_standalone_payment(payment_hash) _payment = await get_standalone_payment(payment.payment_hash)
assert payment assert _payment
assert payment.status == PaymentState.SUCCESS.value assert _payment.status == PaymentState.SUCCESS.value
assert payment.checking_id == payment_hash assert _payment.checking_id == payment.payment_hash
assert payment.amount == -2104_000 assert _payment.amount == -2104_000
assert payment.bolt11 == external_invoice.payment_request assert _payment.bolt11 == external_invoice.payment_request
assert payment.preimage == preimage assert _payment.preimage == preimage
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
assert wallet assert wallet
@ -430,7 +419,7 @@ async def test_retry_pay_success(
AsyncMock(return_value=payment_reponse_pending), AsyncMock(return_value=payment_reponse_pending),
) )
ws_notification = mocker.patch( ws_notification = mocker.patch(
"lnbits.core.services.send_payment_notification", "lnbits.core.services.payments.send_payment_notification",
AsyncMock(return_value=None), AsyncMock(return_value=None),
) )
wallet = await get_wallet(from_wallet.id) wallet = await get_wallet(from_wallet.id)
@ -465,15 +454,15 @@ async def test_pay_external_invoice_success_bad_checking_id(
external_invoice = await external_funding_source.create_invoice(invoice_amount) external_invoice = await external_funding_source.create_invoice(invoice_amount)
assert external_invoice.payment_request assert external_invoice.payment_request
assert external_invoice.checking_id assert external_invoice.checking_id
bad_checking_id = external_invoice.checking_id[::-1] bad_checking_id = f"bad_{external_invoice.checking_id}"
preimage = "0000000000000000000000000000000000000000000000000000000000002108" preimage = "0000000000000000000000000000000000000000000000000000000000002108"
payment_reponse_pending = PaymentResponse( payment_reponse_success = PaymentResponse(
ok=True, checking_id=bad_checking_id, preimage=preimage ok=True, checking_id=bad_checking_id, preimage=preimage
) )
mocker.patch( mocker.patch(
"lnbits.wallets.FakeWallet.pay_invoice", "lnbits.wallets.FakeWallet.pay_invoice",
AsyncMock(return_value=payment_reponse_pending), AsyncMock(return_value=payment_reponse_success),
) )
await pay_invoice( await pay_invoice(
@ -519,10 +508,7 @@ async def test_no_checking_id(
assert payment.checking_id == external_invoice.checking_id assert payment.checking_id == external_invoice.checking_id
assert payment.payment_hash == external_invoice.checking_id assert payment.payment_hash == external_invoice.checking_id
assert payment.amount == -2110_000 assert payment.amount == -2110_000
assert ( assert payment.preimage is None
payment.preimage
== "0000000000000000000000000000000000000000000000000000000000000000"
)
assert payment.status == PaymentState.PENDING.value assert payment.status == PaymentState.PENDING.value
@ -532,6 +518,7 @@ async def test_service_fee(
to_wallet: Wallet, to_wallet: Wallet,
mocker: MockerFixture, mocker: MockerFixture,
external_funding_source: FakeWallet, external_funding_source: FakeWallet,
settings: Settings,
): ):
invoice_amount = 2112 invoice_amount = 2112
external_invoice = await external_funding_source.create_invoice(invoice_amount) external_invoice = await external_funding_source.create_invoice(invoice_amount)
@ -550,27 +537,26 @@ async def test_service_fee(
settings.lnbits_service_fee_wallet = to_wallet.id settings.lnbits_service_fee_wallet = to_wallet.id
settings.lnbits_service_fee = 20 settings.lnbits_service_fee = 20
payment_hash = await pay_invoice( payment = await pay_invoice(
wallet_id=from_wallet.id, wallet_id=from_wallet.id,
payment_request=external_invoice.payment_request, payment_request=external_invoice.payment_request,
) )
payment = await get_standalone_payment(payment_hash) _payment = await get_standalone_payment(payment.payment_hash)
assert payment assert _payment
assert payment.status == PaymentState.SUCCESS.value assert _payment.status == PaymentState.SUCCESS.value
assert payment.checking_id == payment_hash assert _payment.checking_id == payment.payment_hash
assert payment.amount == -2112_000 assert _payment.amount == -2112_000
assert payment.fee == -422_400 assert _payment.fee == -422_400
assert payment.bolt11 == external_invoice.payment_request assert _payment.bolt11 == external_invoice.payment_request
assert payment.preimage == preimage assert _payment.preimage == preimage
service_fee_payment = await get_standalone_payment(f"service_fee_{payment_hash}") service_fee_payment = await get_standalone_payment(
f"service_fee_{payment.payment_hash}"
)
assert service_fee_payment assert service_fee_payment
assert service_fee_payment.status == PaymentState.SUCCESS.value assert service_fee_payment.status == PaymentState.SUCCESS.value
assert service_fee_payment.checking_id == f"service_fee_{payment_hash}" assert service_fee_payment.checking_id == f"service_fee_{payment.payment_hash}"
assert service_fee_payment.amount == 422_400 assert service_fee_payment.amount == 422_400
assert service_fee_payment.bolt11 == external_invoice.payment_request assert service_fee_payment.bolt11 == external_invoice.payment_request
assert ( assert service_fee_payment.preimage is None
service_fee_payment.preimage
== "0000000000000000000000000000000000000000000000000000000000000000"
)

View file

@ -5,7 +5,7 @@ from lnbits.core.services import (
fee_reserve_total, fee_reserve_total,
service_fee, service_fee,
) )
from lnbits.settings import settings from lnbits.settings import Settings
@pytest.mark.asyncio @pytest.mark.asyncio
@ -15,7 +15,7 @@ async def test_fee_reserve_internal():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fee_reserve_min(): async def test_fee_reserve_min(settings: Settings):
settings.lnbits_reserve_fee_percent = 2 settings.lnbits_reserve_fee_percent = 2
settings.lnbits_reserve_fee_min = 500 settings.lnbits_reserve_fee_min = 500
fee = fee_reserve(10000) fee = fee_reserve(10000)
@ -23,7 +23,7 @@ async def test_fee_reserve_min():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fee_reserve_percent(): async def test_fee_reserve_percent(settings: Settings):
settings.lnbits_reserve_fee_percent = 1 settings.lnbits_reserve_fee_percent = 1
settings.lnbits_reserve_fee_min = 100 settings.lnbits_reserve_fee_min = 100
fee = fee_reserve(100000) fee = fee_reserve(100000)
@ -31,14 +31,14 @@ async def test_fee_reserve_percent():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_fee_no_wallet(): async def test_service_fee_no_wallet(settings: Settings):
settings.lnbits_service_fee_wallet = "" settings.lnbits_service_fee_wallet = ""
fee = service_fee(10000) fee = service_fee(10000)
assert fee == 0 assert fee == 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_fee_internal(): async def test_service_fee_internal(settings: Settings):
settings.lnbits_service_fee_wallet = "wallet_id" settings.lnbits_service_fee_wallet = "wallet_id"
settings.lnbits_service_fee_ignore_internal = True settings.lnbits_service_fee_ignore_internal = True
fee = service_fee(10000, internal=True) fee = service_fee(10000, internal=True)
@ -46,7 +46,7 @@ async def test_service_fee_internal():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_fee(): async def test_service_fee(settings: Settings):
settings.lnbits_service_fee_wallet = "wallet_id" settings.lnbits_service_fee_wallet = "wallet_id"
settings.lnbits_service_fee = 2 settings.lnbits_service_fee = 2
fee = service_fee(10000) fee = service_fee(10000)
@ -54,7 +54,7 @@ async def test_service_fee():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_service_fee_max(): async def test_service_fee_max(settings: Settings):
settings.lnbits_service_fee_wallet = "wallet_id" settings.lnbits_service_fee_wallet = "wallet_id"
settings.lnbits_service_fee = 2 settings.lnbits_service_fee = 2
settings.lnbits_service_fee_max = 199 settings.lnbits_service_fee_max = 199
@ -63,7 +63,7 @@ async def test_service_fee_max():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_fee_reserve_total(): async def test_fee_reserve_total(settings: Settings):
settings.lnbits_reserve_fee_percent = 1 settings.lnbits_reserve_fee_percent = 1
settings.lnbits_reserve_fee_min = 100 settings.lnbits_reserve_fee_min = 100
settings.lnbits_service_fee = 2 settings.lnbits_service_fee = 2

View file

@ -1,11 +1,11 @@
import pytest import pytest
from lnbits.core.services import check_wallet_daily_withdraw_limit from lnbits.core.services.payments import check_wallet_daily_withdraw_limit
from lnbits.settings import settings from lnbits.settings import Settings
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_wallet_limit(): async def test_no_wallet_limit(settings: Settings):
settings.lnbits_wallet_limit_daily_max_withdraw = 0 settings.lnbits_wallet_limit_daily_max_withdraw = 0
result = await check_wallet_daily_withdraw_limit( result = await check_wallet_daily_withdraw_limit(
conn=None, wallet_id="333333", amount_msat=0 conn=None, wallet_id="333333", amount_msat=0
@ -15,7 +15,7 @@ async def test_no_wallet_limit():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wallet_limit_but_no_payments(): async def test_wallet_limit_but_no_payments(settings: Settings):
settings.lnbits_wallet_limit_daily_max_withdraw = 5 settings.lnbits_wallet_limit_daily_max_withdraw = 5
result = await check_wallet_daily_withdraw_limit( result = await check_wallet_daily_withdraw_limit(
conn=None, wallet_id="333333", amount_msat=0 conn=None, wallet_id="333333", amount_msat=0
@ -25,7 +25,7 @@ async def test_wallet_limit_but_no_payments():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_no_wallet_spend_allowed(): async def test_no_wallet_spend_allowed(settings: Settings):
settings.lnbits_wallet_limit_daily_max_withdraw = -1 settings.lnbits_wallet_limit_daily_max_withdraw = -1
with pytest.raises( with pytest.raises(

View file

@ -617,7 +617,7 @@
"response_type": "json", "response_type": "json",
"response": { "response": {
"checking_id": "e35526a43d04e985594c0dfab848814f524b1c786598ec9a63beddb2d726ac96", "checking_id": "e35526a43d04e985594c0dfab848814f524b1c786598ec9a63beddb2d726ac96",
"payment_request": "lnbc5550n1pnq9jg3sp52rvwstvjcypjsaenzdh0h30jazvzsf8aaye0julprtth9kysxtuspp5e5s3z7felv4t9zrcc6wpn7ehvjl5yzewanzl5crljdl3jgeffyhqdq2f38xy6t5wvxqzjccqpjrzjq0yzeq76ney45hmjlnlpvu0nakzy2g35hqh0dujq8ujdpr2e42pf2rrs6vqpgcsqqqqqqqqqqqqqqeqqyg9qxpqysgqwftcx89k5pp28435pgxfl2vx3ksemzxccppw2j9yjn0ngr6ed7wj8ztc0d5kmt2mvzdlcgrludhz7jncd5l5l9w820hc4clpwhtqj3gq62g66n" "bolt11": "lnbc5550n1pnq9jg3sp52rvwstvjcypjsaenzdh0h30jazvzsf8aaye0julprtth9kysxtuspp5e5s3z7felv4t9zrcc6wpn7ehvjl5yzewanzl5crljdl3jgeffyhqdq2f38xy6t5wvxqzjccqpjrzjq0yzeq76ney45hmjlnlpvu0nakzy2g35hqh0dujq8ujdpr2e42pf2rrs6vqpgcsqqqqqqqqqqqqqqeqqyg9qxpqysgqwftcx89k5pp28435pgxfl2vx3ksemzxccppw2j9yjn0ngr6ed7wj8ztc0d5kmt2mvzdlcgrludhz7jncd5l5l9w820hc4clpwhtqj3gq62g66n"
} }
} }
] ]
@ -825,7 +825,7 @@
}, },
"response_type": "json", "response_type": "json",
"response": { "response": {
"payment_request": "lnbc5550n1pnq9jg3sp52rvwstvjcypjsaenzdh0h30jazvzsf8aaye0julprtth9kysxtuspp5e5s3z7felv4t9zrcc6wpn7ehvjl5yzewanzl5crljdl3jgeffyhqdq2f38xy6t5wvxqzjccqpjrzjq0yzeq76ney45hmjlnlpvu0nakzy2g35hqh0dujq8ujdpr2e42pf2rrs6vqpgcsqqqqqqqqqqqqqqeqqyg9qxpqysgqwftcx89k5pp28435pgxfl2vx3ksemzxccppw2j9yjn0ngr6ed7wj8ztc0d5kmt2mvzdlcgrludhz7jncd5l5l9w820hc4clpwhtqj3gq62g66n" "bolt11": "lnbc5550n1pnq9jg3sp52rvwstvjcypjsaenzdh0h30jazvzsf8aaye0julprtth9kysxtuspp5e5s3z7felv4t9zrcc6wpn7ehvjl5yzewanzl5crljdl3jgeffyhqdq2f38xy6t5wvxqzjccqpjrzjq0yzeq76ney45hmjlnlpvu0nakzy2g35hqh0dujq8ujdpr2e42pf2rrs6vqpgcsqqqqqqqqqqqqqqeqqyg9qxpqysgqwftcx89k5pp28435pgxfl2vx3ksemzxccppw2j9yjn0ngr6ed7wj8ztc0d5kmt2mvzdlcgrludhz7jncd5l5l9w820hc4clpwhtqj3gq62g66n"
} }
} }
] ]

View file

@ -55,10 +55,10 @@ internal_id = f"internal_{payment_hash}"
cursor.execute( cursor.execute(
""" """
INSERT INTO apipayments INSERT INTO apipayments
(wallet, checking_id, hash, amount, status, memo, fee, expiry, pending) (wallet_id, checking_id, payment_hash, amount, status, memo, fee, expiry)
VALUES VALUES
(:wallet_id, :checking_id, :payment_hash, :amount, (:wallet_id, :checking_id, :payment_hash, :amount,
:status, :memo, :fee, :expiry, :pending) :status, :memo, :fee, :expiry)
""", """,
{ {
"wallet_id": wallet_id, "wallet_id": wallet_id,