Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix code at head, pytype errors, and small Event refactoring #50

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
3 changes: 2 additions & 1 deletion nostr/delegation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import Optional
from dataclasses import dataclass


Expand All @@ -8,7 +9,7 @@ class Delegation:
delegatee_pubkey: str
event_kind: int
duration_secs: int = 30*24*60 # default to 30 days
signature: str = None # set in PrivateKey.sign_delegation
signature: Optional[str] = None # set in PrivateKey.sign_delegation

@property
def expires(self) -> int:
Expand Down
55 changes: 33 additions & 22 deletions nostr/event.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import time
from typing import Optional
import json
from dataclasses import dataclass, field
from enum import IntEnum
Expand All @@ -23,12 +24,12 @@ class EventKind(IntEnum):

@dataclass
class Event:
content: str = None
public_key: str = None
created_at: int = None
kind: int = EventKind.TEXT_NOTE
content: Optional[str] = None
public_key: Optional[str] = None
created_at: Optional[int] = None
kind: Optional[int] = EventKind.TEXT_NOTE
tags: List[List[str]] = field(default_factory=list) # Dataclasses require special handling when the default value is a mutable type
signature: str = None
signature: Optional[str] = None


def __post_init__(self):
Expand Down Expand Up @@ -79,29 +80,39 @@ def verify(self) -> bool:
return pub_key.schnorr_verify(bytes.fromhex(self.id), bytes.fromhex(self.signature), None, raw=True)


def to_message(self) -> str:
return json.dumps(
[
ClientMessageType.EVENT,
{
"id": self.id,
"pubkey": self.public_key,
"created_at": self.created_at,
"kind": self.kind,
"tags": self.tags,
"content": self.content,
"sig": self.signature
}
]
@classmethod
def from_dict(cls, msg: dict) -> 'Event':
# "id" is ignore, as it will be computed from the contents
return Event(
content=msg['content'],
public_key=msg['pubkey'],
created_at=msg['created_at'],
kind=msg['kind'],
tags=msg['tags'],
signature=msg['sig'],
)

def to_dict(self) -> dict:
return {
"id": self.id,
"pubkey": self.public_key,
"created_at": self.created_at,
"kind": self.kind,
"tags": self.tags,
"content": self.content,
"sig": self.signature
}

def to_message(self) -> str:
return json.dumps([ClientMessageType.EVENT, self.to_dict()])



@dataclass
class EncryptedDirectMessage(Event):
recipient_pubkey: str = None
cleartext_content: str = None
reference_event_id: str = None
recipient_pubkey: Optional[str] = None
cleartext_content: Optional[str] = None
reference_event_id: Optional[str] = None


def __post_init__(self):
Expand Down
25 changes: 13 additions & 12 deletions nostr/filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import UserList
from typing import List
from typing import List, Optional

from .event import Event, EventKind

Expand All @@ -16,20 +16,21 @@ class Filter:
added. For example:
# arbitrary tag
filter.add_arbitrary_tag('t', [hashtags])

# promoted to explicit support
Filter(hashtag_refs=[hashtags])
"""
def __init__(
self,
event_ids: List[str] = None,
kinds: List[EventKind] = None,
authors: List[str] = None,
since: int = None,
until: int = None,
event_refs: List[str] = None, # the "#e" attr; list of event ids referenced in an "e" tag
pubkey_refs: List[str] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
limit: int = None) -> None:
self,
event_ids: Optional[List[str]] = None,
kinds: Optional[List[EventKind]] = None,
authors: Optional[List[str]] = None,
since: Optional[int] = None,
until: Optional[int] = None,
event_refs: Optional[List[str]] = None, # the "#e" attr; list of event ids referenced in an "e" tag
pubkey_refs: Optional[List[str]] = None, # The "#p" attr; list of pubkeys referenced in a "p" tag
limit: Optional[int] = None,
) -> None:
self.event_ids = event_ids
self.kinds = kinds
self.authors = authors
Expand Down Expand Up @@ -128,4 +129,4 @@ def match(self, event: Event):
return False

def to_json_array(self) -> list:
return [filter.to_json_object() for filter in self.data]
return [filter.to_json_object() for filter in self.data]
12 changes: 7 additions & 5 deletions nostr/key.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import secrets
import base64
import secp256k1
import secp256k1 # type: ignore
from cffi import FFI
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from hashlib import sha256
from typing import cast, Optional

from .delegation import Delegation
from .event import EncryptedDirectMessage, Event, EventKind
Expand Down Expand Up @@ -35,7 +36,7 @@ def from_npub(cls, npub: str):


class PrivateKey:
def __init__(self, raw_secret: bytes=None) -> None:
def __init__(self, raw_secret: Optional[bytes]=None) -> None:
if not raw_secret is None:
self.raw_secret = raw_secret
else:
Expand Down Expand Up @@ -77,7 +78,7 @@ def encrypt_message(self, message: str, public_key_hex: str) -> str:
encrypted_message = encryptor.update(padded_data) + encryptor.finalize()

return f"{base64.b64encode(encrypted_message).decode()}?iv={base64.b64encode(iv).decode()}"

def encrypt_dm(self, dm: EncryptedDirectMessage) -> None:
dm.content = self.encrypt_message(message=dm.cleartext_content, public_key_hex=dm.recipient_pubkey)

Expand All @@ -104,7 +105,8 @@ def sign_message_hash(self, hash: bytes) -> str:

def sign_event(self, event: Event) -> None:
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE and event.content is None:
self.encrypt_dm(event)
edm = cast(EncryptedDirectMessage, event)
self.encrypt_dm(edm)
if event.public_key is None:
event.public_key = self.public_key.hex()
event.signature = self.sign_message_hash(bytes.fromhex(event.id))
Expand All @@ -116,7 +118,7 @@ def __eq__(self, other):
return self.raw_secret == other.raw_secret


def mine_vanity_key(prefix: str = None, suffix: str = None) -> PrivateKey:
def mine_vanity_key(prefix: Optional[str] = None, suffix: Optional[str] = None) -> PrivateKey:
if prefix is None and suffix is None:
raise ValueError("Expected at least one of 'prefix' or 'suffix' arguments")

Expand Down
10 changes: 1 addition & 9 deletions nostr/message_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,7 @@ def _process_message(self, message: str, url: str):
message_type = message_json[0]
if message_type == RelayMessageType.EVENT:
subscription_id = message_json[1]
e = message_json[2]
event = Event(
e["content"],
e["pubkey"],
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
event = Event.from_dict(message_json[2])
with self.lock:
if not event.id in self._unique_events:
self.events.put(EventMessage(event, subscription_id, url))
Expand Down
43 changes: 26 additions & 17 deletions nostr/relay.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import json
import logging

from dataclasses import dataclass
from threading import Lock
from typing import Optional
from websocket import WebSocketApp
from websocket import WebSocketApp, WebSocketConnectionClosedException
from .event import Event
from .filter import Filters
from .message_pool import MessagePool
from .message_type import RelayMessageType
from .subscription import Subscription


logger = logging.getLogger('nostr')


@dataclass
class RelayPolicy:
should_read: bool = True
Expand All @@ -35,7 +41,7 @@ class Relay:
url: str
message_pool: MessagePool
policy: RelayPolicy = RelayPolicy()
proxy_config: RelayProxyConnectionConfig = RelayProxyConnectionConfig()
proxy_config: Optional[RelayProxyConnectionConfig] = None
ssl_options: Optional[dict] = None

def __post_init__(self):
Expand All @@ -52,16 +58,23 @@ def __post_init__(self):
def connect(self):
self.ws.run_forever(
sslopt=self.ssl_options,
http_proxy_host=self.proxy_config.host,
http_proxy_port=self.proxy_config.port,
proxy_type=self.proxy_config.type
http_proxy_host=self.proxy_config.host if self.proxy_config is not
None else None,
http_proxy_port=self.proxy_config.port if self.proxy_config is not
None else None,
proxy_type=self.proxy_config.type if self.proxy_config is not None
else None,
)

def close(self):
self.ws.close()

def publish(self, message: str):
self.ws.send(message)
try:
self.ws.send(message)
except WebSocketConnectionClosedException as e:
logger.error("Connection closed on publish attempt. url=%s",
self.url)

def add_subscription(self, id, filters: Filters):
with self.lock:
Expand All @@ -84,15 +97,19 @@ def to_json_object(self) -> dict:
}

def _on_open(self, class_obj):
logger.debug("Relay._on_open: url=%s", self.url)
pass

def _on_close(self, class_obj, status_code, message):
logger.debug("Relay._on_close: url=%s, code=%s, message=%s", self.url,
status_code, message)
pass

def _on_message(self, class_obj, message: str):
self.message_pool.add_message(message, self.url)

def _on_error(self, class_obj, error):
logger.debug("Relay._on_error: url=%s, error=%s", self.url, error)
pass

def _is_valid_message(self, message: str) -> bool:
Expand All @@ -107,21 +124,13 @@ def _is_valid_message(self, message: str) -> bool:
if message_type == RelayMessageType.EVENT:
if not len(message_json) == 3:
return False

subscription_id = message_json[1]
with self.lock:
if subscription_id not in self.subscriptions:
return False

e = message_json[2]
event = Event(
e["content"],
e["pubkey"],
e["created_at"],
e["kind"],
e["tags"],
e["sig"],
)
event = Event.from_dict(message_json[2])
if not event.verify():
return False

Expand Down
20 changes: 11 additions & 9 deletions nostr/relay_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import time
import threading
from typing import Optional
from dataclasses import dataclass
from threading import Lock

Expand All @@ -26,20 +27,21 @@ def __post_init__(self):
self.lock: Lock = Lock()

def add_relay(
self,
url: str,
self,
url: str,
policy: RelayPolicy = RelayPolicy(),
ssl_options: dict = None,
proxy_config: RelayProxyConnectionConfig = None):
ssl_options: Optional[dict] = None,
proxy_config: Optional[RelayProxyConnectionConfig] = None):

relay = Relay(url, self.message_pool, policy, ssl_options, proxy_config)
relay = Relay(url, self.message_pool, policy, proxy_config, ssl_options)

with self.lock:
self.relays[url] = relay

threading.Thread(
target=relay.connect,
name=f"{relay.url}-thread"
name=f"{relay.url}-thread",
daemon=True,
).start()

time.sleep(1)
Expand All @@ -55,12 +57,12 @@ def add_subscription_on_relay(self, url: str, id: str, filters: Filters):
if url in self.relays:
relay = self.relays[url]
if not relay.policy.should_read:
raise RelayException(f"Could not send request: {relay_url} is not configured to read from")
raise RelayException(f"Could not send request: {url} is not configured to read from")
relay.add_subscription(id, filters)
request = Request(id, filters)
relay.publish(request.to_message())
else:
raise RelayException(f"Invalid relay url: no connection to {relay_url}")
raise RelayException(f"Invalid relay url: no connection to {url}")

def add_subscription_on_all_relays(self, id: str, filters: Filters):
with self.lock:
Expand All @@ -77,7 +79,7 @@ def close_subscription_on_relay(self, url: str, id: str):
relay.close_subscription(id)
relay.publish(json.dumps(["CLOSE", id]))
else:
raise RelayException(f"Invalid relay url: no connection to {relay_url}")
raise RelayException(f"Invalid relay url: no connection to {url}")

def close_subscription_on_all_relays(self, id: str):
with self.lock:
Expand Down
3 changes: 2 additions & 1 deletion nostr/subscription.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Optional
from .filter import Filters

class Subscription:
def __init__(self, id: str, filters: Filters=None) -> None:
def __init__(self, id: str, filters: Optional[Filters]=None) -> None:
self.id = id
self.filters = filters

Expand Down
Loading