diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..5bcdee7
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,29 @@
+name: CI
+on:
+ push:
+ branches:
+ - main
+ pull_request:
+
+jobs:
+ lint:
+ uses: lnbits/lnbits/.github/workflows/lint.yml@dev
+ tests:
+ runs-on: ubuntu-latest
+ needs: [lint]
+ steps:
+ - uses: actions/checkout@v4
+ - uses: lnbits/lnbits/.github/actions/prepare@dev
+ - name: Run pytest
+ uses: pavelzw/pytest-action@v2
+ env:
+ LNBITS_BACKEND_WALLET_CLASS: FakeWallet
+ PYTHONUNBUFFERED: 1
+ DEBUG: true
+ with:
+ verbose: true
+ job-summary: true
+ emoji: false
+ click-to-expand: true
+ custom-pytest: uv run pytest
+ report-title: 'test'
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
new file mode 100644
index 0000000..27c8a60
--- /dev/null
+++ b/.github/workflows/release.yml
@@ -0,0 +1,58 @@
+on:
+ push:
+ tags:
+ - 'v[0-9]+.[0-9]+.[0-9]+'
+
+jobs:
+ release:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ - name: Create github release
+ env:
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ tag: ${{ github.ref_name }}
+ run: |
+ gh release create "$tag" --generate-notes
+
+ pullrequest:
+ needs: [release]
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ token: ${{ secrets.EXT_GITHUB }}
+ repository: lnbits/lnbits-extensions
+ path: './lnbits-extensions'
+
+ - name: setup git user
+ run: |
+ git config --global user.name "alan"
+ git config --global user.email "alan@lnbits.com"
+
+ - name: Create pull request in extensions repo
+ env:
+ GH_TOKEN: ${{ secrets.EXT_GITHUB }}
+ repo_name: '${{ github.event.repository.name }}'
+ tag: '${{ github.ref_name }}'
+ branch: 'update-${{ github.event.repository.name }}-${{ github.ref_name }}'
+ title: '[UPDATE] ${{ github.event.repository.name }} to ${{ github.ref_name }}'
+ body: 'https://github.com/lnbits/${{ github.event.repository.name }}/releases/${{ github.ref_name }}'
+ archive: 'https://github.com/lnbits/${{ github.event.repository.name }}/archive/refs/tags/${{ github.ref_name }}.zip'
+ run: |
+ cd lnbits-extensions
+ git checkout -b $branch
+
+ # if there is another open PR
+ git pull origin $branch || echo "branch does not exist"
+
+ sh util.sh update_extension $repo_name $tag
+
+ git add -A
+ git commit -am "$title"
+ git push origin $branch
+
+ # check if pr exists before creating it
+ gh config set pager cat
+ check=$(gh pr list -H $branch | wc -l)
+ test $check -ne 0 || gh pr create --title "$title" --body "$body" --repo lnbits/lnbits-extensions
diff --git a/.gitignore b/.gitignore
index 10a11d5..0152b6e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,24 +1,4 @@
-.DS_Store
-._*
-
__pycache__
-*.py[cod]
-*$py.class
+node_modules
.mypy_cache
-.vscode
-*-lock.json
-
-*.egg
-*.egg-info
-.coverage
-.pytest_cache
-.webassets-cache
-htmlcov
-test-reports
-tests/data/*.sqlite3
-
-*.swo
-*.swp
-*.pyo
-*.pyc
-*.env
\ No newline at end of file
+.venv
diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 0000000..725c398
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,12 @@
+{
+ "semi": false,
+ "arrowParens": "avoid",
+ "insertPragma": false,
+ "printWidth": 80,
+ "proseWrap": "preserve",
+ "singleQuote": true,
+ "trailingComma": "none",
+ "useTabs": false,
+ "bracketSameLine": false,
+ "bracketSpacing": false
+}
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..0fac253
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,47 @@
+all: format check
+
+format: prettier black ruff
+
+check: mypy pyright checkblack checkruff checkprettier
+
+prettier:
+ uv run ./node_modules/.bin/prettier --write .
+pyright:
+ uv run ./node_modules/.bin/pyright
+
+mypy:
+ uv run mypy .
+
+black:
+ uv run black .
+
+ruff:
+ uv run ruff check . --fix
+
+checkruff:
+ uv run ruff check .
+
+checkprettier:
+ uv run ./node_modules/.bin/prettier --check .
+
+checkblack:
+ uv run black --check .
+
+checkeditorconfig:
+ editorconfig-checker
+
+test:
+ PYTHONUNBUFFERED=1 \
+ DEBUG=true \
+ uv run pytest
+install-pre-commit-hook:
+ @echo "Installing pre-commit hook to git"
+ @echo "Uninstall the hook with uv run pre-commit uninstall"
+ uv run pre-commit install
+
+pre-commit:
+ uv run pre-commit run --all-files
+
+
+checkbundle:
+ @echo "skipping checkbundle"
diff --git a/README.md b/README.md
index e609b6c..70593a8 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,124 @@
-# nostrclient
+# Nostrclient - [LNbits](https://github.com/lnbits/lnbits) extension
-`nostrclient` is an always-on extension that can open multiple connections to nostr relays and act as a multiplexer for other clients: You open a single websocket to `nostrclient` which then sends the data to multiple relays. The responses from these relays are then sent back to the client.
+For more about LNBits extension check [this tutorial](https://github.com/lnbits/lnbits/wiki/LNbits-Extensions)
+## Overview
+`nostrclient` is an always-on Nostr relay multiplexer that simplifies connecting to multiple Nostr relays. Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you.
-
+### Why Use This?
+
+- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections
+- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays
+- **Resource Efficient** - Share relay connections across multiple clients
+- **Subscription Management** - Automatic subscription ID rewriting prevents conflicts between clients
+
+## Architecture
+
+```mermaid
+flowchart LR
+ A[Client A] -->|WebSocket| N
+ B[Client B] -->|WebSocket| N
+ C[Client C] -->|WebSocket| N
+
+ N[nostrclient
Router] -->|Fan Out| R1[Relay A]
+ N -->|Fan Out| R2[Relay B]
+ N -->|Fan Out| R3[Relay C]
+ N -->|Fan Out| R4[Relay D]
+
+ R1 -.->|Aggregate| N
+ R2 -.->|Aggregate| N
+ R3 -.->|Aggregate| N
+ R4 -.->|Aggregate| N
+```
+
+**Key Feature:** The router rewrites subscription IDs to prevent conflicts when multiple clients use the same IDs.
+
+## Features
+
+- **Multi-Relay Multiplexing** - Connect to multiple Nostr relays through a single WebSocket
+- **Public & Private Endpoints** - Configurable public and private WebSocket access
+- **Automatic Reconnection** - Failed relays are automatically retried with exponential backoff
+- **Subscription Deduplication** - Events are deduplicated before being sent to clients
+- **Health Monitoring** - Track relay connection status, latency, and error rates
+- **Test Endpoint** - Send test messages to verify your setup is working
+
+## How It Works
+
+1. **Client Connection** - Your Nostr client connects to the nostrclient WebSocket endpoint
+2. **Subscription Rewriting** - Each subscription ID is rewritten to prevent conflicts between multiple clients
+3. **Fan-Out** - Subscription requests are sent to all configured relays
+4. **Aggregation** - Events from all relays are collected and deduplicated
+5. **Response** - Events are sent back to the client with the original subscription ID
+
+## Configuration
+
+### WebSocket Endpoints
+
+- **Public Endpoint**: `/api/v1/relay` - Available to anyone (if enabled)
+- **Private Endpoint**: `/api/v1/{encrypted_id}` - Requires valid encrypted endpoint ID
+
+Configure endpoint access in the extension settings:
+
+- `private_ws` - Enable/disable private WebSocket access
+- `public_ws` - Enable/disable public WebSocket access
+
+### Adding Relays
+
+Use the nostrclient UI to add/remove Nostr relays. The extension will automatically:
+
+- Connect to new relays
+- Publish existing subscriptions to new relays
+- Monitor relay health and reconnect as needed
+
+## Testing
+
+### Test Endpoint Functionality
+
+The `Test Endpoint` feature helps verify that your nostrclient WebSocket endpoint works correctly.
+
+**How to test:**
+
+1. Navigate to the nostrclient extension in LNbits
+2. Use the Test Endpoint feature
+3. Send a DM to yourself (or a temporary account)
+4. Verify that messages are sent and received correctly
+
+https://user-images.githubusercontent.com/2951406/236780745-929c33c2-2502-49be-84a3-db02a7aabc0e.mp4
+
+## Troubleshooting
+
+### Connection Issues
+
+- **Check relay status** - View relay health in the nostrclient UI
+- **Verify endpoint configuration** - Ensure public_ws or private_ws is enabled
+- **Check logs** - Review LNbits logs for connection errors
+
+### Subscription Not Receiving Events
+
+- **Verify relays are connected** - Check the relay status in the UI
+- **Test with known event** - Use the Test Endpoint to verify connectivity
+- **Check relay compatibility** - Some relays may not support all Nostr features
+
+## Development
+
+This extension uses `uv` for dependency management.
+
+### Quick Start
+
+```bash
+# Format code
+make format
+
+# Run type checks and linting
+make check
+
+# Run tests
+make test
+```
+
+For more development commands, see the [Makefile](./Makefile).
+
+## License
+
+MIT License - see [LICENSE](./LICENSE)
diff --git a/__init__.py b/__init__.py
index 1555eb9..d7eb435 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,36 +1,59 @@
+import asyncio
+
from fastapi import APIRouter
-from starlette.staticfiles import StaticFiles
+from loguru import logger
-from lnbits.db import Database
-from lnbits.helpers import template_renderer
-from lnbits.tasks import catch_everything_and_restart
-
-db = Database("ext_nostrclient")
+from .crud import db
+from .router import all_routers, nostr_client
+from .tasks import check_relays, init_relays, subscribe_events
+from .views import nostrclient_generic_router
+from .views_api import nostrclient_api_router
nostrclient_static_files = [
{
"path": "/nostrclient/static",
- "app": StaticFiles(directory="lnbits/extensions/nostrclient/static"),
"name": "nostrclient_static",
}
]
nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"])
+nostrclient_ext.include_router(nostrclient_generic_router)
+nostrclient_ext.include_router(nostrclient_api_router)
+scheduled_tasks: list[asyncio.Task] = []
-def nostr_renderer():
- return template_renderer(["lnbits/extensions/nostrclient/templates"])
+async def nostrclient_stop():
+ for task in scheduled_tasks:
+ try:
+ task.cancel()
+ except Exception as ex:
+ logger.warning(ex)
+ for router in all_routers:
+ try:
+ await router.stop()
+ all_routers.remove(router)
+ except Exception as e:
+ logger.error(e)
-from .views import * # noqa
-from .views_api import * # noqa
-
-from .tasks import init_relays, subscribe_events
+ nostr_client.close()
def nostrclient_start():
- loop = asyncio.get_event_loop()
- loop.create_task(catch_everything_and_restart(init_relays))
- # loop.create_task(catch_everything_and_restart(send_data))
- # loop.create_task(catch_everything_and_restart(receive_data))
- loop.create_task(catch_everything_and_restart(subscribe_events))
+ from lnbits.tasks import create_permanent_unique_task
+
+ task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays)
+ task2 = create_permanent_unique_task(
+ "ext_nostrclient_subscrive_events", subscribe_events
+ )
+ task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays)
+ scheduled_tasks.extend([task1, task2, task3])
+
+
+__all__ = [
+ "db",
+ "nostrclient_ext",
+ "nostrclient_start",
+ "nostrclient_static_files",
+ "nostrclient_stop",
+]
diff --git a/cbc.py b/cbc.py
deleted file mode 100644
index 0d9e04f..0000000
--- a/cbc.py
+++ /dev/null
@@ -1,26 +0,0 @@
-from Cryptodome.Cipher import AES
-
-BLOCK_SIZE = 16
-
-
-class AESCipher(object):
- """This class is compatible with crypto.createCipheriv('aes-256-cbc')"""
-
- def __init__(self, key=None):
- self.key = key
-
- def pad(self, data):
- length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
- return data + (chr(length) * length).encode()
-
- def unpad(self, data):
- return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
-
- def encrypt(self, plain_text):
- cipher = AES.new(self.key, AES.MODE_CBC)
- b = plain_text.encode("UTF-8")
- return cipher.iv, cipher.encrypt(self.pad(b))
-
- def decrypt(self, iv, enc_text):
- cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
- return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
diff --git a/config.json b/config.json
index ce8ae18..1f58e7b 100644
--- a/config.json
+++ b/config.json
@@ -1,6 +1,17 @@
{
"name": "Nostr Client",
- "short_description": "Nostr client for extensions",
+ "short_description": "Nostr relay multiplexer",
+ "version": "1.1.0",
"tile": "/nostrclient/static/images/nostr-bitcoin.png",
- "contributors": ["calle"]
+ "contributors": ["calle", "motorina0", "dni"],
+ "min_lnbits_version": "1.4.0",
+ "images": [
+ {
+ "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/1.jpeg"
+ },
+ {
+ "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/2.jpeg"
+ }
+ ],
+ "description_md": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/description.md"
}
diff --git a/crud.py b/crud.py
index b1916b7..d311c72 100644
--- a/crud.py
+++ b/crud.py
@@ -1,29 +1,52 @@
-from typing import List, Optional, Union
+from lnbits.db import Database
-from lnbits.helpers import urlsafe_short_hash
-import shortuuid
-from . import db
-from .models import Relay, RelayList
+from .models import Config, Relay, UserConfig
+
+db = Database("ext_nostrclient")
-async def get_relays() -> RelayList:
- row = await db.fetchall("SELECT * FROM nostrclient.relays")
- return RelayList(__root__=row)
-
-
-async def add_relay(relay: Relay) -> None:
- await db.execute(
- f"""
- INSERT INTO nostrclient.relays (
- id,
- url,
- active
- )
- VALUES (?, ?, ?)
- """,
- (relay.id, relay.url, relay.active),
+async def get_relays() -> list[Relay]:
+ return await db.fetchall(
+ "SELECT * FROM nostrclient.relays",
+ model=Relay,
)
+async def add_relay(relay: Relay) -> Relay:
+ await db.insert("nostrclient.relays", relay)
+ return relay
+
+
async def delete_relay(relay: Relay) -> None:
- await db.execute("DELETE FROM nostrclient.relays WHERE url = ?", (relay.url,))
+ if not relay.url:
+ return
+ await db.execute(
+ "DELETE FROM nostrclient.relays WHERE url = :url", {"url": relay.url}
+ )
+
+
+######################CONFIG#######################
+async def create_config(owner_id: str) -> Config:
+ admin_config = UserConfig(owner_id=owner_id)
+ await db.insert("nostrclient.config", admin_config)
+ return admin_config.extra
+
+
+async def update_config(owner_id: str, config: Config) -> Config:
+ user_config = UserConfig(owner_id=owner_id, extra=config)
+ await db.update("nostrclient.config", user_config, "WHERE owner_id = :owner_id")
+ return user_config.extra
+
+
+async def get_config(owner_id: str) -> Config | None:
+ user_config: UserConfig = await db.fetchone(
+ """
+ SELECT * FROM nostrclient.config
+ WHERE owner_id = :owner_id
+ """,
+ {"owner_id": owner_id},
+ model=UserConfig,
+ )
+ if user_config:
+ return user_config.extra
+ return None
diff --git a/description.md b/description.md
new file mode 100644
index 0000000..5293087
--- /dev/null
+++ b/description.md
@@ -0,0 +1,8 @@
+An always-on relay multiplexer that simplifies connecting to multiple Nostr relays.
+
+Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you.
+
+- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections
+- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays
+- **Resource Efficient** - Share relay connections across multiple clients
+- **Automatic Subscription Management** - Subscription ID rewriting prevents conflicts between clients
diff --git a/helpers.py b/helpers.py
new file mode 100644
index 0000000..bcf5c02
--- /dev/null
+++ b/helpers.py
@@ -0,0 +1,19 @@
+from bech32 import bech32_decode, convertbits
+
+
+def normalize_public_key(pubkey: str) -> str:
+ if pubkey.startswith("npub1"):
+ _, decoded_data = bech32_decode(pubkey)
+ if not decoded_data:
+ raise ValueError("Public Key is not valid npub")
+
+ decoded_data_bits = convertbits(decoded_data, 5, 8, False)
+ if not decoded_data_bits:
+ raise ValueError("Public Key is not valid npub")
+ return bytes(decoded_data_bits).hex()
+
+ # check if valid hex
+ if len(pubkey) != 64:
+ raise ValueError("Public Key is not valid hex")
+ int(pubkey, 16)
+ return pubkey
diff --git a/migrations.py b/migrations.py
index 5a30e45..b16db58 100644
--- a/migrations.py
+++ b/migrations.py
@@ -3,7 +3,7 @@ async def m001_initial(db):
Initial nostrclient table.
"""
await db.execute(
- f"""
+ """
CREATE TABLE nostrclient.relays (
id TEXT NOT NULL PRIMARY KEY,
url TEXT NOT NULL,
@@ -11,3 +11,22 @@ async def m001_initial(db):
);
"""
)
+
+
+async def m002_create_config_table(db):
+ """
+ Allow the extension to persist and retrieve any number of config values.
+ """
+
+ await db.execute(
+ """CREATE TABLE nostrclient.config (
+ json_data TEXT NOT NULL
+ );"""
+ )
+
+
+async def m003_update_config_table(db):
+ await db.execute("ALTER TABLE nostrclient.config RENAME COLUMN json_data TO extra")
+ await db.execute(
+ "ALTER TABLE nostrclient.config ADD COLUMN owner_id TEXT DEFAULT 'admin'"
+ )
diff --git a/models.py b/models.py
index bfbc424..937c6c5 100644
--- a/models.py
+++ b/models.py
@@ -1,92 +1,54 @@
-from typing import List, Dict
-from typing import Optional
-
-from fastapi import Request
+from lnbits.helpers import urlsafe_short_hash
from pydantic import BaseModel, Field
-from fastapi.param_functions import Query
-from dataclasses import dataclass
-from lnbits.helpers import urlsafe_short_hash
+
+class RelayStatus(BaseModel):
+ num_sent_events: int | None = 0
+ num_received_events: int | None = 0
+ error_counter: int | None = 0
+ error_list: list | None = []
+ notice_list: list | None = []
class Relay(BaseModel):
- id: Optional[str] = None
- url: Optional[str] = None
- connected: Optional[bool] = None
- connected_string: Optional[str] = None
- status: Optional[str] = None
- active: Optional[bool] = None
- ping: Optional[int] = None
+ id: str | None = None
+ url: str | None = None
+ active: bool | None = None
+
+ connected: bool | None = Field(default=None, no_database=True)
+ connected_string: str | None = Field(default=None, no_database=True)
+ status: RelayStatus | None = Field(default=None, no_database=True)
+
+ ping: int | None = Field(default=None, no_database=True)
def _init__(self):
if not self.id:
self.id = urlsafe_short_hash()
-class RelayList(BaseModel):
- __root__: List[Relay]
+class RelayDb(BaseModel):
+ id: str
+ url: str
+ active: bool | None = True
-class Event(BaseModel):
- content: str
- pubkey: str
- created_at: Optional[int]
- kind: int
- tags: Optional[List[List[str]]]
- sig: str
+class TestMessage(BaseModel):
+ sender_private_key: str | None
+ reciever_public_key: str
+ message: str
-class Filter(BaseModel):
- ids: Optional[List[str]]
- kinds: Optional[List[int]]
- authors: Optional[List[str]]
- since: Optional[int]
- until: Optional[int]
- e: Optional[List[str]] = Field(alias="#e")
- p: Optional[List[str]] = Field(alias="#p")
- limit: Optional[int]
+class TestMessageResponse(BaseModel):
+ private_key: str
+ public_key: str
+ event_json: str
-class Filters(BaseModel):
- __root__: List[Filter]
+class Config(BaseModel):
+ private_ws: bool = True
+ public_ws: bool = False
-# class nostrKeys(BaseModel):
-# pubkey: str
-# privkey: str
-
-# class nostrNotes(BaseModel):
-# id: str
-# pubkey: str
-# created_at: str
-# kind: int
-# tags: str
-# content: str
-# sig: str
-
-# class nostrCreateRelays(BaseModel):
-# relay: str = Query(None)
-
-# class nostrCreateConnections(BaseModel):
-# pubkey: str = Query(None)
-# relayid: str = Query(None)
-
-# class nostrRelays(BaseModel):
-# id: Optional[str]
-# relay: Optional[str]
-# status: Optional[bool] = False
-
-
-# class nostrRelaySetList(BaseModel):
-# allowlist: Optional[str]
-# denylist: Optional[str]
-
-# class nostrConnections(BaseModel):
-# id: str
-# pubkey: Optional[str]
-# relayid: Optional[str]
-
-# class nostrSubscriptions(BaseModel):
-# id: str
-# userPubkey: Optional[str]
-# subscribedPubkey: Optional[str]
+class UserConfig(BaseModel):
+ owner_id: str
+ extra: Config = Config()
diff --git a/nostr/bech32.py b/nostr/bech32.py
index b068de7..ba2ddd1 100644
--- a/nostr/bech32.py
+++ b/nostr/bech32.py
@@ -23,21 +23,25 @@
from enum import Enum
+
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
+
BECH32 = 1
BECH32M = 2
+
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
-BECH32M_CONST = 0x2bc830a3
+BECH32M_CONST = 0x2BC830A3
+
def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
- generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
+ generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1
for value in values:
top = chk >> 25
- chk = (chk & 0x1ffffff) << 5 ^ value
+ chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk
@@ -57,6 +61,7 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M
return None
+
def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
@@ -68,26 +73,29 @@ def bech32_create_checksum(hrp, data, spec):
def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
- return hrp + '1' + ''.join([CHARSET[d] for d in combined])
+ return hrp + "1" + "".join([CHARSET[d] for d in combined])
+
def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
- if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
- (bech.lower() != bech and bech.upper() != bech)):
+ if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
+ bech.lower() != bech and bech.upper() != bech
+ ):
return (None, None, None)
bech = bech.lower()
- pos = bech.rfind('1')
+ pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
- if not all(x in CHARSET for x in bech[pos+1:]):
+ if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
hrp = bech[:pos]
- data = [CHARSET.find(x) for x in bech[pos+1:]]
+ data = [CHARSET.find(x) for x in bech[pos + 1 :]]
spec = bech32_verify_checksum(hrp, data)
if spec is None:
return (None, None, None)
return (hrp, data[:-6], spec)
+
def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion."""
acc = 0
@@ -116,22 +124,29 @@ def decode(hrp, addr):
hrpgot, data, spec = bech32_decode(addr)
if hrpgot != hrp:
return (None, None)
- decoded = convertbits(data[1:], 5, 8, False)
+ decoded = convertbits(data[1:], 5, 8, False) # type: ignore
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
return (None, None)
- if data[0] > 16:
+ if data[0] > 16: # type: ignore
return (None, None)
- if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
+ if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: # type: ignore
return (None, None)
- if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
+ if (
+ data[0] == 0 # type: ignore
+ and spec != Encoding.BECH32
+ or data[0] != 0 # type: ignore
+ and spec != Encoding.BECH32M
+ ):
return (None, None)
- return (data[0], decoded)
+ return (data[0], decoded) # type: ignore
def encode(hrp, witver, witprog):
"""Encode a segwit address."""
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
- ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5), spec)
+ wit_prog = convertbits(witprog, 8, 5)
+ assert wit_prog
+ ret = bech32_encode(hrp, [witver, *wit_prog], spec)
if decode(hrp, ret) == (None, None):
return None
return ret
diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py
deleted file mode 100644
index a41dbc0..0000000
--- a/nostr/client/cbc.py
+++ /dev/null
@@ -1,41 +0,0 @@
-
-from Cryptodome import Random
-from Cryptodome.Cipher import AES
-
-plain_text = "This is the text to encrypts"
-
-# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode()
-# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode()
-
-key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795")
-
-BLOCK_SIZE = 16
-
-class AESCipher(object):
- """This class is compatible with crypto.createCipheriv('aes-256-cbc')
-
- """
- def __init__(self, key=None):
- self.key = key
-
- def pad(self, data):
- length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
- return data + (chr(length) * length).encode()
-
- def unpad(self, data):
- return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
-
- def encrypt(self, plain_text):
- cipher = AES.new(self.key, AES.MODE_CBC)
- b = plain_text.encode("UTF-8")
- return cipher.iv, cipher.encrypt(self.pad(b))
-
- def decrypt(self, iv, enc_text):
- cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
- return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
-
-if __name__ == "__main__":
- aes = AESCipher(key=key)
- iv, enc_text = aes.encrypt(plain_text)
- dec_text = aes.decrypt(iv, enc_text)
- print(dec_text)
\ No newline at end of file
diff --git a/nostr/client/client.py b/nostr/client/client.py
index 6fb885f..d6fb5c8 100644
--- a/nostr/client/client.py
+++ b/nostr/client/client.py
@@ -1,152 +1,75 @@
-from typing import *
-import ssl
-import time
-import json
-import os
-import base64
+import asyncio
+
+from loguru import logger
-from ..event import Event
from ..relay_manager import RelayManager
-from ..message_type import ClientMessageType
-from ..key import PrivateKey, PublicKey
-
-from ..filter import Filter, Filters
-from ..event import Event, EventKind, EncryptedDirectMessage
-from ..relay_manager import RelayManager
-from ..message_type import ClientMessageType
-
-# from aes import AESCipher
-from . import cbc
class NostrClient:
- relays = [
- "wss://nostr-pub.wellorder.net",
- "wss://nostr.zebedee.cloud",
- "wss://nodestr.fmt.wiz.biz",
- "wss://nostr.oxtr.dev",
- ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
- relay_manager = RelayManager()
- private_key: PrivateKey
- public_key: PublicKey
+ relay_manager: RelayManager
+ running: bool
- def __init__(self, privatekey_hex: str = "", relays: List[str] = [], connect=True):
- self.generate_keys(privatekey_hex)
+ def __init__(self):
+ self.running = True
+ self.relay_manager = RelayManager()
- if len(relays):
- self.relays = relays
- if connect:
- self.connect()
+ def connect(self, relays):
+ for relay in relays:
+ try:
+ self.relay_manager.add_relay(relay)
+ except Exception as e:
+ logger.debug(e)
+ self.running = True
- def connect(self):
- for relay in self.relays:
- self.relay_manager.add_relay(relay)
- self.relay_manager.open_connections(
- {"cert_reqs": ssl.CERT_NONE}
- ) # NOTE: This disables ssl certificate verification
+ def reconnect(self, relays):
+ self.relay_manager.remove_relays()
+ self.connect(relays)
def close(self):
- self.relay_manager.close_connections()
+ try:
+ self.relay_manager.close_all_subscriptions()
+ self.relay_manager.close_connections()
- def generate_keys(self, privatekey_hex: str = None):
- pk = bytes.fromhex(privatekey_hex) if privatekey_hex else None
- self.private_key = PrivateKey(pk)
- self.public_key = self.private_key.public_key
+ self.running = False
+ except Exception as e:
+ logger.error(e)
- def post(self, message: str):
- event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE)
- self.private_key.sign_event(event)
- event_json = event.to_message()
- # print("Publishing message:")
- # print(event_json)
- self.relay_manager.publish_message(event_json)
-
- def get_post(
- self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={}
- ):
- filter = Filter(
- authors=[sender_publickey.hex()] if sender_publickey else None,
- kinds=[EventKind.TEXT_NOTE],
- **filter_kwargs,
- )
- filters = Filters([filter])
- subscription_id = os.urandom(4).hex()
- self.relay_manager.add_subscription(subscription_id, filters)
-
- request = [ClientMessageType.REQUEST, subscription_id]
- request.extend(filters.to_json_array())
- message = json.dumps(request)
- self.relay_manager.publish_message(message)
-
- while True:
- while self.relay_manager.message_pool.has_events():
- event_msg = self.relay_manager.message_pool.get_event()
- if callback_func:
- callback_func(event_msg.event)
- time.sleep(0.1)
-
- def dm(self, message: str, to_pubkey: PublicKey):
- dm = EncryptedDirectMessage(
- recipient_pubkey=to_pubkey.hex(), cleartext_content=message
- )
- self.private_key.sign_event(dm)
- self.relay_manager.publish_event(dm)
-
- def get_dm(self, sender_publickey: PublicKey, callback_func=None):
- filters = Filters(
- [
- Filter(
- kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE],
- pubkey_refs=[sender_publickey.hex()],
- )
- ]
- )
- subscription_id = os.urandom(4).hex()
- self.relay_manager.add_subscription(subscription_id, filters)
-
- request = [ClientMessageType.REQUEST, subscription_id]
- request.extend(filters.to_json_array())
- message = json.dumps(request)
- self.relay_manager.publish_message(message)
-
- while True:
- while self.relay_manager.message_pool.has_events():
- event_msg = self.relay_manager.message_pool.get_event()
- if "?iv=" in event_msg.event.content:
- try:
- shared_secret = self.private_key.compute_shared_secret(
- event_msg.event.public_key
- )
- aes = cbc.AESCipher(key=shared_secret)
- enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=")
- iv = base64.decodebytes(iv_b64.encode("utf-8"))
- enc_text = base64.decodebytes(enc_text_b64.encode("utf-8"))
- dec_text = aes.decrypt(iv, enc_text)
- if callback_func:
- callback_func(event_msg.event, dec_text)
- except:
- pass
- break
- time.sleep(0.1)
-
- def subscribe(
+ async def subscribe(
self,
callback_events_func=None,
callback_notices_func=None,
callback_eosenotices_func=None,
):
- while True:
+ while self.running:
+ self._check_events(callback_events_func)
+ self._check_notices(callback_notices_func)
+ self._check_eos_notices(callback_eosenotices_func)
+
+ await asyncio.sleep(0.2)
+
+ def _check_events(self, callback_events_func=None):
+ try:
while self.relay_manager.message_pool.has_events():
event_msg = self.relay_manager.message_pool.get_event()
if callback_events_func:
callback_events_func(event_msg)
+ except Exception as e:
+ logger.debug(e)
+
+ def _check_notices(self, callback_notices_func=None):
+ try:
while self.relay_manager.message_pool.has_notices():
- event_msg = self.relay_manager.message_pool.has_notices()
+ event_msg = self.relay_manager.message_pool.get_notice()
if callback_notices_func:
callback_notices_func(event_msg)
+ except Exception as e:
+ logger.debug(e)
+
+ def _check_eos_notices(self, callback_eosenotices_func=None):
+ try:
while self.relay_manager.message_pool.has_eose_notices():
event_msg = self.relay_manager.message_pool.get_eose_notice()
if callback_eosenotices_func:
callback_eosenotices_func(event_msg)
-
- time.sleep(0.1)
+ except Exception as e:
+ logger.debug(e)
diff --git a/nostr/delegation.py b/nostr/delegation.py
deleted file mode 100644
index 94801f5..0000000
--- a/nostr/delegation.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import time
-from dataclasses import dataclass
-
-
-@dataclass
-class Delegation:
- delegator_pubkey: str
- delegatee_pubkey: str
- event_kind: int
- duration_secs: int = 30*24*60 # default to 30 days
- signature: str = None # set in PrivateKey.sign_delegation
-
- @property
- def expires(self) -> int:
- return int(time.time()) + self.duration_secs
-
- @property
- def conditions(self) -> str:
- return f"kind={self.event_kind}&created_at<{self.expires}"
-
- @property
- def delegation_token(self) -> str:
- return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
-
- def get_tag(self) -> list[str]:
- """ Called by Event """
- return [
- "delegation",
- self.delegator_pubkey,
- self.conditions,
- self.signature,
- ]
diff --git a/nostr/event.py b/nostr/event.py
index b903e0e..994c0f4 100644
--- a/nostr/event.py
+++ b/nostr/event.py
@@ -1,10 +1,11 @@
-import time
import json
+import time
from dataclasses import dataclass, field
from enum import IntEnum
-from typing import List
-from secp256k1 import PublicKey
from hashlib import sha256
+from typing import Optional
+
+import coincurve
from .message_type import ClientMessageType
@@ -20,14 +21,14 @@ class EventKind(IntEnum):
@dataclass
class Event:
- content: str = None
- public_key: str = None
- created_at: int = None
+ content: Optional[str] = None
+ public_key: Optional[str] = None
+ created_at: Optional[int] = None
kind: int = EventKind.TEXT_NOTE
- tags: List[List[str]] = field(
+ tags: list[list[str]] = field(
default_factory=list
) # Dataclasses require special handling when the default value is a mutable type
- signature: str = None
+ signature: Optional[str] = None
def __post_init__(self):
if self.content is not None and not isinstance(self.content, str):
@@ -39,7 +40,7 @@ class Event:
@staticmethod
def serialize(
- public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str
+ public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str
) -> bytes:
data = [0, public_key, created_at, kind, tags, content]
data_str = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
@@ -47,7 +48,7 @@ class Event:
@staticmethod
def compute_id(
- public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str
+ public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str
):
return sha256(
Event.serialize(public_key, created_at, kind, tags, content)
@@ -56,6 +57,9 @@ class Event:
@property
def id(self) -> str:
# Always recompute the id to reflect the up-to-date state of the Event
+ assert self.public_key
+ assert self.created_at
+ assert self.content
return Event.compute_id(
self.public_key, self.created_at, self.kind, self.tags, self.content
)
@@ -69,12 +73,10 @@ class Event:
self.tags.append(["e", event_id])
def verify(self) -> bool:
- pub_key = PublicKey(
- bytes.fromhex("02" + self.public_key), True
- ) # add 02 for schnorr (bip340)
- return pub_key.schnorr_verify(
- bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True
- )
+ assert self.public_key
+ assert self.signature
+ pub_key = coincurve.PublicKeyXOnly(bytes.fromhex(self.public_key))
+ return pub_key.verify(bytes.fromhex(self.signature), bytes.fromhex(self.id))
def to_message(self) -> str:
return json.dumps(
@@ -95,9 +97,9 @@ class Event:
@dataclass
class EncryptedDirectMessage(Event):
- recipient_pubkey: str = None
- cleartext_content: str = None
- reference_event_id: str = None
+ recipient_pubkey: Optional[str] = None
+ cleartext_content: Optional[str] = None
+ reference_event_id: Optional[str] = None
def __post_init__(self):
if self.content is not None:
@@ -121,6 +123,7 @@ class EncryptedDirectMessage(Event):
def id(self) -> str:
if self.content is None:
raise Exception(
- "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
+ "EncryptedDirectMessage `id` is undefined until its"
+ + " message is encrypted and stored in the `content` field"
)
return super().id
diff --git a/nostr/filter.py b/nostr/filter.py
deleted file mode 100644
index f119079..0000000
--- a/nostr/filter.py
+++ /dev/null
@@ -1,134 +0,0 @@
-from collections import UserList
-from typing import List
-
-from .event import Event, EventKind
-
-
-class Filter:
- """
- NIP-01 filtering.
-
- Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
-
- Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
- If a particular single-letter tag gains prominence, explicit support should be
- added. For example:
- # arbitrary tag
- filter.add_arbitrary_tag('t', [hashtags])
-
- # promoted to explicit support
- Filter(hashtag_refs=[hashtags])
- """
-
- def __init__(
- self,
- event_ids: List[str] = None,
- kinds: List[EventKind] = None,
- authors: List[str] = None,
- since: int = None,
- until: int = None,
- event_refs: List[
- str
- ] = None, # the "#e" attr; list of event ids referenced in an "e" tag
- pubkey_refs: List[
- str
- ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
- limit: int = None,
- ) -> None:
- self.event_ids = event_ids
- self.kinds = kinds
- self.authors = authors
- self.since = since
- self.until = until
- self.event_refs = event_refs
- self.pubkey_refs = pubkey_refs
- self.limit = limit
-
- self.tags = {}
- if self.event_refs:
- self.add_arbitrary_tag("e", self.event_refs)
- if self.pubkey_refs:
- self.add_arbitrary_tag("p", self.pubkey_refs)
-
- def add_arbitrary_tag(self, tag: str, values: list):
- """
- Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
- single-letter tags.
- """
- # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
- tag_key = tag if len(tag) > 1 else f"#{tag}"
- self.tags[tag_key] = values
-
- def matches(self, event: Event) -> bool:
- if self.event_ids is not None and event.id not in self.event_ids:
- return False
- if self.kinds is not None and event.kind not in self.kinds:
- return False
- if self.authors is not None and event.public_key not in self.authors:
- return False
- if self.since is not None and event.created_at < self.since:
- return False
- if self.until is not None and event.created_at > self.until:
- return False
- if (self.event_refs is not None or self.pubkey_refs is not None) and len(
- event.tags
- ) == 0:
- return False
-
- if self.tags:
- e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
- for f_tag, f_tag_values in self.tags.items():
- # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
- f_tag = f_tag.replace("#", "")
-
- if f_tag not in e_tag_identifiers:
- # Event is missing a tag type that we're looking for
- return False
-
- # Multiple values within f_tag_values are treated as OR search; an Event
- # needs to match only one.
- # Note: an Event could have multiple entries of the same tag type
- # (e.g. a reply to multiple people) so we have to check all of them.
- match_found = False
- for e_tag in event.tags:
- if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
- match_found = True
- break
- if not match_found:
- return False
-
- return True
-
- def to_json_object(self) -> dict:
- res = {}
- if self.event_ids is not None:
- res["ids"] = self.event_ids
- if self.kinds is not None:
- res["kinds"] = self.kinds
- if self.authors is not None:
- res["authors"] = self.authors
- if self.since is not None:
- res["since"] = self.since
- if self.until is not None:
- res["until"] = self.until
- if self.limit is not None:
- res["limit"] = self.limit
- if self.tags:
- res.update(self.tags)
-
- return res
-
-
-class Filters(UserList):
- def __init__(self, initlist: "list[Filter]" = []) -> None:
- super().__init__(initlist)
- self.data: "list[Filter]"
-
- def match(self, event: Event):
- for filter in self.data:
- if filter.matches(event):
- return True
- return False
-
- def to_json_array(self) -> list:
- return [filter.to_json_object() for filter in self.data]
diff --git a/nostr/key.py b/nostr/key.py
index d34697f..f7b4e81 100644
--- a/nostr/key.py
+++ b/nostr/key.py
@@ -1,14 +1,12 @@
-import secrets
import base64
-import secp256k1
-from cffi import FFI
-from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
-from cryptography.hazmat.primitives import padding
-from hashlib import sha256
+import secrets
-from .delegation import Delegation
+import coincurve
+from cryptography.hazmat.primitives import padding
+from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+
+from .bech32 import Encoding, bech32_decode, bech32_encode, convertbits
from .event import EncryptedDirectMessage, Event, EventKind
-from . import bech32
class PublicKey:
@@ -16,55 +14,61 @@ class PublicKey:
self.raw_bytes = raw_bytes
def bech32(self) -> str:
- converted_bits = bech32.convertbits(self.raw_bytes, 8, 5)
- return bech32.bech32_encode("npub", converted_bits, bech32.Encoding.BECH32)
+ converted_bits = convertbits(self.raw_bytes, 8, 5)
+ return bech32_encode("npub", converted_bits, Encoding.BECH32)
def hex(self) -> str:
return self.raw_bytes.hex()
- def verify_signed_message_hash(self, hash: str, sig: str) -> bool:
- pk = secp256k1.PublicKey(b"\x02" + self.raw_bytes, True)
- return pk.schnorr_verify(bytes.fromhex(hash), bytes.fromhex(sig), None, True)
+ def verify_signed_message_hash(self, message_hash: str, sig: str) -> bool:
+ pk = coincurve.PublicKeyXOnly(self.raw_bytes)
+ return pk.verify(bytes.fromhex(sig), bytes.fromhex(message_hash))
@classmethod
def from_npub(cls, npub: str):
"""Load a PublicKey from its bech32/npub form"""
- hrp, data, spec = bech32.bech32_decode(npub)
- raw_public_key = bech32.convertbits(data, 5, 8)[:-1]
+ _, data, _ = bech32_decode(npub)
+ raw_data = convertbits(data, 5, 8)
+ assert raw_data
+ raw_public_key = raw_data[:-1]
return cls(bytes(raw_public_key))
class PrivateKey:
- def __init__(self, raw_secret: bytes = None) -> None:
- if not raw_secret is None:
+ def __init__(self, raw_secret: bytes | None = None) -> None:
+ if raw_secret is not None:
self.raw_secret = raw_secret
else:
self.raw_secret = secrets.token_bytes(32)
- sk = secp256k1.PrivateKey(self.raw_secret)
- self.public_key = PublicKey(sk.pubkey.serialize()[1:])
+ sk = coincurve.PrivateKey(self.raw_secret)
+ assert sk.public_key
+ self.public_key = PublicKey(sk.public_key.format()[1:])
@classmethod
def from_nsec(cls, nsec: str):
"""Load a PrivateKey from its bech32/nsec form"""
- hrp, data, spec = bech32.bech32_decode(nsec)
- raw_secret = bech32.convertbits(data, 5, 8)[:-1]
+ _, data, _ = bech32_decode(nsec)
+ raw_data = convertbits(data, 5, 8)
+ assert raw_data
+ raw_secret = raw_data[:-1]
return cls(bytes(raw_secret))
def bech32(self) -> str:
- converted_bits = bech32.convertbits(self.raw_secret, 8, 5)
- return bech32.bech32_encode("nsec", converted_bits, bech32.Encoding.BECH32)
+ converted_bits = convertbits(self.raw_secret, 8, 5)
+ return bech32_encode("nsec", converted_bits, Encoding.BECH32)
def hex(self) -> str:
return self.raw_secret.hex()
def tweak_add(self, scalar: bytes) -> bytes:
- sk = secp256k1.PrivateKey(self.raw_secret)
- return sk.tweak_add(scalar)
+ sk = coincurve.PrivateKey(self.raw_secret)
+ return sk.add(scalar).to_der()
def compute_shared_secret(self, public_key_hex: str) -> bytes:
- pk = secp256k1.PublicKey(bytes.fromhex("02" + public_key_hex), True)
- return pk.ecdh(self.raw_secret, hashfn=copy_x)
+ pk = coincurve.PublicKey(bytes.fromhex("02" + public_key_hex))
+ sk = coincurve.PrivateKey(self.raw_secret)
+ return sk.ecdh(pk.format())
def encrypt_message(self, message: str, public_key_hex: str) -> str:
padder = padding.PKCS7(128).padder()
@@ -78,9 +82,14 @@ class PrivateKey:
encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
- return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
+ return (
+ f"{base64.b64encode(encrypted_message).decode()}"
+ + f"?iv={base64.b64encode(iv).decode()}"
+ )
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
+ assert dm.cleartext_content
+ assert dm.recipient_pubkey
dm.content = self.encrypt_message(
message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey
)
@@ -103,28 +112,23 @@ class PrivateKey:
return unpadded_data.decode()
- def sign_message_hash(self, hash: bytes) -> str:
- sk = secp256k1.PrivateKey(self.raw_secret)
- sig = sk.schnorr_sign(hash, None, raw=True)
+ def sign_message_hash(self, message_hash: bytes) -> str:
+ sk = coincurve.PrivateKey(self.raw_secret)
+ sig = sk.sign_schnorr(message_hash)
return sig.hex()
def sign_event(self, event: Event) -> None:
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None:
- self.encrypt_dm(event)
+ self.encrypt_dm(event) # type: ignore
if event.public_key is None:
event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id))
- def sign_delegation(self, delegation: Delegation) -> None:
- delegation.signature = self.sign_message_hash(
- sha256(delegation.delegation_token.encode()).digest()
- )
-
def __eq__(self, other):
return self.raw_secret == other.raw_secret
-def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey:
+def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> PrivateKey:
if prefix is None and suffix is None:
raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments")
@@ -140,14 +144,3 @@ def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey:
break
return sk
-
-
-ffi = FFI()
-
-
-@ffi.callback(
- "int (unsigned char *, const unsigned char *, const unsigned char *, void *)"
-)
-def copy_x(output, x32, y32, data):
- ffi.memmove(output, x32, 32)
- return 1
diff --git a/nostr/message_pool.py b/nostr/message_pool.py
index d364cf2..a3e6c5f 100644
--- a/nostr/message_pool.py
+++ b/nostr/message_pool.py
@@ -1,13 +1,16 @@
import json
from queue import Queue
from threading import Lock
+
from .message_type import RelayMessageType
-from .event import Event
class EventMessage:
- def __init__(self, event: Event, subscription_id: str, url: str) -> None:
+ def __init__(
+ self, event: str, event_id: str, subscription_id: str, url: str
+ ) -> None:
self.event = event
+ self.event_id = event_id
self.subscription_id = subscription_id
self.url = url
@@ -58,20 +61,29 @@ class MessagePool:
message_type = message_json[0]
if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1]
- e = message_json[2]
- event = Event(
- e["content"],
- e["pubkey"],
- e["created_at"],
- e["kind"],
- e["tags"],
- e["sig"],
- )
+ event = message_json[2]
+ if "id" not in event:
+ return
+ event_id = event["id"]
+
with self.lock:
- if not event.id in self._unique_events:
- self.events.put(EventMessage(event, subscription_id, url))
- self._unique_events.add(event.id)
+ if f"{subscription_id}_{event_id}" not in self._unique_events:
+ self._accept_event(
+ EventMessage(json.dumps(event), event_id, subscription_id, url)
+ )
elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url))
+
+ def _accept_event(self, event_message: EventMessage):
+ """
+ Event uniqueness is considered per `subscription_id`. The `subscription_id` is
+ rewritten to be unique and it is the same accross relays. The same event can
+ come from different subscriptions (from the same client or from different ones).
+ Clients that have joined later should receive older events.
+ """
+ self.events.put(event_message)
+ self._unique_events.add(
+ f"{event_message.subscription_id}_{event_message.event_id}"
+ )
diff --git a/nostr/message_type.py b/nostr/message_type.py
index 3f5206b..d37cdfd 100644
--- a/nostr/message_type.py
+++ b/nostr/message_type.py
@@ -3,13 +3,20 @@ class ClientMessageType:
REQUEST = "REQ"
CLOSE = "CLOSE"
+
class RelayMessageType:
EVENT = "EVENT"
NOTICE = "NOTICE"
END_OF_STORED_EVENTS = "EOSE"
+ COMMAND_RESULT = "OK"
@staticmethod
def is_valid(type: str) -> bool:
- if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS:
+ if (
+ type == RelayMessageType.EVENT
+ or type == RelayMessageType.NOTICE
+ or type == RelayMessageType.END_OF_STORED_EVENTS
+ or type == RelayMessageType.COMMAND_RESULT
+ ):
return True
- return False
\ No newline at end of file
+ return False
diff --git a/nostr/pow.py b/nostr/pow.py
deleted file mode 100644
index e006288..0000000
--- a/nostr/pow.py
+++ /dev/null
@@ -1,54 +0,0 @@
-import time
-from .event import Event
-from .key import PrivateKey
-
-def zero_bits(b: int) -> int:
- n = 0
-
- if b == 0:
- return 8
-
- while b >> 1:
- b = b >> 1
- n += 1
-
- return 7 - n
-
-def count_leading_zero_bits(hex_str: str) -> int:
- total = 0
- for i in range(0, len(hex_str) - 2, 2):
- bits = zero_bits(int(hex_str[i:i+2], 16))
- total += bits
-
- if bits != 8:
- break
-
- return total
-
-def mine_event(content: str, difficulty: int, public_key: str, kind: int, tags: list=[]) -> Event:
- all_tags = [["nonce", "1", str(difficulty)]]
- all_tags.extend(tags)
-
- created_at = int(time.time())
- event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
- num_leading_zero_bits = count_leading_zero_bits(event_id)
-
- attempts = 1
- while num_leading_zero_bits < difficulty:
- attempts += 1
- all_tags[0][1] = str(attempts)
- created_at = int(time.time())
- event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
- num_leading_zero_bits = count_leading_zero_bits(event_id)
-
- return Event(public_key, content, created_at, kind, all_tags, event_id)
-
-def mine_key(difficulty: int) -> PrivateKey:
- sk = PrivateKey()
- num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
-
- while num_leading_zero_bits < difficulty:
- sk = PrivateKey()
- num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
-
- return sk
diff --git a/nostr/relay.py b/nostr/relay.py
index ee78baa..d762963 100644
--- a/nostr/relay.py
+++ b/nostr/relay.py
@@ -1,49 +1,37 @@
+import asyncio
import json
import time
from queue import Queue
-from threading import Lock
+
+from loguru import logger
from websocket import WebSocketApp
-from .event import Event
-from .filter import Filters
+
from .message_pool import MessagePool
-from .message_type import RelayMessageType
from .subscription import Subscription
-class RelayPolicy:
- def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
- self.should_read = should_read
- self.should_write = should_write
-
- def to_json_object(self) -> dict[str, bool]:
- return {"read": self.should_read, "write": self.should_write}
-
-
class Relay:
- def __init__(
- self,
- url: str,
- policy: RelayPolicy,
- message_pool: MessagePool,
- subscriptions: dict[str, Subscription] = {},
- ) -> None:
+ def __init__(self, url: str, message_pool: MessagePool) -> None:
self.url = url
- self.policy = policy
self.message_pool = message_pool
- self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
+ self.shutdown: bool = False
+
self.error_counter: int = 0
- self.error_threshold: int = 0
+ self.error_threshold: int = 100
+ self.error_list: list[str] = []
+ self.notice_list: list[str] = []
+ self.last_error_date: int = 0
self.num_received_events: int = 0
self.num_sent_events: int = 0
self.num_subscriptions: int = 0
- self.ssl_options: dict = {}
- self.proxy: dict = {}
- self.lock = Lock()
- self.queue = Queue()
+
+ self.queue: Queue = Queue()
+
+ def connect(self):
self.ws = WebSocketApp(
- url,
+ self.url,
on_open=self._on_open,
on_message=self._on_message,
on_error=self._on_error,
@@ -51,31 +39,20 @@ class Relay:
on_ping=self._on_ping,
on_pong=self._on_pong,
)
-
- def connect(self, ssl_options: dict = None, proxy: dict = None):
- self.ssl_options = ssl_options
- self.proxy = proxy
if not self.connected:
- self.ws.run_forever(
- sslopt=ssl_options,
- http_proxy_host=None if proxy is None else proxy.get("host"),
- http_proxy_port=None if proxy is None else proxy.get("port"),
- proxy_type=None if proxy is None else proxy.get("type"),
- ping_interval=5,
- )
+ self.ws.run_forever(ping_interval=10)
def close(self):
- self.ws.close()
-
- def check_reconnect(self):
try:
- self.close()
- except:
- pass
+ self.ws.close()
+ except Exception as e:
+ logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
self.connected = False
- if self.reconnect:
- time.sleep(1)
- self.connect(self.ssl_options, self.proxy)
+ self.shutdown = True
+
+ @property
+ def error_threshold_reached(self):
+ return self.error_threshold and self.error_counter >= self.error_threshold
@property
def ping(self):
@@ -85,99 +62,65 @@ class Relay:
def publish(self, message: str):
self.queue.put(message)
- def queue_worker(self):
+ def publish_subscriptions(self, subscriptions: list[Subscription]):
+ for s in subscriptions:
+ assert s.filters
+ json_str = json.dumps(["REQ", s.id, *s.filters])
+ self.publish(json_str)
+
+ async def queue_worker(self):
while True:
if self.connected:
- message = self.queue.get()
- self.num_sent_events += 1
- self.ws.send(message)
+ try:
+ message = self.queue.get(timeout=1)
+ self.num_sent_events += 1
+ self.ws.send(message)
+ except Exception as _:
+ pass
else:
- time.sleep(0.1)
+ await asyncio.sleep(1)
- def add_subscription(self, id, filters: Filters):
- with self.lock:
- self.subscriptions[id] = Subscription(id, filters)
+ if self.shutdown:
+ logger.warning(f"[Relay: {self.url}] Closing queue worker.")
+ return
- def close_subscription(self, id: str) -> None:
- with self.lock:
- self.subscriptions.pop(id)
+ def close_subscription(self, sub_id: str) -> None:
+ try:
+ self.publish(json.dumps(["CLOSE", sub_id]))
+ except Exception as e:
+ logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")
- def update_subscription(self, id: str, filters: Filters) -> None:
- with self.lock:
- subscription = self.subscriptions[id]
- subscription.filters = filters
+ def add_notice(self, notice: str):
+ self.notice_list = [notice, *self.notice_list]
- def to_json_object(self) -> dict:
- return {
- "url": self.url,
- "policy": self.policy.to_json_object(),
- "subscriptions": [
- subscription.to_json_object()
- for subscription in self.subscriptions.values()
- ],
- }
-
- def _on_open(self, class_obj):
+ def _on_open(self, _):
+ logger.info(f"[Relay: {self.url}] Connected.")
self.connected = True
- pass
+ self.shutdown = False
- def _on_close(self, class_obj, status_code, message):
- self.connected = False
- pass
+ def _on_close(self, _, status_code, message):
+ logger.warning(
+ f"[Relay: {self.url}] Connection closed."
+ + f" Status: '{status_code}'. Message: '{message}'."
+ )
+ self.close()
- def _on_message(self, class_obj, message: str):
- if self._is_valid_message(message):
- self.num_received_events += 1
- self.message_pool.add_message(message, self.url)
+ def _on_message(self, _, message: str):
+ self.num_received_events += 1
+ self.message_pool.add_message(message, self.url)
- def _on_error(self, class_obj, error):
- self.connected = False
+ def _on_error(self, _, error):
+ logger.warning(f"[Relay: {self.url}] Error: '{error!s}'")
+ self._append_error_message(str(error))
+ self.close()
+
+ def _on_ping(self, *_):
+ return
+
+ def _on_pong(self, *_):
+ return
+
+ def _append_error_message(self, message):
self.error_counter += 1
- if self.error_threshold and self.error_counter > self.error_threshold:
- pass
- else:
- self.check_reconnect()
-
- def _on_ping(self, class_obj, message):
- return
-
- def _on_pong(self, class_obj, message):
- return
-
- def _is_valid_message(self, message: str) -> bool:
- message = message.strip("\n")
- if not message or message[0] != "[" or message[-1] != "]":
- return False
-
- message_json = json.loads(message)
- message_type = message_json[0]
- if not RelayMessageType.is_valid(message_type):
- return False
- if message_type == RelayMessageType.EVENT:
- if not len(message_json) == 3:
- return False
-
- subscription_id = message_json[1]
- with self.lock:
- if subscription_id not in self.subscriptions:
- return False
-
- e = message_json[2]
- event = Event(
- e["content"],
- e["pubkey"],
- e["created_at"],
- e["kind"],
- e["tags"],
- e["sig"],
- )
- if not event.verify():
- return False
-
- with self.lock:
- subscription = self.subscriptions[subscription_id]
-
- if subscription.filters and not subscription.filters.match(event):
- return False
-
- return True
+ self.error_list = [message, *self.error_list]
+ self.last_error_date = int(time.time())
diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py
index 5b92d8d..2aa27c5 100644
--- a/nostr/relay_manager.py
+++ b/nostr/relay_manager.py
@@ -1,52 +1,99 @@
-import json
+import asyncio
import threading
+import time
+from typing import List
-from .event import Event
-from .filter import Filters
-from .message_pool import MessagePool
-from .message_type import ClientMessageType
-from .relay import Relay, RelayPolicy
+from loguru import logger
-
-class RelayException(Exception):
- pass
+from .message_pool import MessagePool, NoticeMessage
+from .relay import Relay
+from .subscription import Subscription
class RelayManager:
def __init__(self) -> None:
self.relays: dict[str, Relay] = {}
+ self.threads: dict[str, threading.Thread] = {}
+ self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool()
+ self._cached_subscriptions: dict[str, Subscription] = {}
+ self._subscriptions_lock = threading.Lock()
- def add_relay(
- self, url: str, read: bool = True, write: bool = True, subscriptions={}
- ):
- policy = RelayPolicy(read, write)
- relay = Relay(url, policy, self.message_pool, subscriptions)
+ def add_relay(self, url: str) -> Relay:
+ if url in list(self.relays.keys()):
+ logger.debug(f"Relay '{url}' already present.")
+ return self.relays[url]
+
+ relay = Relay(url, self.message_pool)
self.relays[url] = relay
- def remove_relay(self, url: str):
- self.relays.pop(url)
+ self._open_connection(relay)
+
+ relay.publish_subscriptions(list(self._cached_subscriptions.values()))
+ return relay
+
+ def remove_relay(self, url: str):
+ try:
+ self.relays[url].close()
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.relays:
+ self.relays.pop(url)
+
+ try:
+ self.threads[url].join(timeout=5)
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.threads:
+ self.threads.pop(url)
+
+ try:
+ self.queue_threads[url].join(timeout=5)
+ except Exception as e:
+ logger.debug(e)
+
+ if url in self.queue_threads:
+ self.queue_threads.pop(url)
+
+ def remove_relays(self):
+ relay_urls = list(self.relays.keys())
+ for url in relay_urls:
+ self.remove_relay(url)
+
+ def add_subscription(self, id: str, filters: List[str]):
+ s = Subscription(id, filters)
+ with self._subscriptions_lock:
+ self._cached_subscriptions[id] = s
- def add_subscription(self, id: str, filters: Filters):
for relay in self.relays.values():
- relay.add_subscription(id, filters)
+ relay.publish_subscriptions([s])
def close_subscription(self, id: str):
- for relay in self.relays.values():
- relay.close_subscription(id)
+ try:
+ logger.info(f"Closing subscription: '{id}'.")
+ with self._subscriptions_lock:
+ if id in self._cached_subscriptions:
+ self._cached_subscriptions.pop(id)
- def open_connections(self, ssl_options: dict = None, proxy: dict = None):
- for relay in self.relays.values():
- threading.Thread(
- target=relay.connect,
- args=(ssl_options, proxy),
- name=f"{relay.url}-thread",
- daemon=True,
- ).start()
+ for relay in self.relays.values():
+ relay.close_subscription(id)
+ except Exception as e:
+ logger.debug(e)
- threading.Thread(
- target=relay.queue_worker, name=f"{relay.url}-queue", daemon=True
- ).start()
+ def close_subscriptions(self, subscriptions: List[str]):
+ for id in subscriptions:
+ self.close_subscription(id)
+
+ def close_all_subscriptions(self):
+ all_subscriptions = list(self._cached_subscriptions.keys())
+ self.close_subscriptions(all_subscriptions)
+
+ def check_and_restart_relays(self):
+ stopped_relays = [r for r in self.relays.values() if r.shutdown]
+ for relay in stopped_relays:
+ self._restart_relay(relay)
def close_connections(self):
for relay in self.relays.values():
@@ -54,16 +101,43 @@ class RelayManager:
def publish_message(self, message: str):
for relay in self.relays.values():
- if relay.policy.should_write:
- relay.publish(message)
+ relay.publish(message)
- def publish_event(self, event: Event):
- """Verifies that the Event is publishable before submitting it to relays"""
- if event.signature is None:
- raise RelayException(f"Could not publish {event.id}: must be signed")
+ def handle_notice(self, notice: NoticeMessage):
+ relay = next((r for r in self.relays.values() if r.url == notice.url))
+ if relay:
+ relay.add_notice(notice.content)
- if not event.verify():
- raise RelayException(
- f"Could not publish {event.id}: failed to verify signature {event.signature}"
- )
- self.publish_message(event.to_message())
+ def _open_connection(self, relay: Relay):
+ self.threads[relay.url] = threading.Thread(
+ target=relay.connect,
+ name=f"{relay.url}-thread",
+ daemon=True,
+ )
+ self.threads[relay.url].start()
+
+ def wrap_async_queue_worker():
+ asyncio.run(relay.queue_worker())
+
+ self.queue_threads[relay.url] = threading.Thread(
+ target=wrap_async_queue_worker,
+ name=f"{relay.url}-queue",
+ daemon=True,
+ )
+ self.queue_threads[relay.url].start()
+
+ def _restart_relay(self, relay: Relay):
+ time_since_last_error = time.time() - relay.last_error_date
+
+ min_wait_time = min(
+ 60 * relay.error_counter, 60 * 60
+ ) # try at least once an hour
+ if time_since_last_error < min_wait_time:
+ return
+
+ logger.info(f"Restarting connection to relay '{relay.url}'")
+
+ self.remove_relay(relay.url)
+ new_relay = self.add_relay(relay.url)
+ new_relay.error_counter = relay.error_counter
+ new_relay.error_list = relay.error_list
diff --git a/nostr/subscription.py b/nostr/subscription.py
index 7afba20..ed60f7e 100644
--- a/nostr/subscription.py
+++ b/nostr/subscription.py
@@ -1,12 +1,7 @@
-from .filter import Filters
+from typing import Optional
+
class Subscription:
- def __init__(self, id: str, filters: Filters=None) -> None:
+ def __init__(self, id: str, filters: Optional[list[str]] = None) -> None:
self.id = id
self.filters = filters
-
- def to_json_object(self):
- return {
- "id": self.id,
- "filters": self.filters.to_json_array()
- }
diff --git a/package-lock.json b/package-lock.json
new file mode 100644
index 0000000..1180ffb
--- /dev/null
+++ b/package-lock.json
@@ -0,0 +1,59 @@
+{
+ "name": "nostrclient",
+ "version": "1.0.0",
+ "lockfileVersion": 3,
+ "requires": true,
+ "packages": {
+ "": {
+ "name": "nostrclient",
+ "version": "1.0.0",
+ "license": "ISC",
+ "dependencies": {
+ "prettier": "^3.2.5",
+ "pyright": "^1.1.358"
+ }
+ },
+ "node_modules/fsevents": {
+ "version": "2.3.3",
+ "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
+ "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
+ "hasInstallScript": true,
+ "optional": true,
+ "os": [
+ "darwin"
+ ],
+ "engines": {
+ "node": "^8.16.0 || ^10.6.0 || >=11.0.0"
+ }
+ },
+ "node_modules/prettier": {
+ "version": "3.3.3",
+ "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz",
+ "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==",
+ "bin": {
+ "prettier": "bin/prettier.cjs"
+ },
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/prettier/prettier?sponsor=1"
+ }
+ },
+ "node_modules/pyright": {
+ "version": "1.1.374",
+ "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.374.tgz",
+ "integrity": "sha512-ISbC1YnYDYrEatoKKjfaA5uFIp0ddC/xw9aSlN/EkmwupXUMVn41Jl+G6wHEjRhC+n4abHZeGpEvxCUus/K9dA==",
+ "bin": {
+ "pyright": "index.js",
+ "pyright-langserver": "langserver.index.js"
+ },
+ "engines": {
+ "node": ">=14.0.0"
+ },
+ "optionalDependencies": {
+ "fsevents": "~2.3.3"
+ }
+ }
+ }
+}
diff --git a/package.json b/package.json
new file mode 100644
index 0000000..7b84315
--- /dev/null
+++ b/package.json
@@ -0,0 +1,15 @@
+{
+ "name": "nostrclient",
+ "version": "1.0.0",
+ "description": "",
+ "main": "index.js",
+ "scripts": {
+ "test": "echo \"Error: no test specified\" && exit 1"
+ },
+ "author": "",
+ "license": "ISC",
+ "dependencies": {
+ "prettier": "^3.2.5",
+ "pyright": "^1.1.358"
+ }
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c7c0e56
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,98 @@
+[project]
+name = "lnbits-nostrclient"
+version = "1.1.0"
+requires-python = ">=3.10,<3.13"
+description = "LNbits, free and open-source Lightning wallet and accounts system."
+authors = [{ name = "Alan Bits", email = "alan@lnbits.com" }]
+urls = { Homepage = "https://lnbits.com", Repository = "https://github.com/lnbits/nostrclient" }
+dependencies = [ "lnbits>1" ]
+
+[tool.poetry]
+package-mode = false
+
+[tool.uv]
+dev-dependencies = [
+ "black",
+ "pytest-asyncio",
+ "pytest",
+ "mypy",
+ "pre-commit",
+ "ruff",
+ "pytest-md",
+ "types-cffi",
+]
+
+[tool.mypy]
+exclude = "(nostr/*)"
+plugins = ["pydantic.mypy"]
+
+[[tool.mypy.overrides]]
+module = [
+ "nostr.*",
+]
+follow_imports = "skip"
+ignore_missing_imports = "True"
+
+[tool.pydantic-mypy]
+init_forbid_extra = true
+init_typed = true
+warn_required_dynamic_aliases = true
+warn_untyped_fields = true
+
+[tool.pytest.ini_options]
+log_cli = false
+testpaths = [
+ "tests"
+]
+
+[tool.black]
+line-length = 88
+
+[tool.ruff]
+# Same as Black. + 10% rule of black
+line-length = 88
+exclude = [
+ "nostr",
+]
+
+[tool.ruff.lint]
+# Enable:
+# F - pyflakes
+# E - pycodestyle errors
+# W - pycodestyle warnings
+# I - isort
+# A - flake8-builtins
+# C - mccabe
+# N - naming
+# UP - pyupgrade
+# RUF - ruff
+# B - bugbear
+select = ["F", "E", "W", "I", "A", "C", "N", "UP", "RUF", "B"]
+ignore = ["C901"]
+
+# Allow autofix for all enabled rules (when `--fix`) is provided.
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+# needed for pydantic
+[tool.ruff.lint.pep8-naming]
+classmethod-decorators = [
+ "root_validator",
+]
+
+# Ignore unused imports in __init__.py files.
+# [tool.ruff.lint.extend-per-file-ignores]
+# "__init__.py" = ["F401", "F403"]
+
+# [tool.ruff.lint.mccabe]
+# max-complexity = 10
+
+[tool.ruff.lint.flake8-bugbear]
+# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
+extend-immutable-calls = [
+ "fastapi.Depends",
+ "fastapi.Query",
+]
diff --git a/router.py b/router.py
new file mode 100644
index 0000000..a7054e9
--- /dev/null
+++ b/router.py
@@ -0,0 +1,175 @@
+import asyncio
+import json
+from typing import ClassVar
+
+from fastapi import WebSocket, WebSocketDisconnect
+from lnbits.helpers import urlsafe_short_hash
+from loguru import logger
+
+from .nostr.client.client import NostrClient
+
+# from . import nostr_client
+from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
+
+nostr_client: NostrClient = NostrClient()
+all_routers: list["NostrRouter"] = []
+
+
+class NostrRouter:
+ received_subscription_events: ClassVar[dict[str, list[EventMessage]]] = {}
+ received_subscription_notices: ClassVar[list[NoticeMessage]] = []
+ received_subscription_eosenotices: ClassVar[dict[str, EndOfStoredEventsMessage]] = (
+ {}
+ )
+
+ def __init__(self, websocket: WebSocket):
+ self.connected: bool = True
+ self.websocket: WebSocket = websocket
+ self.tasks: list[asyncio.Task] = []
+ self.original_subscription_ids: dict[str, str] = {}
+
+ @property
+ def subscriptions(self) -> list[str]:
+ return list(self.original_subscription_ids.keys())
+
+ def start(self):
+ self.connected = True
+ self.tasks.append(asyncio.create_task(self._client_to_nostr()))
+ self.tasks.append(asyncio.create_task(self._nostr_to_client()))
+
+ async def stop(self):
+ nostr_client.relay_manager.close_subscriptions(self.subscriptions)
+ self.connected = False
+
+ for t in self.tasks:
+ try:
+ t.cancel()
+ except Exception as _:
+ pass
+
+ try:
+ await self.websocket.close(reason="Websocket connection closed")
+ except Exception as _:
+ pass
+
+ async def _client_to_nostr(self):
+ """
+ Receives requests / data from the client and forwards it to relays.
+ """
+ while self.connected:
+ try:
+ json_str = await self.websocket.receive_text()
+ except WebSocketDisconnect as e:
+ logger.debug(e)
+ await self.stop()
+ break
+
+ try:
+ await self._handle_client_to_nostr(json_str)
+ except Exception as e:
+ logger.debug(f"Failed to handle client message: '{e!s}'.")
+
+ async def _nostr_to_client(self):
+ """Sends responses from relays back to the client."""
+ while self.connected:
+ try:
+ await self._handle_subscriptions()
+ self._handle_notices()
+ except Exception as e:
+ logger.debug(f"Failed to handle response for client: '{e!s}'.")
+ await asyncio.sleep(1)
+ await asyncio.sleep(0.1)
+
+ async def _handle_subscriptions(self):
+ for s in self.subscriptions:
+ if s in NostrRouter.received_subscription_events:
+ await self._handle_received_subscription_events(s)
+ if s in NostrRouter.received_subscription_eosenotices:
+ await self._handle_received_subscription_eosenotices(s)
+
+ async def _handle_received_subscription_eosenotices(self, s):
+ try:
+ if s not in self.original_subscription_ids:
+ return
+ s_original = self.original_subscription_ids[s]
+ event_to_forward = ["EOSE", s_original]
+ del NostrRouter.received_subscription_eosenotices[s]
+
+ await self.websocket.send_text(json.dumps(event_to_forward))
+ except Exception as e:
+ logger.debug(e)
+
+ async def _handle_received_subscription_events(self, s):
+ try:
+ if s not in NostrRouter.received_subscription_events:
+ return
+
+ while len(NostrRouter.received_subscription_events[s]):
+ event_message = NostrRouter.received_subscription_events[s].pop(0)
+ event_json = event_message.event
+
+ # this reconstructs the original response from the relay
+ # reconstruct original subscription id
+ s_original = self.original_subscription_ids[s]
+ event_to_forward = json.dumps(
+ ["EVENT", s_original, json.loads(event_json)]
+ )
+ await self.websocket.send_text(event_to_forward)
+ except Exception as e:
+ logger.warning(
+ f"[NOSTRCLIENT] Error in _handle_received_subscription_events: {e}"
+ )
+
+ def _handle_notices(self):
+ while len(NostrRouter.received_subscription_notices):
+ my_event = NostrRouter.received_subscription_notices.pop(0)
+ logger.debug(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']")
+ # Note: we don't send it to the user because
+ # we don't know who should receive it
+ nostr_client.relay_manager.handle_notice(my_event)
+
+ async def _handle_client_to_nostr(self, json_str):
+ json_data = json.loads(json_str)
+ assert len(json_data), "Bad JSON array"
+
+ if json_data[0] == "REQ":
+ self._handle_client_req(json_data)
+ return
+
+ if json_data[0] == "CLOSE":
+ self._handle_client_close(json_data[1])
+ return
+
+ if json_data[0] == "EVENT":
+ nostr_client.relay_manager.publish_message(json_str)
+ return
+
+ def _handle_client_req(self, json_data):
+ subscription_id = json_data[1]
+ logger.info(f"New subscription: '{subscription_id}'")
+ subscription_id_rewritten = urlsafe_short_hash()
+ self.original_subscription_ids[subscription_id_rewritten] = subscription_id
+ filters = json_data[2:]
+
+ nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters)
+
+ def _handle_client_close(self, subscription_id):
+ subscription_id_rewritten = next(
+ (
+ k
+ for k, v in self.original_subscription_ids.items()
+ if v == subscription_id
+ ),
+ None,
+ )
+ if subscription_id_rewritten:
+ self.original_subscription_ids.pop(subscription_id_rewritten)
+ nostr_client.relay_manager.close_subscription(subscription_id_rewritten)
+ logger.info(
+ f"""
+ Unsubscribe from '{subscription_id_rewritten}'.
+ Original id: '{subscription_id}.'
+ """
+ )
+ else:
+ logger.info(f"Failed to unsubscribe from '{subscription_id}.'")
diff --git a/services.py b/services.py
deleted file mode 100644
index f584146..0000000
--- a/services.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import asyncio
-import json
-from typing import List, Union
-from .models import RelayList, Relay, Event, Filter, Filters
-
-from .nostr.event import Event as NostrEvent
-from .nostr.filter import Filter as NostrFilter
-from .nostr.filter import Filters as NostrFilters
-from .tasks import (
- client,
- received_event_queue,
- received_subscription_events,
- received_subscription_eosenotices,
-)
-from fastapi import WebSocket, WebSocketDisconnect
-from lnbits.helpers import urlsafe_short_hash
-
-
-class NostrRouter:
- def __init__(self, websocket):
- self.subscriptions: List[str] = []
- self.connected: bool = True
- self.websocket = websocket
- self.tasks: List[asyncio.Task] = []
- self.subscription_id_rewrite: str = urlsafe_short_hash()
-
- async def client_to_nostr(self):
- """Receives requests / data from the client and forwards it to relays. If the
- request was a subscription/filter, registers it with the nostr client lib.
- Remembers the subscription id so we can send back responses from the relay to this
- client in `nostr_to_client`"""
- while True:
- try:
- json_str = await self.websocket.receive_text()
- except WebSocketDisconnect:
- self.connected = False
- break
- # print(json_str)
-
- # registers a subscription if the input was a REQ request
- subscription_id, json_str_rewritten = await self._add_nostr_subscription(
- json_str
- )
- if subscription_id and json_str_rewritten:
- self.subscriptions.append(subscription_id)
- json_str = json_str_rewritten
-
- # publish data
- client.relay_manager.publish_message(json_str)
-
- async def nostr_to_client(self):
- """Sends responses from relays back to the client. Polls the subscriptions of this client
- stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
- is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
- the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
- that we had previously rewritten in order to avoid collisions when multiple clients use the same id."""
- while True and self.connected:
- for s in self.subscriptions:
- if s in received_subscription_events:
- while len(received_subscription_events[s]):
- my_event = received_subscription_events[s].pop(0)
- # event.to_message() does not include the subscription ID, we have to add it manually
- event_json = {
- "id": my_event.id,
- "pubkey": my_event.public_key,
- "created_at": my_event.created_at,
- "kind": my_event.kind,
- "tags": my_event.tags,
- "content": my_event.content,
- "sig": my_event.signature,
- }
-
- # this reconstructs the original response from the relay
- # reconstruct oriiginal subscription id
- s_original = s[len(f"{self.subscription_id_rewrite}_") :]
- event_to_forward = ["EVENT", s_original, event_json]
- # print(json.dumps(event_to_forward))
-
- # send data back to client
- await self.websocket.send_text(json.dumps(event_to_forward))
- if s in received_subscription_eosenotices:
- my_event = received_subscription_eosenotices[s]
- s_original = s[len(f"{self.subscription_id_rewrite}_") :]
- event_to_forward = ["EOSE", s_original]
- del received_subscription_eosenotices[s]
- # send data back to client
- await self.websocket.send_text(json.dumps(event_to_forward))
- await asyncio.sleep(0.1)
-
- async def start(self):
- self.tasks.append(asyncio.create_task(self.client_to_nostr()))
- self.tasks.append(asyncio.create_task(self.nostr_to_client()))
-
- async def stop(self):
- for t in self.tasks:
- t.cancel()
-
- def _marshall_nostr_filters(self, data: Union[dict, list]):
- filters = data if isinstance(data, list) else [data]
- filters = [Filter.parse_obj(f) for f in filters]
- filter_list: list[NostrFilter] = []
- for filter in filters:
- filter_list.append(
- NostrFilter(
- event_ids=filter.ids or [],
- kinds=filter.kinds, # type: ignore
- authors=filter.authors or [],
- since=filter.since, # type: ignore
- until=filter.until, # type: ignore
- event_refs=filter.e, # type: ignore
- pubkey_refs=filter.p, # type: ignore
- limit=filter.limit, # type: ignore
- )
- )
- return NostrFilters(filter_list)
-
- async def _add_nostr_subscription(self, json_str):
- """Parses a (string) request from a client. If it is a subscription (REQ), it will
- register the subscription in the nostr client library that we're using so we can
- receive the callbacks on it later. Will rewrite the subscription id since we expect
- multiple clients to use the router and want to avoid subscription id collisions"""
- json_data = json.loads(json_str)
- assert len(json_data)
- if json_data[0] == "REQ":
- subscription_id = json_data[1]
- subscription_id_rewritten = (
- f"{self.subscription_id_rewrite}_{subscription_id}"
- )
- fltr = json_data[2]
- filters = self._marshall_nostr_filters(fltr)
- client.relay_manager.add_subscription(subscription_id_rewritten, filters)
- request_rewritten = json.dumps(["REQ", subscription_id_rewritten, fltr])
- return subscription_id_rewritten, request_rewritten
- return None, None
diff --git a/static/images/1.jpeg b/static/images/1.jpeg
new file mode 100644
index 0000000..1f00661
Binary files /dev/null and b/static/images/1.jpeg differ
diff --git a/static/images/2.jpeg b/static/images/2.jpeg
new file mode 100644
index 0000000..e7301fd
Binary files /dev/null and b/static/images/2.jpeg differ
diff --git a/tasks.py b/tasks.py
index 40d8aec..2a76765 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,76 +1,71 @@
import asyncio
-import ssl
import threading
-from .nostr.client.client import NostrClient
-from .nostr.event import Event
-from .nostr.message_pool import EventMessage, NoticeMessage, EndOfStoredEventsMessage
-from .nostr.key import PublicKey
-from .nostr.relay_manager import RelayManager
-
-
-client = NostrClient(
- connect=False,
-)
-
-received_event_queue: asyncio.Queue[EventMessage] = asyncio.Queue(0)
-received_subscription_events: dict[str, list[Event]] = {}
-received_subscription_notices: dict[str, list[NoticeMessage]] = {}
-received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
+from loguru import logger
from .crud import get_relays
+from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
+from .router import NostrRouter, nostr_client
async def init_relays():
+ # get relays from db
relays = await get_relays()
- client.relays = list(set([r.url for r in relays.__root__ if r.url]))
- client.connect()
- return
+ # set relays and connect to them
+ valid_relays = [r.url for r in relays if r.url]
+
+ nostr_client.reconnect(valid_relays)
+
+
+async def check_relays():
+ """Check relays that have been disconnected"""
+ while True:
+ try:
+ await asyncio.sleep(20)
+ nostr_client.relay_manager.check_and_restart_relays()
+ except Exception as e:
+ logger.warning(f"Cannot restart relays: '{e!s}'.")
async def subscribe_events():
- while not any([r.connected for r in client.relay_manager.relays.values()]):
+ while not [r.connected for r in nostr_client.relay_manager.relays.values()]:
await asyncio.sleep(2)
- def callback_events(eventMessage: EventMessage):
- # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}")
- if eventMessage.subscription_id in received_subscription_events:
- # do not add duplicate events (by event id)
- if eventMessage.event.id in set(
- [
- e.id
- for e in received_subscription_events[eventMessage.subscription_id]
- ]
- ):
- return
+ def callback_events(event_message: EventMessage):
+ sub_id = event_message.subscription_id
+ if sub_id not in NostrRouter.received_subscription_events:
+ NostrRouter.received_subscription_events[sub_id] = [event_message]
+ return
- received_subscription_events[eventMessage.subscription_id].append(
- eventMessage.event
+ # do not add duplicate events (by event id)
+ ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
+ if event_message.event_id in ids:
+ return
+
+ NostrRouter.received_subscription_events[sub_id].append(event_message)
+
+ def callback_notices(notice_message: NoticeMessage):
+ if notice_message not in NostrRouter.received_subscription_notices:
+ NostrRouter.received_subscription_notices.append(notice_message)
+
+ def callback_eose_notices(event_message: EndOfStoredEventsMessage):
+ sub_id = event_message.subscription_id
+ if sub_id in NostrRouter.received_subscription_eosenotices:
+ return
+
+ NostrRouter.received_subscription_eosenotices[sub_id] = event_message
+
+ def wrap_async_subscribe():
+ asyncio.run(
+ nostr_client.subscribe(
+ callback_events,
+ callback_notices,
+ callback_eose_notices,
)
- else:
- received_subscription_events[eventMessage.subscription_id] = [
- eventMessage.event
- ]
- return
-
- def callback_notices(eventMessage: NoticeMessage):
- return
-
- def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
- if eventMessage.subscription_id not in received_subscription_eosenotices:
- received_subscription_eosenotices[
- eventMessage.subscription_id
- ] = eventMessage
-
- return
+ )
t = threading.Thread(
- target=client.subscribe,
- args=(
- callback_events,
- callback_notices,
- callback_eose_notices,
- ),
+ target=wrap_async_subscribe,
name="Nostr-event-subscription",
daemon=True,
)
diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html
index 8ddd362..ca44f1b 100644
--- a/templates/nostrclient/index.html
+++ b/templates/nostrclient/index.html
@@ -2,6 +2,52 @@
%} {% block page %} {% raw %}
This extension is a always-on nostr client that other extensions can - use to send and receive events on nostr. - - Add multiple nostr relays to connect to. The extension then opens a websocket for you to use - at -
-
-
+