[fix] check user extension access (#2519)
* feat: check user extension access * fix: handle upgraded extensions
This commit is contained in:
parent
d4da96597e
commit
44b458ebb8
8 changed files with 66 additions and 56 deletions
|
|
@ -261,10 +261,10 @@ async def build_all_installed_extensions_list(
|
||||||
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
|
MUST be installed by default (see LNBITS_EXTENSIONS_DEFAULT_INSTALL).
|
||||||
"""
|
"""
|
||||||
installed_extensions = await get_installed_extensions()
|
installed_extensions = await get_installed_extensions()
|
||||||
|
settings.lnbits_all_extensions_ids = {e.id for e in installed_extensions}
|
||||||
|
|
||||||
installed_extensions_ids = [e.id for e in installed_extensions]
|
|
||||||
for ext_id in settings.lnbits_extensions_default_install:
|
for ext_id in settings.lnbits_extensions_default_install:
|
||||||
if ext_id in installed_extensions_ids:
|
if ext_id in settings.lnbits_all_extensions_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
ext_releases = await InstallableExtension.get_extension_releases(ext_id)
|
ext_releases = await InstallableExtension.get_extension_releases(ext_id)
|
||||||
|
|
@ -318,8 +318,7 @@ async def restore_installed_extension(app: FastAPI, ext: InstallableExtension):
|
||||||
|
|
||||||
# mount routes for the new version
|
# mount routes for the new version
|
||||||
core_app_extra.register_new_ext_routes(extension)
|
core_app_extra.register_new_ext_routes(extension)
|
||||||
if extension.upgrade_hash:
|
ext.notify_upgrade(extension.upgrade_hash)
|
||||||
ext.notify_upgrade()
|
|
||||||
|
|
||||||
|
|
||||||
def register_custom_extensions_path():
|
def register_custom_extensions_path():
|
||||||
|
|
|
||||||
|
|
@ -316,7 +316,7 @@ async def check_invalid_payments(
|
||||||
async def load_disabled_extension_list() -> None:
|
async def load_disabled_extension_list() -> None:
|
||||||
"""Update list of extensions that have been explicitly disabled"""
|
"""Update list of extensions that have been explicitly disabled"""
|
||||||
inactive_extensions = await get_inactive_extensions()
|
inactive_extensions = await get_inactive_extensions()
|
||||||
settings.lnbits_deactivated_extensions += inactive_extensions
|
settings.lnbits_deactivated_extensions.update(inactive_extensions)
|
||||||
|
|
||||||
|
|
||||||
@extensions.command("list")
|
@extensions.command("list")
|
||||||
|
|
|
||||||
|
|
@ -322,10 +322,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
|
||||||
)
|
)
|
||||||
|
|
||||||
if user:
|
if user:
|
||||||
extensions = await (conn or db).fetchall(
|
extensions = await get_user_active_extensions_ids(user_id, conn)
|
||||||
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
|
|
||||||
(user_id,),
|
|
||||||
)
|
|
||||||
wallets = await (conn or db).fetchall(
|
wallets = await (conn or db).fetchall(
|
||||||
"""
|
"""
|
||||||
SELECT *, COALESCE((
|
SELECT *, COALESCE((
|
||||||
|
|
@ -344,7 +341,7 @@ async def get_user(user_id: str, conn: Optional[Connection] = None) -> Optional[
|
||||||
email=user["email"],
|
email=user["email"],
|
||||||
username=user["username"],
|
username=user["username"],
|
||||||
extensions=[
|
extensions=[
|
||||||
e[0] for e in extensions if User.is_extension_for_user(e[0], user["id"])
|
e for e in extensions if User.is_extension_for_user(e[0], user["id"])
|
||||||
],
|
],
|
||||||
wallets=[Wallet(**w) for w in wallets],
|
wallets=[Wallet(**w) for w in wallets],
|
||||||
admin=user["id"] == settings.super_user
|
admin=user["id"] == settings.super_user
|
||||||
|
|
@ -482,6 +479,16 @@ async def update_user_extension(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_user_active_extensions_ids(
|
||||||
|
user_id: str, conn: Optional[Connection] = None
|
||||||
|
) -> List[str]:
|
||||||
|
rows = await (conn or db).fetchall(
|
||||||
|
"""SELECT extension FROM extensions WHERE "user" = ? AND active""",
|
||||||
|
(user_id,),
|
||||||
|
)
|
||||||
|
return [e[0] for e in rows]
|
||||||
|
|
||||||
|
|
||||||
# wallets
|
# wallets
|
||||||
# -------
|
# -------
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,14 +94,12 @@ async def api_install_extension(
|
||||||
# call stop while the old routes are still active
|
# call stop while the old routes are still active
|
||||||
await stop_extension_background_work(data.ext_id, user.id, access_token)
|
await stop_extension_background_work(data.ext_id, user.id, access_token)
|
||||||
|
|
||||||
if data.ext_id not in settings.lnbits_deactivated_extensions:
|
settings.lnbits_deactivated_extensions.add(data.ext_id)
|
||||||
settings.lnbits_deactivated_extensions += [data.ext_id]
|
|
||||||
|
|
||||||
# mount routes for the new version
|
# mount routes for the new version
|
||||||
core_app_extra.register_new_ext_routes(extension)
|
core_app_extra.register_new_ext_routes(extension)
|
||||||
|
|
||||||
if extension.upgrade_hash:
|
ext_info.notify_upgrade(extension.upgrade_hash)
|
||||||
ext_info.notify_upgrade()
|
|
||||||
|
|
||||||
return extension
|
return extension
|
||||||
except AssertionError as exc:
|
except AssertionError as exc:
|
||||||
|
|
@ -151,8 +149,7 @@ async def api_uninstall_extension(
|
||||||
# call stop while the old routes are still active
|
# call stop while the old routes are still active
|
||||||
await stop_extension_background_work(ext_id, user.id, access_token)
|
await stop_extension_background_work(ext_id, user.id, access_token)
|
||||||
|
|
||||||
if ext_id not in settings.lnbits_deactivated_extensions:
|
settings.lnbits_deactivated_extensions.add(ext_id)
|
||||||
settings.lnbits_deactivated_extensions += [ext_id]
|
|
||||||
|
|
||||||
for ext_info in extensions:
|
for ext_info in extensions:
|
||||||
ext_info.clean_extension_files()
|
ext_info.clean_extension_files()
|
||||||
|
|
|
||||||
|
|
@ -115,19 +115,15 @@ async def extensions_install(
|
||||||
all_extensions = get_valid_extensions()
|
all_extensions = get_valid_extensions()
|
||||||
ext = next((e for e in all_extensions if e.code == ext_id), None)
|
ext = next((e for e in all_extensions if e.code == ext_id), None)
|
||||||
if ext_id and user.admin:
|
if ext_id and user.admin:
|
||||||
if deactivate and deactivate not in settings.lnbits_deactivated_extensions:
|
if deactivate:
|
||||||
settings.lnbits_deactivated_extensions += [deactivate]
|
settings.lnbits_deactivated_extensions.add(deactivate)
|
||||||
elif activate:
|
elif activate:
|
||||||
# if extension never loaded (was deactivated on server startup)
|
# if extension never loaded (was deactivated on server startup)
|
||||||
if ext_id not in sys.modules.keys():
|
if ext_id not in sys.modules.keys():
|
||||||
# run extension start-up routine
|
# run extension start-up routine
|
||||||
core_app_extra.register_new_ext_routes(ext)
|
core_app_extra.register_new_ext_routes(ext)
|
||||||
|
|
||||||
settings.lnbits_deactivated_extensions = list(
|
settings.lnbits_deactivated_extensions.remove(activate)
|
||||||
filter(
|
|
||||||
lambda e: e != activate, settings.lnbits_deactivated_extensions
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
await update_installed_extension_state(
|
await update_installed_extension_state(
|
||||||
ext_id=ext_id, active=activate is not None
|
ext_id=ext_id, active=activate is not None
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from lnbits.core.crud import (
|
||||||
get_account_by_email,
|
get_account_by_email,
|
||||||
get_account_by_username,
|
get_account_by_username,
|
||||||
get_user,
|
get_user,
|
||||||
|
get_user_active_extensions_ids,
|
||||||
get_wallet_for_key,
|
get_wallet_for_key,
|
||||||
)
|
)
|
||||||
from lnbits.core.models import KeyType, User, WalletTypeInfo
|
from lnbits.core.models import KeyType, User, WalletTypeInfo
|
||||||
|
|
@ -88,16 +89,7 @@ class KeyChecker(SecurityBase):
|
||||||
detail="Invalid adminkey.",
|
detail="Invalid adminkey.",
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
await _check_user_extension_access(wallet.user, request["path"])
|
||||||
wallet.user != settings.super_user
|
|
||||||
and wallet.user not in settings.lnbits_admin_users
|
|
||||||
and settings.lnbits_admin_extensions
|
|
||||||
and request["path"].split("/")[1] in settings.lnbits_admin_extensions
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=HTTPStatus.FORBIDDEN,
|
|
||||||
detail="User not authorized for this extension.",
|
|
||||||
)
|
|
||||||
|
|
||||||
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
key_type = KeyType.admin if wallet.adminkey == key_value else KeyType.invoice
|
||||||
return WalletTypeInfo(key_type, wallet)
|
return WalletTypeInfo(key_type, wallet)
|
||||||
|
|
@ -161,15 +153,7 @@ async def check_user_exists(
|
||||||
user = await get_user(account.id)
|
user = await get_user(account.id)
|
||||||
assert user, "User not found for account."
|
assert user, "User not found for account."
|
||||||
|
|
||||||
if (
|
await _check_user_extension_access(user.id, r["path"])
|
||||||
user.id != settings.super_user
|
|
||||||
and user.id not in settings.lnbits_admin_users
|
|
||||||
and settings.lnbits_admin_extensions
|
|
||||||
and r["path"].split("/")[1] in settings.lnbits_admin_extensions
|
|
||||||
):
|
|
||||||
raise HTTPException(
|
|
||||||
HTTPStatus.UNAUTHORIZED, "User not authorized for extension."
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
@ -226,6 +210,28 @@ def parse_filters(model: Type[TFilterModel]):
|
||||||
return dependency
|
return dependency
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_user_extension_access(user_id: str, current_path: str):
|
||||||
|
"""
|
||||||
|
Check if the user has access to a particular extension.
|
||||||
|
Raises HTTP Forbidden if the user is not allowed.
|
||||||
|
"""
|
||||||
|
path = current_path.split("/")
|
||||||
|
ext_id = path[3] if path[1] == "upgrades" else path[1]
|
||||||
|
if settings.is_admin_extension(ext_id) and not settings.is_admin_user(user_id):
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.FORBIDDEN,
|
||||||
|
f"User not authorized for extension '{ext_id}'.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if settings.is_extension_id(ext_id):
|
||||||
|
ext_ids = await get_user_active_extensions_ids(user_id)
|
||||||
|
if ext_id not in ext_ids:
|
||||||
|
raise HTTPException(
|
||||||
|
HTTPStatus.FORBIDDEN,
|
||||||
|
f"User extension '{ext_id}' not enabled.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def _get_account_from_token(access_token):
|
async def _get_account_from_token(access_token):
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(access_token, settings.auth_secret_key, "HS256")
|
payload = jwt.decode(access_token, settings.auth_secret_key, "HS256")
|
||||||
|
|
|
||||||
|
|
@ -479,22 +479,15 @@ class InstallableExtension(BaseModel):
|
||||||
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
|
shutil.copytree(Path(self.ext_upgrade_dir), Path(self.ext_dir))
|
||||||
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
|
logger.success(f"Extension {self.name} ({self.installed_version}) installed.")
|
||||||
|
|
||||||
def notify_upgrade(self) -> None:
|
def notify_upgrade(self, upgrade_hash: Optional[str]) -> None:
|
||||||
"""
|
"""
|
||||||
Update the list of upgraded extensions. The middleware will perform
|
Update the list of upgraded extensions. The middleware will perform
|
||||||
redirects based on this
|
redirects based on this
|
||||||
"""
|
"""
|
||||||
|
if upgrade_hash:
|
||||||
|
settings.lnbits_upgraded_extensions.add(f"{self.hash}/{self.id}")
|
||||||
|
|
||||||
clean_upgraded_exts = list(
|
settings.lnbits_all_extensions_ids.add(self.id)
|
||||||
filter(
|
|
||||||
lambda old_ext: not old_ext.endswith(f"/{self.id}"),
|
|
||||||
settings.lnbits_upgraded_extensions,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
settings.lnbits_upgraded_extensions = [
|
|
||||||
*clean_upgraded_exts,
|
|
||||||
f"{self.hash}/{self.id}",
|
|
||||||
]
|
|
||||||
|
|
||||||
def clean_extension_files(self):
|
def clean_extension_files(self):
|
||||||
# remove downloaded archive
|
# remove downloaded archive
|
||||||
|
|
|
||||||
|
|
@ -63,12 +63,15 @@ class ExtensionsInstallSettings(LNbitsSettings):
|
||||||
|
|
||||||
class InstalledExtensionsSettings(LNbitsSettings):
|
class InstalledExtensionsSettings(LNbitsSettings):
|
||||||
# installed extensions that have been deactivated
|
# installed extensions that have been deactivated
|
||||||
lnbits_deactivated_extensions: list[str] = Field(default=[])
|
lnbits_deactivated_extensions: set[str] = Field(default=[])
|
||||||
# upgraded extensions that require API redirects
|
# upgraded extensions that require API redirects
|
||||||
lnbits_upgraded_extensions: list[str] = Field(default=[])
|
lnbits_upgraded_extensions: set[str] = Field(default=[])
|
||||||
# list of redirects that extensions want to perform
|
# list of redirects that extensions want to perform
|
||||||
lnbits_extensions_redirects: list[Any] = Field(default=[])
|
lnbits_extensions_redirects: list[Any] = Field(default=[])
|
||||||
|
|
||||||
|
# list of all extension ids
|
||||||
|
lnbits_all_extensions_ids: set[Any] = Field(default=[])
|
||||||
|
|
||||||
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
|
def extension_upgrade_path(self, ext_id: str) -> Optional[str]:
|
||||||
return next(
|
return next(
|
||||||
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
|
(e for e in self.lnbits_upgraded_extensions if e.endswith(f"/{ext_id}")),
|
||||||
|
|
@ -481,7 +484,7 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
|
||||||
case_sensitive = False
|
case_sensitive = False
|
||||||
json_loads = list_parse_fallback
|
json_loads = list_parse_fallback
|
||||||
|
|
||||||
def is_user_allowed(self, user_id: str):
|
def is_user_allowed(self, user_id: str) -> bool:
|
||||||
return (
|
return (
|
||||||
len(self.lnbits_allowed_users) == 0
|
len(self.lnbits_allowed_users) == 0
|
||||||
or user_id in self.lnbits_allowed_users
|
or user_id in self.lnbits_allowed_users
|
||||||
|
|
@ -489,6 +492,15 @@ class Settings(EditableSettings, ReadOnlySettings, TransientSettings, BaseSettin
|
||||||
or user_id == self.super_user
|
or user_id == self.super_user
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def is_admin_user(self, user_id: str) -> bool:
|
||||||
|
return user_id in self.lnbits_admin_users or user_id == self.super_user
|
||||||
|
|
||||||
|
def is_admin_extension(self, ext_id: str) -> bool:
|
||||||
|
return ext_id in self.lnbits_admin_extensions
|
||||||
|
|
||||||
|
def is_extension_id(self, ext_id: str) -> bool:
|
||||||
|
return ext_id in self.lnbits_all_extensions_ids
|
||||||
|
|
||||||
|
|
||||||
class SuperSettings(EditableSettings):
|
class SuperSettings(EditableSettings):
|
||||||
super_user: str
|
super_user: str
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue