diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml deleted file mode 100644 index 5bcdee7..0000000 --- a/.github/workflows/ci.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: CI -on: - push: - branches: - - main - pull_request: - -jobs: - lint: - uses: lnbits/lnbits/.github/workflows/lint.yml@dev - tests: - runs-on: ubuntu-latest - needs: [lint] - steps: - - uses: actions/checkout@v4 - - uses: lnbits/lnbits/.github/actions/prepare@dev - - name: Run pytest - uses: pavelzw/pytest-action@v2 - env: - LNBITS_BACKEND_WALLET_CLASS: FakeWallet - PYTHONUNBUFFERED: 1 - DEBUG: true - with: - verbose: true - job-summary: true - emoji: false - click-to-expand: true - custom-pytest: uv run pytest - report-title: 'test' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 27c8a60..7ec9b48 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,9 +1,10 @@ on: push: tags: - - 'v[0-9]+.[0-9]+.[0-9]+' + - "v[0-9]+.[0-9]+.[0-9]+" jobs: + release: runs-on: ubuntu-latest steps: @@ -33,12 +34,12 @@ jobs: - name: Create pull request in extensions repo env: GH_TOKEN: ${{ secrets.EXT_GITHUB }} - repo_name: '${{ github.event.repository.name }}' - tag: '${{ github.ref_name }}' - branch: 'update-${{ github.event.repository.name }}-${{ github.ref_name }}' - title: '[UPDATE] ${{ github.event.repository.name }} to ${{ github.ref_name }}' - body: 'https://github.com/lnbits/${{ github.event.repository.name }}/releases/${{ github.ref_name }}' - archive: 'https://github.com/lnbits/${{ github.event.repository.name }}/archive/refs/tags/${{ github.ref_name }}.zip' + repo_name: "${{ github.event.repository.name }}" + tag: "${{ github.ref_name }}" + branch: "update-${{ github.event.repository.name }}-${{ github.ref_name }}" + title: "[UPDATE] ${{ github.event.repository.name }} to ${{ github.ref_name }}" + body: "https://github.com/lnbits/${{ github.event.repository.name }}/releases/${{ github.ref_name }}" + archive: "https://github.com/lnbits/${{ github.event.repository.name }}/archive/refs/tags/${{ github.ref_name }}.zip" run: | cd lnbits-extensions git checkout -b $branch diff --git a/.gitignore b/.gitignore index 0152b6e..10a11d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,24 @@ +.DS_Store +._* + __pycache__ -node_modules +*.py[cod] +*$py.class .mypy_cache -.venv +.vscode +*-lock.json + +*.egg +*.egg-info +.coverage +.pytest_cache +.webassets-cache +htmlcov +test-reports +tests/data/*.sqlite3 + +*.swo +*.swp +*.pyo +*.pyc +*.env \ No newline at end of file diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 725c398..0000000 --- a/.prettierrc +++ /dev/null @@ -1,12 +0,0 @@ -{ - "semi": false, - "arrowParens": "avoid", - "insertPragma": false, - "printWidth": 80, - "proseWrap": "preserve", - "singleQuote": true, - "trailingComma": "none", - "useTabs": false, - "bracketSameLine": false, - "bracketSpacing": false -} diff --git a/Makefile b/Makefile deleted file mode 100644 index 0fac253..0000000 --- a/Makefile +++ /dev/null @@ -1,47 +0,0 @@ -all: format check - -format: prettier black ruff - -check: mypy pyright checkblack checkruff checkprettier - -prettier: - uv run ./node_modules/.bin/prettier --write . -pyright: - uv run ./node_modules/.bin/pyright - -mypy: - uv run mypy . - -black: - uv run black . - -ruff: - uv run ruff check . --fix - -checkruff: - uv run ruff check . - -checkprettier: - uv run ./node_modules/.bin/prettier --check . - -checkblack: - uv run black --check . - -checkeditorconfig: - editorconfig-checker - -test: - PYTHONUNBUFFERED=1 \ - DEBUG=true \ - uv run pytest -install-pre-commit-hook: - @echo "Installing pre-commit hook to git" - @echo "Uninstall the hook with uv run pre-commit uninstall" - uv run pre-commit install - -pre-commit: - uv run pre-commit run --all-files - - -checkbundle: - @echo "skipping checkbundle" diff --git a/README.md b/README.md index 70593a8..02d12b7 100644 --- a/README.md +++ b/README.md @@ -2,123 +2,14 @@ For more about LNBits extension check [this tutorial](https://github.com/lnbits/lnbits/wiki/LNbits-Extensions) -## Overview +`nostrclient` is an always-on extension that can open multiple connections to nostr relays and act as a multiplexer for other clients: You open a single websocket to `nostrclient` which then sends the data to multiple relays. The responses from these relays are then sent back to the client. -`nostrclient` is an always-on Nostr relay multiplexer that simplifies connecting to multiple Nostr relays. Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you. +![2023-03-08 18 11 07](https://user-images.githubusercontent.com/93376500/225265727-369f0f8a-196e-41df-a0d1-98b50a0228be.jpg) -### Why Use This? +### Troubleshoot -- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections -- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays -- **Resource Efficient** - Share relay connections across multiple clients -- **Subscription Management** - Automatic subscription ID rewriting prevents conflicts between clients +The `Test Endpoint` functionality heps the user to check that the `nostrclient` web-socket endpoint works as expected. -## Architecture - -```mermaid -flowchart LR - A[Client A] -->|WebSocket| N - B[Client B] -->|WebSocket| N - C[Client C] -->|WebSocket| N - - N[nostrclient
Router] -->|Fan Out| R1[Relay A] - N -->|Fan Out| R2[Relay B] - N -->|Fan Out| R3[Relay C] - N -->|Fan Out| R4[Relay D] - - R1 -.->|Aggregate| N - R2 -.->|Aggregate| N - R3 -.->|Aggregate| N - R4 -.->|Aggregate| N -``` - -**Key Feature:** The router rewrites subscription IDs to prevent conflicts when multiple clients use the same IDs. - -## Features - -- **Multi-Relay Multiplexing** - Connect to multiple Nostr relays through a single WebSocket -- **Public & Private Endpoints** - Configurable public and private WebSocket access -- **Automatic Reconnection** - Failed relays are automatically retried with exponential backoff -- **Subscription Deduplication** - Events are deduplicated before being sent to clients -- **Health Monitoring** - Track relay connection status, latency, and error rates -- **Test Endpoint** - Send test messages to verify your setup is working - -## How It Works - -1. **Client Connection** - Your Nostr client connects to the nostrclient WebSocket endpoint -2. **Subscription Rewriting** - Each subscription ID is rewritten to prevent conflicts between multiple clients -3. **Fan-Out** - Subscription requests are sent to all configured relays -4. **Aggregation** - Events from all relays are collected and deduplicated -5. **Response** - Events are sent back to the client with the original subscription ID - -## Configuration - -### WebSocket Endpoints - -- **Public Endpoint**: `/api/v1/relay` - Available to anyone (if enabled) -- **Private Endpoint**: `/api/v1/{encrypted_id}` - Requires valid encrypted endpoint ID - -Configure endpoint access in the extension settings: - -- `private_ws` - Enable/disable private WebSocket access -- `public_ws` - Enable/disable public WebSocket access - -### Adding Relays - -Use the nostrclient UI to add/remove Nostr relays. The extension will automatically: - -- Connect to new relays -- Publish existing subscriptions to new relays -- Monitor relay health and reconnect as needed - -## Testing - -### Test Endpoint Functionality - -The `Test Endpoint` feature helps verify that your nostrclient WebSocket endpoint works correctly. - -**How to test:** - -1. Navigate to the nostrclient extension in LNbits -2. Use the Test Endpoint feature -3. Send a DM to yourself (or a temporary account) -4. Verify that messages are sent and received correctly +The LNbits user can DM itself (or a temp account) from `nostrclient` and verify that the messages are sent and received correctly. https://user-images.githubusercontent.com/2951406/236780745-929c33c2-2502-49be-84a3-db02a7aabc0e.mp4 - -## Troubleshooting - -### Connection Issues - -- **Check relay status** - View relay health in the nostrclient UI -- **Verify endpoint configuration** - Ensure public_ws or private_ws is enabled -- **Check logs** - Review LNbits logs for connection errors - -### Subscription Not Receiving Events - -- **Verify relays are connected** - Check the relay status in the UI -- **Test with known event** - Use the Test Endpoint to verify connectivity -- **Check relay compatibility** - Some relays may not support all Nostr features - -## Development - -This extension uses `uv` for dependency management. - -### Quick Start - -```bash -# Format code -make format - -# Run type checks and linting -make check - -# Run tests -make test -``` - -For more development commands, see the [Makefile](./Makefile). - -## License - -MIT License - see [LICENSE](./LICENSE) diff --git a/__init__.py b/__init__.py index d7eb435..7f573e7 100644 --- a/__init__.py +++ b/__init__.py @@ -1,13 +1,15 @@ import asyncio +from typing import List from fastapi import APIRouter -from loguru import logger -from .crud import db -from .router import all_routers, nostr_client -from .tasks import check_relays, init_relays, subscribe_events -from .views import nostrclient_generic_router -from .views_api import nostrclient_api_router +from lnbits.db import Database +from lnbits.helpers import template_renderer +from lnbits.tasks import catch_everything_and_restart + +from .nostr.client.client import NostrClient as NostrClientLib + +db = Database("ext_nostrclient") nostrclient_static_files = [ { @@ -17,43 +19,31 @@ nostrclient_static_files = [ ] nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"]) -nostrclient_ext.include_router(nostrclient_generic_router) -nostrclient_ext.include_router(nostrclient_api_router) -scheduled_tasks: list[asyncio.Task] = [] + +scheduled_tasks: List[asyncio.Task] = [] + +class NostrClient: + def __init__(self): + self.client: NostrClientLib = NostrClientLib(connect=False) -async def nostrclient_stop(): - for task in scheduled_tasks: - try: - task.cancel() - except Exception as ex: - logger.warning(ex) +nostr = NostrClient() - for router in all_routers: - try: - await router.stop() - all_routers.remove(router) - except Exception as e: - logger.error(e) - nostr_client.close() +def nostr_renderer(): + return template_renderer(["nostrclient/templates"]) + + +from .tasks import check_relays, init_relays, subscribe_events +from .views import * # noqa +from .views_api import * # noqa def nostrclient_start(): - from lnbits.tasks import create_permanent_unique_task - - task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays) - task2 = create_permanent_unique_task( - "ext_nostrclient_subscrive_events", subscribe_events - ) - task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays) - scheduled_tasks.extend([task1, task2, task3]) - - -__all__ = [ - "db", - "nostrclient_ext", - "nostrclient_start", - "nostrclient_static_files", - "nostrclient_stop", -] + loop = asyncio.get_event_loop() + task1 = loop.create_task(catch_everything_and_restart(init_relays)) + scheduled_tasks.append(task1) + task2 = loop.create_task(catch_everything_and_restart(subscribe_events)) + scheduled_tasks.append(task2) + task3 = loop.create_task(catch_everything_and_restart(check_relays)) + scheduled_tasks.append(task3) diff --git a/cbc.py b/cbc.py new file mode 100644 index 0000000..0d9e04f --- /dev/null +++ b/cbc.py @@ -0,0 +1,26 @@ +from Cryptodome.Cipher import AES + +BLOCK_SIZE = 16 + + +class AESCipher(object): + """This class is compatible with crypto.createCipheriv('aes-256-cbc')""" + + def __init__(self, key=None): + self.key = key + + def pad(self, data): + length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) + return data + (chr(length) * length).encode() + + def unpad(self, data): + return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] + + def encrypt(self, plain_text): + cipher = AES.new(self.key, AES.MODE_CBC) + b = plain_text.encode("UTF-8") + return cipher.iv, cipher.encrypt(self.pad(b)) + + def decrypt(self, iv, enc_text): + cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) + return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) diff --git a/config.json b/config.json index 1f58e7b..d8b886b 100644 --- a/config.json +++ b/config.json @@ -1,17 +1,7 @@ { "name": "Nostr Client", - "short_description": "Nostr relay multiplexer", - "version": "1.1.0", + "short_description": "Nostr client for extensions", "tile": "/nostrclient/static/images/nostr-bitcoin.png", - "contributors": ["calle", "motorina0", "dni"], - "min_lnbits_version": "1.4.0", - "images": [ - { - "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/1.jpeg" - }, - { - "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/2.jpeg" - } - ], - "description_md": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/description.md" + "contributors": ["calle"], + "min_lnbits_version": "0.11.0" } diff --git a/crud.py b/crud.py index d311c72..780642d 100644 --- a/crud.py +++ b/crud.py @@ -1,52 +1,31 @@ -from lnbits.db import Database +from typing import List, Optional, Union -from .models import Config, Relay, UserConfig +import shortuuid -db = Database("ext_nostrclient") +from lnbits.helpers import urlsafe_short_hash + +from . import db +from .models import Relay, RelayList -async def get_relays() -> list[Relay]: - return await db.fetchall( - "SELECT * FROM nostrclient.relays", - model=Relay, +async def get_relays() -> RelayList: + row = await db.fetchall("SELECT * FROM nostrclient.relays") + return RelayList(__root__=row) + + +async def add_relay(relay: Relay) -> None: + await db.execute( + f""" + INSERT INTO nostrclient.relays ( + id, + url, + active + ) + VALUES (?, ?, ?) + """, + (relay.id, relay.url, relay.active), ) -async def add_relay(relay: Relay) -> Relay: - await db.insert("nostrclient.relays", relay) - return relay - - async def delete_relay(relay: Relay) -> None: - if not relay.url: - return - await db.execute( - "DELETE FROM nostrclient.relays WHERE url = :url", {"url": relay.url} - ) - - -######################CONFIG####################### -async def create_config(owner_id: str) -> Config: - admin_config = UserConfig(owner_id=owner_id) - await db.insert("nostrclient.config", admin_config) - return admin_config.extra - - -async def update_config(owner_id: str, config: Config) -> Config: - user_config = UserConfig(owner_id=owner_id, extra=config) - await db.update("nostrclient.config", user_config, "WHERE owner_id = :owner_id") - return user_config.extra - - -async def get_config(owner_id: str) -> Config | None: - user_config: UserConfig = await db.fetchone( - """ - SELECT * FROM nostrclient.config - WHERE owner_id = :owner_id - """, - {"owner_id": owner_id}, - model=UserConfig, - ) - if user_config: - return user_config.extra - return None + await db.execute("DELETE FROM nostrclient.relays WHERE url = ?", (relay.url,)) diff --git a/description.md b/description.md deleted file mode 100644 index 5293087..0000000 --- a/description.md +++ /dev/null @@ -1,8 +0,0 @@ -An always-on relay multiplexer that simplifies connecting to multiple Nostr relays. - -Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you. - -- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections -- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays -- **Resource Efficient** - Share relay connections across multiple clients -- **Automatic Subscription Management** - Subscription ID rewriting prevents conflicts between clients diff --git a/migrations.py b/migrations.py index b16db58..5a30e45 100644 --- a/migrations.py +++ b/migrations.py @@ -3,7 +3,7 @@ async def m001_initial(db): Initial nostrclient table. """ await db.execute( - """ + f""" CREATE TABLE nostrclient.relays ( id TEXT NOT NULL PRIMARY KEY, url TEXT NOT NULL, @@ -11,22 +11,3 @@ async def m001_initial(db): ); """ ) - - -async def m002_create_config_table(db): - """ - Allow the extension to persist and retrieve any number of config values. - """ - - await db.execute( - """CREATE TABLE nostrclient.config ( - json_data TEXT NOT NULL - );""" - ) - - -async def m003_update_config_table(db): - await db.execute("ALTER TABLE nostrclient.config RENAME COLUMN json_data TO extra") - await db.execute( - "ALTER TABLE nostrclient.config ADD COLUMN owner_id TEXT DEFAULT 'admin'" - ) diff --git a/models.py b/models.py index 937c6c5..88651fc 100644 --- a/models.py +++ b/models.py @@ -1,54 +1,68 @@ -from lnbits.helpers import urlsafe_short_hash +from dataclasses import dataclass +from typing import Dict, List, Optional + +from fastapi import Request +from fastapi.param_functions import Query from pydantic import BaseModel, Field +from lnbits.helpers import urlsafe_short_hash + class RelayStatus(BaseModel): - num_sent_events: int | None = 0 - num_received_events: int | None = 0 - error_counter: int | None = 0 - error_list: list | None = [] - notice_list: list | None = [] - - + num_sent_events: Optional[int] = 0 + num_received_events: Optional[int] = 0 + error_counter: Optional[int] = 0 + error_list: Optional[List] = [] + notice_list: Optional[List] = [] + class Relay(BaseModel): - id: str | None = None - url: str | None = None - active: bool | None = None - - connected: bool | None = Field(default=None, no_database=True) - connected_string: str | None = Field(default=None, no_database=True) - status: RelayStatus | None = Field(default=None, no_database=True) - - ping: int | None = Field(default=None, no_database=True) + id: Optional[str] = None + url: Optional[str] = None + connected: Optional[bool] = None + connected_string: Optional[str] = None + status: Optional[RelayStatus] = None + active: Optional[bool] = None + ping: Optional[int] = None def _init__(self): if not self.id: self.id = urlsafe_short_hash() -class RelayDb(BaseModel): - id: str - url: str - active: bool | None = True +class RelayList(BaseModel): + __root__: List[Relay] + + +class Event(BaseModel): + content: str + pubkey: str + created_at: Optional[int] + kind: int + tags: Optional[List[List[str]]] + sig: str + + +class Filter(BaseModel): + ids: Optional[List[str]] + kinds: Optional[List[int]] + authors: Optional[List[str]] + since: Optional[int] + until: Optional[int] + e: Optional[List[str]] = Field(alias="#e") + p: Optional[List[str]] = Field(alias="#p") + limit: Optional[int] + + +class Filters(BaseModel): + __root__: List[Filter] class TestMessage(BaseModel): - sender_private_key: str | None + sender_private_key: Optional[str] reciever_public_key: str message: str - class TestMessageResponse(BaseModel): private_key: str public_key: str event_json: str - - -class Config(BaseModel): - private_ws: bool = True - public_ws: bool = False - - -class UserConfig(BaseModel): - owner_id: str - extra: Config = Config() diff --git a/tests/__init__.py b/nostr/__init__.py similarity index 100% rename from tests/__init__.py rename to nostr/__init__.py diff --git a/nostr/bech32.py b/nostr/bech32.py index ba2ddd1..61a92c4 100644 --- a/nostr/bech32.py +++ b/nostr/bech32.py @@ -26,22 +26,19 @@ from enum import Enum class Encoding(Enum): """Enumeration type to list the various supported encodings.""" - BECH32 = 1 BECH32M = 2 - CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" -BECH32M_CONST = 0x2BC830A3 - +BECH32M_CONST = 0x2bc830a3 def bech32_polymod(values): """Internal function that computes the Bech32 checksum.""" - generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] + generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3] chk = 1 for value in values: top = chk >> 25 - chk = (chk & 0x1FFFFFF) << 5 ^ value + chk = (chk & 0x1ffffff) << 5 ^ value for i in range(5): chk ^= generator[i] if ((top >> i) & 1) else 0 return chk @@ -61,7 +58,6 @@ def bech32_verify_checksum(hrp, data): return Encoding.BECH32M return None - def bech32_create_checksum(hrp, data, spec): """Compute the checksum values given HRP and data.""" values = bech32_hrp_expand(hrp) + data @@ -73,29 +69,26 @@ def bech32_create_checksum(hrp, data, spec): def bech32_encode(hrp, data, spec): """Compute a Bech32 string given HRP and data values.""" combined = data + bech32_create_checksum(hrp, data, spec) - return hrp + "1" + "".join([CHARSET[d] for d in combined]) - + return hrp + '1' + ''.join([CHARSET[d] for d in combined]) def bech32_decode(bech): """Validate a Bech32/Bech32m string, and determine HRP and data.""" - if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( - bech.lower() != bech and bech.upper() != bech - ): + if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or + (bech.lower() != bech and bech.upper() != bech)): return (None, None, None) bech = bech.lower() - pos = bech.rfind("1") + pos = bech.rfind('1') if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: return (None, None, None) - if not all(x in CHARSET for x in bech[pos + 1 :]): + if not all(x in CHARSET for x in bech[pos+1:]): return (None, None, None) hrp = bech[:pos] - data = [CHARSET.find(x) for x in bech[pos + 1 :]] + data = [CHARSET.find(x) for x in bech[pos+1:]] spec = bech32_verify_checksum(hrp, data) if spec is None: return (None, None, None) return (hrp, data[:-6], spec) - def convertbits(data, frombits, tobits, pad=True): """General power-of-2 base conversion.""" acc = 0 @@ -124,29 +117,22 @@ def decode(hrp, addr): hrpgot, data, spec = bech32_decode(addr) if hrpgot != hrp: return (None, None) - decoded = convertbits(data[1:], 5, 8, False) # type: ignore + decoded = convertbits(data[1:], 5, 8, False) if decoded is None or len(decoded) < 2 or len(decoded) > 40: return (None, None) - if data[0] > 16: # type: ignore + if data[0] > 16: return (None, None) - if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: # type: ignore + if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: return (None, None) - if ( - data[0] == 0 # type: ignore - and spec != Encoding.BECH32 - or data[0] != 0 # type: ignore - and spec != Encoding.BECH32M - ): + if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M: return (None, None) - return (data[0], decoded) # type: ignore + return (data[0], decoded) def encode(hrp, witver, witprog): """Encode a segwit address.""" spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M - wit_prog = convertbits(witprog, 8, 5) - assert wit_prog - ret = bech32_encode(hrp, [witver, *wit_prog], spec) + ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5), spec) if decode(hrp, ret) == (None, None): return None return ret diff --git a/nostr/client/client.py b/nostr/client/client.py index d6fb5c8..db07a06 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -1,38 +1,25 @@ import asyncio - -from loguru import logger +from typing import List from ..relay_manager import RelayManager class NostrClient: - relay_manager: RelayManager - running: bool + relays = [ ] + relay_manager = RelayManager() - def __init__(self): - self.running = True - self.relay_manager = RelayManager() + def __init__(self, relays: List[str] = [], connect=True): + if len(relays): + self.relays = relays + if connect: + self.connect() - def connect(self, relays): - for relay in relays: - try: - self.relay_manager.add_relay(relay) - except Exception as e: - logger.debug(e) - self.running = True - - def reconnect(self, relays): - self.relay_manager.remove_relays() - self.connect(relays) + async def connect(self): + for relay in self.relays: + self.relay_manager.add_relay(relay) def close(self): - try: - self.relay_manager.close_all_subscriptions() - self.relay_manager.close_connections() - - self.running = False - except Exception as e: - logger.error(e) + self.relay_manager.close_connections() async def subscribe( self, @@ -40,36 +27,18 @@ class NostrClient: callback_notices_func=None, callback_eosenotices_func=None, ): - while self.running: - self._check_events(callback_events_func) - self._check_notices(callback_notices_func) - self._check_eos_notices(callback_eosenotices_func) - - await asyncio.sleep(0.2) - - def _check_events(self, callback_events_func=None): - try: + while True: while self.relay_manager.message_pool.has_events(): event_msg = self.relay_manager.message_pool.get_event() if callback_events_func: callback_events_func(event_msg) - except Exception as e: - logger.debug(e) - - def _check_notices(self, callback_notices_func=None): - try: while self.relay_manager.message_pool.has_notices(): event_msg = self.relay_manager.message_pool.get_notice() if callback_notices_func: callback_notices_func(event_msg) - except Exception as e: - logger.debug(e) - - def _check_eos_notices(self, callback_eosenotices_func=None): - try: while self.relay_manager.message_pool.has_eose_notices(): event_msg = self.relay_manager.message_pool.get_eose_notice() if callback_eosenotices_func: callback_eosenotices_func(event_msg) - except Exception as e: - logger.debug(e) + + await asyncio.sleep(0.5) diff --git a/nostr/delegation.py b/nostr/delegation.py new file mode 100644 index 0000000..94801f5 --- /dev/null +++ b/nostr/delegation.py @@ -0,0 +1,32 @@ +import time +from dataclasses import dataclass + + +@dataclass +class Delegation: + delegator_pubkey: str + delegatee_pubkey: str + event_kind: int + duration_secs: int = 30*24*60 # default to 30 days + signature: str = None # set in PrivateKey.sign_delegation + + @property + def expires(self) -> int: + return int(time.time()) + self.duration_secs + + @property + def conditions(self) -> str: + return f"kind={self.event_kind}&created_at<{self.expires}" + + @property + def delegation_token(self) -> str: + return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}" + + def get_tag(self) -> list[str]: + """ Called by Event """ + return [ + "delegation", + self.delegator_pubkey, + self.conditions, + self.signature, + ] diff --git a/nostr/event.py b/nostr/event.py index 994c0f4..65b187d 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -3,9 +3,9 @@ import time from dataclasses import dataclass, field from enum import IntEnum from hashlib import sha256 -from typing import Optional +from typing import List -import coincurve +from secp256k1 import PublicKey from .message_type import ClientMessageType @@ -21,14 +21,14 @@ class EventKind(IntEnum): @dataclass class Event: - content: Optional[str] = None - public_key: Optional[str] = None - created_at: Optional[int] = None + content: str = None + public_key: str = None + created_at: int = None kind: int = EventKind.TEXT_NOTE - tags: list[list[str]] = field( + tags: List[List[str]] = field( default_factory=list ) # Dataclasses require special handling when the default value is a mutable type - signature: Optional[str] = None + signature: str = None def __post_init__(self): if self.content is not None and not isinstance(self.content, str): @@ -40,7 +40,7 @@ class Event: @staticmethod def serialize( - public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str + public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str ) -> bytes: data = [0, public_key, created_at, kind, tags, content] data_str = json.dumps(data, separators=(",", ":"), ensure_ascii=False) @@ -48,7 +48,7 @@ class Event: @staticmethod def compute_id( - public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str + public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str ): return sha256( Event.serialize(public_key, created_at, kind, tags, content) @@ -57,9 +57,6 @@ class Event: @property def id(self) -> str: # Always recompute the id to reflect the up-to-date state of the Event - assert self.public_key - assert self.created_at - assert self.content return Event.compute_id( self.public_key, self.created_at, self.kind, self.tags, self.content ) @@ -73,10 +70,12 @@ class Event: self.tags.append(["e", event_id]) def verify(self) -> bool: - assert self.public_key - assert self.signature - pub_key = coincurve.PublicKeyXOnly(bytes.fromhex(self.public_key)) - return pub_key.verify(bytes.fromhex(self.signature), bytes.fromhex(self.id)) + pub_key = PublicKey( + bytes.fromhex("02" + self.public_key), True + ) # add 02 for schnorr (bip340) + return pub_key.schnorr_verify( + bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True + ) def to_message(self) -> str: return json.dumps( @@ -97,9 +96,9 @@ class Event: @dataclass class EncryptedDirectMessage(Event): - recipient_pubkey: Optional[str] = None - cleartext_content: Optional[str] = None - reference_event_id: Optional[str] = None + recipient_pubkey: str = None + cleartext_content: str = None + reference_event_id: str = None def __post_init__(self): if self.content is not None: @@ -123,7 +122,6 @@ class EncryptedDirectMessage(Event): def id(self) -> str: if self.content is None: raise Exception( - "EncryptedDirectMessage `id` is undefined until its" - + " message is encrypted and stored in the `content` field" + "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field" ) return super().id diff --git a/nostr/filter.py b/nostr/filter.py new file mode 100644 index 0000000..f119079 --- /dev/null +++ b/nostr/filter.py @@ -0,0 +1,134 @@ +from collections import UserList +from typing import List + +from .event import Event, EventKind + + +class Filter: + """ + NIP-01 filtering. + + Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`. + + Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`. + If a particular single-letter tag gains prominence, explicit support should be + added. For example: + # arbitrary tag + filter.add_arbitrary_tag('t', [hashtags]) + + # promoted to explicit support + Filter(hashtag_refs=[hashtags]) + """ + + def __init__( + self, + event_ids: List[str] = None, + kinds: List[EventKind] = None, + authors: List[str] = None, + since: int = None, + until: int = None, + event_refs: List[ + str + ] = None, # the "#e" attr; list of event ids referenced in an "e" tag + pubkey_refs: List[ + str + ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag + limit: int = None, + ) -> None: + self.event_ids = event_ids + self.kinds = kinds + self.authors = authors + self.since = since + self.until = until + self.event_refs = event_refs + self.pubkey_refs = pubkey_refs + self.limit = limit + + self.tags = {} + if self.event_refs: + self.add_arbitrary_tag("e", self.event_refs) + if self.pubkey_refs: + self.add_arbitrary_tag("p", self.pubkey_refs) + + def add_arbitrary_tag(self, tag: str, values: list): + """ + Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12 + single-letter tags. + """ + # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#" + tag_key = tag if len(tag) > 1 else f"#{tag}" + self.tags[tag_key] = values + + def matches(self, event: Event) -> bool: + if self.event_ids is not None and event.id not in self.event_ids: + return False + if self.kinds is not None and event.kind not in self.kinds: + return False + if self.authors is not None and event.public_key not in self.authors: + return False + if self.since is not None and event.created_at < self.since: + return False + if self.until is not None and event.created_at > self.until: + return False + if (self.event_refs is not None or self.pubkey_refs is not None) and len( + event.tags + ) == 0: + return False + + if self.tags: + e_tag_identifiers = set([e_tag[0] for e_tag in event.tags]) + for f_tag, f_tag_values in self.tags.items(): + # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags + f_tag = f_tag.replace("#", "") + + if f_tag not in e_tag_identifiers: + # Event is missing a tag type that we're looking for + return False + + # Multiple values within f_tag_values are treated as OR search; an Event + # needs to match only one. + # Note: an Event could have multiple entries of the same tag type + # (e.g. a reply to multiple people) so we have to check all of them. + match_found = False + for e_tag in event.tags: + if e_tag[0] == f_tag and e_tag[1] in f_tag_values: + match_found = True + break + if not match_found: + return False + + return True + + def to_json_object(self) -> dict: + res = {} + if self.event_ids is not None: + res["ids"] = self.event_ids + if self.kinds is not None: + res["kinds"] = self.kinds + if self.authors is not None: + res["authors"] = self.authors + if self.since is not None: + res["since"] = self.since + if self.until is not None: + res["until"] = self.until + if self.limit is not None: + res["limit"] = self.limit + if self.tags: + res.update(self.tags) + + return res + + +class Filters(UserList): + def __init__(self, initlist: "list[Filter]" = []) -> None: + super().__init__(initlist) + self.data: "list[Filter]" + + def match(self, event: Event): + for filter in self.data: + if filter.matches(event): + return True + return False + + def to_json_array(self) -> list: + return [filter.to_json_object() for filter in self.data] diff --git a/nostr/key.py b/nostr/key.py index f7b4e81..8089e11 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -1,11 +1,14 @@ import base64 import secrets +from hashlib import sha256 -import coincurve +import secp256k1 +from cffi import FFI from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from .bech32 import Encoding, bech32_decode, bech32_encode, convertbits +from . import bech32 +from .delegation import Delegation from .event import EncryptedDirectMessage, Event, EventKind @@ -14,61 +17,55 @@ class PublicKey: self.raw_bytes = raw_bytes def bech32(self) -> str: - converted_bits = convertbits(self.raw_bytes, 8, 5) - return bech32_encode("npub", converted_bits, Encoding.BECH32) + converted_bits = bech32.convertbits(self.raw_bytes, 8, 5) + return bech32.bech32_encode("npub", converted_bits, bech32.Encoding.BECH32) def hex(self) -> str: return self.raw_bytes.hex() - def verify_signed_message_hash(self, message_hash: str, sig: str) -> bool: - pk = coincurve.PublicKeyXOnly(self.raw_bytes) - return pk.verify(bytes.fromhex(sig), bytes.fromhex(message_hash)) + def verify_signed_message_hash(self, hash: str, sig: str) -> bool: + pk = secp256k1.PublicKey(b"\x02" + self.raw_bytes, True) + return pk.schnorr_verify(bytes.fromhex(hash), bytes.fromhex(sig), None, True) @classmethod def from_npub(cls, npub: str): """Load a PublicKey from its bech32/npub form""" - _, data, _ = bech32_decode(npub) - raw_data = convertbits(data, 5, 8) - assert raw_data - raw_public_key = raw_data[:-1] + hrp, data, spec = bech32.bech32_decode(npub) + raw_public_key = bech32.convertbits(data, 5, 8)[:-1] return cls(bytes(raw_public_key)) class PrivateKey: - def __init__(self, raw_secret: bytes | None = None) -> None: - if raw_secret is not None: + def __init__(self, raw_secret: bytes = None) -> None: + if not raw_secret is None: self.raw_secret = raw_secret else: self.raw_secret = secrets.token_bytes(32) - sk = coincurve.PrivateKey(self.raw_secret) - assert sk.public_key - self.public_key = PublicKey(sk.public_key.format()[1:]) + sk = secp256k1.PrivateKey(self.raw_secret) + self.public_key = PublicKey(sk.pubkey.serialize()[1:]) @classmethod def from_nsec(cls, nsec: str): """Load a PrivateKey from its bech32/nsec form""" - _, data, _ = bech32_decode(nsec) - raw_data = convertbits(data, 5, 8) - assert raw_data - raw_secret = raw_data[:-1] + hrp, data, spec = bech32.bech32_decode(nsec) + raw_secret = bech32.convertbits(data, 5, 8)[:-1] return cls(bytes(raw_secret)) def bech32(self) -> str: - converted_bits = convertbits(self.raw_secret, 8, 5) - return bech32_encode("nsec", converted_bits, Encoding.BECH32) + converted_bits = bech32.convertbits(self.raw_secret, 8, 5) + return bech32.bech32_encode("nsec", converted_bits, bech32.Encoding.BECH32) def hex(self) -> str: return self.raw_secret.hex() def tweak_add(self, scalar: bytes) -> bytes: - sk = coincurve.PrivateKey(self.raw_secret) - return sk.add(scalar).to_der() + sk = secp256k1.PrivateKey(self.raw_secret) + return sk.tweak_add(scalar) def compute_shared_secret(self, public_key_hex: str) -> bytes: - pk = coincurve.PublicKey(bytes.fromhex("02" + public_key_hex)) - sk = coincurve.PrivateKey(self.raw_secret) - return sk.ecdh(pk.format()) + pk = secp256k1.PublicKey(bytes.fromhex("02" + public_key_hex), True) + return pk.ecdh(self.raw_secret, hashfn=copy_x) def encrypt_message(self, message: str, public_key_hex: str) -> str: padder = padding.PKCS7(128).padder() @@ -82,14 +79,9 @@ class PrivateKey: encryptor = cipher.encryptor() encrypted_message = encryptor.update(padded_data) + encryptor.finalize() - return ( - f"{base64.b64encode(encrypted_message).decode()}" - + f"?iv={base64.b64encode(iv).decode()}" - ) + return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: - assert dm.cleartext_content - assert dm.recipient_pubkey dm.content = self.encrypt_message( message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey ) @@ -112,23 +104,28 @@ class PrivateKey: return unpadded_data.decode() - def sign_message_hash(self, message_hash: bytes) -> str: - sk = coincurve.PrivateKey(self.raw_secret) - sig = sk.sign_schnorr(message_hash) + def sign_message_hash(self, hash: bytes) -> str: + sk = secp256k1.PrivateKey(self.raw_secret) + sig = sk.schnorr_sign(hash, None, raw=True) return sig.hex() def sign_event(self, event: Event) -> None: if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None: - self.encrypt_dm(event) # type: ignore + self.encrypt_dm(event) if event.public_key is None: event.public_key = self.public_key.hex() event.signature = self.sign_message_hash(bytes.fromhex(event.id)) + def sign_delegation(self, delegation: Delegation) -> None: + delegation.signature = self.sign_message_hash( + sha256(delegation.delegation_token.encode()).digest() + ) + def __eq__(self, other): return self.raw_secret == other.raw_secret -def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> PrivateKey: +def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: if prefix is None and suffix is None: raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments") @@ -144,3 +141,14 @@ def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> Pri break return sk + + +ffi = FFI() + + +@ffi.callback( + "int (unsigned char *, const unsigned char *, const unsigned char *, void *)" +) +def copy_x(output, x32, y32, data): + ffi.memmove(output, x32, 32) + return 1 diff --git a/nostr/message_pool.py b/nostr/message_pool.py index a3e6c5f..02f7fd4 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -2,15 +2,13 @@ import json from queue import Queue from threading import Lock +from .event import Event from .message_type import RelayMessageType class EventMessage: - def __init__( - self, event: str, event_id: str, subscription_id: str, url: str - ) -> None: + def __init__(self, event: Event, subscription_id: str, url: str) -> None: self.event = event - self.event_id = event_id self.subscription_id = subscription_id self.url = url @@ -61,16 +59,18 @@ class MessagePool: message_type = message_json[0] if message_type == RelayMessageType.EVENT: subscription_id = message_json[1] - event = message_json[2] - if "id" not in event: - return - event_id = event["id"] - + e = message_json[2] + event = Event( + e["content"], + e["pubkey"], + e["created_at"], + e["kind"], + e["tags"], + e["sig"], + ) with self.lock: - if f"{subscription_id}_{event_id}" not in self._unique_events: - self._accept_event( - EventMessage(json.dumps(event), event_id, subscription_id, url) - ) + if not f"{subscription_id}_{event.id}" in self._unique_events: + self._accept_event(EventMessage(event, subscription_id, url)) elif message_type == RelayMessageType.NOTICE: self.notices.put(NoticeMessage(message_json[1], url)) elif message_type == RelayMessageType.END_OF_STORED_EVENTS: @@ -78,12 +78,10 @@ class MessagePool: def _accept_event(self, event_message: EventMessage): """ - Event uniqueness is considered per `subscription_id`. The `subscription_id` is - rewritten to be unique and it is the same accross relays. The same event can - come from different subscriptions (from the same client or from different ones). - Clients that have joined later should receive older events. + Event uniqueness is considered per `subscription_id`. + The `subscription_id` is rewritten to be unique and it is the same accross relays. + The same event can come from different subscriptions (from the same client or from different ones). + Clients that have joined later should receive older events. """ self.events.put(event_message) - self._unique_events.add( - f"{event_message.subscription_id}_{event_message.event_id}" - ) + self._unique_events.add(f"{event_message.subscription_id}_{event_message.event.id}") \ No newline at end of file diff --git a/nostr/relay.py b/nostr/relay.py index d762963..caacba0 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -2,34 +2,57 @@ import asyncio import json import time from queue import Queue +from threading import Lock +from typing import List from loguru import logger from websocket import WebSocketApp +from .event import Event +from .filter import Filters from .message_pool import MessagePool +from .message_type import RelayMessageType from .subscription import Subscription +class RelayPolicy: + def __init__(self, should_read: bool = True, should_write: bool = True) -> None: + self.should_read = should_read + self.should_write = should_write + + def to_json_object(self) -> dict[str, bool]: + return {"read": self.should_read, "write": self.should_write} + + class Relay: - def __init__(self, url: str, message_pool: MessagePool) -> None: + def __init__( + self, + url: str, + policy: RelayPolicy, + message_pool: MessagePool, + subscriptions: dict[str, Subscription] = {}, + ) -> None: self.url = url + self.policy = policy self.message_pool = message_pool + self.subscriptions = subscriptions self.connected: bool = False self.reconnect: bool = True self.shutdown: bool = False - self.error_counter: int = 0 self.error_threshold: int = 100 - self.error_list: list[str] = [] - self.notice_list: list[str] = [] + self.error_list: List[str] = [] + self.notice_list: List[str] = [] self.last_error_date: int = 0 self.num_received_events: int = 0 self.num_sent_events: int = 0 self.num_subscriptions: int = 0 + self.ssl_options: dict = {} + self.proxy: dict = {} + self.lock = Lock() + self.queue = Queue() - self.queue: Queue = Queue() - - def connect(self): + def connect(self, ssl_options: dict = None, proxy: dict = None): self.ws = WebSocketApp( self.url, on_open=self._on_open, @@ -39,14 +62,19 @@ class Relay: on_ping=self._on_ping, on_pong=self._on_pong, ) + self.ssl_options = ssl_options + self.proxy = proxy if not self.connected: - self.ws.run_forever(ping_interval=10) + self.ws.run_forever( + sslopt=ssl_options, + http_proxy_host=None if proxy is None else proxy.get("host"), + http_proxy_port=None if proxy is None else proxy.get("port"), + proxy_type=None if proxy is None else proxy.get("type"), + ping_interval=5, + ) def close(self): - try: - self.ws.close() - except Exception as e: - logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}") + self.ws.close() self.connected = False self.shutdown = True @@ -62,10 +90,10 @@ class Relay: def publish(self, message: str): self.queue.put(message) - def publish_subscriptions(self, subscriptions: list[Subscription]): - for s in subscriptions: - assert s.filters - json_str = json.dumps(["REQ", s.id, *s.filters]) + def publish_subscriptions(self): + for _, subscription in self.subscriptions.items(): + s = subscription.to_json_object() + json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) self.publish(json_str) async def queue_worker(self): @@ -75,44 +103,55 @@ class Relay: message = self.queue.get(timeout=1) self.num_sent_events += 1 self.ws.send(message) - except Exception as _: + except: pass else: await asyncio.sleep(1) - + if self.shutdown: - logger.warning(f"[Relay: {self.url}] Closing queue worker.") - return + logger.warning(f"Closing queue worker for '{self.url}'.") + break - def close_subscription(self, sub_id: str) -> None: - try: - self.publish(json.dumps(["CLOSE", sub_id])) - except Exception as e: - logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}") + def add_subscription(self, id, filters: Filters): + with self.lock: + self.subscriptions[id] = Subscription(id, filters) + + def close_subscription(self, id: str) -> None: + with self.lock: + self.subscriptions.pop(id) + self.publish(json.dumps(["CLOSE", id])) + + def to_json_object(self) -> dict: + return { + "url": self.url, + "policy": self.policy.to_json_object(), + "subscriptions": [ + subscription.to_json_object() + for subscription in self.subscriptions.values() + ], + } def add_notice(self, notice: str): - self.notice_list = [notice, *self.notice_list] + self.notice_list = ([notice] + self.notice_list)[:20] def _on_open(self, _): - logger.info(f"[Relay: {self.url}] Connected.") + logger.info(f"Connected to relay: '{self.url}'.") self.connected = True - self.shutdown = False - + def _on_close(self, _, status_code, message): - logger.warning( - f"[Relay: {self.url}] Connection closed." - + f" Status: '{status_code}'. Message: '{message}'." - ) + logger.warning(f"Connection to relay {self.url} closed. Status: '{status_code}'. Message: '{message}'.") self.close() def _on_message(self, _, message: str): - self.num_received_events += 1 - self.message_pool.add_message(message, self.url) + if self._is_valid_message(message): + self.num_received_events += 1 + self.message_pool.add_message(message, self.url) def _on_error(self, _, error): - logger.warning(f"[Relay: {self.url}] Error: '{error!s}'") + logger.warning(f"Relay error: '{str(error)}'") self._append_error_message(str(error)) - self.close() + self.connected = False + self.error_counter += 1 def _on_ping(self, *_): return @@ -120,7 +159,65 @@ class Relay: def _on_pong(self, *_): return + def _is_valid_message(self, message: str) -> bool: + message = message.strip("\n") + if not message or message[0] != "[" or message[-1] != "]": + return False + + message_json = json.loads(message) + message_type = message_json[0] + + if not RelayMessageType.is_valid(message_type): + return False + + if message_type == RelayMessageType.EVENT: + return self._is_valid_event_message(message_json) + + if message_type == RelayMessageType.COMMAND_RESULT: + return self._is_valid_command_result_message(message, message_json) + + return True + + def _is_valid_event_message(self, message_json): + if not len(message_json) == 3: + return False + + subscription_id = message_json[1] + with self.lock: + if subscription_id not in self.subscriptions: + return False + + e = message_json[2] + event = Event( + e["content"], + e["pubkey"], + e["created_at"], + e["kind"], + e["tags"], + e["sig"], + ) + if not event.verify(): + return False + + with self.lock: + subscription = self.subscriptions[subscription_id] + + if subscription.filters and not subscription.filters.match(event): + return False + + return True + + def _is_valid_command_result_message(self, message, message_json): + if not len(message_json) < 3: + return False + + if message_json[2] != True: + logger.warning(f"Relay '{self.url}' negative command result: '{message}'") + self._append_error_message(message) + return False + + return True + def _append_error_message(self, message): - self.error_counter += 1 - self.error_list = [message, *self.error_list] - self.last_error_date = int(time.time()) + self.error_list = ([message] + self.error_list)[:20] + self.last_error_date = int(time.time()) \ No newline at end of file diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index 2aa27c5..f639fb0 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,15 +1,21 @@ + import asyncio +import ssl import threading import time -from typing import List from loguru import logger +from .filter import Filters from .message_pool import MessagePool, NoticeMessage -from .relay import Relay +from .relay import Relay, RelayPolicy from .subscription import Subscription +class RelayException(Exception): + pass + + class RelayManager: def __init__(self) -> None: self.relays: dict[str, Relay] = {} @@ -19,98 +25,72 @@ class RelayManager: self._cached_subscriptions: dict[str, Subscription] = {} self._subscriptions_lock = threading.Lock() - def add_relay(self, url: str) -> Relay: + def add_relay(self, url: str, read: bool = True, write: bool = True) -> Relay: if url in list(self.relays.keys()): - logger.debug(f"Relay '{url}' already present.") - return self.relays[url] + return + + with self._subscriptions_lock: + subscriptions = self._cached_subscriptions.copy() - relay = Relay(url, self.message_pool) + policy = RelayPolicy(read, write) + relay = Relay(url, policy, self.message_pool, subscriptions) self.relays[url] = relay - self._open_connection(relay) + self._open_connection( + relay, + {"cert_reqs": ssl.CERT_NONE} + ) # NOTE: This disables ssl certificate verification - relay.publish_subscriptions(list(self._cached_subscriptions.values())) + relay.publish_subscriptions() return relay def remove_relay(self, url: str): - try: - self.relays[url].close() - except Exception as e: - logger.debug(e) + self.relays[url].close() + self.relays.pop(url) + self.threads[url].join(timeout=5) + self.threads.pop(url) + self.queue_threads[url].join(timeout=5) + self.queue_threads.pop(url) + - if url in self.relays: - self.relays.pop(url) - - try: - self.threads[url].join(timeout=5) - except Exception as e: - logger.debug(e) - - if url in self.threads: - self.threads.pop(url) - - try: - self.queue_threads[url].join(timeout=5) - except Exception as e: - logger.debug(e) - - if url in self.queue_threads: - self.queue_threads.pop(url) - - def remove_relays(self): - relay_urls = list(self.relays.keys()) - for url in relay_urls: - self.remove_relay(url) - - def add_subscription(self, id: str, filters: List[str]): - s = Subscription(id, filters) + def add_subscription(self, id: str, filters: Filters): with self._subscriptions_lock: - self._cached_subscriptions[id] = s + self._cached_subscriptions[id] = Subscription(id, filters) for relay in self.relays.values(): - relay.publish_subscriptions([s]) + relay.add_subscription(id, filters) def close_subscription(self, id: str): - try: - logger.info(f"Closing subscription: '{id}'.") - with self._subscriptions_lock: - if id in self._cached_subscriptions: - self._cached_subscriptions.pop(id) + with self._subscriptions_lock: + self._cached_subscriptions.pop(id) - for relay in self.relays.values(): - relay.close_subscription(id) - except Exception as e: - logger.debug(e) - - def close_subscriptions(self, subscriptions: List[str]): - for id in subscriptions: - self.close_subscription(id) - - def close_all_subscriptions(self): - all_subscriptions = list(self._cached_subscriptions.keys()) - self.close_subscriptions(all_subscriptions) + for relay in self.relays.values(): + relay.close_subscription(id) def check_and_restart_relays(self): stopped_relays = [r for r in self.relays.values() if r.shutdown] for relay in stopped_relays: self._restart_relay(relay) + def close_connections(self): for relay in self.relays.values(): relay.close() def publish_message(self, message: str): for relay in self.relays.values(): - relay.publish(message) + if relay.policy.should_write: + relay.publish(message) def handle_notice(self, notice: NoticeMessage): relay = next((r for r in self.relays.values() if r.url == notice.url)) if relay: relay.add_notice(notice.content) - def _open_connection(self, relay: Relay): + def _open_connection(self, relay: Relay, ssl_options: dict = None, proxy: dict = None): self.threads[relay.url] = threading.Thread( target=relay.connect, + args=(ssl_options, proxy), name=f"{relay.url}-thread", daemon=True, ) @@ -118,7 +98,7 @@ class RelayManager: def wrap_async_queue_worker(): asyncio.run(relay.queue_worker()) - + self.queue_threads[relay.url] = threading.Thread( target=wrap_async_queue_worker, name=f"{relay.url}-queue", @@ -128,16 +108,14 @@ class RelayManager: def _restart_relay(self, relay: Relay): time_since_last_error = time.time() - relay.last_error_date - - min_wait_time = min( - 60 * relay.error_counter, 60 * 60 - ) # try at least once an hour + + min_wait_time = min(60 * relay.error_counter, 60 * 60 * 24) # try at least once a day if time_since_last_error < min_wait_time: return - + logger.info(f"Restarting connection to relay '{relay.url}'") self.remove_relay(relay.url) new_relay = self.add_relay(relay.url) new_relay.error_counter = relay.error_counter - new_relay.error_list = relay.error_list + new_relay.error_list = relay.error_list \ No newline at end of file diff --git a/nostr/subscription.py b/nostr/subscription.py index ed60f7e..76da0af 100644 --- a/nostr/subscription.py +++ b/nostr/subscription.py @@ -1,7 +1,13 @@ -from typing import Optional +from .filter import Filters class Subscription: - def __init__(self, id: str, filters: Optional[list[str]] = None) -> None: + def __init__(self, id: str, filters: Filters=None) -> None: self.id = id self.filters = filters + + def to_json_object(self): + return { + "id": self.id, + "filters": self.filters.to_json_array() + } diff --git a/package-lock.json b/package-lock.json deleted file mode 100644 index 1180ffb..0000000 --- a/package-lock.json +++ /dev/null @@ -1,59 +0,0 @@ -{ - "name": "nostrclient", - "version": "1.0.0", - "lockfileVersion": 3, - "requires": true, - "packages": { - "": { - "name": "nostrclient", - "version": "1.0.0", - "license": "ISC", - "dependencies": { - "prettier": "^3.2.5", - "pyright": "^1.1.358" - } - }, - "node_modules/fsevents": { - "version": "2.3.3", - "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", - "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", - "hasInstallScript": true, - "optional": true, - "os": [ - "darwin" - ], - "engines": { - "node": "^8.16.0 || ^10.6.0 || >=11.0.0" - } - }, - "node_modules/prettier": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", - "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", - "bin": { - "prettier": "bin/prettier.cjs" - }, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/prettier/prettier?sponsor=1" - } - }, - "node_modules/pyright": { - "version": "1.1.374", - "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.374.tgz", - "integrity": "sha512-ISbC1YnYDYrEatoKKjfaA5uFIp0ddC/xw9aSlN/EkmwupXUMVn41Jl+G6wHEjRhC+n4abHZeGpEvxCUus/K9dA==", - "bin": { - "pyright": "index.js", - "pyright-langserver": "langserver.index.js" - }, - "engines": { - "node": ">=14.0.0" - }, - "optionalDependencies": { - "fsevents": "~2.3.3" - } - } - } -} diff --git a/package.json b/package.json deleted file mode 100644 index 7b84315..0000000 --- a/package.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "name": "nostrclient", - "version": "1.0.0", - "description": "", - "main": "index.js", - "scripts": { - "test": "echo \"Error: no test specified\" && exit 1" - }, - "author": "", - "license": "ISC", - "dependencies": { - "prettier": "^3.2.5", - "pyright": "^1.1.358" - } -} diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index c7c0e56..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,98 +0,0 @@ -[project] -name = "lnbits-nostrclient" -version = "1.1.0" -requires-python = ">=3.10,<3.13" -description = "LNbits, free and open-source Lightning wallet and accounts system." -authors = [{ name = "Alan Bits", email = "alan@lnbits.com" }] -urls = { Homepage = "https://lnbits.com", Repository = "https://github.com/lnbits/nostrclient" } -dependencies = [ "lnbits>1" ] - -[tool.poetry] -package-mode = false - -[tool.uv] -dev-dependencies = [ - "black", - "pytest-asyncio", - "pytest", - "mypy", - "pre-commit", - "ruff", - "pytest-md", - "types-cffi", -] - -[tool.mypy] -exclude = "(nostr/*)" -plugins = ["pydantic.mypy"] - -[[tool.mypy.overrides]] -module = [ - "nostr.*", -] -follow_imports = "skip" -ignore_missing_imports = "True" - -[tool.pydantic-mypy] -init_forbid_extra = true -init_typed = true -warn_required_dynamic_aliases = true -warn_untyped_fields = true - -[tool.pytest.ini_options] -log_cli = false -testpaths = [ - "tests" -] - -[tool.black] -line-length = 88 - -[tool.ruff] -# Same as Black. + 10% rule of black -line-length = 88 -exclude = [ - "nostr", -] - -[tool.ruff.lint] -# Enable: -# F - pyflakes -# E - pycodestyle errors -# W - pycodestyle warnings -# I - isort -# A - flake8-builtins -# C - mccabe -# N - naming -# UP - pyupgrade -# RUF - ruff -# B - bugbear -select = ["F", "E", "W", "I", "A", "C", "N", "UP", "RUF", "B"] -ignore = ["C901"] - -# Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = ["ALL"] -unfixable = [] - -# Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" - -# needed for pydantic -[tool.ruff.lint.pep8-naming] -classmethod-decorators = [ - "root_validator", -] - -# Ignore unused imports in __init__.py files. -# [tool.ruff.lint.extend-per-file-ignores] -# "__init__.py" = ["F401", "F403"] - -# [tool.ruff.lint.mccabe] -# max-complexity = 10 - -[tool.ruff.lint.flake8-bugbear] -# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. -extend-immutable-calls = [ - "fastapi.Depends", - "fastapi.Query", -] diff --git a/router.py b/router.py index a7054e9..cc0a380 100644 --- a/router.py +++ b/router.py @@ -1,85 +1,84 @@ import asyncio import json -from typing import ClassVar +from typing import List, Union -from fastapi import WebSocket, WebSocketDisconnect -from lnbits.helpers import urlsafe_short_hash +from fastapi import WebSocketDisconnect from loguru import logger -from .nostr.client.client import NostrClient +from lnbits.helpers import urlsafe_short_hash -# from . import nostr_client -from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage - -nostr_client: NostrClient = NostrClient() -all_routers: list["NostrRouter"] = [] +from . import nostr +from .models import Event, Filter +from .nostr.filter import Filter as NostrFilter +from .nostr.filter import Filters as NostrFilters +from .nostr.message_pool import EndOfStoredEventsMessage, NoticeMessage class NostrRouter: - received_subscription_events: ClassVar[dict[str, list[EventMessage]]] = {} - received_subscription_notices: ClassVar[list[NoticeMessage]] = [] - received_subscription_eosenotices: ClassVar[dict[str, EndOfStoredEventsMessage]] = ( - {} - ) - def __init__(self, websocket: WebSocket): + received_subscription_events: dict[str, list[Event]] = {} + received_subscription_notices: list[NoticeMessage] = [] + received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} + + def __init__(self, websocket): + self.subscriptions: List[str] = [] self.connected: bool = True - self.websocket: WebSocket = websocket - self.tasks: list[asyncio.Task] = [] - self.original_subscription_ids: dict[str, str] = {} + self.websocket = websocket + self.tasks: List[asyncio.Task] = [] + self.original_subscription_ids = {} - @property - def subscriptions(self) -> list[str]: - return list(self.original_subscription_ids.keys()) - - def start(self): - self.connected = True - self.tasks.append(asyncio.create_task(self._client_to_nostr())) - self.tasks.append(asyncio.create_task(self._nostr_to_client())) - - async def stop(self): - nostr_client.relay_manager.close_subscriptions(self.subscriptions) - self.connected = False - - for t in self.tasks: - try: - t.cancel() - except Exception as _: - pass - - try: - await self.websocket.close(reason="Websocket connection closed") - except Exception as _: - pass - - async def _client_to_nostr(self): - """ - Receives requests / data from the client and forwards it to relays. - """ - while self.connected: + async def client_to_nostr(self): + """Receives requests / data from the client and forwards it to relays. If the + request was a subscription/filter, registers it with the nostr client lib. + Remembers the subscription id so we can send back responses from the relay to this + client in `nostr_to_client`""" + while True: try: json_str = await self.websocket.receive_text() - except WebSocketDisconnect as e: - logger.debug(e) - await self.stop() + except WebSocketDisconnect: + self.connected = False break try: await self._handle_client_to_nostr(json_str) except Exception as e: - logger.debug(f"Failed to handle client message: '{e!s}'.") + logger.debug(f"Failed to handle client message: '{str(e)}'.") - async def _nostr_to_client(self): - """Sends responses from relays back to the client.""" - while self.connected: + + async def nostr_to_client(self): + """Sends responses from relays back to the client. Polls the subscriptions of this client + stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which + is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs + the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id + that we had previously rewritten in order to avoid collisions when multiple clients use the same id. + """ + while True and self.connected: try: await self._handle_subscriptions() self._handle_notices() except Exception as e: - logger.debug(f"Failed to handle response for client: '{e!s}'.") - await asyncio.sleep(1) + logger.debug(f"Failed to handle response for client: '{str(e)}'.") await asyncio.sleep(0.1) + + async def start(self): + self.tasks.append(asyncio.create_task(self.client_to_nostr())) + self.tasks.append(asyncio.create_task(self.nostr_to_client())) + + async def stop(self): + for t in self.tasks: + try: + t.cancel() + except: + pass + + for s in self.subscriptions: + try: + nostr.client.relay_manager.close_subscription(s) + except: + pass + self.connected = False + async def _handle_subscriptions(self): for s in self.subscriptions: if s in NostrRouter.received_subscription_events: @@ -87,6 +86,8 @@ class NostrRouter: if s in NostrRouter.received_subscription_eosenotices: await self._handle_received_subscription_eosenotices(s) + + async def _handle_received_subscription_eosenotices(self, s): try: if s not in self.original_subscription_ids: @@ -94,7 +95,7 @@ class NostrRouter: s_original = self.original_subscription_ids[s] event_to_forward = ["EOSE", s_original] del NostrRouter.received_subscription_eosenotices[s] - + await self.websocket.send_text(json.dumps(event_to_forward)) except Exception as e: logger.debug(e) @@ -103,73 +104,97 @@ class NostrRouter: try: if s not in NostrRouter.received_subscription_events: return - while len(NostrRouter.received_subscription_events[s]): - event_message = NostrRouter.received_subscription_events[s].pop(0) - event_json = event_message.event + my_event = NostrRouter.received_subscription_events[s].pop(0) + # event.to_message() does not include the subscription ID, we have to add it manually + event_json = { + "id": my_event.id, + "pubkey": my_event.public_key, + "created_at": my_event.created_at, + "kind": my_event.kind, + "tags": my_event.tags, + "content": my_event.content, + "sig": my_event.signature, + } # this reconstructs the original response from the relay # reconstruct original subscription id s_original = self.original_subscription_ids[s] - event_to_forward = json.dumps( - ["EVENT", s_original, json.loads(event_json)] - ) - await self.websocket.send_text(event_to_forward) + event_to_forward = ["EVENT", s_original, event_json] + await self.websocket.send_text(json.dumps(event_to_forward)) except Exception as e: - logger.warning( - f"[NOSTRCLIENT] Error in _handle_received_subscription_events: {e}" - ) + logger.debug(e) def _handle_notices(self): while len(NostrRouter.received_subscription_notices): my_event = NostrRouter.received_subscription_notices.pop(0) - logger.debug(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']") - # Note: we don't send it to the user because - # we don't know who should receive it - nostr_client.relay_manager.handle_notice(my_event) + # note: we don't send it to the user because we don't know who should receive it + logger.info(f"Relay ('{my_event.url}') notice: '{my_event.content}']") + nostr.client.relay_manager.handle_notice(my_event) + + + + def _marshall_nostr_filters(self, data: Union[dict, list]): + filters = data if isinstance(data, list) else [data] + filters = [Filter.parse_obj(f) for f in filters] + filter_list: list[NostrFilter] = [] + for filter in filters: + filter_list.append( + NostrFilter( + event_ids=filter.ids, # type: ignore + kinds=filter.kinds, # type: ignore + authors=filter.authors, # type: ignore + since=filter.since, # type: ignore + until=filter.until, # type: ignore + event_refs=filter.e, # type: ignore + pubkey_refs=filter.p, # type: ignore + limit=filter.limit, # type: ignore + ) + ) + return NostrFilters(filter_list) async def _handle_client_to_nostr(self, json_str): + """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will + register the subscription in the nostr client library that we're using so we can + receive the callbacks on it later. Will rewrite the subscription id since we expect + multiple clients to use the router and want to avoid subscription id collisions + """ + json_data = json.loads(json_str) - assert len(json_data), "Bad JSON array" + assert len(json_data) + if json_data[0] == "REQ": self._handle_client_req(json_data) return - + if json_data[0] == "CLOSE": self._handle_client_close(json_data[1]) return if json_data[0] == "EVENT": - nostr_client.relay_manager.publish_message(json_str) + nostr.client.relay_manager.publish_message(json_str) return def _handle_client_req(self, json_data): subscription_id = json_data[1] - logger.info(f"New subscription: '{subscription_id}'") subscription_id_rewritten = urlsafe_short_hash() self.original_subscription_ids[subscription_id_rewritten] = subscription_id - filters = json_data[2:] + fltr = json_data[2:] + filters = self._marshall_nostr_filters(fltr) - nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters) + nostr.client.relay_manager.add_subscription( + subscription_id_rewritten, filters + ) + request_rewritten = json.dumps([json_data[0], subscription_id_rewritten] + fltr) + + self.subscriptions.append(subscription_id_rewritten) + nostr.client.relay_manager.publish_message(request_rewritten) def _handle_client_close(self, subscription_id): - subscription_id_rewritten = next( - ( - k - for k, v in self.original_subscription_ids.items() - if v == subscription_id - ), - None, - ) + subscription_id_rewritten = next((k for k, v in self.original_subscription_ids.items() if v == subscription_id), None) if subscription_id_rewritten: self.original_subscription_ids.pop(subscription_id_rewritten) - nostr_client.relay_manager.close_subscription(subscription_id_rewritten) - logger.info( - f""" - Unsubscribe from '{subscription_id_rewritten}'. - Original id: '{subscription_id}.' - """ - ) + nostr.client.relay_manager.close_subscription(subscription_id_rewritten) else: - logger.info(f"Failed to unsubscribe from '{subscription_id}.'") + logger.debug(f"Failed to unsubscribe from '{subscription_id}.'") diff --git a/static/images/1.jpeg b/static/images/1.jpeg deleted file mode 100644 index 1f00661..0000000 Binary files a/static/images/1.jpeg and /dev/null differ diff --git a/static/images/2.jpeg b/static/images/2.jpeg deleted file mode 100644 index e7301fd..0000000 Binary files a/static/images/2.jpeg and /dev/null differ diff --git a/tasks.py b/tasks.py index 2a76765..4c316bc 100644 --- a/tasks.py +++ b/tasks.py @@ -3,66 +3,75 @@ import threading from loguru import logger +from . import nostr from .crud import get_relays from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage -from .router import NostrRouter, nostr_client +from .router import NostrRouter, nostr async def init_relays(): + # reinitialize the entire client + nostr.__init__() # get relays from db relays = await get_relays() # set relays and connect to them - valid_relays = [r.url for r in relays if r.url] - - nostr_client.reconnect(valid_relays) + nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) + await nostr.client.connect() async def check_relays(): - """Check relays that have been disconnected""" + """ Check relays that have been disconnected """ while True: try: await asyncio.sleep(20) - nostr_client.relay_manager.check_and_restart_relays() + nostr.client.relay_manager.check_and_restart_relays() except Exception as e: - logger.warning(f"Cannot restart relays: '{e!s}'.") - + logger.warning(f"Cannot restart relays: '{str(e)}'.") + async def subscribe_events(): - while not [r.connected for r in nostr_client.relay_manager.relays.values()]: + while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): await asyncio.sleep(2) - def callback_events(event_message: EventMessage): - sub_id = event_message.subscription_id - if sub_id not in NostrRouter.received_subscription_events: - NostrRouter.received_subscription_events[sub_id] = [event_message] - return + def callback_events(eventMessage: EventMessage): + if eventMessage.subscription_id in NostrRouter.received_subscription_events: + # do not add duplicate events (by event id) + if eventMessage.event.id in set( + [ + e.id + for e in NostrRouter.received_subscription_events[eventMessage.subscription_id] + ] + ): + return - # do not add duplicate events (by event id) - ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]] - if event_message.event_id in ids: - return + NostrRouter.received_subscription_events[eventMessage.subscription_id].append( + eventMessage.event + ) + else: + NostrRouter.received_subscription_events[eventMessage.subscription_id] = [ + eventMessage.event + ] + return - NostrRouter.received_subscription_events[sub_id].append(event_message) + def callback_notices(noticeMessage: NoticeMessage): + if noticeMessage not in NostrRouter.received_subscription_notices: + NostrRouter.received_subscription_notices.append(noticeMessage) + return - def callback_notices(notice_message: NoticeMessage): - if notice_message not in NostrRouter.received_subscription_notices: - NostrRouter.received_subscription_notices.append(notice_message) + def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): + if eventMessage.subscription_id not in NostrRouter.received_subscription_eosenotices: + NostrRouter.received_subscription_eosenotices[ + eventMessage.subscription_id + ] = eventMessage - def callback_eose_notices(event_message: EndOfStoredEventsMessage): - sub_id = event_message.subscription_id - if sub_id in NostrRouter.received_subscription_eosenotices: - return - - NostrRouter.received_subscription_eosenotices[sub_id] = event_message + return def wrap_async_subscribe(): - asyncio.run( - nostr_client.subscribe( - callback_events, - callback_notices, - callback_eose_notices, - ) - ) + asyncio.run(nostr.client.subscribe( + callback_events, + callback_notices, + callback_eose_notices, + )) t = threading.Thread( target=wrap_async_subscribe, diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html index ca44f1b..a0c5999 100644 --- a/templates/nostrclient/index.html +++ b/templates/nostrclient/index.html @@ -4,47 +4,21 @@
-
-
- +
+
+
-
- - +
+ + +
-
- -
@@ -55,36 +29,18 @@
Nostrclient
- +
- +