diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
deleted file mode 100644
index 5bcdee7..0000000
--- a/.github/workflows/ci.yml
+++ /dev/null
@@ -1,29 +0,0 @@
-name: CI
-on:
- push:
- branches:
- - main
- pull_request:
-
-jobs:
- lint:
- uses: lnbits/lnbits/.github/workflows/lint.yml@dev
- tests:
- runs-on: ubuntu-latest
- needs: [lint]
- steps:
- - uses: actions/checkout@v4
- - uses: lnbits/lnbits/.github/actions/prepare@dev
- - name: Run pytest
- uses: pavelzw/pytest-action@v2
- env:
- LNBITS_BACKEND_WALLET_CLASS: FakeWallet
- PYTHONUNBUFFERED: 1
- DEBUG: true
- with:
- verbose: true
- job-summary: true
- emoji: false
- click-to-expand: true
- custom-pytest: uv run pytest
- report-title: 'test'
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
deleted file mode 100644
index 27c8a60..0000000
--- a/.github/workflows/release.yml
+++ /dev/null
@@ -1,58 +0,0 @@
-on:
- push:
- tags:
- - 'v[0-9]+.[0-9]+.[0-9]+'
-
-jobs:
- release:
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v3
- - name: Create github release
- env:
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- tag: ${{ github.ref_name }}
- run: |
- gh release create "$tag" --generate-notes
-
- pullrequest:
- needs: [release]
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v3
- with:
- token: ${{ secrets.EXT_GITHUB }}
- repository: lnbits/lnbits-extensions
- path: './lnbits-extensions'
-
- - name: setup git user
- run: |
- git config --global user.name "alan"
- git config --global user.email "alan@lnbits.com"
-
- - name: Create pull request in extensions repo
- env:
- GH_TOKEN: ${{ secrets.EXT_GITHUB }}
- repo_name: '${{ github.event.repository.name }}'
- tag: '${{ github.ref_name }}'
- branch: 'update-${{ github.event.repository.name }}-${{ github.ref_name }}'
- title: '[UPDATE] ${{ github.event.repository.name }} to ${{ github.ref_name }}'
- body: 'https://github.com/lnbits/${{ github.event.repository.name }}/releases/${{ github.ref_name }}'
- archive: 'https://github.com/lnbits/${{ github.event.repository.name }}/archive/refs/tags/${{ github.ref_name }}.zip'
- run: |
- cd lnbits-extensions
- git checkout -b $branch
-
- # if there is another open PR
- git pull origin $branch || echo "branch does not exist"
-
- sh util.sh update_extension $repo_name $tag
-
- git add -A
- git commit -am "$title"
- git push origin $branch
-
- # check if pr exists before creating it
- gh config set pager cat
- check=$(gh pr list -H $branch | wc -l)
- test $check -ne 0 || gh pr create --title "$title" --body "$body" --repo lnbits/lnbits-extensions
diff --git a/.gitignore b/.gitignore
index 0152b6e..10a11d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,24 @@
+.DS_Store
+._*
+
__pycache__
-node_modules
+*.py[cod]
+*$py.class
.mypy_cache
-.venv
+.vscode
+*-lock.json
+
+*.egg
+*.egg-info
+.coverage
+.pytest_cache
+.webassets-cache
+htmlcov
+test-reports
+tests/data/*.sqlite3
+
+*.swo
+*.swp
+*.pyo
+*.pyc
+*.env
\ No newline at end of file
diff --git a/.prettierrc b/.prettierrc
deleted file mode 100644
index 725c398..0000000
--- a/.prettierrc
+++ /dev/null
@@ -1,12 +0,0 @@
-{
- "semi": false,
- "arrowParens": "avoid",
- "insertPragma": false,
- "printWidth": 80,
- "proseWrap": "preserve",
- "singleQuote": true,
- "trailingComma": "none",
- "useTabs": false,
- "bracketSameLine": false,
- "bracketSpacing": false
-}
diff --git a/Makefile b/Makefile
deleted file mode 100644
index 0fac253..0000000
--- a/Makefile
+++ /dev/null
@@ -1,47 +0,0 @@
-all: format check
-
-format: prettier black ruff
-
-check: mypy pyright checkblack checkruff checkprettier
-
-prettier:
- uv run ./node_modules/.bin/prettier --write .
-pyright:
- uv run ./node_modules/.bin/pyright
-
-mypy:
- uv run mypy .
-
-black:
- uv run black .
-
-ruff:
- uv run ruff check . --fix
-
-checkruff:
- uv run ruff check .
-
-checkprettier:
- uv run ./node_modules/.bin/prettier --check .
-
-checkblack:
- uv run black --check .
-
-checkeditorconfig:
- editorconfig-checker
-
-test:
- PYTHONUNBUFFERED=1 \
- DEBUG=true \
- uv run pytest
-install-pre-commit-hook:
- @echo "Installing pre-commit hook to git"
- @echo "Uninstall the hook with uv run pre-commit uninstall"
- uv run pre-commit install
-
-pre-commit:
- uv run pre-commit run --all-files
-
-
-checkbundle:
- @echo "skipping checkbundle"
diff --git a/README.md b/README.md
index 70593a8..5f9bfbc 100644
--- a/README.md
+++ b/README.md
@@ -1,124 +1,5 @@
-# Nostrclient - [LNbits](https://github.com/lnbits/lnbits) extension
+# nostrclient
-For more about LNBits extension check [this tutorial](https://github.com/lnbits/lnbits/wiki/LNbits-Extensions)
+`nostrclient` is an always-on extension that can open multiple connections to nostr relays and act as a multiplexer for other clients: You open a single websocket to `nostrclient` which then sends the data to multiple relays. The responses from these relays are then sent back to the client.
-## Overview
-
-`nostrclient` is an always-on Nostr relay multiplexer that simplifies connecting to multiple Nostr relays. Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you.
-
-### Why Use This?
-
-- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections
-- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays
-- **Resource Efficient** - Share relay connections across multiple clients
-- **Subscription Management** - Automatic subscription ID rewriting prevents conflicts between clients
-
-## Architecture
-
-```mermaid
-flowchart LR
- A[Client A] -->|WebSocket| N
- B[Client B] -->|WebSocket| N
- C[Client C] -->|WebSocket| N
-
- N[nostrclient
Router] -->|Fan Out| R1[Relay A]
- N -->|Fan Out| R2[Relay B]
- N -->|Fan Out| R3[Relay C]
- N -->|Fan Out| R4[Relay D]
-
- R1 -.->|Aggregate| N
- R2 -.->|Aggregate| N
- R3 -.->|Aggregate| N
- R4 -.->|Aggregate| N
-```
-
-**Key Feature:** The router rewrites subscription IDs to prevent conflicts when multiple clients use the same IDs.
-
-## Features
-
-- **Multi-Relay Multiplexing** - Connect to multiple Nostr relays through a single WebSocket
-- **Public & Private Endpoints** - Configurable public and private WebSocket access
-- **Automatic Reconnection** - Failed relays are automatically retried with exponential backoff
-- **Subscription Deduplication** - Events are deduplicated before being sent to clients
-- **Health Monitoring** - Track relay connection status, latency, and error rates
-- **Test Endpoint** - Send test messages to verify your setup is working
-
-## How It Works
-
-1. **Client Connection** - Your Nostr client connects to the nostrclient WebSocket endpoint
-2. **Subscription Rewriting** - Each subscription ID is rewritten to prevent conflicts between multiple clients
-3. **Fan-Out** - Subscription requests are sent to all configured relays
-4. **Aggregation** - Events from all relays are collected and deduplicated
-5. **Response** - Events are sent back to the client with the original subscription ID
-
-## Configuration
-
-### WebSocket Endpoints
-
-- **Public Endpoint**: `/api/v1/relay` - Available to anyone (if enabled)
-- **Private Endpoint**: `/api/v1/{encrypted_id}` - Requires valid encrypted endpoint ID
-
-Configure endpoint access in the extension settings:
-
-- `private_ws` - Enable/disable private WebSocket access
-- `public_ws` - Enable/disable public WebSocket access
-
-### Adding Relays
-
-Use the nostrclient UI to add/remove Nostr relays. The extension will automatically:
-
-- Connect to new relays
-- Publish existing subscriptions to new relays
-- Monitor relay health and reconnect as needed
-
-## Testing
-
-### Test Endpoint Functionality
-
-The `Test Endpoint` feature helps verify that your nostrclient WebSocket endpoint works correctly.
-
-**How to test:**
-
-1. Navigate to the nostrclient extension in LNbits
-2. Use the Test Endpoint feature
-3. Send a DM to yourself (or a temporary account)
-4. Verify that messages are sent and received correctly
-
-https://user-images.githubusercontent.com/2951406/236780745-929c33c2-2502-49be-84a3-db02a7aabc0e.mp4
-
-## Troubleshooting
-
-### Connection Issues
-
-- **Check relay status** - View relay health in the nostrclient UI
-- **Verify endpoint configuration** - Ensure public_ws or private_ws is enabled
-- **Check logs** - Review LNbits logs for connection errors
-
-### Subscription Not Receiving Events
-
-- **Verify relays are connected** - Check the relay status in the UI
-- **Test with known event** - Use the Test Endpoint to verify connectivity
-- **Check relay compatibility** - Some relays may not support all Nostr features
-
-## Development
-
-This extension uses `uv` for dependency management.
-
-### Quick Start
-
-```bash
-# Format code
-make format
-
-# Run type checks and linting
-make check
-
-# Run tests
-make test
-```
-
-For more development commands, see the [Makefile](./Makefile).
-
-## License
-
-MIT License - see [LICENSE](./LICENSE)
+
diff --git a/__init__.py b/__init__.py
index d7eb435..60d8e23 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,59 +1,33 @@
-import asyncio
-
from fastapi import APIRouter
-from loguru import logger
+from starlette.staticfiles import StaticFiles
-from .crud import db
-from .router import all_routers, nostr_client
-from .tasks import check_relays, init_relays, subscribe_events
-from .views import nostrclient_generic_router
-from .views_api import nostrclient_api_router
+from lnbits.db import Database
+from lnbits.helpers import template_renderer
+from lnbits.tasks import catch_everything_and_restart
+
+db = Database("ext_nostrclient")
nostrclient_static_files = [
{
"path": "/nostrclient/static",
+ "app": StaticFiles(directory="lnbits/extensions/nostrclient/static"),
"name": "nostrclient_static",
}
]
nostrclient_ext: APIRouter = APIRouter(prefix="/nostrclient", tags=["nostrclient"])
-nostrclient_ext.include_router(nostrclient_generic_router)
-nostrclient_ext.include_router(nostrclient_api_router)
-scheduled_tasks: list[asyncio.Task] = []
-async def nostrclient_stop():
- for task in scheduled_tasks:
- try:
- task.cancel()
- except Exception as ex:
- logger.warning(ex)
+def nostr_renderer():
+ return template_renderer(["lnbits/extensions/nostrclient/templates"])
- for router in all_routers:
- try:
- await router.stop()
- all_routers.remove(router)
- except Exception as e:
- logger.error(e)
- nostr_client.close()
+from .tasks import init_relays, subscribe_events
+from .views import * # noqa
+from .views_api import * # noqa
def nostrclient_start():
- from lnbits.tasks import create_permanent_unique_task
-
- task1 = create_permanent_unique_task("ext_nostrclient_init_relays", init_relays)
- task2 = create_permanent_unique_task(
- "ext_nostrclient_subscrive_events", subscribe_events
- )
- task3 = create_permanent_unique_task("ext_nostrclient_check_relays", check_relays)
- scheduled_tasks.extend([task1, task2, task3])
-
-
-__all__ = [
- "db",
- "nostrclient_ext",
- "nostrclient_start",
- "nostrclient_static_files",
- "nostrclient_stop",
-]
+ loop = asyncio.get_event_loop()
+ loop.create_task(catch_everything_and_restart(init_relays))
+ loop.create_task(catch_everything_and_restart(subscribe_events))
diff --git a/cbc.py b/cbc.py
new file mode 100644
index 0000000..0d9e04f
--- /dev/null
+++ b/cbc.py
@@ -0,0 +1,26 @@
+from Cryptodome.Cipher import AES
+
+BLOCK_SIZE = 16
+
+
+class AESCipher(object):
+ """This class is compatible with crypto.createCipheriv('aes-256-cbc')"""
+
+ def __init__(self, key=None):
+ self.key = key
+
+ def pad(self, data):
+ length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
+ return data + (chr(length) * length).encode()
+
+ def unpad(self, data):
+ return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
+
+ def encrypt(self, plain_text):
+ cipher = AES.new(self.key, AES.MODE_CBC)
+ b = plain_text.encode("UTF-8")
+ return cipher.iv, cipher.encrypt(self.pad(b))
+
+ def decrypt(self, iv, enc_text):
+ cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
+ return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
diff --git a/config.json b/config.json
index 1f58e7b..ce8ae18 100644
--- a/config.json
+++ b/config.json
@@ -1,17 +1,6 @@
{
"name": "Nostr Client",
- "short_description": "Nostr relay multiplexer",
- "version": "1.1.0",
+ "short_description": "Nostr client for extensions",
"tile": "/nostrclient/static/images/nostr-bitcoin.png",
- "contributors": ["calle", "motorina0", "dni"],
- "min_lnbits_version": "1.4.0",
- "images": [
- {
- "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/1.jpeg"
- },
- {
- "uri": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/static/images/2.jpeg"
- }
- ],
- "description_md": "https://raw.githubusercontent.com/lnbits/nostrclient/add-extension-metadata/description.md"
+ "contributors": ["calle"]
}
diff --git a/crud.py b/crud.py
index d311c72..780642d 100644
--- a/crud.py
+++ b/crud.py
@@ -1,52 +1,31 @@
-from lnbits.db import Database
+from typing import List, Optional, Union
-from .models import Config, Relay, UserConfig
+import shortuuid
-db = Database("ext_nostrclient")
+from lnbits.helpers import urlsafe_short_hash
+
+from . import db
+from .models import Relay, RelayList
-async def get_relays() -> list[Relay]:
- return await db.fetchall(
- "SELECT * FROM nostrclient.relays",
- model=Relay,
+async def get_relays() -> RelayList:
+ row = await db.fetchall("SELECT * FROM nostrclient.relays")
+ return RelayList(__root__=row)
+
+
+async def add_relay(relay: Relay) -> None:
+ await db.execute(
+ f"""
+ INSERT INTO nostrclient.relays (
+ id,
+ url,
+ active
+ )
+ VALUES (?, ?, ?)
+ """,
+ (relay.id, relay.url, relay.active),
)
-async def add_relay(relay: Relay) -> Relay:
- await db.insert("nostrclient.relays", relay)
- return relay
-
-
async def delete_relay(relay: Relay) -> None:
- if not relay.url:
- return
- await db.execute(
- "DELETE FROM nostrclient.relays WHERE url = :url", {"url": relay.url}
- )
-
-
-######################CONFIG#######################
-async def create_config(owner_id: str) -> Config:
- admin_config = UserConfig(owner_id=owner_id)
- await db.insert("nostrclient.config", admin_config)
- return admin_config.extra
-
-
-async def update_config(owner_id: str, config: Config) -> Config:
- user_config = UserConfig(owner_id=owner_id, extra=config)
- await db.update("nostrclient.config", user_config, "WHERE owner_id = :owner_id")
- return user_config.extra
-
-
-async def get_config(owner_id: str) -> Config | None:
- user_config: UserConfig = await db.fetchone(
- """
- SELECT * FROM nostrclient.config
- WHERE owner_id = :owner_id
- """,
- {"owner_id": owner_id},
- model=UserConfig,
- )
- if user_config:
- return user_config.extra
- return None
+ await db.execute("DELETE FROM nostrclient.relays WHERE url = ?", (relay.url,))
diff --git a/description.md b/description.md
deleted file mode 100644
index 5293087..0000000
--- a/description.md
+++ /dev/null
@@ -1,8 +0,0 @@
-An always-on relay multiplexer that simplifies connecting to multiple Nostr relays.
-
-Instead of your Nostr client managing connections to dozens of relays, you connect to a single WebSocket endpoint provided by `nostrclient`, which then fans out your requests to all configured relays and aggregates the responses back to you.
-
-- **Simplified Client Configuration** - Connect to one endpoint instead of managing multiple relay connections
-- **Always-On Connectivity** - Your LNbits instance maintains persistent connections to relays
-- **Resource Efficient** - Share relay connections across multiple clients
-- **Automatic Subscription Management** - Subscription ID rewriting prevents conflicts between clients
diff --git a/helpers.py b/helpers.py
deleted file mode 100644
index bcf5c02..0000000
--- a/helpers.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from bech32 import bech32_decode, convertbits
-
-
-def normalize_public_key(pubkey: str) -> str:
- if pubkey.startswith("npub1"):
- _, decoded_data = bech32_decode(pubkey)
- if not decoded_data:
- raise ValueError("Public Key is not valid npub")
-
- decoded_data_bits = convertbits(decoded_data, 5, 8, False)
- if not decoded_data_bits:
- raise ValueError("Public Key is not valid npub")
- return bytes(decoded_data_bits).hex()
-
- # check if valid hex
- if len(pubkey) != 64:
- raise ValueError("Public Key is not valid hex")
- int(pubkey, 16)
- return pubkey
diff --git a/migrations.py b/migrations.py
index b16db58..5a30e45 100644
--- a/migrations.py
+++ b/migrations.py
@@ -3,7 +3,7 @@ async def m001_initial(db):
Initial nostrclient table.
"""
await db.execute(
- """
+ f"""
CREATE TABLE nostrclient.relays (
id TEXT NOT NULL PRIMARY KEY,
url TEXT NOT NULL,
@@ -11,22 +11,3 @@ async def m001_initial(db):
);
"""
)
-
-
-async def m002_create_config_table(db):
- """
- Allow the extension to persist and retrieve any number of config values.
- """
-
- await db.execute(
- """CREATE TABLE nostrclient.config (
- json_data TEXT NOT NULL
- );"""
- )
-
-
-async def m003_update_config_table(db):
- await db.execute("ALTER TABLE nostrclient.config RENAME COLUMN json_data TO extra")
- await db.execute(
- "ALTER TABLE nostrclient.config ADD COLUMN owner_id TEXT DEFAULT 'admin'"
- )
diff --git a/models.py b/models.py
index 937c6c5..4ed1e30 100644
--- a/models.py
+++ b/models.py
@@ -1,54 +1,91 @@
-from lnbits.helpers import urlsafe_short_hash
+from dataclasses import dataclass
+from typing import Dict, List, Optional
+
+from fastapi import Request
+from fastapi.param_functions import Query
from pydantic import BaseModel, Field
-
-class RelayStatus(BaseModel):
- num_sent_events: int | None = 0
- num_received_events: int | None = 0
- error_counter: int | None = 0
- error_list: list | None = []
- notice_list: list | None = []
+from lnbits.helpers import urlsafe_short_hash
class Relay(BaseModel):
- id: str | None = None
- url: str | None = None
- active: bool | None = None
-
- connected: bool | None = Field(default=None, no_database=True)
- connected_string: str | None = Field(default=None, no_database=True)
- status: RelayStatus | None = Field(default=None, no_database=True)
-
- ping: int | None = Field(default=None, no_database=True)
+ id: Optional[str] = None
+ url: Optional[str] = None
+ connected: Optional[bool] = None
+ connected_string: Optional[str] = None
+ status: Optional[str] = None
+ active: Optional[bool] = None
+ ping: Optional[int] = None
def _init__(self):
if not self.id:
self.id = urlsafe_short_hash()
-class RelayDb(BaseModel):
- id: str
- url: str
- active: bool | None = True
+class RelayList(BaseModel):
+ __root__: List[Relay]
-class TestMessage(BaseModel):
- sender_private_key: str | None
- reciever_public_key: str
- message: str
+class Event(BaseModel):
+ content: str
+ pubkey: str
+ created_at: Optional[int]
+ kind: int
+ tags: Optional[List[List[str]]]
+ sig: str
-class TestMessageResponse(BaseModel):
- private_key: str
- public_key: str
- event_json: str
+class Filter(BaseModel):
+ ids: Optional[List[str]]
+ kinds: Optional[List[int]]
+ authors: Optional[List[str]]
+ since: Optional[int]
+ until: Optional[int]
+ e: Optional[List[str]] = Field(alias="#e")
+ p: Optional[List[str]] = Field(alias="#p")
+ limit: Optional[int]
-class Config(BaseModel):
- private_ws: bool = True
- public_ws: bool = False
+class Filters(BaseModel):
+ __root__: List[Filter]
-class UserConfig(BaseModel):
- owner_id: str
- extra: Config = Config()
+# class nostrKeys(BaseModel):
+# pubkey: str
+# privkey: str
+
+# class nostrNotes(BaseModel):
+# id: str
+# pubkey: str
+# created_at: str
+# kind: int
+# tags: str
+# content: str
+# sig: str
+
+# class nostrCreateRelays(BaseModel):
+# relay: str = Query(None)
+
+# class nostrCreateConnections(BaseModel):
+# pubkey: str = Query(None)
+# relayid: str = Query(None)
+
+# class nostrRelays(BaseModel):
+# id: Optional[str]
+# relay: Optional[str]
+# status: Optional[bool] = False
+
+
+# class nostrRelaySetList(BaseModel):
+# allowlist: Optional[str]
+# denylist: Optional[str]
+
+# class nostrConnections(BaseModel):
+# id: str
+# pubkey: Optional[str]
+# relayid: Optional[str]
+
+# class nostrSubscriptions(BaseModel):
+# id: str
+# userPubkey: Optional[str]
+# subscribedPubkey: Optional[str]
diff --git a/tests/__init__.py b/nostr/__init__.py
similarity index 100%
rename from tests/__init__.py
rename to nostr/__init__.py
diff --git a/nostr/bech32.py b/nostr/bech32.py
index ba2ddd1..b068de7 100644
--- a/nostr/bech32.py
+++ b/nostr/bech32.py
@@ -23,25 +23,21 @@
from enum import Enum
-
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
-
BECH32 = 1
BECH32M = 2
-
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
-BECH32M_CONST = 0x2BC830A3
-
+BECH32M_CONST = 0x2bc830a3
def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
- generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
+ generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
chk = 1
for value in values:
top = chk >> 25
- chk = (chk & 0x1FFFFFF) << 5 ^ value
+ chk = (chk & 0x1ffffff) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk
@@ -61,7 +57,6 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M
return None
-
def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
@@ -73,29 +68,26 @@ def bech32_create_checksum(hrp, data, spec):
def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
- return hrp + "1" + "".join([CHARSET[d] for d in combined])
-
+ return hrp + '1' + ''.join([CHARSET[d] for d in combined])
def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
- if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
- bech.lower() != bech and bech.upper() != bech
- ):
+ if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
+ (bech.lower() != bech and bech.upper() != bech)):
return (None, None, None)
bech = bech.lower()
- pos = bech.rfind("1")
+ pos = bech.rfind('1')
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
- if not all(x in CHARSET for x in bech[pos + 1 :]):
+ if not all(x in CHARSET for x in bech[pos+1:]):
return (None, None, None)
hrp = bech[:pos]
- data = [CHARSET.find(x) for x in bech[pos + 1 :]]
+ data = [CHARSET.find(x) for x in bech[pos+1:]]
spec = bech32_verify_checksum(hrp, data)
if spec is None:
return (None, None, None)
return (hrp, data[:-6], spec)
-
def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion."""
acc = 0
@@ -124,29 +116,22 @@ def decode(hrp, addr):
hrpgot, data, spec = bech32_decode(addr)
if hrpgot != hrp:
return (None, None)
- decoded = convertbits(data[1:], 5, 8, False) # type: ignore
+ decoded = convertbits(data[1:], 5, 8, False)
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
return (None, None)
- if data[0] > 16: # type: ignore
+ if data[0] > 16:
return (None, None)
- if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32: # type: ignore
+ if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
- if (
- data[0] == 0 # type: ignore
- and spec != Encoding.BECH32
- or data[0] != 0 # type: ignore
- and spec != Encoding.BECH32M
- ):
+ if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
return (None, None)
- return (data[0], decoded) # type: ignore
+ return (data[0], decoded)
def encode(hrp, witver, witprog):
"""Encode a segwit address."""
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
- wit_prog = convertbits(witprog, 8, 5)
- assert wit_prog
- ret = bech32_encode(hrp, [witver, *wit_prog], spec)
+ ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5), spec)
if decode(hrp, ret) == (None, None):
return None
return ret
diff --git a/nostr/client/cbc.py b/nostr/client/cbc.py
new file mode 100644
index 0000000..a41dbc0
--- /dev/null
+++ b/nostr/client/cbc.py
@@ -0,0 +1,41 @@
+
+from Cryptodome import Random
+from Cryptodome.Cipher import AES
+
+plain_text = "This is the text to encrypts"
+
+# encrypted = "7mH9jq3K9xNfWqIyu9gNpUz8qBvGwsrDJ+ACExdV1DvGgY8q39dkxVKeXD7LWCDrPnoD/ZFHJMRMis8v9lwHfNgJut8EVTMuJJi8oTgJevOBXl+E+bJPwej9hY3k20rgCQistNRtGHUzdWyOv7S1tg==".encode()
+# iv = "GzDzqOVShWu3Pl2313FBpQ==".encode()
+
+key = bytes.fromhex("3aa925cb69eb613e2928f8a18279c78b1dca04541dfd064df2eda66b59880795")
+
+BLOCK_SIZE = 16
+
+class AESCipher(object):
+ """This class is compatible with crypto.createCipheriv('aes-256-cbc')
+
+ """
+ def __init__(self, key=None):
+ self.key = key
+
+ def pad(self, data):
+ length = BLOCK_SIZE - (len(data) % BLOCK_SIZE)
+ return data + (chr(length) * length).encode()
+
+ def unpad(self, data):
+ return data[: -(data[-1] if type(data[-1]) == int else ord(data[-1]))]
+
+ def encrypt(self, plain_text):
+ cipher = AES.new(self.key, AES.MODE_CBC)
+ b = plain_text.encode("UTF-8")
+ return cipher.iv, cipher.encrypt(self.pad(b))
+
+ def decrypt(self, iv, enc_text):
+ cipher = AES.new(self.key, AES.MODE_CBC, iv=iv)
+ return self.unpad(cipher.decrypt(enc_text).decode("UTF-8"))
+
+if __name__ == "__main__":
+ aes = AESCipher(key=key)
+ iv, enc_text = aes.encrypt(plain_text)
+ dec_text = aes.decrypt(iv, enc_text)
+ print(dec_text)
\ No newline at end of file
diff --git a/nostr/client/client.py b/nostr/client/client.py
index d6fb5c8..6fb885f 100644
--- a/nostr/client/client.py
+++ b/nostr/client/client.py
@@ -1,75 +1,152 @@
-import asyncio
-
-from loguru import logger
+from typing import *
+import ssl
+import time
+import json
+import os
+import base64
+from ..event import Event
from ..relay_manager import RelayManager
+from ..message_type import ClientMessageType
+from ..key import PrivateKey, PublicKey
+
+from ..filter import Filter, Filters
+from ..event import Event, EventKind, EncryptedDirectMessage
+from ..relay_manager import RelayManager
+from ..message_type import ClientMessageType
+
+# from aes import AESCipher
+from . import cbc
class NostrClient:
- relay_manager: RelayManager
- running: bool
+ relays = [
+ "wss://nostr-pub.wellorder.net",
+ "wss://nostr.zebedee.cloud",
+ "wss://nodestr.fmt.wiz.biz",
+ "wss://nostr.oxtr.dev",
+ ] # ["wss://nostr.oxtr.dev"] # ["wss://relay.nostr.info"] "wss://nostr-pub.wellorder.net" "ws://91.237.88.218:2700", "wss://nostrrr.bublina.eu.org", ""wss://nostr-relay.freeberty.net"", , "wss://nostr.oxtr.dev", "wss://relay.nostr.info", "wss://nostr-pub.wellorder.net" , "wss://relayer.fiatjaf.com", "wss://nodestr.fmt.wiz.biz/", "wss://no.str.cr"
+ relay_manager = RelayManager()
+ private_key: PrivateKey
+ public_key: PublicKey
- def __init__(self):
- self.running = True
- self.relay_manager = RelayManager()
+ def __init__(self, privatekey_hex: str = "", relays: List[str] = [], connect=True):
+ self.generate_keys(privatekey_hex)
- def connect(self, relays):
- for relay in relays:
- try:
- self.relay_manager.add_relay(relay)
- except Exception as e:
- logger.debug(e)
- self.running = True
+ if len(relays):
+ self.relays = relays
+ if connect:
+ self.connect()
- def reconnect(self, relays):
- self.relay_manager.remove_relays()
- self.connect(relays)
+ def connect(self):
+ for relay in self.relays:
+ self.relay_manager.add_relay(relay)
+ self.relay_manager.open_connections(
+ {"cert_reqs": ssl.CERT_NONE}
+ ) # NOTE: This disables ssl certificate verification
def close(self):
- try:
- self.relay_manager.close_all_subscriptions()
- self.relay_manager.close_connections()
+ self.relay_manager.close_connections()
- self.running = False
- except Exception as e:
- logger.error(e)
+ def generate_keys(self, privatekey_hex: str = None):
+ pk = bytes.fromhex(privatekey_hex) if privatekey_hex else None
+ self.private_key = PrivateKey(pk)
+ self.public_key = self.private_key.public_key
- async def subscribe(
+ def post(self, message: str):
+ event = Event(message, self.public_key.hex(), kind=EventKind.TEXT_NOTE)
+ self.private_key.sign_event(event)
+ event_json = event.to_message()
+ # print("Publishing message:")
+ # print(event_json)
+ self.relay_manager.publish_message(event_json)
+
+ def get_post(
+ self, sender_publickey: PublicKey = None, callback_func=None, filter_kwargs={}
+ ):
+ filter = Filter(
+ authors=[sender_publickey.hex()] if sender_publickey else None,
+ kinds=[EventKind.TEXT_NOTE],
+ **filter_kwargs,
+ )
+ filters = Filters([filter])
+ subscription_id = os.urandom(4).hex()
+ self.relay_manager.add_subscription(subscription_id, filters)
+
+ request = [ClientMessageType.REQUEST, subscription_id]
+ request.extend(filters.to_json_array())
+ message = json.dumps(request)
+ self.relay_manager.publish_message(message)
+
+ while True:
+ while self.relay_manager.message_pool.has_events():
+ event_msg = self.relay_manager.message_pool.get_event()
+ if callback_func:
+ callback_func(event_msg.event)
+ time.sleep(0.1)
+
+ def dm(self, message: str, to_pubkey: PublicKey):
+ dm = EncryptedDirectMessage(
+ recipient_pubkey=to_pubkey.hex(), cleartext_content=message
+ )
+ self.private_key.sign_event(dm)
+ self.relay_manager.publish_event(dm)
+
+ def get_dm(self, sender_publickey: PublicKey, callback_func=None):
+ filters = Filters(
+ [
+ Filter(
+ kinds=[EventKind.ENCRYPTED_DIRECT_MESSAGE],
+ pubkey_refs=[sender_publickey.hex()],
+ )
+ ]
+ )
+ subscription_id = os.urandom(4).hex()
+ self.relay_manager.add_subscription(subscription_id, filters)
+
+ request = [ClientMessageType.REQUEST, subscription_id]
+ request.extend(filters.to_json_array())
+ message = json.dumps(request)
+ self.relay_manager.publish_message(message)
+
+ while True:
+ while self.relay_manager.message_pool.has_events():
+ event_msg = self.relay_manager.message_pool.get_event()
+ if "?iv=" in event_msg.event.content:
+ try:
+ shared_secret = self.private_key.compute_shared_secret(
+ event_msg.event.public_key
+ )
+ aes = cbc.AESCipher(key=shared_secret)
+ enc_text_b64, iv_b64 = event_msg.event.content.split("?iv=")
+ iv = base64.decodebytes(iv_b64.encode("utf-8"))
+ enc_text = base64.decodebytes(enc_text_b64.encode("utf-8"))
+ dec_text = aes.decrypt(iv, enc_text)
+ if callback_func:
+ callback_func(event_msg.event, dec_text)
+ except:
+ pass
+ break
+ time.sleep(0.1)
+
+ def subscribe(
self,
callback_events_func=None,
callback_notices_func=None,
callback_eosenotices_func=None,
):
- while self.running:
- self._check_events(callback_events_func)
- self._check_notices(callback_notices_func)
- self._check_eos_notices(callback_eosenotices_func)
-
- await asyncio.sleep(0.2)
-
- def _check_events(self, callback_events_func=None):
- try:
+ while True:
while self.relay_manager.message_pool.has_events():
event_msg = self.relay_manager.message_pool.get_event()
if callback_events_func:
callback_events_func(event_msg)
- except Exception as e:
- logger.debug(e)
-
- def _check_notices(self, callback_notices_func=None):
- try:
while self.relay_manager.message_pool.has_notices():
- event_msg = self.relay_manager.message_pool.get_notice()
+ event_msg = self.relay_manager.message_pool.has_notices()
if callback_notices_func:
callback_notices_func(event_msg)
- except Exception as e:
- logger.debug(e)
-
- def _check_eos_notices(self, callback_eosenotices_func=None):
- try:
while self.relay_manager.message_pool.has_eose_notices():
event_msg = self.relay_manager.message_pool.get_eose_notice()
if callback_eosenotices_func:
callback_eosenotices_func(event_msg)
- except Exception as e:
- logger.debug(e)
+
+ time.sleep(0.1)
diff --git a/nostr/delegation.py b/nostr/delegation.py
new file mode 100644
index 0000000..94801f5
--- /dev/null
+++ b/nostr/delegation.py
@@ -0,0 +1,32 @@
+import time
+from dataclasses import dataclass
+
+
+@dataclass
+class Delegation:
+ delegator_pubkey: str
+ delegatee_pubkey: str
+ event_kind: int
+ duration_secs: int = 30*24*60 # default to 30 days
+ signature: str = None # set in PrivateKey.sign_delegation
+
+ @property
+ def expires(self) -> int:
+ return int(time.time()) + self.duration_secs
+
+ @property
+ def conditions(self) -> str:
+ return f"kind={self.event_kind}&created_at<{self.expires}"
+
+ @property
+ def delegation_token(self) -> str:
+ return f"nostr:delegation:{self.delegatee_pubkey}:{self.conditions}"
+
+ def get_tag(self) -> list[str]:
+ """ Called by Event """
+ return [
+ "delegation",
+ self.delegator_pubkey,
+ self.conditions,
+ self.signature,
+ ]
diff --git a/nostr/event.py b/nostr/event.py
index 994c0f4..b903e0e 100644
--- a/nostr/event.py
+++ b/nostr/event.py
@@ -1,11 +1,10 @@
-import json
import time
+import json
from dataclasses import dataclass, field
from enum import IntEnum
+from typing import List
+from secp256k1 import PublicKey
from hashlib import sha256
-from typing import Optional
-
-import coincurve
from .message_type import ClientMessageType
@@ -21,14 +20,14 @@ class EventKind(IntEnum):
@dataclass
class Event:
- content: Optional[str] = None
- public_key: Optional[str] = None
- created_at: Optional[int] = None
+ content: str = None
+ public_key: str = None
+ created_at: int = None
kind: int = EventKind.TEXT_NOTE
- tags: list[list[str]] = field(
+ tags: List[List[str]] = field(
default_factory=list
) # Dataclasses require special handling when the default value is a mutable type
- signature: Optional[str] = None
+ signature: str = None
def __post_init__(self):
if self.content is not None and not isinstance(self.content, str):
@@ -40,7 +39,7 @@ class Event:
@staticmethod
def serialize(
- public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str
+ public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str
) -> bytes:
data = [0, public_key, created_at, kind, tags, content]
data_str = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
@@ -48,7 +47,7 @@ class Event:
@staticmethod
def compute_id(
- public_key: str, created_at: int, kind: int, tags: list[list[str]], content: str
+ public_key: str, created_at: int, kind: int, tags: List[List[str]], content: str
):
return sha256(
Event.serialize(public_key, created_at, kind, tags, content)
@@ -57,9 +56,6 @@ class Event:
@property
def id(self) -> str:
# Always recompute the id to reflect the up-to-date state of the Event
- assert self.public_key
- assert self.created_at
- assert self.content
return Event.compute_id(
self.public_key, self.created_at, self.kind, self.tags, self.content
)
@@ -73,10 +69,12 @@ class Event:
self.tags.append(["e", event_id])
def verify(self) -> bool:
- assert self.public_key
- assert self.signature
- pub_key = coincurve.PublicKeyXOnly(bytes.fromhex(self.public_key))
- return pub_key.verify(bytes.fromhex(self.signature), bytes.fromhex(self.id))
+ pub_key = PublicKey(
+ bytes.fromhex("02" + self.public_key), True
+ ) # add 02 for schnorr (bip340)
+ return pub_key.schnorr_verify(
+ bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True
+ )
def to_message(self) -> str:
return json.dumps(
@@ -97,9 +95,9 @@ class Event:
@dataclass
class EncryptedDirectMessage(Event):
- recipient_pubkey: Optional[str] = None
- cleartext_content: Optional[str] = None
- reference_event_id: Optional[str] = None
+ recipient_pubkey: str = None
+ cleartext_content: str = None
+ reference_event_id: str = None
def __post_init__(self):
if self.content is not None:
@@ -123,7 +121,6 @@ class EncryptedDirectMessage(Event):
def id(self) -> str:
if self.content is None:
raise Exception(
- "EncryptedDirectMessage `id` is undefined until its"
- + " message is encrypted and stored in the `content` field"
+ "EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
)
return super().id
diff --git a/nostr/filter.py b/nostr/filter.py
new file mode 100644
index 0000000..f119079
--- /dev/null
+++ b/nostr/filter.py
@@ -0,0 +1,134 @@
+from collections import UserList
+from typing import List
+
+from .event import Event, EventKind
+
+
+class Filter:
+ """
+ NIP-01 filtering.
+
+ Explicitly supports "#e" and "#p" tag filters via `event_refs` and `pubkey_refs`.
+
+ Arbitrary NIP-12 single-letter tag filters are also supported via `add_arbitrary_tag`.
+ If a particular single-letter tag gains prominence, explicit support should be
+ added. For example:
+ # arbitrary tag
+ filter.add_arbitrary_tag('t', [hashtags])
+
+ # promoted to explicit support
+ Filter(hashtag_refs=[hashtags])
+ """
+
+ def __init__(
+ self,
+ event_ids: List[str] = None,
+ kinds: List[EventKind] = None,
+ authors: List[str] = None,
+ since: int = None,
+ until: int = None,
+ event_refs: List[
+ str
+ ] = None, # the "#e" attr; list of event ids referenced in an "e" tag
+ pubkey_refs: List[
+ str
+ ] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
+ limit: int = None,
+ ) -> None:
+ self.event_ids = event_ids
+ self.kinds = kinds
+ self.authors = authors
+ self.since = since
+ self.until = until
+ self.event_refs = event_refs
+ self.pubkey_refs = pubkey_refs
+ self.limit = limit
+
+ self.tags = {}
+ if self.event_refs:
+ self.add_arbitrary_tag("e", self.event_refs)
+ if self.pubkey_refs:
+ self.add_arbitrary_tag("p", self.pubkey_refs)
+
+ def add_arbitrary_tag(self, tag: str, values: list):
+ """
+ Filter on any arbitrary tag with explicit handling for NIP-01 and NIP-12
+ single-letter tags.
+ """
+ # NIP-01 'e' and 'p' tags and any NIP-12 single-letter tags must be prefixed with "#"
+ tag_key = tag if len(tag) > 1 else f"#{tag}"
+ self.tags[tag_key] = values
+
+ def matches(self, event: Event) -> bool:
+ if self.event_ids is not None and event.id not in self.event_ids:
+ return False
+ if self.kinds is not None and event.kind not in self.kinds:
+ return False
+ if self.authors is not None and event.public_key not in self.authors:
+ return False
+ if self.since is not None and event.created_at < self.since:
+ return False
+ if self.until is not None and event.created_at > self.until:
+ return False
+ if (self.event_refs is not None or self.pubkey_refs is not None) and len(
+ event.tags
+ ) == 0:
+ return False
+
+ if self.tags:
+ e_tag_identifiers = set([e_tag[0] for e_tag in event.tags])
+ for f_tag, f_tag_values in self.tags.items():
+ # Omit any NIP-01 or NIP-12 "#" chars on single-letter tags
+ f_tag = f_tag.replace("#", "")
+
+ if f_tag not in e_tag_identifiers:
+ # Event is missing a tag type that we're looking for
+ return False
+
+ # Multiple values within f_tag_values are treated as OR search; an Event
+ # needs to match only one.
+ # Note: an Event could have multiple entries of the same tag type
+ # (e.g. a reply to multiple people) so we have to check all of them.
+ match_found = False
+ for e_tag in event.tags:
+ if e_tag[0] == f_tag and e_tag[1] in f_tag_values:
+ match_found = True
+ break
+ if not match_found:
+ return False
+
+ return True
+
+ def to_json_object(self) -> dict:
+ res = {}
+ if self.event_ids is not None:
+ res["ids"] = self.event_ids
+ if self.kinds is not None:
+ res["kinds"] = self.kinds
+ if self.authors is not None:
+ res["authors"] = self.authors
+ if self.since is not None:
+ res["since"] = self.since
+ if self.until is not None:
+ res["until"] = self.until
+ if self.limit is not None:
+ res["limit"] = self.limit
+ if self.tags:
+ res.update(self.tags)
+
+ return res
+
+
+class Filters(UserList):
+ def __init__(self, initlist: "list[Filter]" = []) -> None:
+ super().__init__(initlist)
+ self.data: "list[Filter]"
+
+ def match(self, event: Event):
+ for filter in self.data:
+ if filter.matches(event):
+ return True
+ return False
+
+ def to_json_array(self) -> list:
+ return [filter.to_json_object() for filter in self.data]
diff --git a/nostr/key.py b/nostr/key.py
index f7b4e81..d34697f 100644
--- a/nostr/key.py
+++ b/nostr/key.py
@@ -1,12 +1,14 @@
-import base64
import secrets
-
-import coincurve
-from cryptography.hazmat.primitives import padding
+import base64
+import secp256k1
+from cffi import FFI
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
+from cryptography.hazmat.primitives import padding
+from hashlib import sha256
-from .bech32 import Encoding, bech32_decode, bech32_encode, convertbits
+from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind
+from . import bech32
class PublicKey:
@@ -14,61 +16,55 @@ class PublicKey:
self.raw_bytes = raw_bytes
def bech32(self) -> str:
- converted_bits = convertbits(self.raw_bytes, 8, 5)
- return bech32_encode("npub", converted_bits, Encoding.BECH32)
+ converted_bits = bech32.convertbits(self.raw_bytes, 8, 5)
+ return bech32.bech32_encode("npub", converted_bits, bech32.Encoding.BECH32)
def hex(self) -> str:
return self.raw_bytes.hex()
- def verify_signed_message_hash(self, message_hash: str, sig: str) -> bool:
- pk = coincurve.PublicKeyXOnly(self.raw_bytes)
- return pk.verify(bytes.fromhex(sig), bytes.fromhex(message_hash))
+ def verify_signed_message_hash(self, hash: str, sig: str) -> bool:
+ pk = secp256k1.PublicKey(b"\x02" + self.raw_bytes, True)
+ return pk.schnorr_verify(bytes.fromhex(hash), bytes.fromhex(sig), None, True)
@classmethod
def from_npub(cls, npub: str):
"""Load a PublicKey from its bech32/npub form"""
- _, data, _ = bech32_decode(npub)
- raw_data = convertbits(data, 5, 8)
- assert raw_data
- raw_public_key = raw_data[:-1]
+ hrp, data, spec = bech32.bech32_decode(npub)
+ raw_public_key = bech32.convertbits(data, 5, 8)[:-1]
return cls(bytes(raw_public_key))
class PrivateKey:
- def __init__(self, raw_secret: bytes | None = None) -> None:
- if raw_secret is not None:
+ def __init__(self, raw_secret: bytes = None) -> None:
+ if not raw_secret is None:
self.raw_secret = raw_secret
else:
self.raw_secret = secrets.token_bytes(32)
- sk = coincurve.PrivateKey(self.raw_secret)
- assert sk.public_key
- self.public_key = PublicKey(sk.public_key.format()[1:])
+ sk = secp256k1.PrivateKey(self.raw_secret)
+ self.public_key = PublicKey(sk.pubkey.serialize()[1:])
@classmethod
def from_nsec(cls, nsec: str):
"""Load a PrivateKey from its bech32/nsec form"""
- _, data, _ = bech32_decode(nsec)
- raw_data = convertbits(data, 5, 8)
- assert raw_data
- raw_secret = raw_data[:-1]
+ hrp, data, spec = bech32.bech32_decode(nsec)
+ raw_secret = bech32.convertbits(data, 5, 8)[:-1]
return cls(bytes(raw_secret))
def bech32(self) -> str:
- converted_bits = convertbits(self.raw_secret, 8, 5)
- return bech32_encode("nsec", converted_bits, Encoding.BECH32)
+ converted_bits = bech32.convertbits(self.raw_secret, 8, 5)
+ return bech32.bech32_encode("nsec", converted_bits, bech32.Encoding.BECH32)
def hex(self) -> str:
return self.raw_secret.hex()
def tweak_add(self, scalar: bytes) -> bytes:
- sk = coincurve.PrivateKey(self.raw_secret)
- return sk.add(scalar).to_der()
+ sk = secp256k1.PrivateKey(self.raw_secret)
+ return sk.tweak_add(scalar)
def compute_shared_secret(self, public_key_hex: str) -> bytes:
- pk = coincurve.PublicKey(bytes.fromhex("02" + public_key_hex))
- sk = coincurve.PrivateKey(self.raw_secret)
- return sk.ecdh(pk.format())
+ pk = secp256k1.PublicKey(bytes.fromhex("02" + public_key_hex), True)
+ return pk.ecdh(self.raw_secret, hashfn=copy_x)
def encrypt_message(self, message: str, public_key_hex: str) -> str:
padder = padding.PKCS7(128).padder()
@@ -82,14 +78,9 @@ class PrivateKey:
encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
- return (
- f"{base64.b64encode(encrypted_message).decode()}"
- + f"?iv={base64.b64encode(iv).decode()}"
- )
+ return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
- assert dm.cleartext_content
- assert dm.recipient_pubkey
dm.content = self.encrypt_message(
message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey
)
@@ -112,23 +103,28 @@ class PrivateKey:
return unpadded_data.decode()
- def sign_message_hash(self, message_hash: bytes) -> str:
- sk = coincurve.PrivateKey(self.raw_secret)
- sig = sk.sign_schnorr(message_hash)
+ def sign_message_hash(self, hash: bytes) -> str:
+ sk = secp256k1.PrivateKey(self.raw_secret)
+ sig = sk.schnorr_sign(hash, None, raw=True)
return sig.hex()
def sign_event(self, event: Event) -> None:
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None:
- self.encrypt_dm(event) # type: ignore
+ self.encrypt_dm(event)
if event.public_key is None:
event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id))
+ def sign_delegation(self, delegation: Delegation) -> None:
+ delegation.signature = self.sign_message_hash(
+ sha256(delegation.delegation_token.encode()).digest()
+ )
+
def __eq__(self, other):
return self.raw_secret == other.raw_secret
-def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> PrivateKey:
+def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey:
if prefix is None and suffix is None:
raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments")
@@ -144,3 +140,14 @@ def mine_vanity_key(prefix: str | None = None, suffix: str | None = None) -> Pri
break
return sk
+
+
+ffi = FFI()
+
+
+@ffi.callback(
+ "int (unsigned char *, const unsigned char *, const unsigned char *, void *)"
+)
+def copy_x(output, x32, y32, data):
+ ffi.memmove(output, x32, 32)
+ return 1
diff --git a/nostr/message_pool.py b/nostr/message_pool.py
index a3e6c5f..d364cf2 100644
--- a/nostr/message_pool.py
+++ b/nostr/message_pool.py
@@ -1,16 +1,13 @@
import json
from queue import Queue
from threading import Lock
-
from .message_type import RelayMessageType
+from .event import Event
class EventMessage:
- def __init__(
- self, event: str, event_id: str, subscription_id: str, url: str
- ) -> None:
+ def __init__(self, event: Event, subscription_id: str, url: str) -> None:
self.event = event
- self.event_id = event_id
self.subscription_id = subscription_id
self.url = url
@@ -61,29 +58,20 @@ class MessagePool:
message_type = message_json[0]
if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1]
- event = message_json[2]
- if "id" not in event:
- return
- event_id = event["id"]
-
+ e = message_json[2]
+ event = Event(
+ e["content"],
+ e["pubkey"],
+ e["created_at"],
+ e["kind"],
+ e["tags"],
+ e["sig"],
+ )
with self.lock:
- if f"{subscription_id}_{event_id}" not in self._unique_events:
- self._accept_event(
- EventMessage(json.dumps(event), event_id, subscription_id, url)
- )
+ if not event.id in self._unique_events:
+ self.events.put(EventMessage(event, subscription_id, url))
+ self._unique_events.add(event.id)
elif message_type == RelayMessageType.NOTICE:
self.notices.put(NoticeMessage(message_json[1], url))
elif message_type == RelayMessageType.END_OF_STORED_EVENTS:
self.eose_notices.put(EndOfStoredEventsMessage(message_json[1], url))
-
- def _accept_event(self, event_message: EventMessage):
- """
- Event uniqueness is considered per `subscription_id`. The `subscription_id` is
- rewritten to be unique and it is the same accross relays. The same event can
- come from different subscriptions (from the same client or from different ones).
- Clients that have joined later should receive older events.
- """
- self.events.put(event_message)
- self._unique_events.add(
- f"{event_message.subscription_id}_{event_message.event_id}"
- )
diff --git a/nostr/message_type.py b/nostr/message_type.py
index d37cdfd..3f5206b 100644
--- a/nostr/message_type.py
+++ b/nostr/message_type.py
@@ -3,20 +3,13 @@ class ClientMessageType:
REQUEST = "REQ"
CLOSE = "CLOSE"
-
class RelayMessageType:
EVENT = "EVENT"
NOTICE = "NOTICE"
END_OF_STORED_EVENTS = "EOSE"
- COMMAND_RESULT = "OK"
@staticmethod
def is_valid(type: str) -> bool:
- if (
- type == RelayMessageType.EVENT
- or type == RelayMessageType.NOTICE
- or type == RelayMessageType.END_OF_STORED_EVENTS
- or type == RelayMessageType.COMMAND_RESULT
- ):
+ if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS:
return True
- return False
+ return False
\ No newline at end of file
diff --git a/nostr/pow.py b/nostr/pow.py
new file mode 100644
index 0000000..e006288
--- /dev/null
+++ b/nostr/pow.py
@@ -0,0 +1,54 @@
+import time
+from .event import Event
+from .key import PrivateKey
+
+def zero_bits(b: int) -> int:
+ n = 0
+
+ if b == 0:
+ return 8
+
+ while b >> 1:
+ b = b >> 1
+ n += 1
+
+ return 7 - n
+
+def count_leading_zero_bits(hex_str: str) -> int:
+ total = 0
+ for i in range(0, len(hex_str) - 2, 2):
+ bits = zero_bits(int(hex_str[i:i+2], 16))
+ total += bits
+
+ if bits != 8:
+ break
+
+ return total
+
+def mine_event(content: str, difficulty: int, public_key: str, kind: int, tags: list=[]) -> Event:
+ all_tags = [["nonce", "1", str(difficulty)]]
+ all_tags.extend(tags)
+
+ created_at = int(time.time())
+ event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
+ num_leading_zero_bits = count_leading_zero_bits(event_id)
+
+ attempts = 1
+ while num_leading_zero_bits < difficulty:
+ attempts += 1
+ all_tags[0][1] = str(attempts)
+ created_at = int(time.time())
+ event_id = Event.compute_id(public_key, created_at, kind, all_tags, content)
+ num_leading_zero_bits = count_leading_zero_bits(event_id)
+
+ return Event(public_key, content, created_at, kind, all_tags, event_id)
+
+def mine_key(difficulty: int) -> PrivateKey:
+ sk = PrivateKey()
+ num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
+
+ while num_leading_zero_bits < difficulty:
+ sk = PrivateKey()
+ num_leading_zero_bits = count_leading_zero_bits(sk.public_key.hex())
+
+ return sk
diff --git a/nostr/relay.py b/nostr/relay.py
index d762963..ee78baa 100644
--- a/nostr/relay.py
+++ b/nostr/relay.py
@@ -1,37 +1,49 @@
-import asyncio
import json
import time
from queue import Queue
-
-from loguru import logger
+from threading import Lock
from websocket import WebSocketApp
-
+from .event import Event
+from .filter import Filters
from .message_pool import MessagePool
+from .message_type import RelayMessageType
from .subscription import Subscription
+class RelayPolicy:
+ def __init__(self, should_read: bool = True, should_write: bool = True) -> None:
+ self.should_read = should_read
+ self.should_write = should_write
+
+ def to_json_object(self) -> dict[str, bool]:
+ return {"read": self.should_read, "write": self.should_write}
+
+
class Relay:
- def __init__(self, url: str, message_pool: MessagePool) -> None:
+ def __init__(
+ self,
+ url: str,
+ policy: RelayPolicy,
+ message_pool: MessagePool,
+ subscriptions: dict[str, Subscription] = {},
+ ) -> None:
self.url = url
+ self.policy = policy
self.message_pool = message_pool
+ self.subscriptions = subscriptions
self.connected: bool = False
self.reconnect: bool = True
- self.shutdown: bool = False
-
self.error_counter: int = 0
- self.error_threshold: int = 100
- self.error_list: list[str] = []
- self.notice_list: list[str] = []
- self.last_error_date: int = 0
+ self.error_threshold: int = 0
self.num_received_events: int = 0
self.num_sent_events: int = 0
self.num_subscriptions: int = 0
-
- self.queue: Queue = Queue()
-
- def connect(self):
+ self.ssl_options: dict = {}
+ self.proxy: dict = {}
+ self.lock = Lock()
+ self.queue = Queue()
self.ws = WebSocketApp(
- self.url,
+ url,
on_open=self._on_open,
on_message=self._on_message,
on_error=self._on_error,
@@ -39,20 +51,31 @@ class Relay:
on_ping=self._on_ping,
on_pong=self._on_pong,
)
+
+ def connect(self, ssl_options: dict = None, proxy: dict = None):
+ self.ssl_options = ssl_options
+ self.proxy = proxy
if not self.connected:
- self.ws.run_forever(ping_interval=10)
+ self.ws.run_forever(
+ sslopt=ssl_options,
+ http_proxy_host=None if proxy is None else proxy.get("host"),
+ http_proxy_port=None if proxy is None else proxy.get("port"),
+ proxy_type=None if proxy is None else proxy.get("type"),
+ ping_interval=5,
+ )
def close(self):
- try:
- self.ws.close()
- except Exception as e:
- logger.warning(f"[Relay: {self.url}] Failed to close websocket: {e}")
- self.connected = False
- self.shutdown = True
+ self.ws.close()
- @property
- def error_threshold_reached(self):
- return self.error_threshold and self.error_counter >= self.error_threshold
+ def check_reconnect(self):
+ try:
+ self.close()
+ except:
+ pass
+ self.connected = False
+ if self.reconnect:
+ time.sleep(1)
+ self.connect(self.ssl_options, self.proxy)
@property
def ping(self):
@@ -62,65 +85,99 @@ class Relay:
def publish(self, message: str):
self.queue.put(message)
- def publish_subscriptions(self, subscriptions: list[Subscription]):
- for s in subscriptions:
- assert s.filters
- json_str = json.dumps(["REQ", s.id, *s.filters])
- self.publish(json_str)
-
- async def queue_worker(self):
+ def queue_worker(self):
while True:
if self.connected:
- try:
- message = self.queue.get(timeout=1)
- self.num_sent_events += 1
- self.ws.send(message)
- except Exception as _:
- pass
+ message = self.queue.get()
+ self.num_sent_events += 1
+ self.ws.send(message)
else:
- await asyncio.sleep(1)
+ time.sleep(0.1)
- if self.shutdown:
- logger.warning(f"[Relay: {self.url}] Closing queue worker.")
- return
+ def add_subscription(self, id, filters: Filters):
+ with self.lock:
+ self.subscriptions[id] = Subscription(id, filters)
- def close_subscription(self, sub_id: str) -> None:
- try:
- self.publish(json.dumps(["CLOSE", sub_id]))
- except Exception as e:
- logger.debug(f"[Relay: {self.url}] Failed to close subscription: {e}")
+ def close_subscription(self, id: str) -> None:
+ with self.lock:
+ self.subscriptions.pop(id)
- def add_notice(self, notice: str):
- self.notice_list = [notice, *self.notice_list]
+ def update_subscription(self, id: str, filters: Filters) -> None:
+ with self.lock:
+ subscription = self.subscriptions[id]
+ subscription.filters = filters
- def _on_open(self, _):
- logger.info(f"[Relay: {self.url}] Connected.")
+ def to_json_object(self) -> dict:
+ return {
+ "url": self.url,
+ "policy": self.policy.to_json_object(),
+ "subscriptions": [
+ subscription.to_json_object()
+ for subscription in self.subscriptions.values()
+ ],
+ }
+
+ def _on_open(self, class_obj):
self.connected = True
- self.shutdown = False
+ pass
- def _on_close(self, _, status_code, message):
- logger.warning(
- f"[Relay: {self.url}] Connection closed."
- + f" Status: '{status_code}'. Message: '{message}'."
- )
- self.close()
+ def _on_close(self, class_obj, status_code, message):
+ self.connected = False
+ pass
- def _on_message(self, _, message: str):
- self.num_received_events += 1
- self.message_pool.add_message(message, self.url)
+ def _on_message(self, class_obj, message: str):
+ if self._is_valid_message(message):
+ self.num_received_events += 1
+ self.message_pool.add_message(message, self.url)
- def _on_error(self, _, error):
- logger.warning(f"[Relay: {self.url}] Error: '{error!s}'")
- self._append_error_message(str(error))
- self.close()
-
- def _on_ping(self, *_):
- return
-
- def _on_pong(self, *_):
- return
-
- def _append_error_message(self, message):
+ def _on_error(self, class_obj, error):
+ self.connected = False
self.error_counter += 1
- self.error_list = [message, *self.error_list]
- self.last_error_date = int(time.time())
+ if self.error_threshold and self.error_counter > self.error_threshold:
+ pass
+ else:
+ self.check_reconnect()
+
+ def _on_ping(self, class_obj, message):
+ return
+
+ def _on_pong(self, class_obj, message):
+ return
+
+ def _is_valid_message(self, message: str) -> bool:
+ message = message.strip("\n")
+ if not message or message[0] != "[" or message[-1] != "]":
+ return False
+
+ message_json = json.loads(message)
+ message_type = message_json[0]
+ if not RelayMessageType.is_valid(message_type):
+ return False
+ if message_type == RelayMessageType.EVENT:
+ if not len(message_json) == 3:
+ return False
+
+ subscription_id = message_json[1]
+ with self.lock:
+ if subscription_id not in self.subscriptions:
+ return False
+
+ e = message_json[2]
+ event = Event(
+ e["content"],
+ e["pubkey"],
+ e["created_at"],
+ e["kind"],
+ e["tags"],
+ e["sig"],
+ )
+ if not event.verify():
+ return False
+
+ with self.lock:
+ subscription = self.subscriptions[subscription_id]
+
+ if subscription.filters and not subscription.filters.match(event):
+ return False
+
+ return True
diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py
index 2aa27c5..5b92d8d 100644
--- a/nostr/relay_manager.py
+++ b/nostr/relay_manager.py
@@ -1,99 +1,52 @@
-import asyncio
+import json
import threading
-import time
-from typing import List
-from loguru import logger
+from .event import Event
+from .filter import Filters
+from .message_pool import MessagePool
+from .message_type import ClientMessageType
+from .relay import Relay, RelayPolicy
-from .message_pool import MessagePool, NoticeMessage
-from .relay import Relay
-from .subscription import Subscription
+
+class RelayException(Exception):
+ pass
class RelayManager:
def __init__(self) -> None:
self.relays: dict[str, Relay] = {}
- self.threads: dict[str, threading.Thread] = {}
- self.queue_threads: dict[str, threading.Thread] = {}
self.message_pool = MessagePool()
- self._cached_subscriptions: dict[str, Subscription] = {}
- self._subscriptions_lock = threading.Lock()
- def add_relay(self, url: str) -> Relay:
- if url in list(self.relays.keys()):
- logger.debug(f"Relay '{url}' already present.")
- return self.relays[url]
-
- relay = Relay(url, self.message_pool)
+ def add_relay(
+ self, url: str, read: bool = True, write: bool = True, subscriptions={}
+ ):
+ policy = RelayPolicy(read, write)
+ relay = Relay(url, policy, self.message_pool, subscriptions)
self.relays[url] = relay
- self._open_connection(relay)
-
- relay.publish_subscriptions(list(self._cached_subscriptions.values()))
- return relay
-
def remove_relay(self, url: str):
- try:
- self.relays[url].close()
- except Exception as e:
- logger.debug(e)
-
- if url in self.relays:
- self.relays.pop(url)
-
- try:
- self.threads[url].join(timeout=5)
- except Exception as e:
- logger.debug(e)
-
- if url in self.threads:
- self.threads.pop(url)
-
- try:
- self.queue_threads[url].join(timeout=5)
- except Exception as e:
- logger.debug(e)
-
- if url in self.queue_threads:
- self.queue_threads.pop(url)
-
- def remove_relays(self):
- relay_urls = list(self.relays.keys())
- for url in relay_urls:
- self.remove_relay(url)
-
- def add_subscription(self, id: str, filters: List[str]):
- s = Subscription(id, filters)
- with self._subscriptions_lock:
- self._cached_subscriptions[id] = s
+ self.relays.pop(url)
+ def add_subscription(self, id: str, filters: Filters):
for relay in self.relays.values():
- relay.publish_subscriptions([s])
+ relay.add_subscription(id, filters)
def close_subscription(self, id: str):
- try:
- logger.info(f"Closing subscription: '{id}'.")
- with self._subscriptions_lock:
- if id in self._cached_subscriptions:
- self._cached_subscriptions.pop(id)
+ for relay in self.relays.values():
+ relay.close_subscription(id)
- for relay in self.relays.values():
- relay.close_subscription(id)
- except Exception as e:
- logger.debug(e)
+ def open_connections(self, ssl_options: dict = None, proxy: dict = None):
+ for relay in self.relays.values():
+ threading.Thread(
+ target=relay.connect,
+ args=(ssl_options, proxy),
+ name=f"{relay.url}-thread",
+ daemon=True,
+ ).start()
- def close_subscriptions(self, subscriptions: List[str]):
- for id in subscriptions:
- self.close_subscription(id)
-
- def close_all_subscriptions(self):
- all_subscriptions = list(self._cached_subscriptions.keys())
- self.close_subscriptions(all_subscriptions)
-
- def check_and_restart_relays(self):
- stopped_relays = [r for r in self.relays.values() if r.shutdown]
- for relay in stopped_relays:
- self._restart_relay(relay)
+ threading.Thread(
+ target=relay.queue_worker, name=f"{relay.url}-queue", daemon=True
+ ).start()
def close_connections(self):
for relay in self.relays.values():
@@ -101,43 +54,16 @@ class RelayManager:
def publish_message(self, message: str):
for relay in self.relays.values():
- relay.publish(message)
+ if relay.policy.should_write:
+ relay.publish(message)
- def handle_notice(self, notice: NoticeMessage):
- relay = next((r for r in self.relays.values() if r.url == notice.url))
- if relay:
- relay.add_notice(notice.content)
+ def publish_event(self, event: Event):
+ """Verifies that the Event is publishable before submitting it to relays"""
+ if event.signature is None:
+ raise RelayException(f"Could not publish {event.id}: must be signed")
- def _open_connection(self, relay: Relay):
- self.threads[relay.url] = threading.Thread(
- target=relay.connect,
- name=f"{relay.url}-thread",
- daemon=True,
- )
- self.threads[relay.url].start()
-
- def wrap_async_queue_worker():
- asyncio.run(relay.queue_worker())
-
- self.queue_threads[relay.url] = threading.Thread(
- target=wrap_async_queue_worker,
- name=f"{relay.url}-queue",
- daemon=True,
- )
- self.queue_threads[relay.url].start()
-
- def _restart_relay(self, relay: Relay):
- time_since_last_error = time.time() - relay.last_error_date
-
- min_wait_time = min(
- 60 * relay.error_counter, 60 * 60
- ) # try at least once an hour
- if time_since_last_error < min_wait_time:
- return
-
- logger.info(f"Restarting connection to relay '{relay.url}'")
-
- self.remove_relay(relay.url)
- new_relay = self.add_relay(relay.url)
- new_relay.error_counter = relay.error_counter
- new_relay.error_list = relay.error_list
+ if not event.verify():
+ raise RelayException(
+ f"Could not publish {event.id}: failed to verify signature {event.signature}"
+ )
+ self.publish_message(event.to_message())
diff --git a/nostr/subscription.py b/nostr/subscription.py
index ed60f7e..7afba20 100644
--- a/nostr/subscription.py
+++ b/nostr/subscription.py
@@ -1,7 +1,12 @@
-from typing import Optional
-
+from .filter import Filters
class Subscription:
- def __init__(self, id: str, filters: Optional[list[str]] = None) -> None:
+ def __init__(self, id: str, filters: Filters=None) -> None:
self.id = id
self.filters = filters
+
+ def to_json_object(self):
+ return {
+ "id": self.id,
+ "filters": self.filters.to_json_array()
+ }
diff --git a/package-lock.json b/package-lock.json
deleted file mode 100644
index 1180ffb..0000000
--- a/package-lock.json
+++ /dev/null
@@ -1,59 +0,0 @@
-{
- "name": "nostrclient",
- "version": "1.0.0",
- "lockfileVersion": 3,
- "requires": true,
- "packages": {
- "": {
- "name": "nostrclient",
- "version": "1.0.0",
- "license": "ISC",
- "dependencies": {
- "prettier": "^3.2.5",
- "pyright": "^1.1.358"
- }
- },
- "node_modules/fsevents": {
- "version": "2.3.3",
- "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
- "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
- "hasInstallScript": true,
- "optional": true,
- "os": [
- "darwin"
- ],
- "engines": {
- "node": "^8.16.0 || ^10.6.0 || >=11.0.0"
- }
- },
- "node_modules/prettier": {
- "version": "3.3.3",
- "resolved": "https://registry.npmjs.org/prettier/-/prettier-3.3.3.tgz",
- "integrity": "sha512-i2tDNA0O5IrMO757lfrdQZCc2jPNDVntV0m/+4whiDfWaTKfMNgR7Qz0NAeGz/nRqF4m5/6CLzbP4/liHt12Ew==",
- "bin": {
- "prettier": "bin/prettier.cjs"
- },
- "engines": {
- "node": ">=14"
- },
- "funding": {
- "url": "https://github.com/prettier/prettier?sponsor=1"
- }
- },
- "node_modules/pyright": {
- "version": "1.1.374",
- "resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.374.tgz",
- "integrity": "sha512-ISbC1YnYDYrEatoKKjfaA5uFIp0ddC/xw9aSlN/EkmwupXUMVn41Jl+G6wHEjRhC+n4abHZeGpEvxCUus/K9dA==",
- "bin": {
- "pyright": "index.js",
- "pyright-langserver": "langserver.index.js"
- },
- "engines": {
- "node": ">=14.0.0"
- },
- "optionalDependencies": {
- "fsevents": "~2.3.3"
- }
- }
- }
-}
diff --git a/package.json b/package.json
deleted file mode 100644
index 7b84315..0000000
--- a/package.json
+++ /dev/null
@@ -1,15 +0,0 @@
-{
- "name": "nostrclient",
- "version": "1.0.0",
- "description": "",
- "main": "index.js",
- "scripts": {
- "test": "echo \"Error: no test specified\" && exit 1"
- },
- "author": "",
- "license": "ISC",
- "dependencies": {
- "prettier": "^3.2.5",
- "pyright": "^1.1.358"
- }
-}
diff --git a/pyproject.toml b/pyproject.toml
deleted file mode 100644
index c7c0e56..0000000
--- a/pyproject.toml
+++ /dev/null
@@ -1,98 +0,0 @@
-[project]
-name = "lnbits-nostrclient"
-version = "1.1.0"
-requires-python = ">=3.10,<3.13"
-description = "LNbits, free and open-source Lightning wallet and accounts system."
-authors = [{ name = "Alan Bits", email = "alan@lnbits.com" }]
-urls = { Homepage = "https://lnbits.com", Repository = "https://github.com/lnbits/nostrclient" }
-dependencies = [ "lnbits>1" ]
-
-[tool.poetry]
-package-mode = false
-
-[tool.uv]
-dev-dependencies = [
- "black",
- "pytest-asyncio",
- "pytest",
- "mypy",
- "pre-commit",
- "ruff",
- "pytest-md",
- "types-cffi",
-]
-
-[tool.mypy]
-exclude = "(nostr/*)"
-plugins = ["pydantic.mypy"]
-
-[[tool.mypy.overrides]]
-module = [
- "nostr.*",
-]
-follow_imports = "skip"
-ignore_missing_imports = "True"
-
-[tool.pydantic-mypy]
-init_forbid_extra = true
-init_typed = true
-warn_required_dynamic_aliases = true
-warn_untyped_fields = true
-
-[tool.pytest.ini_options]
-log_cli = false
-testpaths = [
- "tests"
-]
-
-[tool.black]
-line-length = 88
-
-[tool.ruff]
-# Same as Black. + 10% rule of black
-line-length = 88
-exclude = [
- "nostr",
-]
-
-[tool.ruff.lint]
-# Enable:
-# F - pyflakes
-# E - pycodestyle errors
-# W - pycodestyle warnings
-# I - isort
-# A - flake8-builtins
-# C - mccabe
-# N - naming
-# UP - pyupgrade
-# RUF - ruff
-# B - bugbear
-select = ["F", "E", "W", "I", "A", "C", "N", "UP", "RUF", "B"]
-ignore = ["C901"]
-
-# Allow autofix for all enabled rules (when `--fix`) is provided.
-fixable = ["ALL"]
-unfixable = []
-
-# Allow unused variables when underscore-prefixed.
-dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
-
-# needed for pydantic
-[tool.ruff.lint.pep8-naming]
-classmethod-decorators = [
- "root_validator",
-]
-
-# Ignore unused imports in __init__.py files.
-# [tool.ruff.lint.extend-per-file-ignores]
-# "__init__.py" = ["F401", "F403"]
-
-# [tool.ruff.lint.mccabe]
-# max-complexity = 10
-
-[tool.ruff.lint.flake8-bugbear]
-# Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`.
-extend-immutable-calls = [
- "fastapi.Depends",
- "fastapi.Query",
-]
diff --git a/router.py b/router.py
deleted file mode 100644
index a7054e9..0000000
--- a/router.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import asyncio
-import json
-from typing import ClassVar
-
-from fastapi import WebSocket, WebSocketDisconnect
-from lnbits.helpers import urlsafe_short_hash
-from loguru import logger
-
-from .nostr.client.client import NostrClient
-
-# from . import nostr_client
-from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
-
-nostr_client: NostrClient = NostrClient()
-all_routers: list["NostrRouter"] = []
-
-
-class NostrRouter:
- received_subscription_events: ClassVar[dict[str, list[EventMessage]]] = {}
- received_subscription_notices: ClassVar[list[NoticeMessage]] = []
- received_subscription_eosenotices: ClassVar[dict[str, EndOfStoredEventsMessage]] = (
- {}
- )
-
- def __init__(self, websocket: WebSocket):
- self.connected: bool = True
- self.websocket: WebSocket = websocket
- self.tasks: list[asyncio.Task] = []
- self.original_subscription_ids: dict[str, str] = {}
-
- @property
- def subscriptions(self) -> list[str]:
- return list(self.original_subscription_ids.keys())
-
- def start(self):
- self.connected = True
- self.tasks.append(asyncio.create_task(self._client_to_nostr()))
- self.tasks.append(asyncio.create_task(self._nostr_to_client()))
-
- async def stop(self):
- nostr_client.relay_manager.close_subscriptions(self.subscriptions)
- self.connected = False
-
- for t in self.tasks:
- try:
- t.cancel()
- except Exception as _:
- pass
-
- try:
- await self.websocket.close(reason="Websocket connection closed")
- except Exception as _:
- pass
-
- async def _client_to_nostr(self):
- """
- Receives requests / data from the client and forwards it to relays.
- """
- while self.connected:
- try:
- json_str = await self.websocket.receive_text()
- except WebSocketDisconnect as e:
- logger.debug(e)
- await self.stop()
- break
-
- try:
- await self._handle_client_to_nostr(json_str)
- except Exception as e:
- logger.debug(f"Failed to handle client message: '{e!s}'.")
-
- async def _nostr_to_client(self):
- """Sends responses from relays back to the client."""
- while self.connected:
- try:
- await self._handle_subscriptions()
- self._handle_notices()
- except Exception as e:
- logger.debug(f"Failed to handle response for client: '{e!s}'.")
- await asyncio.sleep(1)
- await asyncio.sleep(0.1)
-
- async def _handle_subscriptions(self):
- for s in self.subscriptions:
- if s in NostrRouter.received_subscription_events:
- await self._handle_received_subscription_events(s)
- if s in NostrRouter.received_subscription_eosenotices:
- await self._handle_received_subscription_eosenotices(s)
-
- async def _handle_received_subscription_eosenotices(self, s):
- try:
- if s not in self.original_subscription_ids:
- return
- s_original = self.original_subscription_ids[s]
- event_to_forward = ["EOSE", s_original]
- del NostrRouter.received_subscription_eosenotices[s]
-
- await self.websocket.send_text(json.dumps(event_to_forward))
- except Exception as e:
- logger.debug(e)
-
- async def _handle_received_subscription_events(self, s):
- try:
- if s not in NostrRouter.received_subscription_events:
- return
-
- while len(NostrRouter.received_subscription_events[s]):
- event_message = NostrRouter.received_subscription_events[s].pop(0)
- event_json = event_message.event
-
- # this reconstructs the original response from the relay
- # reconstruct original subscription id
- s_original = self.original_subscription_ids[s]
- event_to_forward = json.dumps(
- ["EVENT", s_original, json.loads(event_json)]
- )
- await self.websocket.send_text(event_to_forward)
- except Exception as e:
- logger.warning(
- f"[NOSTRCLIENT] Error in _handle_received_subscription_events: {e}"
- )
-
- def _handle_notices(self):
- while len(NostrRouter.received_subscription_notices):
- my_event = NostrRouter.received_subscription_notices.pop(0)
- logger.debug(f"[Relay '{my_event.url}'] Notice: '{my_event.content}']")
- # Note: we don't send it to the user because
- # we don't know who should receive it
- nostr_client.relay_manager.handle_notice(my_event)
-
- async def _handle_client_to_nostr(self, json_str):
- json_data = json.loads(json_str)
- assert len(json_data), "Bad JSON array"
-
- if json_data[0] == "REQ":
- self._handle_client_req(json_data)
- return
-
- if json_data[0] == "CLOSE":
- self._handle_client_close(json_data[1])
- return
-
- if json_data[0] == "EVENT":
- nostr_client.relay_manager.publish_message(json_str)
- return
-
- def _handle_client_req(self, json_data):
- subscription_id = json_data[1]
- logger.info(f"New subscription: '{subscription_id}'")
- subscription_id_rewritten = urlsafe_short_hash()
- self.original_subscription_ids[subscription_id_rewritten] = subscription_id
- filters = json_data[2:]
-
- nostr_client.relay_manager.add_subscription(subscription_id_rewritten, filters)
-
- def _handle_client_close(self, subscription_id):
- subscription_id_rewritten = next(
- (
- k
- for k, v in self.original_subscription_ids.items()
- if v == subscription_id
- ),
- None,
- )
- if subscription_id_rewritten:
- self.original_subscription_ids.pop(subscription_id_rewritten)
- nostr_client.relay_manager.close_subscription(subscription_id_rewritten)
- logger.info(
- f"""
- Unsubscribe from '{subscription_id_rewritten}'.
- Original id: '{subscription_id}.'
- """
- )
- else:
- logger.info(f"Failed to unsubscribe from '{subscription_id}.'")
diff --git a/services.py b/services.py
new file mode 100644
index 0000000..e03ad1d
--- /dev/null
+++ b/services.py
@@ -0,0 +1,146 @@
+import asyncio
+import json
+from typing import List, Union
+
+from fastapi import WebSocket, WebSocketDisconnect
+
+from lnbits.helpers import urlsafe_short_hash
+
+from .models import Event, Filter, Filters, Relay, RelayList
+from .nostr.client.client import NostrClient as NostrClientLib
+from .nostr.event import Event as NostrEvent
+from .nostr.filter import Filter as NostrFilter
+from .nostr.filter import Filters as NostrFilters
+from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
+
+received_subscription_events: dict[str, list[Event]] = {}
+received_subscription_notices: dict[str, list[NoticeMessage]] = {}
+received_subscription_eosenotices: dict[str, EndOfStoredEventsMessage] = {}
+
+
+class NostrClient:
+ def __init__(self):
+ self.client: NostrClientLib = NostrClientLib(connect=False)
+
+
+nostr = NostrClient()
+
+
+class NostrRouter:
+ def __init__(self, websocket):
+ self.subscriptions: List[str] = []
+ self.connected: bool = True
+ self.websocket = websocket
+ self.tasks: List[asyncio.Task] = []
+ self.subscription_id_rewrite: str = urlsafe_short_hash()
+
+ async def client_to_nostr(self):
+ """Receives requests / data from the client and forwards it to relays. If the
+ request was a subscription/filter, registers it with the nostr client lib.
+ Remembers the subscription id so we can send back responses from the relay to this
+ client in `nostr_to_client`"""
+ while True:
+ try:
+ json_str = await self.websocket.receive_text()
+ except WebSocketDisconnect:
+ self.connected = False
+ break
+
+ # registers a subscription if the input was a REQ request
+ subscription_id, json_str_rewritten = await self._add_nostr_subscription(
+ json_str
+ )
+ if subscription_id and json_str_rewritten:
+ self.subscriptions.append(subscription_id)
+ json_str = json_str_rewritten
+
+ # publish data
+ nostr.client.relay_manager.publish_message(json_str)
+
+ async def nostr_to_client(self):
+ """Sends responses from relays back to the client. Polls the subscriptions of this client
+ stored in `my_subscriptions`. Then gets all responses for this subscription id from `received_subscription_events` which
+ is filled in tasks.py. Takes one response after the other and relays it back to the client. Reconstructs
+ the reponse manually because the nostr client lib we're using can't do it. Reconstructs the original subscription id
+ that we had previously rewritten in order to avoid collisions when multiple clients use the same id."""
+ while True and self.connected:
+ for s in self.subscriptions:
+ if s in received_subscription_events:
+ while len(received_subscription_events[s]):
+ my_event = received_subscription_events[s].pop(0)
+ # event.to_message() does not include the subscription ID, we have to add it manually
+ event_json = {
+ "id": my_event.id,
+ "pubkey": my_event.public_key,
+ "created_at": my_event.created_at,
+ "kind": my_event.kind,
+ "tags": my_event.tags,
+ "content": my_event.content,
+ "sig": my_event.signature,
+ }
+
+ # this reconstructs the original response from the relay
+ # reconstruct original subscription id
+ s_original = s[len(f"{self.subscription_id_rewrite}_") :]
+ event_to_forward = ["EVENT", s_original, event_json]
+ # print(json.dumps(event_to_forward))
+
+ # send data back to client
+ await self.websocket.send_text(json.dumps(event_to_forward))
+ if s in received_subscription_eosenotices:
+ my_event = received_subscription_eosenotices[s]
+ s_original = s[len(f"{self.subscription_id_rewrite}_") :]
+ event_to_forward = ["EOSE", s_original]
+ del received_subscription_eosenotices[s]
+ # send data back to client
+ await self.websocket.send_text(json.dumps(event_to_forward))
+ await asyncio.sleep(0.1)
+
+ async def start(self):
+ self.tasks.append(asyncio.create_task(self.client_to_nostr()))
+ self.tasks.append(asyncio.create_task(self.nostr_to_client()))
+
+ async def stop(self):
+ for t in self.tasks:
+ t.cancel()
+ self.connected = False
+
+ def _marshall_nostr_filters(self, data: Union[dict, list]):
+ filters = data if isinstance(data, list) else [data]
+ filters = [Filter.parse_obj(f) for f in filters]
+ filter_list: list[NostrFilter] = []
+ for filter in filters:
+ filter_list.append(
+ NostrFilter(
+ event_ids=filter.ids, # type: ignore
+ kinds=filter.kinds, # type: ignore
+ authors=filter.authors, # type: ignore
+ since=filter.since, # type: ignore
+ until=filter.until, # type: ignore
+ event_refs=filter.e, # type: ignore
+ pubkey_refs=filter.p, # type: ignore
+ limit=filter.limit, # type: ignore
+ )
+ )
+ return NostrFilters(filter_list)
+
+ async def _add_nostr_subscription(self, json_str):
+ """Parses a (string) request from a client. If it is a subscription (REQ), it will
+ register the subscription in the nostr client library that we're using so we can
+ receive the callbacks on it later. Will rewrite the subscription id since we expect
+ multiple clients to use the router and want to avoid subscription id collisions"""
+ json_data = json.loads(json_str)
+ assert len(json_data)
+ if json_data[0] == "REQ":
+ subscription_id = json_data[1]
+ subscription_id_rewritten = (
+ f"{self.subscription_id_rewrite}_{subscription_id}"
+ )
+ fltr = json_data[2]
+ filters = self._marshall_nostr_filters(fltr)
+ nostr.client.relay_manager.add_subscription(
+ subscription_id_rewritten, filters
+ )
+ request_rewritten = json.dumps(["REQ", subscription_id_rewritten, fltr])
+ return subscription_id_rewritten, request_rewritten
+ return None, None
diff --git a/static/images/1.jpeg b/static/images/1.jpeg
deleted file mode 100644
index 1f00661..0000000
Binary files a/static/images/1.jpeg and /dev/null differ
diff --git a/static/images/2.jpeg b/static/images/2.jpeg
deleted file mode 100644
index e7301fd..0000000
Binary files a/static/images/2.jpeg and /dev/null differ
diff --git a/tasks.py b/tasks.py
index 2a76765..790337c 100644
--- a/tasks.py
+++ b/tasks.py
@@ -1,71 +1,91 @@
import asyncio
+import ssl
+import json
import threading
-from loguru import logger
-
from .crud import get_relays
+from .nostr.event import Event
+from .nostr.key import PublicKey
from .nostr.message_pool import EndOfStoredEventsMessage, EventMessage, NoticeMessage
-from .router import NostrRouter, nostr_client
+from .nostr.relay_manager import RelayManager
+from .services import (
+ nostr,
+ received_subscription_eosenotices,
+ received_subscription_events,
+)
async def init_relays():
+ # we save any subscriptions teporarily to re-add them after reinitializing the client
+ subscriptions = {}
+ for relay in nostr.client.relay_manager.relays.values():
+ # relay.add_subscription(id, filters)
+ for subscription_id, filters in relay.subscriptions.items():
+ subscriptions[subscription_id] = filters
+
+ # reinitialize the entire client
+ nostr.__init__()
# get relays from db
relays = await get_relays()
# set relays and connect to them
- valid_relays = [r.url for r in relays if r.url]
+ nostr.client.relays = list(set([r.url for r in relays.__root__ if r.url]))
+ nostr.client.connect()
- nostr_client.reconnect(valid_relays)
-
-
-async def check_relays():
- """Check relays that have been disconnected"""
- while True:
- try:
- await asyncio.sleep(20)
- nostr_client.relay_manager.check_and_restart_relays()
- except Exception as e:
- logger.warning(f"Cannot restart relays: '{e!s}'.")
+ await asyncio.sleep(2)
+ # re-add subscriptions
+ for subscription_id, subscription in subscriptions.items():
+ nostr.client.relay_manager.add_subscription(
+ subscription_id, subscription.filters
+ )
+ s = subscription.to_json_object()
+ json_str = json.dumps(["REQ", s["id"], s["filters"][0]])
+ nostr.client.relay_manager.publish_message(json_str)
+ return
async def subscribe_events():
- while not [r.connected for r in nostr_client.relay_manager.relays.values()]:
+ while not any([r.connected for r in nostr.client.relay_manager.relays.values()]):
await asyncio.sleep(2)
- def callback_events(event_message: EventMessage):
- sub_id = event_message.subscription_id
- if sub_id not in NostrRouter.received_subscription_events:
- NostrRouter.received_subscription_events[sub_id] = [event_message]
- return
+ def callback_events(eventMessage: EventMessage):
+ # print(f"From {event.public_key[:3]}..{event.public_key[-3:]}: {event.content}")
+ if eventMessage.subscription_id in received_subscription_events:
+ # do not add duplicate events (by event id)
+ if eventMessage.event.id in set(
+ [
+ e.id
+ for e in received_subscription_events[eventMessage.subscription_id]
+ ]
+ ):
+ return
- # do not add duplicate events (by event id)
- ids = [e.event_id for e in NostrRouter.received_subscription_events[sub_id]]
- if event_message.event_id in ids:
- return
-
- NostrRouter.received_subscription_events[sub_id].append(event_message)
-
- def callback_notices(notice_message: NoticeMessage):
- if notice_message not in NostrRouter.received_subscription_notices:
- NostrRouter.received_subscription_notices.append(notice_message)
-
- def callback_eose_notices(event_message: EndOfStoredEventsMessage):
- sub_id = event_message.subscription_id
- if sub_id in NostrRouter.received_subscription_eosenotices:
- return
-
- NostrRouter.received_subscription_eosenotices[sub_id] = event_message
-
- def wrap_async_subscribe():
- asyncio.run(
- nostr_client.subscribe(
- callback_events,
- callback_notices,
- callback_eose_notices,
+ received_subscription_events[eventMessage.subscription_id].append(
+ eventMessage.event
)
- )
+ else:
+ received_subscription_events[eventMessage.subscription_id] = [
+ eventMessage.event
+ ]
+ return
+
+ def callback_notices(eventMessage: NoticeMessage):
+ return
+
+ def callback_eose_notices(eventMessage: EndOfStoredEventsMessage):
+ if eventMessage.subscription_id not in received_subscription_eosenotices:
+ received_subscription_eosenotices[
+ eventMessage.subscription_id
+ ] = eventMessage
+
+ return
t = threading.Thread(
- target=wrap_async_subscribe,
+ target=nostr.client.subscribe,
+ args=(
+ callback_events,
+ callback_notices,
+ callback_eose_notices,
+ ),
name="Nostr-event-subscription",
daemon=True,
)
diff --git a/templates/nostrclient/index.html b/templates/nostrclient/index.html
index ca44f1b..abe5a87 100644
--- a/templates/nostrclient/index.html
+++ b/templates/nostrclient/index.html
@@ -2,52 +2,6 @@
%} {% block page %} {% raw %}
+