[fix] user sorting performance (#3561)

This commit is contained in:
Vlad Stan 2025-11-25 09:16:35 +02:00 committed by GitHub
parent 148ba9d275
commit d55e2a0e1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 164 additions and 39 deletions

View file

@ -237,7 +237,7 @@ class AccountOverview(Account):
class AccountFilters(FilterModel): class AccountFilters(FilterModel):
__search_fields__ = [ __search_fields__ = [
"user", "id",
"email", "email",
"username", "username",
"pubkey", "pubkey",
@ -245,17 +245,18 @@ class AccountFilters(FilterModel):
"wallet_id", "wallet_id",
] ]
__sort_fields__ = [ __sort_fields__ = [
"balance_msat", "id",
"email", "email",
"username", "username",
"transaction_count", "pubkey",
"wallet_count", "external_id",
"last_payment", "created_at",
"updated_at",
] ]
email: str | None = None id: str | None = None
user: str | None = None
username: str | None = None username: str | None = None
email: str | None = None
pubkey: str | None = None pubkey: str | None = None
external_id: str | None = None external_id: str | None = None
wallet_id: str | None = None wallet_id: str | None = None

View file

@ -235,11 +235,14 @@ class Connection(Compat):
table_name: if provided some optimisations can be applied. table_name: if provided some optimisations can be applied.
""" """
if table_name and not _valid_sql_name(table_name):
raise ValueError(f"Invalid table name: '{table_name}'.")
if not filters: if not filters:
filters = Filters() filters = Filters()
if table_name:
if not _valid_sql_name(table_name):
raise ValueError(f"Invalid table name: '{table_name}'.")
filters.set_table_name(table_name)
clause = filters.where(where) clause = filters.where(where)
parsed_values = filters.values(values) parsed_values = filters.values(values)
@ -491,6 +494,7 @@ class Page(BaseModel, Generic[T]):
class Filter(BaseModel, Generic[TFilterModel]): class Filter(BaseModel, Generic[TFilterModel]):
table_name: str | None = None
field: str field: str
op: Operator = Operator.EQ op: Operator = Operator.EQ
model: type[TFilterModel] | None model: type[TFilterModel] | None
@ -533,20 +537,20 @@ class Filter(BaseModel, Generic[TFilterModel]):
@property @property
def statement(self) -> str: def statement(self) -> str:
prefix = f"{self.table_name}." if self.table_name else ""
stmt = [] stmt = []
for key in self.values.keys() if self.values else []: for key in self.values.keys() if self.values else []:
clean_key = key.split("__")[0] clean_key = key.split("__")[0]
if self.model and self.model.__fields__[clean_key].type_ == datetime: if self.model and self.model.__fields__[clean_key].type_ == datetime:
placeholder = compat_timestamp_placeholder(key) placeholder = compat_timestamp_placeholder(key)
stmt.append(f"{clean_key} {self.op.as_sql} {placeholder}") stmt.append(f"{prefix}{clean_key} {self.op.as_sql} {placeholder}")
else: else:
stmt.append(f"{clean_key} {self.op.as_sql} :{key}") stmt.append(f"{prefix}{clean_key} {self.op.as_sql} :{key}")
if self.op == Operator.EVERY: if self.op == Operator.EVERY:
statement = " AND ".join(stmt) statement = " AND ".join(stmt)
else: else:
statement = " OR ".join(stmt) statement = " OR ".join(stmt)
return f"({statement})" return f"({statement})"
@ -563,13 +567,14 @@ class Filters(BaseModel, Generic[TFilterModel]):
search: str | None = None search: str | None = None
offset: int | None = None offset: int | None = None
limit: int | None = None limit: int | None = 10
sortby: str | None = None sortby: str | None = None
direction: Literal["asc", "desc"] | None = None direction: Literal["asc", "desc"] | None = None
model: type[TFilterModel] | None = None model: type[TFilterModel] | None = None
table_name: str | None = None
@root_validator(pre=True) @root_validator(pre=True)
def validate_sortby(cls, values): def validate_sortby(cls, values):
sortby = values.get("sortby") sortby = values.get("sortby")
@ -584,8 +589,8 @@ class Filters(BaseModel, Generic[TFilterModel]):
def pagination(self) -> str: def pagination(self) -> str:
stmt = "" stmt = ""
if self.limit: self.limit = self.limit or 10
stmt += f"LIMIT {self.limit} " stmt += f"LIMIT {min(1000, self.limit)} "
if self.offset: if self.offset:
stmt += f"OFFSET {self.offset}" stmt += f"OFFSET {self.offset}"
return stmt return stmt
@ -611,7 +616,8 @@ class Filters(BaseModel, Generic[TFilterModel]):
def order_by(self) -> str: def order_by(self) -> str:
if self.sortby: if self.sortby:
return f"ORDER BY {self.sortby} {self.direction or 'asc'}" prefix = f"{self.table_name}." if self.table_name else ""
return f"ORDER BY {prefix}{self.sortby} {self.direction or 'asc'}"
return "" return ""
def values(self, values: dict | None = None) -> dict: def values(self, values: dict | None = None) -> dict:
@ -631,6 +637,11 @@ class Filters(BaseModel, Generic[TFilterModel]):
values["search"] = f"%{self.search.lower()}%" values["search"] = f"%{self.search.lower()}%"
return values return values
def set_table_name(self, table_name: str) -> None:
self.table_name = table_name
for page_filter in self.filters:
page_filter.table_name = table_name
class DbJsonEncoder(json.JSONEncoder): class DbJsonEncoder(json.JSONEncoder):
def default(self, o): def default(self, o):

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -716,11 +716,14 @@ window.localisation.en = {
add_label: 'Add Label', add_label: 'Add Label',
label: 'Label', label: 'Label',
labels: 'Labels', labels: 'Labels',
label_filter: 'Label Filter',
no_labels_defined: 'No labels defined yet', no_labels_defined: 'No labels defined yet',
manage_labels: 'Manage Labels', manage_labels: 'Manage Labels',
update_label: 'Update Label', update_label: 'Update Label',
delete_label: 'Delete Label', delete_label: 'Delete Label',
add_remove_labels: 'Add or Remove Labels', add_remove_labels: 'Add or Remove Labels',
payment_labels_updated: 'Payment labels updated', payment_labels_updated: 'Payment labels updated',
color: 'Color' color: 'Color',
sort: 'Sort',
sort_by: 'Sort by'
} }

View file

@ -85,10 +85,10 @@ window.PageUsers = {
sortable: false sortable: false
}, },
{ {
name: 'user', name: 'id',
align: 'left', align: 'left',
label: 'User Id', label: 'User Id',
field: 'user', field: 'id',
sortable: false sortable: false
}, },
@ -119,7 +119,7 @@ window.PageUsers = {
align: 'left', align: 'left',
label: 'Balance', label: 'Balance',
field: 'balance_msat', field: 'balance_msat',
sortable: true sortable: false
}, },
{ {
@ -127,7 +127,7 @@ window.PageUsers = {
align: 'left', align: 'left',
label: 'Payments', label: 'Payments',
field: 'transaction_count', field: 'transaction_count',
sortable: true sortable: false
}, },
{ {
@ -135,16 +135,24 @@ window.PageUsers = {
align: 'left', align: 'left',
label: 'Last Payment', label: 'Last Payment',
field: 'last_payment', field: 'last_payment',
sortable: true sortable: false
} }
], ],
pagination: { pagination: {
sortBy: 'balance_msat', sortBy: 'created_at',
rowsPerPage: 10, rowsPerPage: 10,
page: 1, page: 1,
descending: true, descending: true,
rowsNumber: 10 rowsNumber: 10
}, },
sortFields: [
{name: 'id', label: 'User ID'},
{name: 'username', label: 'Username'},
{name: 'email', label: 'Email'},
{name: 'pubkey', label: 'Public Key'},
{name: 'created_at', label: 'Creation Date'},
{name: 'updated_at', label: 'Last Updated'}
],
search: null, search: null,
hideEmpty: true, hideEmpty: true,
loading: false loading: false
@ -198,6 +206,16 @@ window.PageUsers = {
}) })
.catch(LNbits.utils.notifyApiError) .catch(LNbits.utils.notifyApiError)
}, },
sortByColumn(columnName) {
if (this.usersTable.pagination.sortBy === columnName) {
this.usersTable.pagination.descending =
!this.usersTable.pagination.descending
} else {
this.usersTable.pagination.sortBy = columnName
this.usersTable.pagination.descending = false
}
this.fetchUsers()
},
createUser() { createUser() {
LNbits.api LNbits.api
.request('POST', '/users/api/v1/user', null, this.activeUser.data) .request('POST', '/users/api/v1/user', null, this.activeUser.data)

View file

@ -515,7 +515,44 @@
> >
<template v-slot:header="props"> <template v-slot:header="props">
<q-tr :props="props"> <q-tr :props="props">
<q-th auto-width></q-th> <q-th auto-width>
<q-btn-dropdown color="primary" icon="sort" flat dense>
<q-list>
<template
class="full-width"
v-for="column in usersTable.sortFields"
:key="column.name"
>
<q-item
@click="sortByColumn(column.name)"
clickable
v-ripple
v-close-popup
dense
>
<q-item-section>
<q-item-label lines="1" class="full-width"
><span v-text="column.label"></span
></q-item-label>
</q-item-section>
<q-item-section side>
<template
v-if="
usersTable.pagination.sortBy === column.name
"
>
<q-icon
v-if="usersTable.pagination.descending"
name="arrow_downward"
></q-icon>
<q-icon v-else name="arrow_upward"></q-icon>
</template>
</q-item-section>
</q-item>
</template>
</q-list>
</q-btn-dropdown>
</q-th>
<q-th v-for="col in props.cols" :key="col.name" :props="props"> <q-th v-for="col in props.cols" :key="col.name" :props="props">
<q-input <q-input
v-if=" v-if="

View file

@ -1,10 +1,12 @@
from typing import Any from typing import Any
from uuid import uuid4
import pytest import pytest
import shortuuid import shortuuid
from httpx import AsyncClient from httpx import AsyncClient
from lnbits.core.models.users import User from lnbits.core.models.users import Account, User
from lnbits.core.services.users import create_user_account
from lnbits.settings import Settings from lnbits.settings import Settings
from lnbits.utils.nostr import generate_keypair, hex_to_npub from lnbits.utils.nostr import generate_keypair, hex_to_npub
@ -465,3 +467,46 @@ async def test_create_user_invalid_npub(
headers={"Authorization": f"Bearer {superuser_token}"}, headers={"Authorization": f"Bearer {superuser_token}"},
) )
assert create_resp.status_code == 400 assert create_resp.status_code == 400
@pytest.mark.anyio
async def test_search_users(http_client: AsyncClient, superuser_token):
namespace_id = shortuuid.uuid()[:8]
users = []
user_count = 15
for index in range(user_count):
username = f"u_{namespace_id}_{index:03d}"
user = await create_user_account(
Account(
id=uuid4().hex,
username=username,
email=f"{username}@lnbits.com",
pubkey="",
external_id=None,
)
)
users.append(user)
create_resp = await http_client.get(
"/users/api/v1/user?sortby=id&direction=desc",
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert create_resp.status_code == 200
create_resp = await http_client.get(
"/users/api/v1/user"
f"?sortby=username&direction=desc&username[like]=u_{namespace_id}",
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert create_resp.status_code == 200
data = create_resp.json()
assert data["total"] == user_count
assert data["data"][0]["username"] == users[user_count - 1].username
create_resp = await http_client.get(
"/users/api/v1/user" f"?sortby=username&direction=desc&id={users[0].id}",
headers={"Authorization": f"Bearer {superuser_token}"},
)
assert create_resp.status_code == 200
data = create_resp.json()
assert data["total"] == 1
assert data["data"][0]["username"] == users[0].username

View file

@ -36,37 +36,42 @@ async def test_crud_get_payments(app):
await update_wallet_balance(wallet, -10) await update_wallet_balance(wallet, -10)
wallet.balance_msat += -10 * 1000 wallet.balance_msat += -10 * 1000
payments = await get_payments(wallet_id=wallet.id) filters = Filters(limit=100)
payments = await get_payments(wallet_id=wallet.id, filters=filters)
assert len(payments) == 22, "should return 22 successful payments" assert len(payments) == 22, "should return 22 successful payments"
payments = await get_payments(wallet_id=wallet.id, incoming=True) payments = await get_payments(wallet_id=wallet.id, incoming=True, filters=filters)
assert len(payments) == 11, "should return 11 successful incoming payments" assert len(payments) == 11, "should return 11 successful incoming payments"
await update_payments(payments) await update_payments(payments)
payments = await get_payments(wallet_id=wallet.id, outgoing=True) payments = await get_payments(wallet_id=wallet.id, outgoing=True, filters=filters)
assert len(payments) == 11, "should return 11 successful outgoing payments" assert len(payments) == 11, "should return 11 successful outgoing payments"
await update_payments(payments) await update_payments(payments)
payments = await get_payments(wallet_id=wallet.id, pending=True) payments = await get_payments(wallet_id=wallet.id, pending=True, filters=filters)
assert len(payments) == 4, "should return 4 pending payments" assert len(payments) == 4, "should return 4 pending payments"
# function signature should have Optional[bool] for complete and pending to make # function signature should have Optional[bool] for complete and pending to make
# this distinction possible # this distinction possible
payments = await get_payments(wallet_id=wallet.id, pending=False) payments = await get_payments(wallet_id=wallet.id, pending=False, filters=filters)
assert len(payments) == 22, "should return all payments" assert len(payments) == 22, "should return all payments"
payments = await get_payments(wallet_id=wallet.id, complete=True, pending=True) payments = await get_payments(
wallet_id=wallet.id, complete=True, pending=True, filters=filters
)
assert len(payments) == 20, "should return 4 pending and 16 complete payments" assert len(payments) == 20, "should return 4 pending and 16 complete payments"
payments = await get_payments(wallet_id=wallet.id, complete=True, outgoing=True) payments = await get_payments(
wallet_id=wallet.id, complete=True, outgoing=True, filters=filters
)
assert ( assert (
len(payments) == 10 len(payments) == 10
), "should return 8 complete outgoing payments and 2 pending outgoing payments" ), "should return 8 complete outgoing payments and 2 pending outgoing payments"
payments = await get_payments(wallet_id=wallet.id) payments = await get_payments(wallet_id=wallet.id, filters=filters)
assert len(payments) == 22, "should return all payments" assert len(payments) == 22, "should return all payments"
payments = await get_payments(wallet_id=wallet.id, complete=True) payments = await get_payments(wallet_id=wallet.id, complete=True, filters=filters)
assert ( assert (
len(payments) == 18 len(payments) == 18
), "should return 14 successful payment and 4 pending payments" ), "should return 14 successful payment and 4 pending payments"

View file

@ -9,6 +9,7 @@ from lnbits.core.crud.wallets import (
get_wallets, get_wallets,
update_wallet, update_wallet,
) )
from lnbits.core.models.payments import PaymentFilters
from lnbits.core.models.users import User from lnbits.core.models.users import User
from lnbits.core.models.wallets import ( from lnbits.core.models.wallets import (
Wallet, Wallet,
@ -29,6 +30,7 @@ from lnbits.core.services.wallets import (
reject_wallet_invitation, reject_wallet_invitation,
update_wallet_share_permissions, update_wallet_share_permissions,
) )
from lnbits.db import Filters
from lnbits.exceptions import InvoiceError, PaymentError from lnbits.exceptions import InvoiceError, PaymentError
from tests.conftest import new_user from tests.conftest import new_user
@ -563,7 +565,10 @@ async def test_shared_wallet_view_permissions(from_wallet: Wallet):
await pay_invoice(wallet_id=from_wallet.id, payment_request=payment.bolt11) await pay_invoice(wallet_id=from_wallet.id, payment_request=payment.bolt11)
wallet_balance += payment.sat wallet_balance += payment.sat
shared_wallet_payments = await get_payments(wallet_id=mirror_wallet.id) filters = Filters(limit=100, model=PaymentFilters)
shared_wallet_payments = await get_payments(
wallet_id=mirror_wallet.id, filters=filters
)
assert len(shared_wallet_payments) == payment_count assert len(shared_wallet_payments) == payment_count
mirror_wallet = await get_wallet(mirror_wallet.id) mirror_wallet = await get_wallet(mirror_wallet.id)
assert mirror_wallet is not None assert mirror_wallet is not None