diff --git a/nostr/delegation.py b/nostr/delegation.py index 94801f5..47e35f5 100644 --- a/nostr/delegation.py +++ b/nostr/delegation.py @@ -1,4 +1,5 @@ import time +from typing import Optional from dataclasses import dataclass @@ -8,7 +9,7 @@ class Delegation: delegatee_pubkey: str event_kind: int duration_secs: int = 30*24*60 # default to 30 days - signature: str = None # set in PrivateKey.sign_delegation + signature: Optional[str] = None # set in PrivateKey.sign_delegation @property def expires(self) -> int: diff --git a/nostr/event.py b/nostr/event.py index b6b8ccf..3eb3962 100644 --- a/nostr/event.py +++ b/nostr/event.py @@ -1,4 +1,5 @@ import time +from typing import Optional import json from dataclasses import dataclass, field from enum import IntEnum @@ -23,12 +24,12 @@ class EventKind(IntEnum): @dataclass class Event: - content: str = None - public_key: str = None - created_at: int = None - kind: int = EventKind.TEXT_NOTE + content: Optional[str] = None + public_key: Optional[str] = None + created_at: Optional[int] = None + kind: Optional[int] = EventKind.TEXT_NOTE tags: List[List[str]] = field(default_factory=list) # Dataclasses require special handling when the default value is a mutable type - signature: str = None + signature: Optional[str] = None def __post_init__(self): @@ -79,29 +80,39 @@ def verify(self) -> bool: return pub_key.schnorr_verify(bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True) - def to_message(self) -> str: - return json.dumps( - [ - ClientMessageType.EVENT, - { - "id": self.id, - "pubkey": self.public_key, - "created_at": self.created_at, - "kind": self.kind, - "tags": self.tags, - "content": self.content, - "sig": self.signature - } - ] + @classmethod + def from_dict(cls, msg: dict) -> 'Event': + # "id" is ignore, as it will be computed from the contents + return Event( + content=msg['content'], + public_key=msg['pubkey'], + created_at=msg['created_at'], + kind=msg['kind'], + tags=msg['tags'], + signature=msg['sig'], ) + def to_dict(self) -> dict: + return { + "id": self.id, + "pubkey": self.public_key, + "created_at": self.created_at, + "kind": self.kind, + "tags": self.tags, + "content": self.content, + "sig": self.signature + } + + def to_message(self) -> str: + return json.dumps([ClientMessageType.EVENT, self.to_dict()]) + @dataclass class EncryptedDirectMessage(Event): - recipient_pubkey: str = None - cleartext_content: str = None - reference_event_id: str = None + recipient_pubkey: Optional[str] = None + cleartext_content: Optional[str] = None + reference_event_id: Optional[str] = None def __post_init__(self): diff --git a/nostr/filter.py b/nostr/filter.py index f4cb0a5..a531dcb 100644 --- a/nostr/filter.py +++ b/nostr/filter.py @@ -1,5 +1,5 @@ from collections import UserList -from typing import List +from typing import List, Optional from .event import Event, EventKind @@ -16,20 +16,21 @@ class Filter: added. For example: # arbitrary tag filter.add_arbitrary_tag('t', [hashtags]) - + # promoted to explicit support Filter(hashtag_refs=[hashtags]) """ def __init__( - self, - event_ids: List[str] = None, - kinds: List[EventKind] = None, - authors: List[str] = None, - since: int = None, - until: int = None, - event_refs: List[str] = None, # the "#e" attr; list of event ids referenced in an "e" tag - pubkey_refs: List[str] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag - limit: int = None) -> None: + self, + event_ids: Optional[List[str]] = None, + kinds: Optional[List[EventKind]] = None, + authors: Optional[List[str]] = None, + since: Optional[int] = None, + until: Optional[int] = None, + event_refs: Optional[List[str]] = None, # the "#e" attr; list of event ids referenced in an "e" tag + pubkey_refs: Optional[List[str]] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag + limit: Optional[int] = None, + ) -> None: self.event_ids = event_ids self.kinds = kinds self.authors = authors @@ -128,4 +129,4 @@ def match(self, event: Event): return False def to_json_array(self) -> list: - return [filter.to_json_object() for filter in self.data] \ No newline at end of file + return [filter.to_json_object() for filter in self.data] diff --git a/nostr/key.py b/nostr/key.py index 511b198..13ce533 100644 --- a/nostr/key.py +++ b/nostr/key.py @@ -1,10 +1,11 @@ import secrets import base64 -import secp256k1 +import secp256k1 # type: ignore from cffi import FFI from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives import padding from hashlib import sha256 +from typing import cast, Optional from .delegation import Delegation from .event import EncryptedDirectMessage, Event, EventKind @@ -35,7 +36,7 @@ def from_npub(cls, npub: str): class PrivateKey: - def __init__(self, raw_secret: bytes=None) -> None: + def __init__(self, raw_secret: Optional[bytes]=None) -> None: if not raw_secret is None: self.raw_secret = raw_secret else: @@ -77,7 +78,7 @@ def encrypt_message(self, message: str, public_key_hex: str) -> str: encrypted_message = encryptor.update(padded_data) + encryptor.finalize() return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}" - + def encrypt_dm(self, dm: EncryptedDirectMessage) -> None: dm.content = self.encrypt_message(message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey) @@ -104,7 +105,8 @@ def sign_message_hash(self, hash: bytes) -> str: def sign_event(self, event: Event) -> None: if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None: - self.encrypt_dm(event) + edm = cast(EncryptedDirectMessage, event) + self.encrypt_dm(edm) if event.public_key is None: event.public_key = self.public_key.hex() event.signature = self.sign_message_hash(bytes.fromhex(event.id)) @@ -116,7 +118,7 @@ def __eq__(self, other): return self.raw_secret == other.raw_secret -def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey: +def mine_vanity_key(prefix: Optional[str] = None, suffix: Optional[str] = None) -> PrivateKey: if prefix is None and suffix is None: raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments") diff --git a/nostr/message_pool.py b/nostr/message_pool.py index 8abcf60..418da8b 100644 --- a/nostr/message_pool.py +++ b/nostr/message_pool.py @@ -54,15 +54,7 @@ def _process_message(self, message: str, url: str): message_type = message_json[0] if message_type == RelayMessageType.EVENT: subscription_id = message_json[1] - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) + event = Event.from_dict(message_json[2]) with self.lock: if not event.id in self._unique_events: self.events.put(EventMessage(event, subscription_id, url)) diff --git a/nostr/relay.py b/nostr/relay.py index 103e497..f56c17d 100644 --- a/nostr/relay.py +++ b/nostr/relay.py @@ -1,14 +1,20 @@ import json +import logging + from dataclasses import dataclass from threading import Lock from typing import Optional -from websocket import WebSocketApp +from websocket import WebSocketApp, WebSocketConnectionClosedException from .event import Event from .filter import Filters from .message_pool import MessagePool from .message_type import RelayMessageType from .subscription import Subscription + +logger = logging.getLogger('nostr') + + @dataclass class RelayPolicy: should_read: bool = True @@ -35,7 +41,7 @@ class Relay: url: str message_pool: MessagePool policy: RelayPolicy = RelayPolicy() - proxy_config: RelayProxyConnectionConfig = RelayProxyConnectionConfig() + proxy_config: Optional[RelayProxyConnectionConfig] = None ssl_options: Optional[dict] = None def __post_init__(self): @@ -52,16 +58,23 @@ def __post_init__(self): def connect(self): self.ws.run_forever( sslopt=self.ssl_options, - http_proxy_host=self.proxy_config.host, - http_proxy_port=self.proxy_config.port, - proxy_type=self.proxy_config.type + http_proxy_host=self.proxy_config.host if self.proxy_config is not + None else None, + http_proxy_port=self.proxy_config.port if self.proxy_config is not + None else None, + proxy_type=self.proxy_config.type if self.proxy_config is not None + else None, ) def close(self): self.ws.close() def publish(self, message: str): - self.ws.send(message) + try: + self.ws.send(message) + except WebSocketConnectionClosedException as e: + logger.error("Connection closed on publish attempt. url=%s", + self.url) def add_subscription(self, id, filters: Filters): with self.lock: @@ -84,15 +97,19 @@ def to_json_object(self) -> dict: } def _on_open(self, class_obj): + logger.debug("Relay._on_open: url=%s", self.url) pass def _on_close(self, class_obj, status_code, message): + logger.debug("Relay._on_close: url=%s, code=%s, message=%s", self.url, + status_code, message) pass def _on_message(self, class_obj, message: str): self.message_pool.add_message(message, self.url) - + def _on_error(self, class_obj, error): + logger.debug("Relay._on_error: url=%s, error=%s", self.url, error) pass def _is_valid_message(self, message: str) -> bool: @@ -107,21 +124,13 @@ def _is_valid_message(self, message: str) -> bool: if message_type == RelayMessageType.EVENT: if not len(message_json) == 3: return False - + subscription_id = message_json[1] with self.lock: if subscription_id not in self.subscriptions: return False - e = message_json[2] - event = Event( - e["content"], - e["pubkey"], - e["created_at"], - e["kind"], - e["tags"], - e["sig"], - ) + event = Event.from_dict(message_json[2]) if not event.verify(): return False diff --git a/nostr/relay_manager.py b/nostr/relay_manager.py index bfaf254..e69ed96 100644 --- a/nostr/relay_manager.py +++ b/nostr/relay_manager.py @@ -1,6 +1,7 @@ import json import time import threading +from typing import Optional from dataclasses import dataclass from threading import Lock @@ -26,20 +27,21 @@ def __post_init__(self): self.lock: Lock = Lock() def add_relay( - self, - url: str, + self, + url: str, policy: RelayPolicy = RelayPolicy(), - ssl_options: dict = None, - proxy_config: RelayProxyConnectionConfig = None): + ssl_options: Optional[dict] = None, + proxy_config: Optional[RelayProxyConnectionConfig] = None): - relay = Relay(url, self.message_pool, policy, ssl_options, proxy_config) + relay = Relay(url, self.message_pool, policy, proxy_config, ssl_options) with self.lock: self.relays[url] = relay threading.Thread( target=relay.connect, - name=f"{relay.url}-thread" + name=f"{relay.url}-thread", + daemon=True, ).start() time.sleep(1) @@ -55,12 +57,12 @@ def add_subscription_on_relay(self, url: str, id: str, filters: Filters): if url in self.relays: relay = self.relays[url] if not relay.policy.should_read: - raise RelayException(f"Could not send request: {relay_url} is not configured to read from") + raise RelayException(f"Could not send request: {url} is not configured to read from") relay.add_subscription(id, filters) request = Request(id, filters) relay.publish(request.to_message()) else: - raise RelayException(f"Invalid relay url: no connection to {relay_url}") + raise RelayException(f"Invalid relay url: no connection to {url}") def add_subscription_on_all_relays(self, id: str, filters: Filters): with self.lock: @@ -77,7 +79,7 @@ def close_subscription_on_relay(self, url: str, id: str): relay.close_subscription(id) relay.publish(json.dumps(["CLOSE", id])) else: - raise RelayException(f"Invalid relay url: no connection to {relay_url}") + raise RelayException(f"Invalid relay url: no connection to {url}") def close_subscription_on_all_relays(self, id: str): with self.lock: diff --git a/nostr/subscription.py b/nostr/subscription.py index 7afba20..87fbc05 100644 --- a/nostr/subscription.py +++ b/nostr/subscription.py @@ -1,7 +1,8 @@ +from typing import Optional from .filter import Filters class Subscription: - def __init__(self, id: str, filters: Filters=None) -> None: + def __init__(self, id: str, filters: Optional[Filters]=None) -> None: self.id = id self.filters = filters diff --git a/test/test_event.py b/test/test_event.py index 10aa41a..8a0e18d 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -5,9 +5,11 @@ from nostr.event import Event, EncryptedDirectMessage from nostr.key import PrivateKey +import unittest -class TestEvent: + +class TestEvent(unittest.TestCase): def test_event_default_time(self): """ ensure created_at default value reflects the time at Event object instantiation @@ -17,7 +19,7 @@ def test_event_default_time(self): time.sleep(1.5) event2 = Event(content='test event') assert event1.created_at < event2.created_at - + def test_content_only_instantiation(self): """ should be able to create an Event by only specifying content without kwarg """ @@ -36,7 +38,7 @@ def test_event_id_recomputes(self): # Recomputed id should now be different assert event.id != event_id - + def test_note_id_bech32_conversion(self): """ should convert the event id to its `note`-prepended bech32 form """ @@ -91,6 +93,16 @@ def test_add_pubkey_ref(self): event.add_pubkey_ref(some_pubkey) assert ['p', some_pubkey] in event.tags + def test_dict_roundtrip(self): + """ conversion to dict and back result in same object """ + event = Event(content='test event', + created_at=12345678, + kind=1, + ) + event.add_pubkey_ref("some_pubkey") + + got = Event.from_dict(event.to_dict()) + self.assertEqual(got, event) class TestEncryptedDirectMessage: @@ -106,19 +118,19 @@ def test_content_field_moved_to_cleartext_content(self): dm = EncryptedDirectMessage(content="My message!", recipient_pubkey=self.recipient_pubkey) assert dm.content is None assert dm.cleartext_content is not None - + def test_nokwarg_content_allowed(self): """ Should allow creating a new DM w/no `content` nor `cleartext_content` kwarg """ dm = EncryptedDirectMessage("My message!", recipient_pubkey=self.recipient_pubkey) assert dm.cleartext_content is not None - + def test_recipient_p_tag(self): """ Should generate recipient 'p' tag """ dm = EncryptedDirectMessage(cleartext_content="Secret message!", recipient_pubkey=self.recipient_pubkey) assert ['p', self.recipient_pubkey] in dm.tags - + def test_unencrypted_dm_has_undefined_id(self): """ Should raise Exception if `id` is requested before DM is encrypted """