feat: code quality (#59)

* feat: code quality
* fixup!
This commit is contained in:
dni ⚡ 2024-08-05 11:49:50 +02:00 committed by GitHub
commit badc420069
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
29 changed files with 3048 additions and 212 deletions

View file

@ -23,21 +23,25 @@
from enum import Enum
class Encoding(Enum):
"""Enumeration type to list the various supported encodings."""
BECH32 = 1
BECH32M = 2
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"
BECH32M_CONST = 0x2bc830a3
BECH32M_CONST = 0x2BC830A3
def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
generator = [0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3]
generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
chk = 1
for value in values:
top = chk >> 25
chk = (chk & 0x1ffffff) << 5 ^ value
chk = (chk & 0x1FFFFFF) << 5 ^ value
for i in range(5):
chk ^= generator[i] if ((top >> i) & 1) else 0
return chk
@ -57,37 +61,41 @@ def bech32_verify_checksum(hrp, data):
return Encoding.BECH32M
return None
def bech32_create_checksum(hrp, data, spec):
"""Compute the checksum values given HRP and data."""
values = bech32_hrp_expand(hrp) + data
const = BECH32M_CONST if spec == Encoding.BECH32M else 1
polymod = bech32_polymod(values + [0, 0, 0, 0, 0, 0]) ^ const
polymod = bech32_polymod([*values, 0, 0, 0, 0, 0, 0]) ^ const
return [(polymod >> 5 * (5 - i)) & 31 for i in range(6)]
def bech32_encode(hrp, data, spec):
"""Compute a Bech32 string given HRP and data values."""
combined = data + bech32_create_checksum(hrp, data, spec)
return hrp + '1' + ''.join([CHARSET[d] for d in combined])
return hrp + "1" + "".join([CHARSET[d] for d in combined])
def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
if ((any(ord(x) < 33 or ord(x) > 126 for x in bech)) or
(bech.lower() != bech and bech.upper() != bech)):
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return (None, None, None)
bech = bech.lower()
pos = bech.rfind('1')
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
if not all(x in CHARSET for x in bech[pos+1:]):
if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos+1:]]
data = [CHARSET.find(x) for x in bech[pos + 1 :]]
spec = bech32_verify_checksum(hrp, data)
if spec is None:
return (None, None, None)
return (hrp, data[:-6], spec)
def convertbits(data, frombits, tobits, pad=True):
"""General power-of-2 base conversion."""
acc = 0
@ -114,6 +122,7 @@ def convertbits(data, frombits, tobits, pad=True):
def decode(hrp, addr):
"""Decode a segwit address."""
hrpgot, data, spec = bech32_decode(addr)
assert data, "Invalid bech32 string"
if hrpgot != hrp:
return (None, None)
decoded = convertbits(data[1:], 5, 8, False)
@ -123,7 +132,12 @@ def decode(hrp, addr):
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
if data[0] == 0 and spec != Encoding.BECH32 or data[0] != 0 and spec != Encoding.BECH32M:
if (
data[0] == 0
and spec != Encoding.BECH32
or data[0] != 0
and spec != Encoding.BECH32M
):
return (None, None)
return (data[0], decoded)
@ -131,7 +145,9 @@ def decode(hrp, addr):
def encode(hrp, witver, witprog):
"""Encode a segwit address."""
spec = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
ret = bech32_encode(hrp, [witver] + convertbits(witprog, 8, 5), spec)
bits = convertbits(witprog, 8, 5)
assert bits, "Invalid witness program"
ret = bech32_encode(hrp, [witver, *bits], spec)
if decode(hrp, ret) == (None, None):
return None
return ret

View file

@ -1,10 +1,11 @@
import time
import json
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import List
from secp256k1 import PublicKey
from hashlib import sha256
from typing import List, Optional
from secp256k1 import PublicKey
from .message_type import ClientMessageType
@ -20,14 +21,14 @@ class EventKind(IntEnum):
@dataclass
class Event:
content: str = None
public_key: str = None
created_at: int = None
content: Optional[str] = None
public_key: Optional[str] = None
created_at: Optional[int] = None
kind: int = EventKind.TEXT_NOTE
tags: List[List[str]] = field(
default_factory=list
) # Dataclasses require special handling when the default value is a mutable type
signature: str = None
signature: Optional[str] = None
def __post_init__(self):
if self.content is not None and not isinstance(self.content, str):
@ -56,6 +57,9 @@ class Event:
@property
def id(self) -> str:
# Always recompute the id to reflect the up-to-date state of the Event
assert self.public_key, "Event public key is missing"
assert self.created_at, "Event created_at is missing"
assert self.content, "Event content is missing"
return Event.compute_id(
self.public_key, self.created_at, self.kind, self.tags, self.content
)
@ -69,9 +73,11 @@ class Event:
self.tags.append(["e", event_id])
def verify(self) -> bool:
assert self.public_key, "Event public key is missing"
pub_key = PublicKey(
bytes.fromhex("02" + self.public_key), True
) # add 02 for schnorr (bip340)
assert self.signature, "Event signature is missing"
return pub_key.schnorr_verify(
bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True
)
@ -95,9 +101,9 @@ class Event:
@dataclass
class EncryptedDirectMessage(Event):
recipient_pubkey: str = None
cleartext_content: str = None
reference_event_id: str = None
recipient_pubkey: Optional[str] = None
cleartext_content: Optional[str] = None
reference_event_id: Optional[str] = None
def __post_init__(self):
if self.content is not None:
@ -121,6 +127,7 @@ class EncryptedDirectMessage(Event):
def id(self) -> str:
if self.content is None:
raise Exception(
"EncryptedDirectMessage `id` is undefined until its message is encrypted and stored in the `content` field"
"EncryptedDirectMessage `id` is undefined until its "
"message is encrypted and stored in the `content` field"
)
return super().id

View file

@ -1,13 +1,14 @@
import secrets
import base64
import secrets
from typing import Optional
import secp256k1
from cffi import FFI
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from hashlib import sha256
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from .event import EncryptedDirectMessage, Event, EventKind
from . import bech32
from .event import EncryptedDirectMessage, EventKind
class PublicKey:
@ -21,33 +22,41 @@ class PublicKey:
def hex(self) -> str:
return self.raw_bytes.hex()
def verify_signed_message_hash(self, hash: str, sig: str) -> bool:
def verify_signed_message_hash(self, message_hash: str, sig: str) -> bool:
pk = secp256k1.PublicKey(b"\x02" + self.raw_bytes, True)
return pk.schnorr_verify(bytes.fromhex(hash), bytes.fromhex(sig), None, True)
return pk.schnorr_verify(
bytes.fromhex(message_hash), bytes.fromhex(sig), None, True
)
@classmethod
def from_npub(cls, npub: str):
"""Load a PublicKey from its bech32/npub form"""
hrp, data, spec = bech32.bech32_decode(npub)
raw_public_key = bech32.convertbits(data, 5, 8)[:-1]
assert data, "Invalid npub"
bits = bech32.convertbits(data, 5, 8)
assert bits, "Invalid npub"
raw_public_key = bits[:-1]
return cls(bytes(raw_public_key))
class PrivateKey:
def __init__(self, raw_secret: bytes = None) -> None:
if not raw_secret is None:
def __init__(self, raw_secret: Optional[bytes] = None) -> None:
if raw_secret is not None:
self.raw_secret = raw_secret
else:
self.raw_secret = secrets.token_bytes(32)
sk = secp256k1.PrivateKey(self.raw_secret)
assert sk.pubkey, "Invalid public"
self.public_key = PublicKey(sk.pubkey.serialize()[1:])
@classmethod
def from_nsec(cls, nsec: str):
"""Load a PrivateKey from its bech32/nsec form"""
hrp, data, spec = bech32.bech32_decode(nsec)
raw_secret = bech32.convertbits(data, 5, 8)[:-1]
bits = bech32.convertbits(data, 5, 8)
assert bits, "Invalid nsec"
raw_secret = bits[:-1]
return cls(bytes(raw_secret))
def bech32(self) -> str:
@ -77,9 +86,12 @@ class PrivateKey:
encryptor = cipher.encryptor()
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()
return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"
msg = base64.b64encode(encrypted_message).decode()
return f"{msg}?iv={base64.b64encode(iv).decode()}"
def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
assert dm.recipient_pubkey, "Recipient public key must be set"
assert dm.cleartext_content, "Cleartext content must be set"
dm.content = self.encrypt_message(
message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey
)
@ -102,12 +114,12 @@ class PrivateKey:
return unpadded_data.decode()
def sign_message_hash(self, hash: bytes) -> str:
def sign_message_hash(self, message_hash: bytes) -> str:
sk = secp256k1.PrivateKey(self.raw_secret)
sig = sk.schnorr_sign(hash, None, raw=True)
sig = sk.schnorr_sign(message_hash, None, raw=True)
return sig.hex()
def sign_event(self, event: Event) -> None:
def sign_event(self, event: EncryptedDirectMessage) -> None:
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None:
self.encrypt_dm(event)
if event.public_key is None:
@ -118,7 +130,9 @@ class PrivateKey:
return self.raw_secret == other.raw_secret
def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey:
def mine_vanity_key(
prefix: Optional[str] = None, suffix: Optional[str] = None
) -> PrivateKey:
if prefix is None and suffix is None:
raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments")

View file

@ -3,13 +3,18 @@ class ClientMessageType:
REQUEST = "REQ"
CLOSE = "CLOSE"
class RelayMessageType:
EVENT = "EVENT"
NOTICE = "NOTICE"
END_OF_STORED_EVENTS = "EOSE"
@staticmethod
def is_valid(type: str) -> bool:
if type == RelayMessageType.EVENT or type == RelayMessageType.NOTICE or type == RelayMessageType.END_OF_STORED_EVENTS:
def is_valid(relay_type: str) -> bool:
if (
relay_type == RelayMessageType.EVENT
or relay_type == RelayMessageType.NOTICE
or relay_type == RelayMessageType.END_OF_STORED_EVENTS
):
return True
return False
return False