diff --git a/lnbits/app.py b/lnbits/app.py
index 991291fe..ca98cc58 100644
--- a/lnbits/app.py
+++ b/lnbits/app.py
@@ -52,6 +52,7 @@ from .middleware import (
CustomGZipMiddleware,
ExtensionsRedirectMiddleware,
InstalledExtensionMiddleware,
+ add_first_install_middleware,
add_ip_block_middleware,
add_ratelimit_middleware,
)
@@ -107,6 +108,8 @@ def create_app() -> FastAPI:
register_custom_extensions_path()
+ add_first_install_middleware(app)
+
# adds security middleware
add_ip_block_middleware(app)
add_ratelimit_middleware(app)
diff --git a/lnbits/core/crud.py b/lnbits/core/crud.py
index 7d077ad6..c13db549 100644
--- a/lnbits/core/crud.py
+++ b/lnbits/core/crud.py
@@ -151,11 +151,17 @@ async def get_account(
user_id: str, conn: Optional[Connection] = None
) -> Optional[User]:
row = await (conn or db).fetchone(
- "SELECT id, email, username, created_at, updated_at FROM accounts WHERE id = ?",
+ """
+ SELECT id, email, username, created_at, updated_at, extra
+ FROM accounts WHERE id = ?
+ """,
(user_id,),
)
- return User(**row) if row else None
+ user = User(**row) if row else None
+ if user and row["extra"]:
+ user.config = UserConfig(**json.loads(row["extra"]))
+ return user
async def get_user_password(user_id: str) -> Optional[str]:
diff --git a/lnbits/core/models.py b/lnbits/core/models.py
index 67b1ff5f..9f52a279 100644
--- a/lnbits/core/models.py
+++ b/lnbits/core/models.py
@@ -87,6 +87,10 @@ class UserConfig(BaseModel):
last_name: Optional[str] = None
display_name: Optional[str] = None
picture: Optional[str] = None
+ # Auth provider, possible values:
+ # - "env": the user was created automatically by the system
+ # - "lnbits": the user was created via register form (username/pass or user_id only)
+ # - "google | github | ...": the user was created using an SSO provider
provider: Optional[str] = "lnbits" # auth provider
@@ -141,6 +145,13 @@ class UpdateUserPassword(BaseModel):
password: str = Query(default=..., min_length=8, max_length=50)
password_repeat: str = Query(default=..., min_length=8, max_length=50)
password_old: Optional[str] = Query(default=None, min_length=8, max_length=50)
+ username: Optional[str] = Query(default=..., min_length=2, max_length=20)
+
+
+class UpdateSuperuserPassword(BaseModel):
+ username: str = Query(default=..., min_length=2, max_length=20)
+ password: str = Query(default=..., min_length=8, max_length=50)
+ password_repeat: str = Query(default=..., min_length=8, max_length=50)
class LoginUsr(BaseModel):
diff --git a/lnbits/core/services.py b/lnbits/core/services.py
index 2dcf2fe3..4114e48d 100644
--- a/lnbits/core/services.py
+++ b/lnbits/core/services.py
@@ -52,7 +52,7 @@ from .crud import (
update_super_user,
)
from .helpers import to_valid_user_id
-from .models import Payment, Wallet
+from .models import Payment, UserConfig, Wallet
class PaymentFailure(Exception):
@@ -611,6 +611,10 @@ async def check_admin_settings():
):
send_admin_user_to_saas()
+ account = await get_account(settings.super_user)
+ if account and account.config and account.config.provider == "env":
+ settings.first_install = True
+
logger.success(
"✔️ Admin UI is enabled. run `poetry run lnbits-cli superuser` "
"to get the superuser."
@@ -656,7 +660,9 @@ async def init_admin_settings(super_user: Optional[str] = None) -> SuperSettings
if super_user:
account = await get_account(super_user)
if not account:
- account = await create_account(user_id=super_user)
+ account = await create_account(
+ user_id=super_user, user_config=UserConfig(provider="env")
+ )
if not account.wallets or len(account.wallets) == 0:
await create_wallet(user_id=account.id)
diff --git a/lnbits/core/templates/core/first_install.html b/lnbits/core/templates/core/first_install.html
new file mode 100644
index 00000000..16503bed
--- /dev/null
+++ b/lnbits/core/templates/core/first_install.html
@@ -0,0 +1,139 @@
+{% extends "public.html" %} {% block page %}
+
+
+
+
+
+
+ Welcome to LNbits
+
Set up the Superuser account below.
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+{% endblock %} {% block scripts %}
+
+
+{% endblock %}
diff --git a/lnbits/core/views/auth_api.py b/lnbits/core/views/auth_api.py
index deee0181..2e9ec9a0 100644
--- a/lnbits/core/views/auth_api.py
+++ b/lnbits/core/views/auth_api.py
@@ -38,6 +38,7 @@ from ..models import (
CreateUser,
LoginUsernamePassword,
LoginUsr,
+ UpdateSuperuserPassword,
UpdateUser,
UpdateUserPassword,
User,
@@ -250,6 +251,34 @@ async def update(
raise HTTPException(HTTP_500_INTERNAL_SERVER_ERROR, "Cannot update user.")
+@auth_router.put("/api/v1/auth/first_install")
+async def first_install(data: UpdateSuperuserPassword) -> JSONResponse:
+ if not settings.first_install:
+ raise HTTPException(HTTP_401_UNAUTHORIZED, "This is not your first install")
+ try:
+ await update_account(
+ user_id=settings.super_user,
+ username=data.username,
+ user_config=UserConfig(provider="lnbits"),
+ )
+ super_user = UpdateUserPassword(
+ user_id=settings.super_user,
+ password=data.password,
+ password_repeat=data.password_repeat,
+ username=data.username,
+ )
+ await update_user_password(super_user)
+ settings.first_install = False
+ return _auth_success_response(username=super_user.username)
+ except AssertionError as e:
+ raise HTTPException(HTTP_403_FORBIDDEN, str(e))
+ except Exception as e:
+ logger.debug(e)
+ raise HTTPException(
+ HTTP_500_INTERNAL_SERVER_ERROR, "Cannot update user password."
+ )
+
+
async def _handle_sso_login(userinfo: OpenID, verified_user_id: Optional[str] = None):
email = userinfo.email
if not email or not is_valid_email_address(email):
diff --git a/lnbits/core/views/generic.py b/lnbits/core/views/generic.py
index 72db46e8..c44b3b27 100644
--- a/lnbits/core/views/generic.py
+++ b/lnbits/core/views/generic.py
@@ -53,6 +53,22 @@ async def home(request: Request, lightning: str = ""):
)
+@generic_router.get("/first_install", response_class=HTMLResponse)
+async def first_install(request: Request):
+ if not settings.first_install:
+ return template_renderer().TemplateResponse(
+ "error.html",
+ {
+ "request": request,
+ "err": "Super user account has already been configured.",
+ },
+ )
+ return template_renderer().TemplateResponse(
+ "core/first_install.html",
+ {"request": request},
+ )
+
+
@generic_router.get("/robots.txt", response_class=HTMLResponse)
async def robots():
data = """
diff --git a/lnbits/middleware.py b/lnbits/middleware.py
index ad5e1704..40a5837d 100644
--- a/lnbits/middleware.py
+++ b/lnbits/middleware.py
@@ -2,7 +2,7 @@ from http import HTTPStatus
from typing import Any, List, Tuple, Union
from fastapi import FastAPI, Request
-from fastapi.responses import HTMLResponse, JSONResponse
+from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
@@ -210,3 +210,18 @@ def add_ip_block_middleware(app: FastAPI):
return await call_next(request)
app.middleware("http")(block_allow_ip_middleware)
+
+
+def add_first_install_middleware(app: FastAPI):
+ @app.middleware("http")
+ async def first_install_middleware(request: Request, call_next):
+ if (
+ settings.first_install
+ and request.url.path != "/api/v1/auth/first_install"
+ and request.url.path != "/first_install"
+ and not request.url.path.startswith("/static")
+ ):
+ return RedirectResponse("/first_install")
+ return await call_next(request)
+
+ app.middleware("http")(first_install_middleware)
diff --git a/lnbits/settings.py b/lnbits/settings.py
index e48b9f0d..5c153d7f 100644
--- a/lnbits/settings.py
+++ b/lnbits/settings.py
@@ -390,6 +390,7 @@ class TransientSettings(InstalledExtensionsSettings):
# - are not read from a file or from the `settings` table
# - are not persisted in the `settings` table when the settings are updated
# - are cleared on server restart
+ first_install: bool = Field(default=False)
@classmethod
def readonly_fields(cls):
diff --git a/tests/conftest.py b/tests/conftest.py
index f370b1bd..3987d451 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -50,6 +50,7 @@ async def app():
clean_database(settings)
app = create_app()
await app.router.startup()
+ settings.first_install = False
yield app
await app.router.shutdown()