diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..5bcdee7 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,29 @@ +name: CI +on: + push: + branches: + - main + pull_request: + +jobs: + lint: + uses: lnbits/lnbits/.github/workflows/lint.yml@dev + tests: + runs-on: ubuntu-latest + needs: [lint] + steps: + - uses: actions/checkout@v4 + - uses: lnbits/lnbits/.github/actions/prepare@dev + - name: Run pytest + uses: pavelzw/pytest-action@v2 + env: + LNBITS_BACKEND_WALLET_CLASS: FakeWallet + PYTHONUNBUFFERED: 1 + DEBUG: true + with: + verbose: true + job-summary: true + emoji: false + click-to-expand: true + custom-pytest: uv run pytest + report-title: 'test' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..27c8a60 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,58 @@ +on: + push: + tags: + - 'v[0-9]+.[0-9]+.[0-9]+' + +jobs: + release: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Create github release + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + tag: ${{ github.ref_name }} + run: | + gh release create "$tag" --generate-notes + + pullrequest: + needs: [release] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + with: + token: ${{ secrets.EXT_GITHUB }} + repository: lnbits/lnbits-extensions + path: './lnbits-extensions' + + - name: setup git user + run: | + git config --global user.name "alan" + git config --global user.email "alan@lnbits.com" + + - name: Create pull request in extensions repo + env: + GH_TOKEN: ${{ secrets.EXT_GITHUB }} + repo_name: '${{ github.event.repository.name }}' + tag: '${{ github.ref_name }}' + branch: 'update-${{ github.event.repository.name }}-${{ github.ref_name }}' + title: '[UPDATE] ${{ github.event.repository.name }} to ${{ github.ref_name }}' + body: 'https://github.com/lnbits/${{ github.event.repository.name }}/releases/${{ github.ref_name }}' + archive: 'https://github.com/lnbits/${{ github.event.repository.name }}/archive/refs/tags/${{ github.ref_name }}.zip' + run: | + cd lnbits-extensions + git checkout -b $branch + + # if there is another open PR + git pull origin $branch || echo "branch does not exist" + + sh util.sh update_extension $repo_name $tag + + git add -A + git commit -am "$title" + git push origin $branch + + # check if pr exists before creating it + gh config set pager cat + check=$(gh pr list -H $branch | wc -l) + test $check -ne 0 || gh pr create --title "$title" --body "$body" --repo lnbits/lnbits-extensions diff --git a/.gitignore b/.gitignore index 10a11d5..0152b6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,24 +1,4 @@ -.DS_Store -._* - __pycache__ -*.py[cod] -*$py.class +node_modules .mypy_cache -.vscode -*-lock.json - -*.egg -*.egg-info -.coverage -.pytest_cache -.webassets-cache -htmlcov -test-reports -tests/data/*.sqlite3 - -*.swo -*.swp -*.pyo -*.pyc -*.env \ No newline at end of file +.venv diff --git a/.prettierrc b/.prettierrc new file mode 100644 index 0000000..725c398 --- /dev/null +++ b/.prettierrc @@ -0,0 +1,12 @@ +{ + "semi": false, + "arrowParens": "avoid", + "insertPragma": false, + "printWidth": 80, + "proseWrap": "preserve", + "singleQuote": true, + "trailingComma": "none", + "useTabs": false, + "bracketSameLine": false, + "bracketSpacing": false +} diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0fac253 --- /dev/null +++ b/Makefile @@ -0,0 +1,47 @@ +all: format check + +format: prettier black ruff + +check: mypy pyright checkblack checkruff checkprettier + +prettier: + uv run ./node_modules/.bin/prettier --write . +pyright: + uv run ./node_modules/.bin/pyright + +mypy: + uv run mypy . + +black: + uv run black . + +ruff: + uv run ruff check . --fix + +checkruff: + uv run ruff check . + +checkprettier: + uv run ./node_modules/.bin/prettier --check . + +checkblack: + uv run black --check . + +checkeditorconfig: + editorconfig-checker + +test: + PYTHONUNBUFFERED=1 \ + DEBUG=true \ + uv run pytest +install-pre-commit-hook: + @echo "Installing pre-commit hook to git" + @echo "Uninstall the hook with uv run pre-commit uninstall" + uv run pre-commit install + +pre-commit: + uv run pre-commit run --all-files + + +checkbundle: + @echo "skipping checkbundle" diff --git a/README.md b/README.md index 7fbb640..70593a8 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,124 @@ # Nostrclient - [LNbits](https://github.com/lnbits/lnbits) extension + For more about LNBits extension check [this tutorial](https://github.com/lnbits/lnbits/wiki/LNbits-Extensions) -`nostrclient` is an always-on extension that can open multiple connections to nostr relays and act as a multiplexer for other clients: You open a single websocket to `nostrclient` which then sends the data to multiple relays. The responses from these relays are then sent back to the client. +## Overview -![2023-03-08 18 11 07](https://user-images.githubusercontent.com/93376500/225265727-369f0f8a-196e-41df-a0d1-98b50a0228be.jpg) +`nostrclient` is an always-on Nostr relay multiplexer that simplifies connecting to multiple Nostr relays. Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you. + +### Why Use This? + +- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections +- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays +- **Resource Efficient** - Share relay connections across multiple clients +- **Subscription Management** - Automatic subscription ID rewriting prevents conflicts between clients + +## Architecture + +```mermaid +flowchart LR + A[Client A] -->|WebSocket| N + B[Client B] -->|WebSocket| N + C[Client C] -->|WebSocket| N + + N[nostrclient
Router] -->|Fan Out| R1[Relay A] + N -->|Fan Out| R2[Relay B] + N -->|Fan Out| R3[Relay C] + N -->|Fan Out| R4[Relay D] + + R1 -.->|Aggregate| N + R2 -.->|Aggregate| N + R3 -.->|Aggregate| N + R4 -.->|Aggregate| N +``` + +**Key Feature:** The router rewrites subscription IDs to prevent conflicts when multiple clients use the same IDs. + +## Features + +- **Multi-Relay Multiplexing** - Connect to multiple Nostr relays through a single WebSocket +- **Public & Private Endpoints** - Configurable public and private WebSocket access +- **Automatic Reconnection** - Failed relays are automatically retried with exponential backoff +- **Subscription Deduplication** - Events are deduplicated before being sent to clients +- **Health Monitoring** - Track relay connection status, latency, and error rates +- **Test Endpoint** - Send test messages to verify your setup is working + +## How It Works + +1. **Client Connection** - Your Nostr client connects to the nostrclient WebSocket endpoint +2. **Subscription Rewriting** - Each subscription ID is rewritten to prevent conflicts between multiple clients +3. **Fan-Out** - Subscription requests are sent to all configured relays +4. **Aggregation** - Events from all relays are collected and deduplicated +5. **Response** - Events are sent back to the client with the original subscription ID + +## Configuration + +### WebSocket Endpoints + +- **Public Endpoint**: `/api/v1/relay` - Available to anyone (if enabled) +- **Private Endpoint**: `/api/v1/{encrypted_id}` - Requires valid encrypted endpoint ID + +Configure endpoint access in the extension settings: + +- `private_ws` - Enable/disable private WebSocket access +- `public_ws` - Enable/disable public WebSocket access + +### Adding Relays + +Use the nostrclient UI to add/remove Nostr relays. The extension will automatically: + +- Connect to new relays +- Publish existing subscriptions to new relays +- Monitor relay health and reconnect as needed + +## Testing + +### Test Endpoint Functionality + +The `Test Endpoint` feature helps verify that your nostrclient WebSocket endpoint works correctly. + +**How to test:** + +1. Navigate to the nostrclient extension in LNbits +2. Use the Test Endpoint feature +3. Send a DM to yourself (or a temporary account) +4. Verify that messages are sent and received correctly + +https://user-images.githubusercontent.com/2951406/236780745-929c33c2-2502-49be-84a3-db02a7aabc0e.mp4 + +## Troubleshooting + +### Connection Issues + +- **Check relay status** - View relay health in the nostrclient UI +- **Verify endpoint configuration** - Ensure public_ws or private_ws is enabled +- **Check logs** - Review LNbits logs for connection errors + +### Subscription Not Receiving Events + +- **Verify relays are connected** - Check the relay status in the UI +- **Test with known event** - Use the Test Endpoint to verify connectivity +- **Check relay compatibility** - Some relays may not support all Nostr features + +## Development + +This extension uses `uv` for dependency management. + +### Quick Start + +```bash +# Format code +make format + +# Run type checks and linting +make check + +# Run tests +make test +``` + +For more development commands, see the [Makefile](./Makefile). + +## License + +MIT License - see [LICENSE](./LICENSE) diff --git a/__init__.py b/__init__.py index 019df68..d7eb435 100644 --- a/__init__.py +++ b/__init__.py @@ -1,40 +1,59 @@ import asyncio -from typing import List from fastapi import APIRouter -from starlette.staticfiles import StaticFiles +from loguru import logger -from lnbits.db import Database -from lnbits.helpers import template_renderer -from lnbits.tasks import catch_everything_and_restart - -db = Database("ext_nostrclient") +from .crud import db +from .router import all_routers, nostr_client +from .tasks import check_relays, init_relays, subscribe_events +from .views import nostrclient_generic_router +from .views_api import nostrclient_api_router nostrclient_static_files = [ { "path": "/nostrclient/static", - "app": StaticFiles(directory="lnbits/extensions/nostrclient/static"), "name": "nostrclient_static", } ] nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"]) - -scheduled_tasks: List[asyncio.Task] = [] +nostrclient_ext.include_router(nostrclient_generic_router) +nostrclient_ext.include_router(nostrclient_api_router) +scheduled_tasks: list[asyncio.Task] = [] -def nostr_renderer(): - return template_renderer(["lnbits/extensions/nostrclient/templates"]) +async def nostrclient_stop(): + for task in scheduled_tasks: + try: + task.cancel() + except Exception as ex: + logger.warning(ex) + for router in all_routers: + try: + await router.stop() + all_routers.remove(router) + except Exception as e: + logger.error(e) -from .tasks import init_relays, subscribe_events -from .views import * # noqa -from .views_api import * # noqa + nostr_client.close() def nostrclient_start(): - loop = asyncio.get_event_loop() - task1 = loop.create_task(catch_everything_and_restart(init_relays)) - scheduled_tasks.append(task1) - task2 = loop.create_task(catch_everything_and_restart(subscribe_events)) - scheduled_tasks.append(task2) + from lnbits.tasks import create_permanent_unique_task + + task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays) + task2 = create_permanent_unique_task( + "ext_nostrclient_subscrive_events", subscribe_events + ) + task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays) + scheduled_tasks.extend([task1, task2, task3]) + + +__all__ = [ + "db", + "nostrclient_ext", + "nostrclient_start", + "nostrclient_static_files", + "nostrclient_stop", +] diff --git a/cbc.py b/cbc.py deleted file mode 100644 index 0d9e04f..0000000 --- a/cbc.py +++ /dev/null @@ -1,26 +0,0 @@ -from Cryptodome.Cipher import AES - -BLOCK_SIZE = 16 - - -class AESCipher(object): - """This class is compatible with crypto.createCipheriv('aes-256-cbc')""" - - def __init__(self, key=None): - self.key = key - - def pad(self, data): - length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) - return data + (chr(length) * length).encode() - - def unpad(self, data): - return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] - - def encrypt(self, plain_text): - cipher = AES.new(self.key, AES.MODE_CBC) - b = plain_text.encode("UTF-8") - return cipher.iv, cipher.encrypt(self.pad(b)) - - def decrypt(self, iv, enc_text): - cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) - return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) diff --git a/config.json b/config.json index ce8ae18..1f58e7b 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,17 @@ { "name": "Nostr Client", - "short_description": "Nostr client for extensions", + "short_description": "Nostr relay multiplexer", + "version": "1.1.0", "tile": "/nostrclient/static/images/nostr-bitcoin.png", - "contributors": ["calle"] + "contributors": ["calle", "motorina0", "dni"], + "min_lnbits_version": "1.4.0", + "images": [ + { + "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/1.jpeg" + }, + { + "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/2.jpeg" + } + ], + "description_md": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/description.md" } diff --git a/crud.py b/crud.py index 780642d..d311c72 100644 --- a/crud.py +++ b/crud.py @@ -1,31 +1,52 @@ -from typing import List, Optional, Union +from lnbits.db import Database -import shortuuid +from .models import Config, Relay, UserConfig -from lnbits.helpers import urlsafe_short_hash - -from . import db -from .models import Relay, RelayList +db = Database("ext_nostrclient") -async def get_relays() -> RelayList: - row = await db.fetchall("SELECT * FROM nostrclient.relays") - return RelayList(__root__=row) - - -async def add_relay(relay: Relay) -> None: - await db.execute( - f""" - INSERT INTO nostrclient.relays ( - id, - url, - active - ) - VALUES (?, ?, ?) - """, - (relay.id, relay.url, relay.active), +async def get_relays() -> list[Relay]: + return await db.fetchall( + "SELECT * FROM nostrclient.relays", + model=Relay, ) +async def add_relay(relay: Relay) -> Relay: + await db.insert("nostrclient.relays", relay) + return relay + + async def delete_relay(relay: Relay) -> None: - await db.execute("DELETE FROM nostrclient.relays WHERE url = ?", (relay.url,)) + if not relay.url: + return + await db.execute( + "DELETE FROM nostrclient.relays WHERE url = :url", {"url": relay.url} + ) + + +######################CONFIG####################### +async def create_config(owner_id: str) -> Config: + admin_config = UserConfig(owner_id=owner_id) + await db.insert("nostrclient.config", admin_config) + return admin_config.extra + + +async def update_config(owner_id: str, config: Config) -> Config: + user_config = UserConfig(owner_id=owner_id, extra=config) + await db.update("nostrclient.config", user_config, "WHERE owner_id = :owner_id") + return user_config.extra + + +async def get_config(owner_id: str) -> Config | None: + user_config: UserConfig = await db.fetchone( + """ + SELECT * FROM nostrclient.config + WHERE owner_id = :owner_id + """, + {"owner_id": owner_id}, + model=UserConfig, + ) + if user_config: + return user_config.extra + return None diff --git a/description.md b/description.md new file mode 100644 index 0000000..5293087 --- /dev/null +++ b/description.md @@ -0,0 +1,8 @@ +An always-on relay multiplexer that simplifies connecting to multiple Nostr relays. + +Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you. + +- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections +- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays +- **Resource Efficient** - Share relay connections across multiple clients +- **Automatic Subscription Management** - Subscription ID rewriting prevents conflicts between clients diff --git a/migrations.py b/migrations.py index 5a30e45..b16db58 100644 --- a/migrations.py +++ b/migrations.py @@ -3,7 +3,7 @@ async def m001_initial(db): Initial nostrclient table. """ await db.execute( - f""" + """ CREATE TABLE nostrclient.relays ( id TEXT NOT NULL PRIMARY KEY, url TEXT NOT NULL, @@ -11,3 +11,22 @@ async def m001_initial(db): ); """ ) + + +async def m002_create_config_table(db): + """ + Allow the extension to persist and retrieve any number of config values. + """ + + await db.execute( + """CREATE TABLE nostrclient.config ( + json_data TEXT NOT NULL + );""" + ) + + +async def m003_update_config_table(db): + await db.execute("ALTER TABLE nostrclient.config RENAME COLUMN json_data TO extra") + await db.execute( + "ALTER TABLE nostrclient.config ADD COLUMN owner_id TEXT DEFAULT 'admin'" + ) diff --git a/models.py b/models.py index 1456d83..937c6c5 100644 --- a/models.py +++ b/models.py @@ -1,101 +1,54 @@ -from dataclasses import dataclass -from typing import Dict, List, Optional - -from fastapi import Request -from fastapi.param_functions import Query +from lnbits.helpers import urlsafe_short_hash from pydantic import BaseModel, Field -from lnbits.helpers import urlsafe_short_hash + +class RelayStatus(BaseModel): + num_sent_events: int | None = 0 + num_received_events: int | None = 0 + error_counter: int | None = 0 + error_list: list | None = [] + notice_list: list | None = [] class Relay(BaseModel): - id: Optional[str] = None - url: Optional[str] = None - connected: Optional[bool] = None - connected_string: Optional[str] = None - status: Optional[str] = None - active: Optional[bool] = None - ping: Optional[int] = None + id: str | None = None + url: str | None = None + active: bool | None = None + + connected: bool | None = Field(default=None, no_database=True) + connected_string: str | None = Field(default=None, no_database=True) + status: RelayStatus | None = Field(default=None, no_database=True) + + ping: int | None = Field(default=None, no_database=True) def _init__(self): if not self.id: self.id = urlsafe_short_hash() -class RelayList(BaseModel): - __root__: List[Relay] - - -class Event(BaseModel): - content: str - pubkey: str - created_at: Optional[int] - kind: int - tags: Optional[List[List[str]]] - sig: str - - -class Filter(BaseModel): - ids: Optional[List[str]] - kinds: Optional[List[int]] - authors: Optional[List[str]] - since: Optional[int] - until: Optional[int] - e: Optional[List[str]] = Field(alias="#e") - p: Optional[List[str]] = Field(alias="#p") - limit: Optional[int] - - -class Filters(BaseModel): - __root__: List[Filter] +class RelayDb(BaseModel): + id: str + url: str + active: bool | None = True class TestMessage(BaseModel): - sender_private_key: Optional[str] + sender_private_key: str | None reciever_public_key: str message: str + class TestMessageResponse(BaseModel): private_key: str public_key: str event_json: str -# class nostrKeys(BaseModel): -# pubkey: str -# privkey: str -# class nostrNotes(BaseModel): -# id: str -# pubkey: str -# created_at: str -# kind: int -# tags: str -# content: str -# sig: str - -# class nostrCreateRelays(BaseModel): -# relay: str = Query(None) - -# class nostrCreateConnections(BaseModel): -# pubkey: str = Query(None) -# relayid: str = Query(None) - -# class nostrRelays(BaseModel): -# id: Optional[str] -# relay: Optional[str] -# status: Optional[bool] = False +class Config(BaseModel): + private_ws: bool = True + public_ws: bool = False -# class nostrRelaySetList(BaseModel): -# allowlist: Optional[str] -# denylist: Optional[str] - -# class nostrConnections(BaseModel): -# id: str -# pubkey: Optional[str] -# relayid: Optional[str] - -# class nostrSubscriptions(BaseModel): -# id: str -# userPubkey: Optional[str] -# subscribedPubkey: Optional[str] +class UserConfig(BaseModel): + owner_id: str + extra: Config = Config() diff --git a/nostr/bech32.py b/nostr/bech32.py index b068de7..ba2ddd1 100644 --- a/nostr/bech32.py +++ b/nostr/bech32.py @@ -23,21 +23,25 @@ from enum import Enum + class Encoding(Enum): """Enumeration type to list the various supported encodings.""" + BECH32 = 1 BECH32M = 2 + CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" -BECH32M_CONST = 0x2bc830a3 +BECH32M_CONST = 0x2BC830A3 + def bech32_polymod(values): """Internal function that computes the Bech32 checksum.""" - generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3] + generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] chk = 1 for value in values: top = chk >> 25 - chk = (chk & 0x1ffffff) << 5 ^ value + chk = (chk & 0x1FFFFFF) << 5 ^ value for i in range(5): chk ^= generator[i] if ((top >> i) & 1) else 0 return chk @@ -57,6 +61,7 @@ def bech32_verify_checksum(hrp, data): return Encoding.BECH32M return None + def bech32_create_checksum(hrp, data, spec): """Compute the checksum values given HRP and data.""" values = bech32_hrp_expand(hrp) + data @@ -68,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec): def bech32_encode(hrp, data, spec): """Compute a Bech32 string given HRP and data values.""" combined = data + bech32_create_checksum(hrp, data, spec) - return hrp + '1' + ''.join([CHARSET[d] for d in combined]) + return hrp + "1" + "".join([CHARSET[d] for d in combined]) + def bech32_decode(bech): """Validate a Bech32/Bech32m string, and determine HRP and data.""" - if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or - (bech.lower() != bech and bech.upper() != bech)): + if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( + bech.lower() != bech and bech.upper() != bech + ): return (None, None, None) bech = bech.lower() - pos = bech.rfind('1') + pos = bech.rfind("1") if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: return (None, None, None) - if not all(x in CHARSET for x in bech[pos+1:]): + if not all(x in CHARSET for x in bech[pos + 1 :]): return (None, None, None) hrp = bech[:pos] - data = [CHARSET.find(x) for x in bech[pos+1:]] + data = [CHARSET.find(x) for x in bech[pos + 1 :]] spec = bech32_verify_checksum(hrp, data) if spec is None: return (None, None, None) return (hrp, data[:-6], spec) + def convertbits(data, frombits, tobits, pad=True): """General power-of-2 base conversion.""" acc = 0 @@ -116,22 +124,29 @@ def decode(hrp, addr): hrpgot, data, spec = bech32_decode(addr) if hrpgot != hrp: return (None, None) - decoded = convertbits(data[1:], 5, 8, False) + decoded = convertbits(data[1:], 5, 8, False) # type: ignore if decoded is None or len(decoded) < 2 or len(decoded) > 40: return (None, None) - if data[0] > 16: + if data[0] > 16: # type: ignore return (None, None) - if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: + if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: # type: ignore return (None, None) - if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M: + if ( + data[0] == 0 # type: ignore + and spec != Encoding.BECH32 + or data[0] != 0 # type: ignore + and spec != Encoding.BECH32M + ): return (None, None) - return (data[0], decoded) + return (data[0], decoded) # type: ignore def encode(hrp, witver, witprog): """Encode a segwit address.""" spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M - ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5), spec) + wit_prog = convertbits(witprog, 8, 5) + assert wit_prog + ret = bech32_encode(hrp, [witver, *wit_prog], spec) if decode(hrp, ret) == (None, None): return None return ret diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py deleted file mode 100644 index a41dbc0..0000000 --- a/nostr/client/cbc.py +++ /dev/null @@ -1,41 +0,0 @@ - -from Cryptodome import Random -from Cryptodome.Cipher import AES - -plain_text = "This is the text to encrypts" - -# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode() -# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode() - -key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795") - -BLOCK_SIZE = 16 - -class AESCipher(object): - """This class is compatible with crypto.createCipheriv('aes-256-cbc') - - """ - def __init__(self, key=None): - self.key = key - - def pad(self, data): - length = BLOCK_SIZE - (len(data) % BLOCK_SIZE) - return data + (chr(length) * length).encode() - - def unpad(self, data): - return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))] - - def encrypt(self, plain_text): - cipher = AES.new(self.key, AES.MODE_CBC) - b = plain_text.encode("UTF-8") - return cipher.iv, cipher.encrypt(self.pad(b)) - - def decrypt(self, iv, enc_text): - cipher = AES.new(self.key, AES.MODE_CBC, iv=iv) - return self.unpad(cipher.decrypt(enc_text).decode("UTF-8")) - -if __name__ == "__main__": - aes = AESCipher(key=key) - iv, enc_text = aes.encrypt(plain_text) - dec_text = aes.decrypt(iv, enc_text) - print(dec_text) \ No newline at end of file diff --git a/nostr/client/client.py b/nostr/client/client.py index 6e70f71..d6fb5c8 100644 --- a/nostr/client/client.py +++ b/nostr/client/client.py @@ -1,152 +1,75 @@ -from typing import * -import ssl -import time -import json -import os -import base64 +import asyncio + +from loguru import logger -from ..event import Event from ..relay_manager import RelayManager -from ..message_type import ClientMessageType -from ..key import PrivateKey, PublicKey - -from ..filter import Filter, Filters -from ..event import Event, EventKind, EncryptedDirectMessage -from ..relay_manager import RelayManager -from ..message_type import ClientMessageType - -# from aes import AESCipher -from . import cbc class NostrClient: - relays = [ - "wss://nostr-pub.wellorder.net", - "wss://nostr.zebedee.cloud", - "wss://nodestr.fmt.wiz.biz", - "wss://nostr.oxtr.dev", - ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr" - relay_manager = RelayManager() - private_key: PrivateKey - public_key: PublicKey + relay_manager: RelayManager + running: bool - def __init__(self, privatekey_hex: str = "", relays: List[str] = [], connect=True): - self.generate_keys(privatekey_hex) + def __init__(self): + self.running = True + self.relay_manager = RelayManager() - if len(relays): - self.relays = relays - if connect: - self.connect() + def connect(self, relays): + for relay in relays: + try: + self.relay_manager.add_relay(relay) + except Exception as e: + logger.debug(e) + self.running = True - def connect(self): - for relay in self.relays: - self.relay_manager.add_relay(relay) - self.relay_manager.open_connections( - {"cert_reqs": ssl.CERT_NONE} - ) # NOTE: This disables ssl certificate verification + def reconnect(self, relays): + self.relay_manager.remove_relays() + self.connect(relays) def close(self): - self.relay_manager.close_connections() + try: + self.relay_manager.close_all_subscriptions() + self.relay_manager.close_connections() - def generate_keys(self, privatekey_hex: str = None): - pk = bytes.fromhex(privatekey_hex) if privatekey_hex else None - self.private_key = PrivateKey(pk) - self.public_key = self.private_key.public_key + self.running = False + except Exception as e: + logger.error(e) - def post(self, message: str): - event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE) - self.private_key.sign_event(event) - event_json = event.to_message() - # print("Publishing message:") - # print(event_json) - self.relay_manager.publish_message(event_json) - - def get_post( - self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={} - ): - filter = Filter( - authors=[sender_publickey.hex()] if sender_publickey else None, - kinds=[EventKind.TEXT_NOTE], - **filter_kwargs, - ) - filters = Filters([filter]) - subscription_id = os.urandom(4).hex() - self.relay_manager.add_subscription(subscription_id, filters) - - request = [ClientMessageType.REQUEST, subscription_id] - request.extend(filters.to_json_array()) - message = json.dumps(request) - self.relay_manager.publish_message(message) - - while True: - while self.relay_manager.message_pool.has_events(): - event_msg = self.relay_manager.message_pool.get_event() - if callback_func: - callback_func(event_msg.event) - time.sleep(0.1) - - def dm(self, message: str, to_pubkey: PublicKey): - dm = EncryptedDirectMessage( - recipient_pubkey=to_pubkey.hex(), cleartext_content=message - ) - self.private_key.sign_event(dm) - self.relay_manager.publish_event(dm) - - def get_dm(self, sender_publickey: PublicKey, callback_func=None): - filters = Filters( - [ - Filter( - kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE], - pubkey_refs=[sender_publickey.hex()], - ) - ] - ) - subscription_id = os.urandom(4).hex() - self.relay_manager.add_subscription(subscription_id, filters) - - request = [ClientMessageType.REQUEST, subscription_id] - request.extend(filters.to_json_array()) - message = json.dumps(request) - self.relay_manager.publish_message(message) - - while True: - while self.relay_manager.message_pool.has_events(): - event_msg = self.relay_manager.message_pool.get_event() - if "?iv=" in event_msg.event.content: - try: - shared_secret = self.private_key.compute_shared_secret( - event_msg.event.public_key - ) - aes = cbc.AESCipher(key=shared_secret) - enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=") - iv = base64.decodebytes(iv_b64.encode("utf-8")) - enc_text = base64.decodebytes(enc_text_b64.encode("utf-8")) - dec_text = aes.decrypt(iv, enc_text) - if callback_func: - callback_func(event_msg.event, dec_text) - except: - pass - break - time.sleep(0.1) - - def subscribe( + async def subscribe( self, callback_events_func=None, callback_notices_func=None, callback_eosenotices_func=None, ): - while True: + while self.running: + self._check_events(callback_events_func) + self._check_notices(callback_notices_func) + self._check_eos_notices(callback_eosenotices_func) + + await asyncio.sleep(0.2) + + def _check_events(self, callback_events_func=None): + try: while self.relay_manager.message_pool.has_events(): event_msg = self.relay_manager.message_pool.get_event() if callback_events_func: callback_events_func(event_msg) + except Exception as e: + logger.debug(e) + + def _check_notices(self, callback_notices_func=None): + try: while self.relay_manager.message_pool.has_notices(): event_msg = self.relay_manager.message_pool.get_notice() if callback_notices_func: callback_notices_func(event_msg) + except Exception as e: + logger.debug(e) + + def _check_eos_notices(self, callback_eosenotices_func=None): + try: while self.relay_manager.message_pool.has_eose_notices(): event_msg = self.relay_manager.message_pool.get_eose_notice() if callback_eosenotices_func: callback_eosenotices_func(event_msg) - - time.sleep(0.1) + except Exception as e: + logger.debug(e) diff --git a/nostr/delegation.py b/nostr/delegation.py deleted file mode 100644 index 94801f5..0000000 --- a/nostr/delegation.py +++ /dev/null @@ -1,32 +0,0 @@ -import time -from dataclasses import dataclass - - -@dataclass -class Delegation: - delegator_pubkey: str - delegatee_pubkey: str - event_kind: int - duration_secs: int = 30*24*60 # default to 30 days - signature: str = None # set in PrivateKey.sign_delegation - - @property - def expires(self) -> int: - return int(time.time()) + self.duration_secs - - @property - def conditions(self) -> str: - return f"kind={self.event_kind}&created_at<{self.expires}" - - @property - def delegation_token(self) -> str: - return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}" - - def get_tag(self) -> list[str]: - """ Called by Event """ - return [ - "delegation", - self.delegator_pubkey, - self.conditions, - self.signature, - ] diff --git a/nostr/event.py b/nostr/event.py index b903e0e..994c0f4 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -1,10 +1,11 @@ -import time import json +import time from dataclasses import dataclass, field from enum import IntEnum -from typing import List -from secp256k1 import PublicKey from hashlib import sha256 +from typing import Optional + +import coincurve from .message_type import ClientMessageType @@ -20,14 +21,14 @@ class EventKind(IntEnum): @dataclass class Event: - content: str = None - public_key: str = None - created_at: int = None + content: Optional[str] = None + public_key: Optional[str] = None + created_at: Optional[int] = None kind: int = EventKind.TEXT_NOTE - tags: List[List[str]] = field( + tags: list[list[str]] = field( default_factory=list ) # Dataclasses require special handling when the default value is a mutable type - signature: str = None + signature: Optional[str] = None def __post_init__(self): if self.content is not None and not isinstance(self.content, str): @@ -39,7 +40,7 @@ class Event: @staticmethod def serialize( - public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str + public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str ) -> bytes: data = [0, public_key, created_at, kind, tags, content] data_str = json.dumps(data, separators=(",", ":"), ensure_ascii=False) @@ -47,7 +48,7 @@ class Event: @staticmethod def compute_id( - public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str + public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str ): return sha256( Event.serialize(public_key, created_at, kind, tags, content) @@ -56,6 +57,9 @@ class Event: @property def id(self) -> str: # Always recompute the id to reflect the up-to-date state of the Event + assert self.public_key + assert self.created_at + assert self.content return Event.compute_id( self.public_key, self.created_at, self.kind, self.tags, self.content ) @@ -69,12 +73,10 @@ class Event: self.tags.append(["e", event_id]) def verify(self) -> bool: - pub_key = PublicKey( - bytes.fromhex("02" + self.public_key), True - ) # add 02 for schnorr (bip340) - return pub_key.schnorr_verify( - bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True - ) + assert self.public_key + assert self.signature + pub_key = coincurve.PublicKeyXOnly(bytes.fromhex(self.public_key)) + return pub_key.verify(bytes.fromhex(self.signature), bytes.fromhex(self.id)) def to_message(self) -> str: return json.dumps( @@ -95,9 +97,9 @@ class Event: @dataclass class EncryptedDirectMessage(Event): - recipient_pubkey: str = None - cleartext_content: str = None - reference_event_id: str = None + recipient_pubkey: Optional[str] = None + cleartext_content: Optional[str] = None + reference_event_id: Optional[str] = None def __post_init__(self): if self.content is not None: @@ -121,6 +123,7 @@ class EncryptedDirectMessage(Event): def id(self) -> str: if self.content is None: raise Exception( - "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field" + "EncryptedDirectMessage `id` is undefined until its" + + " message is encrypted and stored in the `content` field" ) return super().id diff --git a/nostr/filter.py b/nostr/filter.py deleted file mode 100644 index f119079..0000000 --- a/nostr/filter.py +++ /dev/null @@ -1,134 +0,0 @@ -from collections import UserList -from typing import List - -from .event import Event, EventKind - - -class Filter: - """ - NIP-01 filtering. - - Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`. - - Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`. - If a particular single-letter tag gains prominence, explicit support should be - added. For example: - # arbitrary tag - filter.add_arbitrary_tag('t', [hashtags]) - - # promoted to explicit support - Filter(hashtag_refs=[hashtags]) - """ - - def __init__( - self, - event_ids: List[str] = None, - kinds: List[EventKind] = None, - authors: List[str] = None, - since: int = None, - until: int = None, - event_refs: List[ - str - ] = None, # the "#e" attr; list of event ids referenced in an "e" tag - pubkey_refs: List[ - str - ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag - limit: int = None, - ) -> None: - self.event_ids = event_ids - self.kinds = kinds - self.authors = authors - self.since = since - self.until = until - self.event_refs = event_refs - self.pubkey_refs = pubkey_refs - self.limit = limit - - self.tags = {} - if self.event_refs: - self.add_arbitrary_tag("e", self.event_refs) - if self.pubkey_refs: - self.add_arbitrary_tag("p", self.pubkey_refs) - - def add_arbitrary_tag(self, tag: str, values: list): - """ - Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12 - single-letter tags. - """ - # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#" - tag_key = tag if len(tag) > 1 else f"#{tag}" - self.tags[tag_key] = values - - def matches(self, event: Event) -> bool: - if self.event_ids is not None and event.id not in self.event_ids: - return False - if self.kinds is not None and event.kind not in self.kinds: - return False - if self.authors is not None and event.public_key not in self.authors: - return False - if self.since is not None and event.created_at < self.since: - return False - if self.until is not None and event.created_at > self.until: - return False - if (self.event_refs is not None or self.pubkey_refs is not None) and len( - event.tags - ) == 0: - return False - - if self.tags: - e_tag_identifiers = set([e_tag[0] for e_tag in event.tags]) - for f_tag, f_tag_values in self.tags.items(): - # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags - f_tag = f_tag.replace("#", "") - - if f_tag not in e_tag_identifiers: - # Event is missing a tag type that we're looking for - return False - - # Multiple values within f_tag_values are treated as OR search; an Event - # needs to match only one. - # Note: an Event could have multiple entries of the same tag type - # (e.g. a reply to multiple people) so we have to check all of them. - match_found = False - for e_tag in event.tags: - if e_tag[0] == f_tag and e_tag[1] in f_tag_values: - match_found = True - break - if not match_found: - return False - - return True - - def to_json_object(self) -> dict: - res = {} - if self.event_ids is not None: - res["ids"] = self.event_ids - if self.kinds is not None: - res["kinds"] = self.kinds - if self.authors is not None: - res["authors"] = self.authors - if self.since is not None: - res["since"] = self.since - if self.until is not None: - res["until"] = self.until - if self.limit is not None: - res["limit"] = self.limit - if self.tags: - res.update(self.tags) - - return res - - -class Filters(UserList): - def __init__(self, initlist: "list[Filter]" = []) -> None: - super().__init__(initlist) - self.data: "list[Filter]" - - def match(self, event: Event): - for filter in self.data: - if filter.matches(event): - return True - return False - - def to_json_array(self) -> list: - return [filter.to_json_object() for filter in self.data] diff --git a/nostr/key.py b/nostr/key.py index d34697f..f7b4e81 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -1,14 +1,12 @@ -import secrets import base64 -import secp256k1 -from cffi import FFI -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes -from cryptography.hazmat.primitives import padding -from hashlib import sha256 +import secrets -from .delegation import Delegation +import coincurve +from cryptography.hazmat.primitives import padding +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + +from .bech32 import Encoding, bech32_decode, bech32_encode, convertbits from .event import EncryptedDirectMessage, Event, EventKind -from . import bech32 class PublicKey: @@ -16,55 +14,61 @@ class PublicKey: self.raw_bytes = raw_bytes def bech32(self) -> str: - converted_bits = bech32.convertbits(self.raw_bytes, 8, 5) - return bech32.bech32_encode("npub", converted_bits, bech32.Encoding.BECH32) + converted_bits = convertbits(self.raw_bytes, 8, 5) + return bech32_encode("npub", converted_bits, Encoding.BECH32) def hex(self) -> str: return self.raw_bytes.hex() - def verify_signed_message_hash(self, hash: str, sig: str) -> bool: - pk = secp256k1.PublicKey(b"\x02" + self.raw_bytes, True) - return pk.schnorr_verify(bytes.fromhex(hash), bytes.fromhex(sig), None, True) + def verify_signed_message_hash(self, message_hash: str, sig: str) -> bool: + pk = coincurve.PublicKeyXOnly(self.raw_bytes) + return pk.verify(bytes.fromhex(sig), bytes.fromhex(message_hash)) @classmethod def from_npub(cls, npub: str): """Load a PublicKey from its bech32/npub form""" - hrp, data, spec = bech32.bech32_decode(npub) - raw_public_key = bech32.convertbits(data, 5, 8)[:-1] + _, data, _ = bech32_decode(npub) + raw_data = convertbits(data, 5, 8) + assert raw_data + raw_public_key = raw_data[:-1] return cls(bytes(raw_public_key)) class PrivateKey: - def __init__(self, raw_secret: bytes = None) -> None: - if not raw_secret is None: + def __init__(self, raw_secret: bytes | None = None) -> None: + if raw_secret is not None: self.raw_secret = raw_secret else: self.raw_secret = secrets.token_bytes(32) - sk = secp256k1.PrivateKey(self.raw_secret) - self.public_key = PublicKey(sk.pubkey.serialize()[1:]) + sk = coincurve.PrivateKey(self.raw_secret) + assert sk.public_key + self.public_key = PublicKey(sk.public_key.format()[1:]) @classmethod def from_nsec(cls, nsec: str): """Load a PrivateKey from its bech32/nsec form""" - hrp, data, spec = bech32.bech32_decode(nsec) - raw_secret = bech32.convertbits(data, 5, 8)[:-1] + _, data, _ = bech32_decode(nsec) + raw_data = convertbits(data, 5, 8) + assert raw_data + raw_secret = raw_data[:-1] return cls(bytes(raw_secret)) def bech32(self) -> str: - converted_bits = bech32.convertbits(self.raw_secret, 8, 5) - return bech32.bech32_encode("nsec", converted_bits, bech32.Encoding.BECH32) + converted_bits = convertbits(self.raw_secret, 8, 5) + return bech32_encode("nsec", converted_bits, Encoding.BECH32) def hex(self) -> str: return self.raw_secret.hex() def tweak_add(self, scalar: bytes) -> bytes: - sk = secp256k1.PrivateKey(self.raw_secret) - return sk.tweak_add(scalar) + sk = coincurve.PrivateKey(self.raw_secret) + return sk.add(scalar).to_der() def compute_shared_secret(self, public_key_hex: str) -> bytes: - pk = secp256k1.PublicKey(bytes.fromhex("02" + public_key_hex), True) - return pk.ecdh(self.raw_secret, hashfn=copy_x) + pk = coincurve.PublicKey(bytes.fromhex("02" + public_key_hex)) + sk = coincurve.PrivateKey(self.raw_secret) + return sk.ecdh(pk.format()) def encrypt_message(self, message: str, public_key_hex: str) -> str: padder = padding.PKCS7(128).padder() @@ -78,9 +82,14 @@ class PrivateKey: encryptor = cipher.encryptor() encrypted_message = encryptor.update(padded_data) + encryptor.finalize() - return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" + return ( + f"{base64.b64encode(encrypted_message).decode()}" + + f"?iv={base64.b64encode(iv).decode()}" + ) def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: + assert dm.cleartext_content + assert dm.recipient_pubkey dm.content = self.encrypt_message( message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey ) @@ -103,28 +112,23 @@ class PrivateKey: return unpadded_data.decode() - def sign_message_hash(self, hash: bytes) -> str: - sk = secp256k1.PrivateKey(self.raw_secret) - sig = sk.schnorr_sign(hash, None, raw=True) + def sign_message_hash(self, message_hash: bytes) -> str: + sk = coincurve.PrivateKey(self.raw_secret) + sig = sk.sign_schnorr(message_hash) return sig.hex() def sign_event(self, event: Event) -> None: if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None: - self.encrypt_dm(event) + self.encrypt_dm(event) # type: ignore if event.public_key is None: event.public_key = self.public_key.hex() event.signature = self.sign_message_hash(bytes.fromhex(event.id)) - def sign_delegation(self, delegation: Delegation) -> None: - delegation.signature = self.sign_message_hash( - sha256(delegation.delegation_token.encode()).digest() - ) - def __eq__(self, other): return self.raw_secret == other.raw_secret -def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: +def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> PrivateKey: if prefix is None and suffix is None: raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments") @@ -140,14 +144,3 @@ def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: break return sk - - -ffi = FFI() - - -@ffi.callback( - "int (unsigned char *, const unsigned char *, const unsigned char *, void *)" -) -def copy_x(output, x32, y32, data): - ffi.memmove(output, x32, 32) - return 1 diff --git a/nostr/message_pool.py b/nostr/message_pool.py index d364cf2..a3e6c5f 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -1,13 +1,16 @@ import json from queue import Queue from threading import Lock + from .message_type import RelayMessageType -from .event import Event class EventMessage: - def __init__(self, event: Event, subscription_id: str, url: str) -> None: + def __init__( + self, event: str, event_id: str, subscription_id: str, url: str + ) -> None: self.event = event + self.event_id = event_id self.subscription_id = subscription_id self.url = url @@ -58,20 +61,29 @@ class MessagePool: message_type = message_json[0] if message_type == RelayMessageType.EVENT: subscription_id = message_json[1] - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) + event = message_json[2] + if "id" not in event: + return + event_id = event["id"] + with self.lock: - if not event.id in self._unique_events: - self.events.put(EventMessage(event, subscription_id, url)) - self._unique_events.add(event.id) + if f"{subscription_id}_{event_id}" not in self._unique_events: + self._accept_event( + EventMessage(json.dumps(event), event_id, subscription_id, url) + ) elif message_type == RelayMessageType.NOTICE: self.notices.put(NoticeMessage(message_json[1], url)) elif message_type == RelayMessageType.END_OF_STORED_EVENTS: self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url)) + + def _accept_event(self, event_message: EventMessage): + """ + Event uniqueness is considered per `subscription_id`. The `subscription_id` is + rewritten to be unique and it is the same accross relays. The same event can + come from different subscriptions (from the same client or from different ones). + Clients that have joined later should receive older events. + """ + self.events.put(event_message) + self._unique_events.add( + f"{event_message.subscription_id}_{event_message.event_id}" + ) diff --git a/nostr/message_type.py b/nostr/message_type.py index 3f5206b..d37cdfd 100644 --- a/nostr/message_type.py +++ b/nostr/message_type.py @@ -3,13 +3,20 @@ class ClientMessageType: REQUEST = "REQ" CLOSE = "CLOSE" + class RelayMessageType: EVENT = "EVENT" NOTICE = "NOTICE" END_OF_STORED_EVENTS = "EOSE" + COMMAND_RESULT = "OK" @staticmethod def is_valid(type: str) -> bool: - if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS: + if ( + type == RelayMessageType.EVENT + or type == RelayMessageType.NOTICE + or type == RelayMessageType.END_OF_STORED_EVENTS + or type == RelayMessageType.COMMAND_RESULT + ): return True - return False \ No newline at end of file + return False diff --git a/nostr/pow.py b/nostr/pow.py deleted file mode 100644 index e006288..0000000 --- a/nostr/pow.py +++ /dev/null @@ -1,54 +0,0 @@ -import time -from .event import Event -from .key import PrivateKey - -def zero_bits(b: int) -> int: - n = 0 - - if b == 0: - return 8 - - while b >> 1: - b = b >> 1 - n += 1 - - return 7 - n - -def count_leading_zero_bits(hex_str: str) -> int: - total = 0 - for i in range(0, len(hex_str) - 2, 2): - bits = zero_bits(int(hex_str[i:i+2], 16)) - total += bits - - if bits != 8: - break - - return total - -def mine_event(content: str, difficulty: int, public_key: str, kind: int, tags: list=[]) -> Event: - all_tags = [["nonce", "1", str(difficulty)]] - all_tags.extend(tags) - - created_at = int(time.time()) - event_id = Event.compute_id(public_key, created_at, kind, all_tags, content) - num_leading_zero_bits = count_leading_zero_bits(event_id) - - attempts = 1 - while num_leading_zero_bits < difficulty: - attempts += 1 - all_tags[0][1] = str(attempts) - created_at = int(time.time()) - event_id = Event.compute_id(public_key, created_at, kind, all_tags, content) - num_leading_zero_bits = count_leading_zero_bits(event_id) - - return Event(public_key, content, created_at, kind, all_tags, event_id) - -def mine_key(difficulty: int) -> PrivateKey: - sk = PrivateKey() - num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex()) - - while num_leading_zero_bits < difficulty: - sk = PrivateKey() - num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex()) - - return sk diff --git a/nostr/relay.py b/nostr/relay.py index 7fb4baa..d762963 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -1,50 +1,35 @@ +import asyncio import json import time from queue import Queue -from threading import Lock + +from loguru import logger from websocket import WebSocketApp -from .event import Event -from .filter import Filters + from .message_pool import MessagePool -from .message_type import RelayMessageType from .subscription import Subscription -class RelayPolicy: - def __init__(self, should_read: bool = True, should_write: bool = True) -> None: - self.should_read = should_read - self.should_write = should_write - - def to_json_object(self) -> dict[str, bool]: - return {"read": self.should_read, "write": self.should_write} - - class Relay: - def __init__( - self, - url: str, - policy: RelayPolicy, - message_pool: MessagePool, - subscriptions: dict[str, Subscription] = {}, - ) -> None: + def __init__(self, url: str, message_pool: MessagePool) -> None: self.url = url - self.policy = policy self.message_pool = message_pool - self.subscriptions = subscriptions self.connected: bool = False self.reconnect: bool = True self.shutdown: bool = False + self.error_counter: int = 0 self.error_threshold: int = 100 + self.error_list: list[str] = [] + self.notice_list: list[str] = [] + self.last_error_date: int = 0 self.num_received_events: int = 0 self.num_sent_events: int = 0 self.num_subscriptions: int = 0 - self.ssl_options: dict = {} - self.proxy: dict = {} - self.lock = Lock() - self.queue = Queue() - def connect(self, ssl_options: dict = None, proxy: dict = None): + self.queue: Queue = Queue() + + def connect(self): self.ws = WebSocketApp( self.url, on_open=self._on_open, @@ -54,30 +39,20 @@ class Relay: on_ping=self._on_ping, on_pong=self._on_pong, ) - self.ssl_options = ssl_options - self.proxy = proxy if not self.connected: - self.ws.run_forever( - sslopt=ssl_options, - http_proxy_host=None if proxy is None else proxy.get("host"), - http_proxy_port=None if proxy is None else proxy.get("port"), - proxy_type=None if proxy is None else proxy.get("type"), - ping_interval=5, - ) + self.ws.run_forever(ping_interval=10) def close(self): - self.ws.close() + try: + self.ws.close() + except Exception as e: + logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}") + self.connected = False self.shutdown = True - def check_reconnect(self): - try: - self.close() - except: - pass - self.connected = False - if self.reconnect: - time.sleep(self.error_counter**2) - self.connect(self.ssl_options, self.proxy) + @property + def error_threshold_reached(self): + return self.error_threshold and self.error_counter >= self.error_threshold @property def ping(self): @@ -87,103 +62,65 @@ class Relay: def publish(self, message: str): self.queue.put(message) - def queue_worker(self, shutdown): + def publish_subscriptions(self, subscriptions: list[Subscription]): + for s in subscriptions: + assert s.filters + json_str = json.dumps(["REQ", s.id, *s.filters]) + self.publish(json_str) + + async def queue_worker(self): while True: if self.connected: try: message = self.queue.get(timeout=1) self.num_sent_events += 1 self.ws.send(message) - except: - if shutdown(): - break + except Exception as _: + pass else: - time.sleep(0.1) + await asyncio.sleep(1) - def add_subscription(self, id, filters: Filters): - with self.lock: - self.subscriptions[id] = Subscription(id, filters) + if self.shutdown: + logger.warning(f"[Relay: {self.url}] Closing queue worker.") + return - def close_subscription(self, id: str) -> None: - with self.lock: - self.subscriptions.pop(id) + def close_subscription(self, sub_id: str) -> None: + try: + self.publish(json.dumps(["CLOSE", sub_id])) + except Exception as e: + logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}") - def update_subscription(self, id: str, filters: Filters) -> None: - with self.lock: - subscription = self.subscriptions[id] - subscription.filters = filters + def add_notice(self, notice: str): + self.notice_list = [notice, *self.notice_list] - def to_json_object(self) -> dict: - return { - "url": self.url, - "policy": self.policy.to_json_object(), - "subscriptions": [ - subscription.to_json_object() - for subscription in self.subscriptions.values() - ], - } - - def _on_open(self, class_obj): + def _on_open(self, _): + logger.info(f"[Relay: {self.url}] Connected.") self.connected = True - pass + self.shutdown = False - def _on_close(self, class_obj, status_code, message): - self.connected = False - if self.error_threshold and self.error_counter > self.error_threshold: - pass - else: - self.check_reconnect() - pass + def _on_close(self, _, status_code, message): + logger.warning( + f"[Relay: {self.url}] Connection closed." + + f" Status: '{status_code}'. Message: '{message}'." + ) + self.close() - def _on_message(self, class_obj, message: str): - if self._is_valid_message(message): - self.num_received_events += 1 - self.message_pool.add_message(message, self.url) + def _on_message(self, _, message: str): + self.num_received_events += 1 + self.message_pool.add_message(message, self.url) - def _on_error(self, class_obj, error): - self.connected = False + def _on_error(self, _, error): + logger.warning(f"[Relay: {self.url}] Error: '{error!s}'") + self._append_error_message(str(error)) + self.close() + + def _on_ping(self, *_): + return + + def _on_pong(self, *_): + return + + def _append_error_message(self, message): self.error_counter += 1 - - def _on_ping(self, class_obj, message): - return - - def _on_pong(self, class_obj, message): - return - - def _is_valid_message(self, message: str) -> bool: - message = message.strip("\n") - if not message or message[0] != "[" or message[-1] != "]": - return False - - message_json = json.loads(message) - message_type = message_json[0] - if not RelayMessageType.is_valid(message_type): - return False - if message_type == RelayMessageType.EVENT: - if not len(message_json) == 3: - return False - - subscription_id = message_json[1] - with self.lock: - if subscription_id not in self.subscriptions: - return False - - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) - if not event.verify(): - return False - - with self.lock: - subscription = self.subscriptions[subscription_id] - - if subscription.filters and not subscription.filters.match(event): - return False - - return True + self.error_list = [message, *self.error_list] + self.last_error_date = int(time.time()) diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index a698a33..2aa27c5 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,15 +1,13 @@ -import json +import asyncio import threading +import time +from typing import List -from .event import Event -from .filter import Filters -from .message_pool import MessagePool -from .message_type import ClientMessageType -from .relay import Relay, RelayPolicy +from loguru import logger - -class RelayException(Exception): - pass +from .message_pool import MessagePool, NoticeMessage +from .relay import Relay +from .subscription import Subscription class RelayManager: @@ -18,45 +16,84 @@ class RelayManager: self.threads: dict[str, threading.Thread] = {} self.queue_threads: dict[str, threading.Thread] = {} self.message_pool = MessagePool() + self._cached_subscriptions: dict[str, Subscription] = {} + self._subscriptions_lock = threading.Lock() - def add_relay( - self, url: str, read: bool = True, write: bool = True, subscriptions={} - ): - policy = RelayPolicy(read, write) - relay = Relay(url, policy, self.message_pool, subscriptions) + def add_relay(self, url: str) -> Relay: + if url in list(self.relays.keys()): + logger.debug(f"Relay '{url}' already present.") + return self.relays[url] + + relay = Relay(url, self.message_pool) self.relays[url] = relay - def remove_relay(self, url: str): - self.relays[url].close() - self.relays.pop(url) - self.threads[url].join(timeout=1) - self.threads.pop(url) + self._open_connection(relay) + + relay.publish_subscriptions(list(self._cached_subscriptions.values())) + return relay + + def remove_relay(self, url: str): + try: + self.relays[url].close() + except Exception as e: + logger.debug(e) + + if url in self.relays: + self.relays.pop(url) + + try: + self.threads[url].join(timeout=5) + except Exception as e: + logger.debug(e) + + if url in self.threads: + self.threads.pop(url) + + try: + self.queue_threads[url].join(timeout=5) + except Exception as e: + logger.debug(e) + + if url in self.queue_threads: + self.queue_threads.pop(url) + + def remove_relays(self): + relay_urls = list(self.relays.keys()) + for url in relay_urls: + self.remove_relay(url) + + def add_subscription(self, id: str, filters: List[str]): + s = Subscription(id, filters) + with self._subscriptions_lock: + self._cached_subscriptions[id] = s - def add_subscription(self, id: str, filters: Filters): for relay in self.relays.values(): - relay.add_subscription(id, filters) + relay.publish_subscriptions([s]) def close_subscription(self, id: str): - for relay in self.relays.values(): - relay.close_subscription(id) + try: + logger.info(f"Closing subscription: '{id}'.") + with self._subscriptions_lock: + if id in self._cached_subscriptions: + self._cached_subscriptions.pop(id) - def open_connections(self, ssl_options: dict = None, proxy: dict = None): - for relay in self.relays.values(): - self.threads[relay.url] = threading.Thread( - target=relay.connect, - args=(ssl_options, proxy), - name=f"{relay.url}-thread", - daemon=True, - ) - self.threads[relay.url].start() + for relay in self.relays.values(): + relay.close_subscription(id) + except Exception as e: + logger.debug(e) - self.queue_threads[relay.url] = threading.Thread( - target=relay.queue_worker, - args=(lambda: relay.shutdown,), - name=f"{relay.url}-queue", - daemon=True, - ) - self.queue_threads[relay.url].start() + def close_subscriptions(self, subscriptions: List[str]): + for id in subscriptions: + self.close_subscription(id) + + def close_all_subscriptions(self): + all_subscriptions = list(self._cached_subscriptions.keys()) + self.close_subscriptions(all_subscriptions) + + def check_and_restart_relays(self): + stopped_relays = [r for r in self.relays.values() if r.shutdown] + for relay in stopped_relays: + self._restart_relay(relay) def close_connections(self): for relay in self.relays.values(): @@ -64,16 +101,43 @@ class RelayManager: def publish_message(self, message: str): for relay in self.relays.values(): - if relay.policy.should_write: - relay.publish(message) + relay.publish(message) - def publish_event(self, event: Event): - """Verifies that the Event is publishable before submitting it to relays""" - if event.signature is None: - raise RelayException(f"Could not publish {event.id}: must be signed") + def handle_notice(self, notice: NoticeMessage): + relay = next((r for r in self.relays.values() if r.url == notice.url)) + if relay: + relay.add_notice(notice.content) - if not event.verify(): - raise RelayException( - f"Could not publish {event.id}: failed to verify signature {event.signature}" - ) - self.publish_message(event.to_message()) + def _open_connection(self, relay: Relay): + self.threads[relay.url] = threading.Thread( + target=relay.connect, + name=f"{relay.url}-thread", + daemon=True, + ) + self.threads[relay.url].start() + + def wrap_async_queue_worker(): + asyncio.run(relay.queue_worker()) + + self.queue_threads[relay.url] = threading.Thread( + target=wrap_async_queue_worker, + name=f"{relay.url}-queue", + daemon=True, + ) + self.queue_threads[relay.url].start() + + def _restart_relay(self, relay: Relay): + time_since_last_error = time.time() - relay.last_error_date + + min_wait_time = min( + 60 * relay.error_counter, 60 * 60 + ) # try at least once an hour + if time_since_last_error < min_wait_time: + return + + logger.info(f"Restarting connection to relay '{relay.url}'") + + self.remove_relay(relay.url) + new_relay = self.add_relay(relay.url) + new_relay.error_counter = relay.error_counter + new_relay.error_list = relay.error_list diff --git a/nostr/subscription.py b/nostr/subscription.py index 7afba20..ed60f7e 100644 --- a/nostr/subscription.py +++ b/nostr/subscription.py @@ -1,12 +1,7 @@ -from .filter import Filters +from typing import Optional + class Subscription: - def __init__(self, id: str, filters: Filters=None) -> None: + def __init__(self, id: str, filters: Optional[list[str]] = None) -> None: self.id = id self.filters = filters - - def to_json_object(self): - return { - "id": self.id, - "filters": self.filters.to_json_array() - } diff --git a/package-lock.json b/package-lock.json new file mode 100644 index 0000000..1180ffb --- /dev/null +++ b/package-lock.json @@ -0,0 +1,59 @@ +{ + "name": "nostrclient", + "version": "1.0.0", + "lockfileVersion": 3, + "requires": true, + "packages": { + "": { + "name": "nostrclient", + "version": "1.0.0", + "license": "ISC", + "dependencies": { + "prettier": "^3.2.5", + "pyright": "^1.1.358" + } + }, + "node_modules/fsevents": { + "version": "2.3.3", + "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", + "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", + "hasInstallScript": true, + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": "^8.16.0 || ^10.6.0 || >=11.0.0" + } + }, + "node_modules/prettier": { + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz", + "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==", + "bin": { + "prettier": "bin/prettier.cjs" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/prettier/prettier?sponsor=1" + } + }, + "node_modules/pyright": { + "version": "1.1.374", + "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.374.tgz", + "integrity": "sha512-ISbC1YnYDYrEatoKKjfaA5uFIp0ddC/xw9aSlN/EkmwupXUMVn41Jl+G6wHEjRhC+n4abHZeGpEvxCUus/K9dA==", + "bin": { + "pyright": "index.js", + "pyright-langserver": "langserver.index.js" + }, + "engines": { + "node": ">=14.0.0" + }, + "optionalDependencies": { + "fsevents": "~2.3.3" + } + } + } +} diff --git a/package.json b/package.json new file mode 100644 index 0000000..7b84315 --- /dev/null +++ b/package.json @@ -0,0 +1,15 @@ +{ + "name": "nostrclient", + "version": "1.0.0", + "description": "", + "main": "index.js", + "scripts": { + "test": "echo \"Error: no test specified\" && exit 1" + }, + "author": "", + "license": "ISC", + "dependencies": { + "prettier": "^3.2.5", + "pyright": "^1.1.358" + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c7c0e56 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,98 @@ +[project] +name = "lnbits-nostrclient" +version = "1.1.0" +requires-python = ">=3.10,<3.13" +description = "LNbits, free and open-source Lightning wallet and accounts system." +authors = [{ name = "Alan Bits", email = "alan@lnbits.com" }] +urls = { Homepage = "https://lnbits.com", Repository = "https://github.com/lnbits/nostrclient" } +dependencies = [ "lnbits>1" ] + +[tool.poetry] +package-mode = false + +[tool.uv] +dev-dependencies = [ + "black", + "pytest-asyncio", + "pytest", + "mypy", + "pre-commit", + "ruff", + "pytest-md", + "types-cffi", +] + +[tool.mypy] +exclude = "(nostr/*)" +plugins = ["pydantic.mypy"] + +[[tool.mypy.overrides]] +module = [ + "nostr.*", +] +follow_imports = "skip" +ignore_missing_imports = "True" + +[tool.pydantic-mypy] +init_forbid_extra = true +init_typed = true +warn_required_dynamic_aliases = true +warn_untyped_fields = true + +[tool.pytest.ini_options] +log_cli = false +testpaths = [ + "tests" +] + +[tool.black] +line-length = 88 + +[tool.ruff] +# Same as Black. + 10% rule of black +line-length = 88 +exclude = [ + "nostr", +] + +[tool.ruff.lint] +# Enable: +# F - pyflakes +# E - pycodestyle errors +# W - pycodestyle warnings +# I - isort +# A - flake8-builtins +# C - mccabe +# N - naming +# UP - pyupgrade +# RUF - ruff +# B - bugbear +select = ["F", "E", "W", "I", "A", "C", "N", "UP", "RUF", "B"] +ignore = ["C901"] + +# Allow autofix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +# needed for pydantic +[tool.ruff.lint.pep8-naming] +classmethod-decorators = [ + "root_validator", +] + +# Ignore unused imports in __init__.py files. +# [tool.ruff.lint.extend-per-file-ignores] +# "__init__.py" = ["F401", "F403"] + +# [tool.ruff.lint.mccabe] +# max-complexity = 10 + +[tool.ruff.lint.flake8-bugbear] +# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. +extend-immutable-calls = [ + "fastapi.Depends", + "fastapi.Query", +] diff --git a/router.py b/router.py new file mode 100644 index 0000000..a7054e9 --- /dev/null +++ b/router.py @@ -0,0 +1,175 @@ +import asyncio +import json +from typing import ClassVar + +from fastapi import WebSocket, WebSocketDisconnect +from lnbits.helpers import urlsafe_short_hash +from loguru import logger + +from .nostr.client.client import NostrClient + +# from . import nostr_client +from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage + +nostr_client: NostrClient = NostrClient() +all_routers: list["NostrRouter"] = [] + + +class NostrRouter: + received_subscription_events: ClassVar[dict[str, list[EventMessage]]] = {} + received_subscription_notices: ClassVar[list[NoticeMessage]] = [] + received_subscription_eosenotices: ClassVar[dict[str, EndOfStoredEventsMessage]] = ( + {} + ) + + def __init__(self, websocket: WebSocket): + self.connected: bool = True + self.websocket: WebSocket = websocket + self.tasks: list[asyncio.Task] = [] + self.original_subscription_ids: dict[str, str] = {} + + @property + def subscriptions(self) -> list[str]: + return list(self.original_subscription_ids.keys()) + + def start(self): + self.connected = True + self.tasks.append(asyncio.create_task(self._client_to_nostr())) + self.tasks.append(asyncio.create_task(self._nostr_to_client())) + + async def stop(self): + nostr_client.relay_manager.close_subscriptions(self.subscriptions) + self.connected = False + + for t in self.tasks: + try: + t.cancel() + except Exception as _: + pass + + try: + await self.websocket.close(reason="Websocket connection closed") + except Exception as _: + pass + + async def _client_to_nostr(self): + """ + Receives requests / data from the client and forwards it to relays. + """ + while self.connected: + try: + json_str = await self.websocket.receive_text() + except WebSocketDisconnect as e: + logger.debug(e) + await self.stop() + break + + try: + await self._handle_client_to_nostr(json_str) + except Exception as e: + logger.debug(f"Failed to handle client message: '{e!s}'.") + + async def _nostr_to_client(self): + """Sends responses from relays back to the client.""" + while self.connected: + try: + await self._handle_subscriptions() + self._handle_notices() + except Exception as e: + logger.debug(f"Failed to handle response for client: '{e!s}'.") + await asyncio.sleep(1) + await asyncio.sleep(0.1) + + async def _handle_subscriptions(self): + for s in self.subscriptions: + if s in NostrRouter.received_subscription_events: + await self._handle_received_subscription_events(s) + if s in NostrRouter.received_subscription_eosenotices: + await self._handle_received_subscription_eosenotices(s) + + async def _handle_received_subscription_eosenotices(self, s): + try: + if s not in self.original_subscription_ids: + return + s_original = self.original_subscription_ids[s] + event_to_forward = ["EOSE", s_original] + del NostrRouter.received_subscription_eosenotices[s] + + await self.websocket.send_text(json.dumps(event_to_forward)) + except Exception as e: + logger.debug(e) + + async def _handle_received_subscription_events(self, s): + try: + if s not in NostrRouter.received_subscription_events: + return + + while len(NostrRouter.received_subscription_events[s]): + event_message = NostrRouter.received_subscription_events[s].pop(0) + event_json = event_message.event + + # this reconstructs the original response from the relay + # reconstruct original subscription id + s_original = self.original_subscription_ids[s] + event_to_forward = json.dumps( + ["EVENT", s_original, json.loads(event_json)] + ) + await self.websocket.send_text(event_to_forward) + except Exception as e: + logger.warning( + f"[NOSTRCLIENT] Error in _handle_received_subscription_events: {e}" + ) + + def _handle_notices(self): + while len(NostrRouter.received_subscription_notices): + my_event = NostrRouter.received_subscription_notices.pop(0) + logger.debug(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']") + # Note: we don't send it to the user because + # we don't know who should receive it + nostr_client.relay_manager.handle_notice(my_event) + + async def _handle_client_to_nostr(self, json_str): + json_data = json.loads(json_str) + assert len(json_data), "Bad JSON array" + + if json_data[0] == "REQ": + self._handle_client_req(json_data) + return + + if json_data[0] == "CLOSE": + self._handle_client_close(json_data[1]) + return + + if json_data[0] == "EVENT": + nostr_client.relay_manager.publish_message(json_str) + return + + def _handle_client_req(self, json_data): + subscription_id = json_data[1] + logger.info(f"New subscription: '{subscription_id}'") + subscription_id_rewritten = urlsafe_short_hash() + self.original_subscription_ids[subscription_id_rewritten] = subscription_id + filters = json_data[2:] + + nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters) + + def _handle_client_close(self, subscription_id): + subscription_id_rewritten = next( + ( + k + for k, v in self.original_subscription_ids.items() + if v == subscription_id + ), + None, + ) + if subscription_id_rewritten: + self.original_subscription_ids.pop(subscription_id_rewritten) + nostr_client.relay_manager.close_subscription(subscription_id_rewritten) + logger.info( + f""" + Unsubscribe from '{subscription_id_rewritten}'. + Original id: '{subscription_id}.' + """ + ) + else: + logger.info(f"Failed to unsubscribe from '{subscription_id}.'") diff --git a/services.py b/services.py deleted file mode 100644 index 82f6578..0000000 --- a/services.py +++ /dev/null @@ -1,163 +0,0 @@ -import asyncio -import json -from typing import List, Union - -from fastapi import WebSocket, WebSocketDisconnect -from loguru import logger - -from lnbits.helpers import urlsafe_short_hash - -from .models import Event, Filter, Filters, Relay, RelayList -from .nostr.client.client import NostrClient as NostrClientLib -from .nostr.event import Event as NostrEvent -from .nostr.filter import Filter as NostrFilter -from .nostr.filter import Filters as NostrFilters -from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage - -received_subscription_events: dict[str, list[Event]] = {} -received_subscription_notices: list[NoticeMessage] = [] -received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {} - - -class NostrClient: - def __init__(self): - self.client: NostrClientLib = NostrClientLib(connect=False) - - -nostr = NostrClient() - - -class NostrRouter: - def __init__(self, websocket): - self.subscriptions: List[str] = [] - self.connected: bool = True - self.websocket = websocket - self.tasks: List[asyncio.Task] = [] - self.oridinal_subscription_ids = {} - - async def client_to_nostr(self): - """Receives requests / data from the client and forwards it to relays. If the - request was a subscription/filter, registers it with the nostr client lib. - Remembers the subscription id so we can send back responses from the relay to this - client in `nostr_to_client`""" - while True: - try: - json_str = await self.websocket.receive_text() - except WebSocketDisconnect: - self.connected = False - break - - # registers a subscription if the input was a REQ request - subscription_id, json_str_rewritten = await self._add_nostr_subscription( - json_str - ) - - if subscription_id and json_str_rewritten: - self.subscriptions.append(subscription_id) - json_str = json_str_rewritten - - # publish data - nostr.client.relay_manager.publish_message(json_str) - - async def nostr_to_client(self): - """Sends responses from relays back to the client. Polls the subscriptions of this client - stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which - is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs - the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id - that we had previously rewritten in order to avoid collisions when multiple clients use the same id. - """ - while True and self.connected: - for s in self.subscriptions: - if s in received_subscription_events: - while len(received_subscription_events[s]): - my_event = received_subscription_events[s].pop(0) - # event.to_message() does not include the subscription ID, we have to add it manually - event_json = { - "id": my_event.id, - "pubkey": my_event.public_key, - "created_at": my_event.created_at, - "kind": my_event.kind, - "tags": my_event.tags, - "content": my_event.content, - "sig": my_event.signature, - } - - # this reconstructs the original response from the relay - # reconstruct original subscription id - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EVENT", s_original, event_json] - - # print("Event to forward") - # print(json.dumps(event_to_forward)) - - # send data back to client - await self.websocket.send_text(json.dumps(event_to_forward)) - if s in received_subscription_eosenotices: - my_event = received_subscription_eosenotices[s] - s_original = self.oridinal_subscription_ids[s] - event_to_forward = ["EOSE", s_original] - del received_subscription_eosenotices[s] - # send data back to client - # print("Sending EOSE", event_to_forward) - await self.websocket.send_text(json.dumps(event_to_forward)) - - # if s in received_subscription_notices: - while len(received_subscription_notices): - my_event = received_subscription_notices.pop(0) - event_to_forward = ["NOTICE", my_event.content] - # send data back to client - logger.debug("Nostrclient: Received notice", event_to_forward[1]) - # note: we don't send it to the user because we don't know who should receive it - # await self.websocket.send_text(json.dumps(event_to_forward)) - await asyncio.sleep(0.1) - - async def start(self): - self.tasks.append(asyncio.create_task(self.client_to_nostr())) - self.tasks.append(asyncio.create_task(self.nostr_to_client())) - - async def stop(self): - for t in self.tasks: - t.cancel() - self.connected = False - - def _marshall_nostr_filters(self, data: Union[dict, list]): - filters = data if isinstance(data, list) else [data] - filters = [Filter.parse_obj(f) for f in filters] - filter_list: list[NostrFilter] = [] - for filter in filters: - filter_list.append( - NostrFilter( - event_ids=filter.ids, # type: ignore - kinds=filter.kinds, # type: ignore - authors=filter.authors, # type: ignore - since=filter.since, # type: ignore - until=filter.until, # type: ignore - event_refs=filter.e, # type: ignore - pubkey_refs=filter.p, # type: ignore - limit=filter.limit, # type: ignore - ) - ) - return NostrFilters(filter_list) - - async def _add_nostr_subscription(self, json_str): - """Parses a (string) request from a client. If it is a subscription (REQ) or a CLOSE, it will - register the subscription in the nostr client library that we're using so we can - receive the callbacks on it later. Will rewrite the subscription id since we expect - multiple clients to use the router and want to avoid subscription id collisions - """ - json_data = json.loads(json_str) - assert len(json_data) - if json_data[0] in ["REQ", "CLOSE"]: - subscription_id = json_data[1] - subscription_id_rewritten = urlsafe_short_hash() - self.oridinal_subscription_ids[subscription_id_rewritten] = subscription_id - fltr = json_data[2] - filters = self._marshall_nostr_filters(fltr) - nostr.client.relay_manager.add_subscription( - subscription_id_rewritten, filters - ) - request_rewritten = json.dumps( - [json_data[0], subscription_id_rewritten, fltr] - ) - return subscription_id_rewritten, request_rewritten - return None, None diff --git a/static/images/1.jpeg b/static/images/1.jpeg new file mode 100644 index 0000000..1f00661 Binary files /dev/null and b/static/images/1.jpeg differ diff --git a/static/images/2.jpeg b/static/images/2.jpeg new file mode 100644 index 0000000..e7301fd Binary files /dev/null and b/static/images/2.jpeg differ diff --git a/tasks.py b/tasks.py index beff9db..2a76765 100644 --- a/tasks.py +++ b/tasks.py @@ -1,94 +1,71 @@ import asyncio -import json -import ssl import threading +from loguru import logger + from .crud import get_relays -from .nostr.event import Event -from .nostr.key import PublicKey from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage -from .nostr.relay_manager import RelayManager -from .services import ( - nostr, - received_subscription_eosenotices, - received_subscription_events, - received_subscription_notices, -) +from .router import NostrRouter, nostr_client async def init_relays(): - # we save any subscriptions teporarily to re-add them after reinitializing the client - subscriptions = {} - for relay in nostr.client.relay_manager.relays.values(): - # relay.add_subscription(id, filters) - for subscription_id, filters in relay.subscriptions.items(): - subscriptions[subscription_id] = filters - - # reinitialize the entire client - nostr.__init__() # get relays from db relays = await get_relays() # set relays and connect to them - nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url])) - nostr.client.connect() + valid_relays = [r.url for r in relays if r.url] - await asyncio.sleep(2) - # re-add subscriptions - for subscription_id, subscription in subscriptions.items(): - nostr.client.relay_manager.add_subscription( - subscription_id, subscription.filters - ) - s = subscription.to_json_object() - json_str = json.dumps(["REQ", s["id"], s["filters"][0]]) - nostr.client.relay_manager.publish_message(json_str) - return + nostr_client.reconnect(valid_relays) + + +async def check_relays(): + """Check relays that have been disconnected""" + while True: + try: + await asyncio.sleep(20) + nostr_client.relay_manager.check_and_restart_relays() + except Exception as e: + logger.warning(f"Cannot restart relays: '{e!s}'.") async def subscribe_events(): - while not any([r.connected for r in nostr.client.relay_manager.relays.values()]): + while not [r.connected for r in nostr_client.relay_manager.relays.values()]: await asyncio.sleep(2) - def callback_events(eventMessage: EventMessage): - # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}") - if eventMessage.subscription_id in received_subscription_events: - # do not add duplicate events (by event id) - if eventMessage.event.id in set( - [ - e.id - for e in received_subscription_events[eventMessage.subscription_id] - ] - ): - return + def callback_events(event_message: EventMessage): + sub_id = event_message.subscription_id + if sub_id not in NostrRouter.received_subscription_events: + NostrRouter.received_subscription_events[sub_id] = [event_message] + return - received_subscription_events[eventMessage.subscription_id].append( - eventMessage.event + # do not add duplicate events (by event id) + ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]] + if event_message.event_id in ids: + return + + NostrRouter.received_subscription_events[sub_id].append(event_message) + + def callback_notices(notice_message: NoticeMessage): + if notice_message not in NostrRouter.received_subscription_notices: + NostrRouter.received_subscription_notices.append(notice_message) + + def callback_eose_notices(event_message: EndOfStoredEventsMessage): + sub_id = event_message.subscription_id + if sub_id in NostrRouter.received_subscription_eosenotices: + return + + NostrRouter.received_subscription_eosenotices[sub_id] = event_message + + def wrap_async_subscribe(): + asyncio.run( + nostr_client.subscribe( + callback_events, + callback_notices, + callback_eose_notices, ) - else: - received_subscription_events[eventMessage.subscription_id] = [ - eventMessage.event - ] - return - - def callback_notices(noticeMessage: NoticeMessage): - if noticeMessage not in received_subscription_notices: - received_subscription_notices.append(noticeMessage) - return - - def callback_eose_notices(eventMessage: EndOfStoredEventsMessage): - if eventMessage.subscription_id not in received_subscription_eosenotices: - received_subscription_eosenotices[ - eventMessage.subscription_id - ] = eventMessage - - return + ) t = threading.Thread( - target=nostr.client.subscribe, - args=( - callback_events, - callback_notices, - callback_eose_notices, - ), + target=wrap_async_subscribe, name="Nostr-event-subscription", daemon=True, ) diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html index bbd8f23..ca44f1b 100644 --- a/templates/nostrclient/index.html +++ b/templates/nostrclient/index.html @@ -4,8 +4,8 @@
-
-
+
+
-
- Add relay - +
+ + + + + + + +
+
+
@@ -45,7 +71,7 @@ -
-
-
- {{ col.value }} -
-
{{ col.value }}
+
+
🟢
+
🔴
+
+
+
+ ⬆️ ⬇️ + + + ⚠️ + + + ⓘ + +
+
+
+ +
+
+
{{ col.value }}
- - - @@ -99,8 +143,18 @@
-
+
+ Copy address Your endpoint:
-
- -
- - -
-
- Sender Private Key: + + + +
+
+ Sender Private Key: +
+
+ +
-
- +
+
+
+ + + No not use your real private key! Leave empty for a randomly + generated key. + +
-
-
-
-
- - This should be a temp private (throw away). No not user your - own private key! - - +
+
+ Sender Public Key: +
+
+ +
+
+
+
+ Test Message: +
+
+ +
+
+
+
+ Receiver Public Key: +
+
+ +
+
+
+
+
+ + This is the recipient of the message. Field required. + +
+
+
+
+ Send Message +
+
+ - - It is optional. One can be generated for you! - + + +
+
+ Sent Data: +
+
+ +
-
-
-
- Sender Public Key: +
+
+ Received Data: +
+
+ +
-
- -
-
-
-
- Test Message: -
-
- -
-
-
-
- Receiver Public Key: -
-
- -
-
-
-
-
- - This is the recipient of the message. Field required. - -
-
-
-
- Send Message -
-
- - - -
-
- Sent Data: -
-
- -
-
-
-
- Received Data: -
-
- -
-
-
+ +
@@ -271,23 +326,71 @@
+ + + + + +
+ Close +
+
+
+ + + + + +
+ +
+ Update + Cancel +
+
+
+
{% endraw %} {% endblock %} {% block scripts %} {{ window_vars(user) }}