[fix] check user extension access (#2519)

* feat: check user extension access
* fix: handle upgraded extensions
This commit is contained in:
Vlad Stan 2024-05-21 13:17:02 +03:00 committed by dni ⚡
parent d4da96597e
commit 44b458ebb8
No known key found for this signature in database
GPG key ID: 886317704CC4E618
8 changed files with 66 additions and 56 deletions

View file

@ -261,10 +261,10 @@ async def build_all_installed_extensions_list(
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL). MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
""" """
installed_extensions = await get_installed_extensions() installed_extensions = await get_installed_extensions()
settings.lnbits_all_extensions_ids = {e.id for e in installed_extensions}
installed_extensions_ids = [e.id for e in installed_extensions]
for ext_id in settings.lnbits_extensions_default_install: for ext_id in settings.lnbits_extensions_default_install:
if ext_id in installed_extensions_ids: if ext_id in settings.lnbits_all_extensions_ids:
continue continue
ext_releases = await InstallableExtension.get_extension_releases(ext_id) ext_releases = await InstallableExtension.get_extension_releases(ext_id)
@ -318,8 +318,7 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
# mount routes for the new version # mount routes for the new version
core_app_extra.register_new_ext_routes(extension) core_app_extra.register_new_ext_routes(extension)
if extension.upgrade_hash: ext.notify_upgrade(extension.upgrade_hash)
ext.notify_upgrade()
def register_custom_extensions_path(): def register_custom_extensions_path():

View file

@ -316,7 +316,7 @@ async def check_invalid_payments(
async def load_disabled_extension_list() -> None: async def load_disabled_extension_list() -> None:
"""Update list of extensions that have been explicitly disabled""" """Update list of extensions that have been explicitly disabled"""
inactive_extensions = await get_inactive_extensions() inactive_extensions = await get_inactive_extensions()
settings.lnbits_deactivated_extensions += inactive_extensions settings.lnbits_deactivated_extensions.update(inactive_extensions)
@extensions.command("list") @extensions.command("list")

View file

@ -322,10 +322,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
) )
if user: if user:
extensions = await (conn or db).fetchall( extensions = await get_user_active_extensions_ids(user_id, conn)
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
(user_id,),
)
wallets = await (conn or db).fetchall( wallets = await (conn or db).fetchall(
""" """
SELECT *, COALESCE(( SELECT *, COALESCE((
@ -344,7 +341,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
email=user["email"], email=user["email"],
username=user["username"], username=user["username"],
extensions=[ extensions=[
e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"]) e for e in extensions if User.is_extension_for_user(e[0], user["id"])
], ],
wallets=[Wallet(**w) for w in wallets], wallets=[Wallet(**w) for w in wallets],
admin=user["id"] == settings.super_user admin=user["id"] == settings.super_user
@ -482,6 +479,16 @@ async def update_user_extension(
) )
async def get_user_active_extensions_ids(
user_id: str, conn: Optional[Connection] = None
) -> List[str]:
rows = await (conn or db).fetchall(
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
(user_id,),
)
return [e[0] for e in rows]
# wallets # wallets
# ------- # -------

View file

@ -94,14 +94,12 @@ async def api_install_extension(
# call stop while the old routes are still active # call stop while the old routes are still active
await stop_extension_background_work(data.ext_id, user.id, access_token) await stop_extension_background_work(data.ext_id, user.id, access_token)
if data.ext_id not in settings.lnbits_deactivated_extensions: settings.lnbits_deactivated_extensions.add(data.ext_id)
settings.lnbits_deactivated_extensions += [data.ext_id]
# mount routes for the new version # mount routes for the new version
core_app_extra.register_new_ext_routes(extension) core_app_extra.register_new_ext_routes(extension)
if extension.upgrade_hash: ext_info.notify_upgrade(extension.upgrade_hash)
ext_info.notify_upgrade()
return extension return extension
except AssertionError as exc: except AssertionError as exc:
@ -151,8 +149,7 @@ async def api_uninstall_extension(
# call stop while the old routes are still active # call stop while the old routes are still active
await stop_extension_background_work(ext_id, user.id, access_token) await stop_extension_background_work(ext_id, user.id, access_token)
if ext_id not in settings.lnbits_deactivated_extensions: settings.lnbits_deactivated_extensions.add(ext_id)
settings.lnbits_deactivated_extensions += [ext_id]
for ext_info in extensions: for ext_info in extensions:
ext_info.clean_extension_files() ext_info.clean_extension_files()

View file

@ -115,19 +115,15 @@ async def extensions_install(
all_extensions = get_valid_extensions() all_extensions = get_valid_extensions()
ext = next((e for e in all_extensions if e.code == ext_id), None) ext = next((e for e in all_extensions if e.code == ext_id), None)
if ext_id and user.admin: if ext_id and user.admin:
if deactivate and deactivate not in settings.lnbits_deactivated_extensions: if deactivate:
settings.lnbits_deactivated_extensions += [deactivate] settings.lnbits_deactivated_extensions.add(deactivate)
elif activate: elif activate:
# if extension never loaded (was deactivated on server startup) # if extension never loaded (was deactivated on server startup)
if ext_id not in sys.modules.keys(): if ext_id not in sys.modules.keys():
# run extension start-up routine # run extension start-up routine
core_app_extra.register_new_ext_routes(ext) core_app_extra.register_new_ext_routes(ext)
settings.lnbits_deactivated_extensions = list( settings.lnbits_deactivated_extensions.remove(activate)
filter(
lambda e: e != activate, settings.lnbits_deactivated_extensions
)
)
await update_installed_extension_state( await update_installed_extension_state(
ext_id=ext_id, active=activate is not None ext_id=ext_id, active=activate is not None

View file

@ -15,6 +15,7 @@ from lnbits.core.crud import (
get_account_by_email, get_account_by_email,
get_account_by_username, get_account_by_username,
get_user, get_user,
get_user_active_extensions_ids,
get_wallet_for_key, get_wallet_for_key,
) )
from lnbits.core.models import KeyType, User, WalletTypeInfo from lnbits.core.models import KeyType, User, WalletTypeInfo
@ -88,16 +89,7 @@ class KeyChecker(SecurityBase):
detail="Invalid adminkey.", detail="Invalid adminkey.",
) )
if ( await _check_user_extension_access(wallet.user, request["path"])
wallet.user != settings.super_user
and wallet.user not in settings.lnbits_admin_users
and settings.lnbits_admin_extensions
and request["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="User not authorized for this extension.",
)
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
return WalletTypeInfo(key_type, wallet) return WalletTypeInfo(key_type, wallet)
@ -161,15 +153,7 @@ async def check_user_exists(
user = await get_user(account.id) user = await get_user(account.id)
assert user, "User not found for account." assert user, "User not found for account."
if ( await _check_user_extension_access(user.id, r["path"])
user.id != settings.super_user
and user.id not in settings.lnbits_admin_users
and settings.lnbits_admin_extensions
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
):
raise HTTPException(
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
)
return user return user
@ -226,6 +210,28 @@ def parse_filters(model: Type[TFilterModel]):
return dependency return dependency
async def _check_user_extension_access(user_id: str, current_path: str):
"""
Check if the user has access to a particular extension.
Raises HTTP Forbidden if the user is not allowed.
"""
path = current_path.split("/")
ext_id = path[3] if path[1] == "upgrades" else path[1]
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User not authorized for extension '{ext_id}'.",
)
if settings.is_extension_id(ext_id):
ext_ids = await get_user_active_extensions_ids(user_id)
if ext_id not in ext_ids:
raise HTTPException(
HTTPStatus.FORBIDDEN,
f"User extension '{ext_id}' not enabled.",
)
async def _get_account_from_token(access_token): async def _get_account_from_token(access_token):
try: try:
payload = jwt.decode(access_token, settings.auth_secret_key, "HS256") payload = jwt.decode(access_token, settings.auth_secret_key, "HS256")

View file

@ -479,22 +479,15 @@ class InstallableExtension(BaseModel):
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir)) shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
logger.success(f"Extension {self.name} ({self.installed_version}) installed.") logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
def notify_upgrade(self) -> None: def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
""" """
Update the list of upgraded extensions. The middleware will perform Update the list of upgraded extensions. The middleware will perform
redirects based on this redirects based on this
""" """
if upgrade_hash:
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
clean_upgraded_exts = list( settings.lnbits_all_extensions_ids.add(self.id)
filter(
lambda old_ext: not old_ext.endswith(f"/{self.id}"),
settings.lnbits_upgraded_extensions,
)
)
settings.lnbits_upgraded_extensions = [
*clean_upgraded_exts,
f"{self.hash}/{self.id}",
]
def clean_extension_files(self): def clean_extension_files(self):
# remove downloaded archive # remove downloaded archive

View file

@ -63,12 +63,15 @@ class ExtensionsInstallSettings(LNbitsSettings):
class InstalledExtensionsSettings(LNbitsSettings): class InstalledExtensionsSettings(LNbitsSettings):
# installed extensions that have been deactivated # installed extensions that have been deactivated
lnbits_deactivated_extensions: list[str] = Field(default=[]) lnbits_deactivated_extensions: set[str] = Field(default=[])
# upgraded extensions that require API redirects # upgraded extensions that require API redirects
lnbits_upgraded_extensions: list[str] = Field(default=[]) lnbits_upgraded_extensions: set[str] = Field(default=[])
# list of redirects that extensions want to perform # list of redirects that extensions want to perform
lnbits_extensions_redirects: list[Any] = Field(default=[]) lnbits_extensions_redirects: list[Any] = Field(default=[])
# list of all extension ids
lnbits_all_extensions_ids: set[Any] = Field(default=[])
def extension_upgrade_path(self, ext_id: str) -> Optional[str]: def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
return next( return next(
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")), (e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
@ -481,7 +484,7 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
case_sensitive = False case_sensitive = False
json_loads = list_parse_fallback json_loads = list_parse_fallback
def is_user_allowed(self, user_id: str): def is_user_allowed(self, user_id: str) -> bool:
return ( return (
len(self.lnbits_allowed_users) == 0 len(self.lnbits_allowed_users) == 0
or user_id in self.lnbits_allowed_users or user_id in self.lnbits_allowed_users
@ -489,6 +492,15 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
or user_id == self.super_user or user_id == self.super_user
) )
def is_admin_user(self, user_id: str) -> bool:
return user_id in self.lnbits_admin_users or user_id == self.super_user
def is_admin_extension(self, ext_id: str) -> bool:
return ext_id in self.lnbits_admin_extensions
def is_extension_id(self, ext_id: str) -> bool:
return ext_id in self.lnbits_all_extensions_ids
class SuperSettings(EditableSettings): class SuperSettings(EditableSettings):
super_user: str super_user: str