diff --git a/lnbits/extensions/watchonly/crud.py b/lnbits/extensions/watchonly/crud.py index 61e47cfe..1d9abcec 100644 --- a/lnbits/extensions/watchonly/crud.py +++ b/lnbits/extensions/watchonly/crud.py @@ -41,8 +41,9 @@ async def create_watch_wallet(user: str, w: WalletAccount) -> WalletAccount: w.meta, ), ) - - return await get_watch_wallet(wallet_id) + wallet = await get_watch_wallet(wallet_id) + assert wallet + return wallet async def get_watch_wallet(wallet_id: str) -> Optional[WalletAccount]: @@ -121,11 +122,11 @@ async def create_fresh_addresses( change_address=False, ) -> List[Address]: if start_address_index > end_address_index: - return None + return [] wallet = await get_watch_wallet(wallet_id) if not wallet: - return None + return [] branch_index = 1 if change_address else 0 @@ -150,7 +151,7 @@ async def create_fresh_addresses( # return fresh addresses rows = await db.fetchall( """ - SELECT * FROM watchonly.addresses + SELECT * FROM watchonly.addresses WHERE wallet = ? AND branch_index = ? AND address_index >= ? AND address_index < ? ORDER BY branch_index, address_index """, @@ -172,7 +173,7 @@ async def get_address_at_index( ) -> Optional[Address]: row = await db.fetchone( """ - SELECT * FROM watchonly.addresses + SELECT * FROM watchonly.addresses WHERE wallet = ? AND branch_index = ? AND address_index = ? """, ( diff --git a/lnbits/extensions/watchonly/models.py b/lnbits/extensions/watchonly/models.py index c6265d6c..24d63bfd 100644 --- a/lnbits/extensions/watchonly/models.py +++ b/lnbits/extensions/watchonly/models.py @@ -1,7 +1,7 @@ from sqlite3 import Row from typing import List, Optional -from fastapi.param_functions import Query +from fastapi import Query from pydantic import BaseModel @@ -35,7 +35,7 @@ class Address(BaseModel): amount: int = 0 branch_index: int = 0 address_index: int - note: str = None + note: Optional[str] = None has_activity: bool = False @classmethod @@ -57,9 +57,9 @@ class TransactionInput(BaseModel): class TransactionOutput(BaseModel): amount: int address: str - branch_index: int = None - address_index: int = None - wallet: str = None + branch_index: Optional[int] = None + address_index: Optional[int] = None + wallet: Optional[str] = None class MasterPublicKey(BaseModel): diff --git a/lnbits/extensions/watchonly/views.py b/lnbits/extensions/watchonly/views.py index 819d1248..8cebc6cc 100644 --- a/lnbits/extensions/watchonly/views.py +++ b/lnbits/extensions/watchonly/views.py @@ -1,6 +1,5 @@ -from fastapi.params import Depends +from fastapi import Depends, Request from fastapi.templating import Jinja2Templates -from starlette.requests import Request from starlette.responses import HTMLResponse from lnbits.core.models import User diff --git a/lnbits/extensions/watchonly/views_api.py b/lnbits/extensions/watchonly/views_api.py index c6e15ea6..5bb43661 100644 --- a/lnbits/extensions/watchonly/views_api.py +++ b/lnbits/extensions/watchonly/views_api.py @@ -1,5 +1,6 @@ import json from http import HTTPStatus +from typing import List import httpx from embit import finalizer, script @@ -7,9 +8,7 @@ from embit.ec import PublicKey from embit.networks import NETWORKS from embit.psbt import PSBT, DerivationPath from embit.transaction import Transaction, TransactionInput, TransactionOutput -from fastapi import Query, Request -from fastapi.params import Depends -from starlette.exceptions import HTTPException +from fastapi import Depends, HTTPException, Query, Request from lnbits.decorators import WalletTypeInfo, get_key_type, require_admin_key from lnbits.extensions.watchonly import watchonly_ext @@ -57,10 +56,8 @@ async def api_wallets_retrieve( return [] -@watchonly_ext.get("/api/v1/wallet/{wallet_id}") -async def api_wallet_retrieve( - wallet_id, wallet: WalletTypeInfo = Depends(get_key_type) -): +@watchonly_ext.get("/api/v1/wallet/{wallet_id}", dependencies=[Depends(get_key_type)]) +async def api_wallet_retrieve(wallet_id: str): w_wallet = await get_watch_wallet(wallet_id) if not w_wallet: @@ -76,7 +73,8 @@ async def api_wallet_create_or_update( data: CreateWallet, w: WalletTypeInfo = Depends(require_admin_key) ): try: - (descriptor, network) = parse_key(data.masterpub) + # TODO: talk to motorina about this + (descriptor, network) = parse_key(data.masterpub) # type: ignore if data.network != network["name"]: raise ValueError( "Account network error. This account is for '{}'".format( @@ -126,8 +124,10 @@ async def api_wallet_create_or_update( return wallet.dict() -@watchonly_ext.delete("/api/v1/wallet/{wallet_id}") -async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin_key)): +@watchonly_ext.delete( + "/api/v1/wallet/{wallet_id}", dependencies=[Depends(require_admin_key)] +) +async def api_wallet_delete(wallet_id: str): wallet = await get_watch_wallet(wallet_id) if not wallet: @@ -144,16 +144,15 @@ async def api_wallet_delete(wallet_id, w: WalletTypeInfo = Depends(require_admin #############################ADDRESSES########################## -@watchonly_ext.get("/api/v1/address/{wallet_id}") -async def api_fresh_address(wallet_id, w: WalletTypeInfo = Depends(get_key_type)): +@watchonly_ext.get("/api/v1/address/{wallet_id}", dependencies=[Depends(get_key_type)]) +async def api_fresh_address(wallet_id: str): address = await get_fresh_address(wallet_id) + assert address return address.dict() -@watchonly_ext.put("/api/v1/address/{id}") -async def api_update_address( - id: str, req: Request, w: WalletTypeInfo = Depends(require_admin_key) -): +@watchonly_ext.put("/api/v1/address/{id}", dependencies=[Depends(require_admin_key)]) +async def api_update_address(id: str, req: Request): body = await req.json() params = {} # amout is only updated if the address has history @@ -162,9 +161,10 @@ async def api_update_address( params["has_activity"] = True if "note" in body: - params["note"] = str(body["note"]) + params["note"] = body["note"] address = await update_address(**params, id=id) + assert address wallet = ( await get_watch_wallet(address.wallet) @@ -189,6 +189,7 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type) addresses = await get_addresses(wallet_id) config = await get_config(w.wallet.user) + assert config if not addresses: await create_fresh_addresses(wallet_id, 0, config.receive_gap_limit) @@ -229,10 +230,8 @@ async def api_get_addresses(wallet_id, w: WalletTypeInfo = Depends(get_key_type) #############################PSBT########################## -@watchonly_ext.post("/api/v1/psbt") -async def api_psbt_create( - data: CreatePsbt, w: WalletTypeInfo = Depends(require_admin_key) -): +@watchonly_ext.post("/api/v1/psbt", dependencies=[Depends(require_admin_key)]) +async def api_psbt_create(data: CreatePsbt): try: vin = [ TransactionInput(bytes.fromhex(inp.tx_id), inp.vout) for inp in data.inputs @@ -246,7 +245,7 @@ async def api_psbt_create( for _, masterpub in enumerate(data.masterpubs): descriptors[masterpub.id] = parse_key(masterpub.public_key) - inputs_extra = [] + inputs_extra: List[dict] = [] for i, inp in enumerate(data.inputs): bip32_derivations = {} @@ -266,14 +265,15 @@ async def api_psbt_create( tx = Transaction(vin=vin, vout=vout) psbt = PSBT(tx) - for i, inp in enumerate(inputs_extra): - psbt.inputs[i].bip32_derivations = inp["bip32_derivations"] - psbt.inputs[i].non_witness_utxo = inp.get("non_witness_utxo", None) + for i, inp_extra in enumerate(inputs_extra): + psbt.inputs[i].bip32_derivations = inp_extra["bip32_derivations"] + psbt.inputs[i].non_witness_utxo = inp_extra.get("non_witness_utxo", None) outputs_extra = [] bip32_derivations = {} for i, out in enumerate(data.outputs): if out.branch_index == 1: + assert out.wallet descriptor = descriptors[out.wallet][0] d = descriptor.derive(out.address_index, out.branch_index) for k in d.keys: @@ -282,8 +282,8 @@ async def api_psbt_create( ) outputs_extra.append({"bip32_derivations": bip32_derivations}) - for i, out in enumerate(outputs_extra): - psbt.outputs[i].bip32_derivations = out["bip32_derivations"] + for i, out_extra in enumerate(outputs_extra): + psbt.outputs[i].bip32_derivations = out_extra["bip32_derivations"] return psbt.to_string() @@ -360,7 +360,8 @@ async def api_tx_broadcast( else config.mempool_endpoint + "/testnet" ) async with httpx.AsyncClient() as client: - r = await client.post(endpoint + "/api/tx", data=data.tx_hex) + r = await client.post(endpoint + "/api/tx", content=data.tx_hex) + r.raise_for_status() tx_id = r.text return tx_id except Exception as e: @@ -375,6 +376,7 @@ async def api_update_config( data: Config, w: WalletTypeInfo = Depends(require_admin_key) ): config = await update_config(data, user=w.wallet.user) + assert config return config.dict() diff --git a/pyproject.toml b/pyproject.toml index 03dbbc8d..b41cff10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,7 +93,6 @@ exclude = """(?x)( | ^lnbits/extensions/boltz. | ^lnbits/extensions/livestream. | ^lnbits/extensions/lnurldevice. - | ^lnbits/extensions/watchonly. | ^lnbits/wallets/lnd_grpc_files. )"""