Merge pull request #313 from arcbtc/FastAPI

Most extensions done
This commit is contained in:
Arc 2021-08-22 22:17:55 +01:00 committed by GitHub
commit 0b132ce928
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 354 additions and 218 deletions

View file

@ -1,4 +1,7 @@
from hypercorn.trio import serve
import trio import trio
import trio_asyncio
from hypercorn.config import Config
from .commands import migrate_databases, transpile_scss, bundle_vendored from .commands import migrate_databases, transpile_scss, bundle_vendored
@ -8,7 +11,7 @@ bundle_vendored()
from .app import create_app from .app import create_app
app = create_app() app = trio.run(create_app)
from .settings import ( from .settings import (
LNBITS_SITE_TITLE, LNBITS_SITE_TITLE,
@ -17,6 +20,8 @@ from .settings import (
LNBITS_DATA_FOLDER, LNBITS_DATA_FOLDER,
WALLET, WALLET,
LNBITS_COMMIT, LNBITS_COMMIT,
HOST,
PORT
) )
print( print(
@ -30,4 +35,6 @@ print(
""" """
) )
app.run(host=app.config["HOST"], port=app.config["PORT"]) config = Config()
config.bind = [f"{HOST}:{PORT}"]
trio_asyncio.run(serve, app, config)

View file

@ -1,12 +1,15 @@
import jinja2
from lnbits.jinja2_templating import Jinja2Templates
import sys import sys
import warnings import warnings
import importlib import importlib
import traceback import traceback
import trio
from quart import g from fastapi import FastAPI
from quart_trio import QuartTrio from fastapi.middleware.cors import CORSMiddleware
from quart_cors import cors # type: ignore from fastapi.middleware.gzip import GZipMiddleware
from quart_compress import Compress # type: ignore from fastapi.staticfiles import StaticFiles
from .commands import db_migrate, handle_assets from .commands import db_migrate, handle_assets
from .core import core_app from .core import core_app
@ -26,32 +29,66 @@ from .tasks import (
catch_everything_and_restart, catch_everything_and_restart,
) )
from .settings import WALLET from .settings import WALLET
from .requestvars import g, request_global
import lnbits.settings
async def create_app(config_object="lnbits.settings") -> FastAPI:
def create_app(config_object="lnbits.settings") -> QuartTrio:
"""Create application factory. """Create application factory.
:param config_object: The configuration object to use. :param config_object: The configuration object to use.
""" """
app = QuartTrio(__name__, static_folder="static") app = FastAPI()
app.config.from_object(config_object) app.mount("/static", StaticFiles(directory="lnbits/static"), name="static")
app.asgi_http_class = ASGIProxyFix
cors(app) origins = [
Compress(app) "http://localhost",
"http://localhost:5000",
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
g().config = lnbits.settings
g().templates = build_standard_jinja_templates()
app.add_middleware(GZipMiddleware, minimum_size=1000)
# app.add_middleware(ASGIProxyFix)
check_funding_source(app) check_funding_source(app)
register_assets(app) register_assets(app)
register_blueprints(app) register_routes(app)
register_filters(app) # register_commands(app)
register_commands(app)
register_async_tasks(app) register_async_tasks(app)
register_exception_handlers(app) # register_exception_handlers(app)
return app return app
def build_standard_jinja_templates():
t = Jinja2Templates(
loader=jinja2.FileSystemLoader(["lnbits/templates", "lnbits/core/templates"]),
)
t.env.globals["SITE_TITLE"] = lnbits.settings.LNBITS_SITE_TITLE
t.env.globals["SITE_TAGLINE"] = lnbits.settings.LNBITS_SITE_TAGLINE
t.env.globals["SITE_DESCRIPTION"] = lnbits.settings.LNBITS_SITE_DESCRIPTION
t.env.globals["LNBITS_THEME_OPTIONS"] = lnbits.settings.LNBITS_THEME_OPTIONS
t.env.globals["LNBITS_VERSION"] = lnbits.settings.LNBITS_COMMIT
t.env.globals["EXTENSIONS"] = get_valid_extensions()
if g().config.DEBUG:
t.env.globals["VENDORED_JS"] = map(url_for_vendored, get_js_vendored())
t.env.globals["VENDORED_CSS"] = map(url_for_vendored, get_css_vendored())
else:
t.env.globals["VENDORED_JS"] = ["/static/bundle.js"]
t.env.globals["VENDORED_CSS"] = ["/static/bundle.css"]
def check_funding_source(app: QuartTrio) -> None: return t
@app.before_serving
def check_funding_source(app: FastAPI) -> None:
@app.on_event("startup")
async def check_wallet_status(): async def check_wallet_status():
error_message, balance = await WALLET.status() error_message, balance = await WALLET.status()
if error_message: if error_message:
@ -67,64 +104,60 @@ def check_funding_source(app: QuartTrio) -> None:
) )
def register_blueprints(app: QuartTrio) -> None: def register_routes(app: FastAPI) -> None:
"""Register Flask blueprints / LNbits extensions.""" """Register Flask blueprints / LNbits extensions."""
app.register_blueprint(core_app) app.include_router(core_app)
for ext in get_valid_extensions(): for ext in get_valid_extensions():
try: try:
ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}") ext_module = importlib.import_module(f"lnbits.extensions.{ext.code}")
bp = getattr(ext_module, f"{ext.code}_ext") ext_route = getattr(ext_module, f"{ext.code}_ext")
app.register_blueprint(bp, url_prefix=f"/{ext.code}") app.include_router(ext_route)
except Exception: except Exception:
raise ImportError( raise ImportError(
f"Please make sure that the extension `{ext.code}` follows conventions." f"Please make sure that the extension `{ext.code}` follows conventions."
) )
def register_commands(app: QuartTrio): def register_commands(app: FastAPI):
"""Register Click commands.""" """Register Click commands."""
app.cli.add_command(db_migrate) app.cli.add_command(db_migrate)
app.cli.add_command(handle_assets) app.cli.add_command(handle_assets)
def register_assets(app: QuartTrio): def register_assets(app: FastAPI):
"""Serve each vendored asset separately or a bundle.""" """Serve each vendored asset separately or a bundle."""
@app.before_request @app.on_event("startup")
async def vendored_assets_variable(): async def vendored_assets_variable():
if app.config["DEBUG"]: if g().config.DEBUG:
g.VENDORED_JS = map(url_for_vendored, get_js_vendored()) g().VENDORED_JS = map(url_for_vendored, get_js_vendored())
g.VENDORED_CSS = map(url_for_vendored, get_css_vendored()) g().VENDORED_CSS = map(url_for_vendored, get_css_vendored())
else: else:
g.VENDORED_JS = ["/static/bundle.js"] g().VENDORED_JS = ["/static/bundle.js"]
g.VENDORED_CSS = ["/static/bundle.css"] g().VENDORED_CSS = ["/static/bundle.css"]
def register_filters(app: QuartTrio):
"""Jinja filters."""
app.jinja_env.globals["SITE_TITLE"] = app.config["LNBITS_SITE_TITLE"]
app.jinja_env.globals["SITE_TAGLINE"] = app.config["LNBITS_SITE_TAGLINE"]
app.jinja_env.globals["SITE_DESCRIPTION"] = app.config["LNBITS_SITE_DESCRIPTION"]
app.jinja_env.globals["LNBITS_THEME_OPTIONS"] = app.config["LNBITS_THEME_OPTIONS"]
app.jinja_env.globals["LNBITS_VERSION"] = app.config["LNBITS_COMMIT"]
app.jinja_env.globals["EXTENSIONS"] = get_valid_extensions()
def register_async_tasks(app): def register_async_tasks(app):
@app.route("/wallet/webhook", methods=["GET", "POST", "PUT", "PATCH", "DELETE"]) @app.route("/wallet/webhook")
async def webhook_listener(): async def webhook_listener():
return await webhook_handler() return await webhook_handler()
@app.before_serving @app.on_event("startup")
async def listeners(): async def listeners():
run_deferred_async() run_deferred_async()
app.nursery.start_soon(catch_everything_and_restart, check_pending_payments) trio.open_process(check_pending_payments)
app.nursery.start_soon(catch_everything_and_restart, invoice_listener) trio.open_process(invoice_listener)
app.nursery.start_soon(catch_everything_and_restart, internal_invoice_listener) trio.open_process(internal_invoice_listener)
async with trio.open_nursery() as n:
pass
# n.start_soon(catch_everything_and_restart, check_pending_payments)
# n.start_soon(catch_everything_and_restart, invoice_listener)
# n.start_soon(catch_everything_and_restart, internal_invoice_listener)
@app.after_serving @app.on_event("shutdown")
async def stop_listeners(): async def stop_listeners():
pass pass

49
lnbits/auth_bearer.py Normal file
View file

@ -0,0 +1,49 @@
from fastapi import Request, HTTPException
from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey
# https://medium.com/data-rebels/fastapi-authentication-revisited-enabling-api-key-authentication-122dc5975680
from fastapi import Security, Depends, FastAPI, HTTPException
from fastapi.security.api_key import APIKeyQuery, APIKeyCookie, APIKeyHeader, APIKey
from fastapi.security.base import SecurityBase
API_KEY = "usr"
API_KEY_NAME = "X-API-key"
api_key_query = APIKeyQuery(name=API_KEY_NAME, auto_error=False)
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
class AuthBearer(SecurityBase):
def __init__(self, scheme_name: str = None, auto_error: bool = True):
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error
async def __call__(self, request: Request):
key = await self.get_api_key()
print(key)
# credentials: HTTPAuthorizationCredentials = await super(AuthBearer, self).__call__(request)
# if credentials:
# if not credentials.scheme == "Bearer":
# raise HTTPException(
# status_code=403, detail="Invalid authentication scheme.")
# if not self.verify_jwt(credentials.credentials):
# raise HTTPException(
# status_code=403, detail="Invalid token or expired token.")
# return credentials.credentials
# else:
# raise HTTPException(
# status_code=403, detail="Invalid authorization code.")
async def get_api_key(self,
api_key_query: str = Security(api_key_query),
api_key_header: str = Security(api_key_header),
):
if api_key_query == API_KEY:
return api_key_query
elif api_key_header == API_KEY:
return api_key_header
else:
raise HTTPException(status_code=403, detail="Could not validate credentials")

View file

@ -1,22 +1,19 @@
from quart import Blueprint from fastapi.routing import APIRouter
from lnbits.db import Database from lnbits.db import Database
db = Database("database") db = Database("database")
core_app: Blueprint = Blueprint( core_app: APIRouter = APIRouter()
"core",
__name__,
template_folder="templates",
static_folder="static",
static_url_path="/core/static",
)
from .views.api import * # noqa
from .views.generic import * # noqa
from .views.public_api import * # noqa
from .tasks import register_listeners
from lnbits.tasks import record_async from lnbits.tasks import record_async
core_app.record(record_async(register_listeners)) from .tasks import register_listeners
from .views.api import * # noqa
from .views.generic import * # noqa
from .views.public_api import * # noqa
@core_app.on_event("startup")
def do_startup():
record_async(register_listeners)

View file

@ -1,10 +1,12 @@
from fastapi.param_functions import Depends
from lnbits.auth_bearer import AuthBearer
from pydantic import BaseModel from pydantic import BaseModel
import trio import trio
import json import json
import httpx import httpx
import hashlib import hashlib
from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult from urllib.parse import urlparse, urlunparse, urlencode, parse_qs, ParseResult
from quart import g, current_app, make_response, url_for from quart import current_app, make_response, url_for
from fastapi import Query from fastapi import Query
@ -15,6 +17,7 @@ from typing import Dict, List, Optional, Union
from lnbits import bolt11, lnurl from lnbits import bolt11, lnurl
from lnbits.decorators import api_check_wallet_key, api_validate_post_request from lnbits.decorators import api_check_wallet_key, api_validate_post_request
from lnbits.utils.exchange_rates import currencies, fiat_amount_as_satoshis from lnbits.utils.exchange_rates import currencies, fiat_amount_as_satoshis
from lnbits.requestvars import g
from .. import core_app, db from .. import core_app, db
from ..crud import get_payments, save_balance_check, update_wallet from ..crud import get_payments, save_balance_check, update_wallet
@ -28,11 +31,14 @@ from ..services import (
from ..tasks import api_invoice_listeners from ..tasks import api_invoice_listeners
@core_app.get("/api/v1/wallet") @core_app.get(
@api_check_wallet_key("invoice") "/api/v1/wallet",
# dependencies=[Depends(AuthBearer())]
)
# @api_check_wallet_key("invoice")
async def api_wallet(): async def api_wallet():
return ( return (
{"id": g.wallet.id, "name": g.wallet.name, "balance": g.wallet.balance_msat}, {"id": g().wallet.id, "name": g().wallet.name, "balance": g().wallet.balance_msat},
HTTPStatus.OK, HTTPStatus.OK,
) )
@ -40,12 +46,12 @@ async def api_wallet():
@core_app.put("/api/v1/wallet/<new_name>") @core_app.put("/api/v1/wallet/<new_name>")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_update_wallet(new_name: str): async def api_update_wallet(new_name: str):
await update_wallet(g.wallet.id, new_name) await update_wallet(g().wallet.id, new_name)
return ( return (
{ {
"id": g.wallet.id, "id": g().wallet.id,
"name": g.wallet.name, "name": g().wallet.name,
"balance": g.wallet.balance_msat, "balance": g().wallet.balance_msat,
}, },
HTTPStatus.OK, HTTPStatus.OK,
) )
@ -55,7 +61,7 @@ async def api_update_wallet(new_name: str):
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_payments(): async def api_payments():
return ( return (
await get_payments(wallet_id=g.wallet.id, pending=True, complete=True), await get_payments(wallet_id=g().wallet.id, pending=True, complete=True),
HTTPStatus.OK, HTTPStatus.OK,
) )
@ -88,7 +94,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
async with db.connect() as conn: async with db.connect() as conn:
try: try:
payment_hash, payment_request = await create_invoice( payment_hash, payment_request = await create_invoice(
wallet_id=g.wallet.id, wallet_id=g().wallet.id,
amount=amount, amount=amount,
memo=memo, memo=memo,
description_hash=description_hash, description_hash=description_hash,
@ -105,8 +111,8 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
lnurl_response: Union[None, bool, str] = None lnurl_response: Union[None, bool, str] = None
if data.lnurl_callback: if data.lnurl_callback:
if "lnurl_balance_check" in g.data: if "lnurl_balance_check" in g().data:
save_balance_check(g.wallet.id, data.lnurl_balance_check) save_balance_check(g().wallet.id, data.lnurl_balance_check)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
try: try:
@ -117,7 +123,7 @@ async def api_payments_create_invoice(data: CreateInvoiceData):
"balanceNotify": url_for( "balanceNotify": url_for(
"core.lnurl_balance_notify", "core.lnurl_balance_notify",
service=urlparse(data.lnurl_callback).netloc, service=urlparse(data.lnurl_callback).netloc,
wal=g.wallet.id, wal=g().wallet.id,
_external=True, _external=True,
), ),
}, },
@ -217,14 +223,14 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
if invoice.amount_msat != data.amount: if invoice.amount_msat != data.amount:
return ( return (
{ {
"message": f"{domain} returned an invalid invoice. Expected {g.data['amount']} msat, got {invoice.amount_msat}." "message": f"{domain} returned an invalid invoice. Expected {g().data['amount']} msat, got {invoice.amount_msat}."
}, },
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
) )
if invoice.description_hash != g.data["description_hash"]: if invoice.description_hash != g().data["description_hash"]:
return ( return (
{ {
"message": f"{domain} returned an invalid invoice. Expected description_hash == {g.data['description_hash']}, got {invoice.description_hash}." "message": f"{domain} returned an invalid invoice. Expected description_hash == {g().data['description_hash']}, got {invoice.description_hash}."
}, },
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
) )
@ -237,7 +243,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
extra["comment"] = data.comment extra["comment"] = data.comment
payment_hash = await pay_invoice( payment_hash = await pay_invoice(
wallet_id=g.wallet.id, wallet_id=g().wallet.id,
payment_request=params["pr"], payment_request=params["pr"],
description=data.description, description=data.description,
extra=extra, extra=extra,
@ -257,7 +263,7 @@ async def api_payments_pay_lnurl(data: CreateLNURLData):
@core_app.get("/api/v1/payments/<payment_hash>") @core_app.get("/api/v1/payments/<payment_hash>")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_payment(payment_hash): async def api_payment(payment_hash):
payment = await g.wallet.get_payment(payment_hash) payment = await g().wallet.get_payment(payment_hash)
if not payment: if not payment:
return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND return {"message": "Payment does not exist."}, HTTPStatus.NOT_FOUND
@ -278,7 +284,7 @@ async def api_payment(payment_hash):
@core_app.get("/api/v1/payments/sse") @core_app.get("/api/v1/payments/sse")
@api_check_wallet_key("invoice", accept_querystring=True) @api_check_wallet_key("invoice", accept_querystring=True)
async def api_payments_sse(): async def api_payments_sse():
this_wallet_id = g.wallet.id this_wallet_id = g().wallet.id
send_payment, receive_payment = trio.open_memory_channel(0) send_payment, receive_payment = trio.open_memory_channel(0)
@ -356,7 +362,7 @@ async def api_lnurlscan(code: str):
params.update(kind="auth") params.update(kind="auth")
params.update(callback=url) # with k1 already in it params.update(callback=url) # with k1 already in it
lnurlauth_key = g.wallet.lnurlauth_key(domain) lnurlauth_key = g().wallet.lnurlauth_key(domain)
params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex()) params.update(pubkey=lnurlauth_key.verifying_key.to_string("compressed").hex())
else: else:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:

View file

@ -1,15 +1,10 @@
from lnbits.requestvars import g
from os import path from os import path
from http import HTTPStatus from http import HTTPStatus
from quart import ( from typing import Optional
g, import jinja2
current_app,
abort, from starlette.responses import HTMLResponse
request,
redirect,
render_template,
send_from_directory,
url_for,
)
from lnbits.core import core_app, db from lnbits.core import core_app, db
from lnbits.decorators import check_user_exists, validate_uuids from lnbits.decorators import check_user_exists, validate_uuids
@ -26,20 +21,18 @@ from ..crud import (
) )
from ..services import redeem_lnurl_withdraw, pay_invoice from ..services import redeem_lnurl_withdraw, pay_invoice
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates from fastapi.responses import FileResponse
from lnbits.jinja2_templating import Jinja2Templates
templates = Jinja2Templates(directory="templates")
@core_app.get("/favicon.ico") @core_app.get("/favicon.ico")
async def favicon(): async def favicon():
return await send_from_directory( return FileResponse("lnbits/core/static/favicon.ico")
path.join(core_app.root_path, "static"), "favicon.ico"
)
@core_app.get("/", response_class=HTMLResponse)
@core_app.get("/")
async def home(request: Request, lightning: str = None): async def home(request: Request, lightning: str = None):
return templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning}) return g().templates.TemplateResponse("core/index.html", {"request": request, "lnurl": lightning})
@core_app.get("/extensions") @core_app.get("/extensions")

View file

@ -7,7 +7,7 @@ from uuid import UUID
from lnbits.core.crud import get_user, get_wallet_for_key from lnbits.core.crud import get_user, get_wallet_for_key
from lnbits.settings import LNBITS_ALLOWED_USERS from lnbits.settings import LNBITS_ALLOWED_USERS
from lnbits.requestvars import g
def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False): def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False):
def wrap(view): def wrap(view):
@ -15,14 +15,14 @@ def api_check_wallet_key(key_type: str = "invoice", accept_querystring=False):
async def wrapped_view(**kwargs): async def wrapped_view(**kwargs):
try: try:
key_value = request.headers.get("X-Api-Key") or request.args["api-key"] key_value = request.headers.get("X-Api-Key") or request.args["api-key"]
g.wallet = await get_wallet_for_key(key_value, key_type) g().wallet = await get_wallet_for_key(key_value, key_type)
except KeyError: except KeyError:
return ( return (
jsonify({"message": "`X-Api-Key` header missing."}), jsonify({"message": "`X-Api-Key` header missing."}),
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
) )
if not g.wallet: if not g().wallet:
return jsonify({"message": "Wrong keys."}), HTTPStatus.UNAUTHORIZED return jsonify({"message": "Wrong keys."}), HTTPStatus.UNAUTHORIZED
return await view(**kwargs) return await view(**kwargs)
@ -44,9 +44,9 @@ def api_validate_post_request(*, schema: dict):
v = Validator(schema) v = Validator(schema)
data = await request.get_json() data = await request.get_json()
g.data = {key: data[key] for key in schema.keys() if key in data} g().data = {key: data[key] for key in schema.keys() if key in data}
if not v.validate(g.data): if not v.validate(g().data):
return ( return (
jsonify({"message": f"Errors in request data: {v.errors}"}), jsonify({"message": f"Errors in request data: {v.errors}"}),
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
@ -63,11 +63,11 @@ def check_user_exists(param: str = "usr"):
def wrap(view): def wrap(view):
@wraps(view) @wraps(view)
async def wrapped_view(**kwargs): async def wrapped_view(**kwargs):
g.user = await get_user(request.args.get(param, type=str)) or abort( g().user = await get_user(request.args.get(param, type=str)) or abort(
HTTPStatus.NOT_FOUND, "User does not exist." HTTPStatus.NOT_FOUND, "User does not exist."
) )
if LNBITS_ALLOWED_USERS and g.user.id not in LNBITS_ALLOWED_USERS: if LNBITS_ALLOWED_USERS and g().user.id not in LNBITS_ALLOWED_USERS:
abort(HTTPStatus.UNAUTHORIZED, "User not authorized.") abort(HTTPStatus.UNAUTHORIZED, "User not authorized.")
return await view(**kwargs) return await view(**kwargs)

View file

@ -2,6 +2,11 @@ import time
from base64 import urlsafe_b64encode from base64 import urlsafe_b64encode
from quart import jsonify, g, request from quart import jsonify, g, request
from fastapi import FastAPI, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from lnbits.core.services import pay_invoice, create_invoice from lnbits.core.services import pay_invoice, create_invoice
from lnbits.core.crud import get_payments, delete_expired_invoices from lnbits.core.crud import get_payments, delete_expired_invoices
from lnbits.decorators import api_validate_post_request from lnbits.decorators import api_validate_post_request
@ -13,62 +18,62 @@ from .decorators import check_wallet
from .utils import to_buffer, decoded_as_lndhub from .utils import to_buffer, decoded_as_lndhub
@lndhub_ext.route("/ext/getinfo", methods=["GET"]) @lndhub_ext.get("/ext/getinfo")
async def lndhub_getinfo(): async def lndhub_getinfo():
return jsonify({"error": True, "code": 1, "message": "bad auth"}) return {"error": True, "code": 1, "message": "bad auth"}
@lndhub_ext.route("/ext/auth", methods=["POST"]) @lndhub_ext.post("/ext/auth")
@api_validate_post_request( # @api_validate_post_request(
schema={ # schema={
"login": {"type": "string", "required": True, "excludes": "refresh_token"}, # "login": {"type": "string", "required": True, "excludes": "refresh_token"},
"password": {"type": "string", "required": True, "excludes": "refresh_token"}, # "password": {"type": "string", "required": True, "excludes": "refresh_token"},
"refresh_token": { # "refresh_token": {
"type": "string", # "type": "string",
"required": True, # "required": True,
"excludes": ["login", "password"], # "excludes": ["login", "password"],
}, # },
} # }
) # )
async def lndhub_auth(): async def lndhub_auth(login: str, password: str, refresh_token: str): #missing the "excludes" thing
token = ( token = (
g.data["refresh_token"] refresh_token
if "refresh_token" in g.data and g.data["refresh_token"] if refresh_token
else urlsafe_b64encode( else urlsafe_b64encode(
(g.data["login"] + ":" + g.data["password"]).encode("utf-8") (login + ":" + password).encode("utf-8")
).decode("ascii") ).decode("ascii")
) )
return jsonify({"refresh_token": token, "access_token": token}) return {"refresh_token": token, "access_token": token}
@lndhub_ext.route("/ext/addinvoice", methods=["POST"]) @lndhub_ext.post("/ext/addinvoice")
@check_wallet() @check_wallet()
@api_validate_post_request( # @api_validate_post_request(
schema={ # schema={
"amt": {"type": "string", "required": True}, # "amt": {"type": "string", "required": True},
"memo": {"type": "string", "required": True}, # "memo": {"type": "string", "required": True},
"preimage": {"type": "string", "required": False}, # "preimage": {"type": "string", "required": False},
} # }
) # )
async def lndhub_addinvoice(): async def lndhub_addinvoice(amt: str, memo: str, preimage: str = ""):
try: try:
_, pr = await create_invoice( _, pr = await create_invoice(
wallet_id=g.wallet.id, wallet_id=g.wallet.id,
amount=int(g.data["amt"]), amount=int(amt),
memo=g.data["memo"], memo=memo,
extra={"tag": "lndhub"}, extra={"tag": "lndhub"},
) )
except Exception as e: except Exception as e:
return jsonify( return
{ {
"error": True, "error": True,
"code": 7, "code": 7,
"message": "Failed to create invoice: " + str(e), "message": "Failed to create invoice: " + str(e),
} }
)
invoice = bolt11.decode(pr) invoice = bolt11.decode(pr)
return jsonify( return
{ {
"pay_req": pr, "pay_req": pr,
"payment_request": pr, "payment_request": pr,
@ -76,30 +81,30 @@ async def lndhub_addinvoice():
"r_hash": to_buffer(invoice.payment_hash), "r_hash": to_buffer(invoice.payment_hash),
"hash": invoice.payment_hash, "hash": invoice.payment_hash,
} }
)
@lndhub_ext.route("/ext/payinvoice", methods=["POST"])
@lndhub_ext.post("/ext/payinvoice")
@check_wallet(requires_admin=True) @check_wallet(requires_admin=True)
@api_validate_post_request(schema={"invoice": {"type": "string", "required": True}}) # @api_validate_post_request(schema={"invoice": {"type": "string", "required": True}})
async def lndhub_payinvoice(): async def lndhub_payinvoice(invoice: str):
try: try:
await pay_invoice( await pay_invoice(
wallet_id=g.wallet.id, wallet_id=g.wallet.id,
payment_request=g.data["invoice"], payment_request=invoice,
extra={"tag": "lndhub"}, extra={"tag": "lndhub"},
) )
except Exception as e: except Exception as e:
return jsonify( return
{ {
"error": True, "error": True,
"code": 10, "code": 10,
"message": "Payment failed: " + str(e), "message": "Payment failed: " + str(e),
} }
)
invoice: bolt11.Invoice = bolt11.decode(g.data["invoice"])
return jsonify( invoice: bolt11.Invoice = bolt11.decode(invoice)
return
{ {
"payment_error": "", "payment_error": "",
"payment_preimage": "0" * 64, "payment_preimage": "0" * 64,
@ -113,16 +118,16 @@ async def lndhub_payinvoice():
"timestamp": int(time.time()), "timestamp": int(time.time()),
"memo": invoice.description, "memo": invoice.description,
} }
)
@lndhub_ext.route("/ext/balance", methods=["GET"])
@lndhub_ext.get("/ext/balance")
@check_wallet() @check_wallet()
async def lndhub_balance(): async def lndhub_balance():
return jsonify({"BTC": {"AvailableBalance": g.wallet.balance}}) return {"BTC": {"AvailableBalance": g.wallet.balance}}
@lndhub_ext.route("/ext/gettxs", methods=["GET"]) @lndhub_ext.get("/ext/gettxs")
@check_wallet() @check_wallet()
async def lndhub_gettxs(): async def lndhub_gettxs():
for payment in await get_payments( for payment in await get_payments(
@ -138,7 +143,7 @@ async def lndhub_gettxs():
) )
limit = int(request.args.get("limit", 200)) limit = int(request.args.get("limit", 200))
return jsonify( return
[ [
{ {
"payment_preimage": payment.preimage, "payment_preimage": payment.preimage,
@ -164,10 +169,10 @@ async def lndhub_gettxs():
)[:limit] )[:limit]
) )
] ]
)
@lndhub_ext.route("/ext/getuserinvoices", methods=["GET"])
@lndhub_ext.get("/ext/getuserinvoices")
@check_wallet() @check_wallet()
async def lndhub_getuserinvoices(): async def lndhub_getuserinvoices():
await delete_expired_invoices() await delete_expired_invoices()
@ -184,7 +189,7 @@ async def lndhub_getuserinvoices():
) )
limit = int(request.args.get("limit", 200)) limit = int(request.args.get("limit", 200))
return jsonify( return
[ [
{ {
"r_hash": to_buffer(invoice.payment_hash), "r_hash": to_buffer(invoice.payment_hash),
@ -210,31 +215,31 @@ async def lndhub_getuserinvoices():
)[:limit] )[:limit]
) )
] ]
)
@lndhub_ext.route("/ext/getbtc", methods=["GET"])
@lndhub_ext.get("/ext/getbtc")
@check_wallet() @check_wallet()
async def lndhub_getbtc(): async def lndhub_getbtc():
"load an address for incoming onchain btc" "load an address for incoming onchain btc"
return jsonify([]) return []
@lndhub_ext.route("/ext/getpending", methods=["GET"]) @lndhub_ext.get("/ext/getpending")
@check_wallet() @check_wallet()
async def lndhub_getpending(): async def lndhub_getpending():
"pending onchain transactions" "pending onchain transactions"
return jsonify([]) return []
@lndhub_ext.route("/ext/decodeinvoice", methods=["GET"]) @lndhub_ext.get("/ext/decodeinvoice")
async def lndhub_decodeinvoice(): async def lndhub_decodeinvoice():
invoice = request.args.get("invoice") invoice = request.args.get("invoice")
inv = bolt11.decode(invoice) inv = bolt11.decode(invoice)
return jsonify(decoded_as_lndhub(inv)) return decoded_as_lndhub(inv)
@lndhub_ext.route("/ext/checkrouteinvoice", methods=["GET"]) @lndhub_ext.get("/ext/checkrouteinvoice")
async def lndhub_checkrouteinvoice(): async def lndhub_checkrouteinvoice():
"not implemented on canonical lndhub" "not implemented on canonical lndhub"
pass pass

View file

@ -48,7 +48,7 @@ async def api_links():
) )
@lnurlp_ext.get("/api/v1/links/<link_id>") @lnurlp_ext.get("/api/v1/links/{link_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_link_retrieve(link_id): async def api_link_retrieve(link_id):
link = await get_pay_link(link_id) link = await get_pay_link(link_id)
@ -65,14 +65,14 @@ class CreateData(BaseModel):
description: str description: str
min: int = Query(0.01, ge=0.01) min: int = Query(0.01, ge=0.01)
max: int = Query(0.01, ge=0.01) max: int = Query(0.01, ge=0.01)
currency: Optional[str] currency: Optional[str]
comment_chars: int = Query(0, ge=0, lt=800) comment_chars: int = Query(0, ge=0, lt=800)
webhook_url: Optional[str] webhook_url: Optional[str]
success_text: Optional[str] success_text: Optional[str]
success_url: Optional[str] success_url: Optional[str]
@lnurlp_ext.post("/api/v1/links") @lnurlp_ext.post("/api/v1/links")
@lnurlp_ext.put("/api/v1/links/<link_id>") @lnurlp_ext.put("/api/v1/links/{link_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_link_create_or_update(data: CreateData, link_id=None): async def api_link_create_or_update(data: CreateData, link_id=None):
if data.min > data.max: if data.min > data.max:
@ -111,7 +111,7 @@ async def api_link_create_or_update(data: CreateData, link_id=None):
) )
@lnurlp_ext.delete("/api/v1/links/<link_id>") @lnurlp_ext.delete("/api/v1/links/{link_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_link_delete(link_id): async def api_link_delete(link_id):
link = await get_pay_link(link_id) link = await get_pay_link(link_id)
@ -127,7 +127,7 @@ async def api_link_delete(link_id):
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@lnurlp_ext.get("/api/v1/rate/<currency>") @lnurlp_ext.get("/api/v1/rate/{currency}")
async def api_check_fiat_rate(currency): async def api_check_fiat_rate(currency):
try: try:
rate = await get_fiat_rate_satoshis(currency) rate = await get_fiat_rate_satoshis(currency)

View file

@ -1,11 +1,12 @@
from quart import Blueprint from fastapi import APIRouter
from lnbits.db import Database from lnbits.db import Database
db = Database("ext_offlineshop") db = Database("ext_offlineshop")
offlineshop_ext: Blueprint = Blueprint( offlineshop_ext: APIRouter = APIRouter(
"offlineshop", __name__, static_folder="static", template_folder="templates" prefix="/Extension",
tags=["Offlineshop"]
) )

View file

@ -53,11 +53,11 @@ class CreateItemsData(BaseModel):
name: str name: str
description: str description: str
image: Optional[str] image: Optional[str]
price: int price: int
unit: str unit: str
@offlineshop_ext.post("/api/v1/offlineshop/items") @offlineshop_ext.post("/api/v1/offlineshop/items")
@offlineshop_ext.put("/api/v1/offlineshop/items/<item_id>") @offlineshop_ext.put("/api/v1/offlineshop/items/{item_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_add_or_update_item(data: CreateItemsData, item_id=None): async def api_add_or_update_item(data: CreateItemsData, item_id=None):
shop = await get_or_create_shop_by_wallet(g.wallet.id) shop = await get_or_create_shop_by_wallet(g.wallet.id)
@ -84,7 +84,7 @@ async def api_add_or_update_item(data: CreateItemsData, item_id=None):
return "", HTTPStatus.OK return "", HTTPStatus.OK
@offlineshop_ext.delete("/api/v1/offlineshop/items/<item_id>") @offlineshop_ext.delete("/api/v1/offlineshop/items/{item_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_delete_item(item_id): async def api_delete_item(item_id):
shop = await get_or_create_shop_by_wallet(g.wallet.id) shop = await get_or_create_shop_by_wallet(g.wallet.id)

View file

@ -29,8 +29,8 @@ class CreateData(BaseModel):
url: Optional[str] = Query(...) url: Optional[str] = Query(...)
memo: Optional[str] = Query(...) memo: Optional[str] = Query(...)
description: str description: str
amount: int amount: int
remembers: bool remembers: bool
@paywall_ext.post("/api/v1/paywalls") @paywall_ext.post("/api/v1/paywalls")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
@ -39,7 +39,7 @@ async def api_paywall_create(data: CreateData):
return paywall, HTTPStatus.CREATED return paywall, HTTPStatus.CREATED
@paywall_ext.delete("/api/v1/paywalls/<paywall_id>") @paywall_ext.delete("/api/v1/paywalls/{paywall_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_paywall_delete(paywall_id): async def api_paywall_delete(paywall_id):
paywall = await get_paywall(paywall_id) paywall = await get_paywall(paywall_id)
@ -55,7 +55,7 @@ async def api_paywall_delete(paywall_id):
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@paywall_ext.post("/api/v1/paywalls/<paywall_id>/invoice") @paywall_ext.post("/api/v1/paywalls/{paywall_id}/invoice")
async def api_paywall_create_invoice(amount: int = Query(..., ge=1), paywall_id = None): async def api_paywall_create_invoice(amount: int = Query(..., ge=1), paywall_id = None):
paywall = await get_paywall(paywall_id) paywall = await get_paywall(paywall_id)
@ -76,26 +76,26 @@ async def api_paywall_create_invoice(amount: int = Query(..., ge=1), paywall_id
extra={"tag": "paywall"}, extra={"tag": "paywall"},
) )
except Exception as e: except Exception as e:
return jsonable_encoder({"message": str(e)}), HTTPStatus.INTERNAL_SERVER_ERROR return {"message": str(e)}, HTTPStatus.INTERNAL_SERVER_ERROR
return ( return (
jsonable_encoder({"payment_hash": payment_hash, "payment_request": payment_request}), {"payment_hash": payment_hash, "payment_request": payment_request},
HTTPStatus.CREATED, HTTPStatus.CREATED,
) )
@paywall_ext.post("/api/v1/paywalls/<paywall_id>/check_invoice") @paywall_ext.post("/api/v1/paywalls/{paywall_id}/check_invoice")
async def api_paywal_check_invoice(payment_hash: str = Query(...), paywall_id = None): async def api_paywal_check_invoice(payment_hash: str = Query(...), paywall_id = None):
paywall = await get_paywall(paywall_id) paywall = await get_paywall(paywall_id)
if not paywall: if not paywall:
return jsonable_encoder({"message": "Paywall does not exist."}), HTTPStatus.NOT_FOUND return {"message": "Paywall does not exist."}, HTTPStatus.NOT_FOUND
try: try:
status = await check_invoice_status(paywall.wallet, payment_hash) status = await check_invoice_status(paywall.wallet, payment_hash)
is_paid = not status.pending is_paid = not status.pending
except Exception: except Exception:
return jsonable_encoder({"paid": False}), HTTPStatus.OK return {"paid": False}, HTTPStatus.OK
if is_paid: if is_paid:
wallet = await get_wallet(paywall.wallet) wallet = await get_wallet(paywall.wallet)
@ -103,8 +103,8 @@ async def api_paywal_check_invoice(payment_hash: str = Query(...), paywall_id =
await payment.set_pending(False) await payment.set_pending(False)
return ( return (
jsonable_encoder({"paid": True, "url": paywall.url, "remembers": paywall.remembers}), {"paid": True, "url": paywall.url, "remembers": paywall.remembers},
HTTPStatus.OK, HTTPStatus.OK,
) )
return jsonable_encoder({"paid": False}), HTTPStatus.OK return {"paid": False}, HTTPStatus.OK

View file

@ -46,7 +46,7 @@ async def api_create_service(data: CreateServicesData):
return service._asdict(), HTTPStatus.CREATED return service._asdict(), HTTPStatus.CREATED
@streamalerts_ext.get("/api/v1/getaccess/<service_id>") @streamalerts_ext.get("/api/v1/getaccess/{service_id}")
async def api_get_access(service_id): async def api_get_access(service_id):
"""Redirect to Streamlabs' Approve/Decline page for API access for Service """Redirect to Streamlabs' Approve/Decline page for API access for Service
with service_id with service_id
@ -69,7 +69,7 @@ async def api_get_access(service_id):
return ({"message": "Service does not exist!"}, HTTPStatus.BAD_REQUEST) return ({"message": "Service does not exist!"}, HTTPStatus.BAD_REQUEST)
@streamalerts_ext.get("/api/v1/authenticate/<service_id>") @streamalerts_ext.get("/api/v1/authenticate/{service_id}")
async def api_authenticate_service(Code: str, State: str, service_id): async def api_authenticate_service(Code: str, State: str, service_id):
"""Endpoint visited via redirect during third party API authentication """Endpoint visited via redirect during third party API authentication
@ -183,7 +183,7 @@ async def api_get_donations():
) )
@streamalerts_ext.put("/api/v1/donations/<donation_id>") @streamalerts_ext.put("/api/v1/donations/{donation_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_update_donation(donation_id=None): async def api_update_donation(donation_id=None):
"""Update a donation with the data given in the request""" """Update a donation with the data given in the request"""
@ -208,7 +208,7 @@ async def api_update_donation(donation_id=None):
return donation._asdict(), HTTPStatus.CREATED return donation._asdict(), HTTPStatus.CREATED
@streamalerts_ext.put("/api/v1/services/<service_id>") @streamalerts_ext.put("/api/v1/services/{service_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_update_service(service_id=None): async def api_update_service(service_id=None):
"""Update a service with the data given in the request""" """Update a service with the data given in the request"""
@ -229,7 +229,7 @@ async def api_update_service(service_id=None):
return service._asdict(), HTTPStatus.CREATED return service._asdict(), HTTPStatus.CREATED
@streamalerts_ext.delete("/api/v1/donations/<donation_id>") @streamalerts_ext.delete("/api/v1/donations/{donation_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_delete_donation(donation_id): async def api_delete_donation(donation_id):
"""Delete the donation with the given donation_id""" """Delete the donation with the given donation_id"""
@ -245,7 +245,7 @@ async def api_delete_donation(donation_id):
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@streamalerts_ext.delete("/api/v1/services/<service_id>") @streamalerts_ext.delete("/api/v1/services/{service_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_delete_service(service_id): async def api_delete_service(service_id):
"""Delete the service with the given service_id""" """Delete the service with the given service_id"""

View file

@ -51,7 +51,7 @@ class CreateDomainsData(BaseModel):
allowed_record_types: str allowed_record_types: str
@subdomains_ext.post("/api/v1/domains") @subdomains_ext.post("/api/v1/domains")
@subdomains_ext.put("/api/v1/domains/<domain_id>") @subdomains_ext.put("/api/v1/domains/{domain_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_domain_create(data: CreateDomainsData, domain_id=None): async def api_domain_create(data: CreateDomainsData, domain_id=None):
if domain_id: if domain_id:
@ -66,10 +66,10 @@ async def api_domain_create(data: CreateDomainsData, domain_id=None):
domain = await update_domain(domain_id, **data) domain = await update_domain(domain_id, **data)
else: else:
domain = await create_domain(**data) domain = await create_domain(**data)
return jsonify(domain._asdict()), HTTPStatus.CREATED return domain._asdict(), HTTPStatus.CREATED
@subdomains_ext.delete("/api/v1/domains/<domain_id>") @subdomains_ext.delete("/api/v1/domains/{domain_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_domain_delete(domain_id): async def api_domain_delete(domain_id):
domain = await get_domain(domain_id) domain = await get_domain(domain_id)
@ -110,14 +110,14 @@ class CreateDomainsData(BaseModel):
duration: int duration: int
record_type: str record_type: str
@subdomains_ext.post("/api/v1/subdomains/<domain_id>") @subdomains_ext.post("/api/v1/subdomains/{domain_id}")
async def api_subdomain_make_subdomain(data: CreateDomainsData, domain_id): async def api_subdomain_make_subdomain(data: CreateDomainsData, domain_id):
domain = await get_domain(domain_id) domain = await get_domain(domain_id)
# If the request is coming for the non-existant domain # If the request is coming for the non-existant domain
if not domain: if not domain:
return jsonify({"message": "LNsubdomain does not exist."}), HTTPStatus.NOT_FOUND return {"message": "LNsubdomain does not exist."}, HTTPStatus.NOT_FOUND
## If record_type is not one of the allowed ones reject the request ## If record_type is not one of the allowed ones reject the request
if data.record_type not in domain.allowed_record_types: if data.record_type not in domain.allowed_record_types:
@ -184,7 +184,7 @@ async def api_subdomain_make_subdomain(data: CreateDomainsData, domain_id):
) )
@subdomains_ext.get("/api/v1/subdomains/<payment_hash>") @subdomains_ext.get("/api/v1/subdomains/{payment_hash}")
async def api_subdomain_send_subdomain(payment_hash): async def api_subdomain_send_subdomain(payment_hash):
subdomain = await get_subdomain(payment_hash) subdomain = await get_subdomain(payment_hash)
try: try:
@ -199,7 +199,7 @@ async def api_subdomain_send_subdomain(payment_hash):
return {"paid": False}, HTTPStatus.OK return {"paid": False}, HTTPStatus.OK
@subdomains_ext.delete("/api/v1/subdomains/<subdomain_id>") @subdomains_ext.delete("/api/v1/subdomains/{subdomain_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_subdomain_delete(subdomain_id): async def api_subdomain_delete(subdomain_id):
subdomain = await get_subdomain(subdomain_id) subdomain = await get_subdomain(subdomain_id)

View file

@ -37,7 +37,7 @@ async def api_wallets_retrieve():
return "" return ""
@watchonly_ext.get("/api/v1/wallet/<wallet_id>") @watchonly_ext.get("/api/v1/wallet/{wallet_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_wallet_retrieve(wallet_id): async def api_wallet_retrieve(wallet_id):
wallet = await get_watch_wallet(wallet_id) wallet = await get_watch_wallet(wallet_id)
@ -63,7 +63,7 @@ async def api_wallet_create_or_update(masterPub: str, Title: str, wallet_id=None
return wallet._asdict(), HTTPStatus.CREATED return wallet._asdict(), HTTPStatus.CREATED
@watchonly_ext.delete("/api/v1/wallet/<wallet_id>") @watchonly_ext.delete("/api/v1/wallet/{wallet_id}")
@api_check_wallet_key("admin") @api_check_wallet_key("admin")
async def api_wallet_delete(wallet_id): async def api_wallet_delete(wallet_id):
wallet = await get_watch_wallet(wallet_id) wallet = await get_watch_wallet(wallet_id)
@ -79,7 +79,7 @@ async def api_wallet_delete(wallet_id):
#############################ADDRESSES########################## #############################ADDRESSES##########################
@watchonly_ext.get("/api/v1/address/<wallet_id>") @watchonly_ext.get("/api/v1/address/{wallet_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_fresh_address(wallet_id): async def api_fresh_address(wallet_id):
await get_fresh_address(wallet_id) await get_fresh_address(wallet_id)
@ -89,7 +89,7 @@ async def api_fresh_address(wallet_id):
return [address._asdict() for address in addresses], HTTPStatus.OK return [address._asdict() for address in addresses], HTTPStatus.OK
@watchonly_ext.get("/api/v1/addresses/<wallet_id>") @watchonly_ext.get("/api/v1/addresses/{wallet_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_get_addresses(wallet_id): async def api_get_addresses(wallet_id):
wallet = await get_watch_wallet(wallet_id) wallet = await get_watch_wallet(wallet_id)

View file

@ -46,7 +46,7 @@ async def api_links():
) )
@withdraw_ext.get("/api/v1/links/<link_id>") @withdraw_ext.get("/api/v1/links/{link_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_link_retrieve(link_id): async def api_link_retrieve(link_id):
link = await get_withdraw_link(link_id, 0) link = await get_withdraw_link(link_id, 0)
@ -70,7 +70,7 @@ class CreateData(BaseModel):
is_unique: bool is_unique: bool
@withdraw_ext.post("/api/v1/links") @withdraw_ext.post("/api/v1/links")
@withdraw_ext.put("/api/v1/links/<link_id>") @withdraw_ext.put("/api/v1/links/{link_id}")
@api_check_wallet_key("admin") @api_check_wallet_key("admin")
async def api_link_create_or_update(data: CreateData, link_id: str = None): async def api_link_create_or_update(data: CreateData, link_id: str = None):
if data.max_withdrawable < data.min_withdrawable: if data.max_withdrawable < data.min_withdrawable:
@ -97,7 +97,7 @@ async def api_link_create_or_update(data: CreateData, link_id: str = None):
HTTPStatus.NOT_FOUND, HTTPStatus.NOT_FOUND,
) )
if link.wallet != g.wallet.id: if link.wallet != g.wallet.id:
return jsonify({"message": "Not your withdraw link."}), HTTPStatus.FORBIDDEN return {"message": "Not your withdraw link."}, HTTPStatus.FORBIDDEN
link = await update_withdraw_link(link_id, **data, usescsv=usescsv, used=0) link = await update_withdraw_link(link_id, **data, usescsv=usescsv, used=0)
else: else:
link = await create_withdraw_link( link = await create_withdraw_link(
@ -109,7 +109,7 @@ async def api_link_create_or_update(data: CreateData, link_id: str = None):
) )
@withdraw_ext.delete("/api/v1/links/<link_id>") @withdraw_ext.delete("/api/v1/links/{link_id}")
@api_check_wallet_key("admin") @api_check_wallet_key("admin")
async def api_link_delete(link_id): async def api_link_delete(link_id):
link = await get_withdraw_link(link_id) link = await get_withdraw_link(link_id)
@ -127,7 +127,7 @@ async def api_link_delete(link_id):
return "", HTTPStatus.NO_CONTENT return "", HTTPStatus.NO_CONTENT
@withdraw_ext.get("/api/v1/links/<the_hash>/<lnurl_id>") @withdraw_ext.get("/api/v1/links/{the_hash}/{lnurl_id}")
@api_check_wallet_key("invoice") @api_check_wallet_key("invoice")
async def api_hash_retrieve(the_hash, lnurl_id): async def api_hash_retrieve(the_hash, lnurl_id):
hashCheck = await get_hash_check(the_hash, lnurl_id) hashCheck = await get_hash_check(the_hash, lnurl_id)

View file

@ -0,0 +1,36 @@
# Borrowed from the excellent accent-starlette
# https://github.com/accent-starlette/starlette-core/blob/master/starlette_core/templating.py
import typing
from starlette import templating
from starlette.datastructures import QueryParams
from lnbits.requestvars import g
try:
import jinja2
except ImportError: # pragma: nocover
jinja2 = None # type: ignore
class Jinja2Templates(templating.Jinja2Templates):
def __init__(self, loader: jinja2.BaseLoader) -> None:
assert jinja2 is not None, "jinja2 must be installed to use Jinja2Templates"
self.env = self.get_environment(loader)
def get_environment(self, loader: "jinja2.BaseLoader") -> "jinja2.Environment":
@jinja2.contextfunction
def url_for(context: dict, name: str, **path_params: typing.Any) -> str:
request = context["request"]
return request.url_for(name, **path_params)
def url_params_update(init: QueryParams, **new: typing.Any) -> QueryParams:
values = dict(init)
values.update(new)
return QueryParams(**values)
env = jinja2.Environment(loader=loader, autoescape=True)
env.globals["url_for"] = url_for
env.globals["url_params_update"] = url_params_update
return env

9
lnbits/requestvars.py Normal file
View file

@ -0,0 +1,9 @@
import contextvars
import types
request_global = contextvars.ContextVar("request_global",
default=types.SimpleNamespace())
def g() -> types.SimpleNamespace:
return request_global.get()

View file

@ -2,7 +2,7 @@
<html lang="en"> <html lang="en">
<head> <head>
{% for url in g.VENDORED_CSS %} {% for url in VENDORED_CSS %}
<link rel="stylesheet" type="text/css" href="{{ url }}" /> <link rel="stylesheet" type="text/css" href="{{ url }}" />
{% endfor %} {% endfor %}
<!----> <!---->
@ -184,7 +184,7 @@
{% block vue_templates %}{% endblock %} {% block vue_templates %}{% endblock %}
<!----> <!---->
{% for url in g.VENDORED_JS %} {% for url in VENDORED_JS %}
<script src="{{ url }}"></script> <script src="{{ url }}"></script>
{% endfor %} {% endfor %}
<!----> <!---->

View file

@ -2,7 +2,7 @@
<html lang="en"> <html lang="en">
<head> <head>
{% for url in g.VENDORED_CSS %} {% for url in VENDORED_CSS %}
<link rel="stylesheet" type="text/css" href="{{ url }}" /> <link rel="stylesheet" type="text/css" href="{{ url }}" />
{% endfor %} {% endfor %}
<style> <style>
@ -33,7 +33,7 @@
</q-page-container> </q-page-container>
</q-layout> </q-layout>
{% for url in g.VENDORED_JS %} {% for url in VENDORED_JS %}
<script src="{{ url }}"></script> <script src="{{ url }}"></script>
{% endfor %} {% endfor %}
<!----> <!---->