Skip to content
This repository was archived by the owner on Sep 26, 2022. It is now read-only.
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 57 additions & 42 deletions common/net/bolt1.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from common.tools import is_256b_hex_str

from common.net.tlv import NetworksTLV
from common.net.tlv import TLVRecord, NetworksTLV
from common.net.bolt9 import FeatureVector
from common.net.utils import message_sanity_checks

Expand All @@ -24,6 +24,19 @@ class Message:
"""

def __init__(self, mtype, payload, extension=None):
if not isinstance(mtype, bytes):
raise TypeError("mtype must be bytes")
if not isinstance(payload, bytes):
raise TypeError("payload must be bytes")
if extension is not None and not isinstance(extension, list):
raise TypeError("extension must be a list if set")
else:
# Normalize the default extension type (for empty lists)
if not extension:
extension = None
elif not all(isinstance(tlv, TLVRecord) for tlv in extension):
raise TypeError("All items in extension must be TLVRecords")

self.type = mtype
self.payload = payload
self.extension = extension
Expand All @@ -45,25 +58,29 @@ def from_bytes(cls, message):
"""

if not isinstance(message, bytes):
raise TypeError(f"message be must a bytearray")
raise TypeError("message be must a bytearray")
if len(message) < 2:
raise ValueError(f"message be must at least 2-byte long")
raise ValueError("message be must at least 2-byte long")

mtype = message[:2]

if message[:2] == message_types["init"]:
if mtype == message_types["init"]:
return InitMessage.from_bytes(message)
elif message[:2] == message_types["error"]:
elif mtype == message_types["error"]:
return ErrorMessage.from_bytes(message)
elif message[:2] == message_types["ping"]:
elif mtype == message_types["ping"]:
return PingMessage.from_bytes(message)
elif message[:2] == message_types["pong"]:
elif mtype == message_types["pong"]:
return PongMessage.from_bytes(message)
else:
raise ValueError("Cannot decode unknown message type")

def serialize(self):
"""Serialises the message."""
Comment thread
sr-gi marked this conversation as resolved.
Outdated
if not self.extension:
return self.type + self.payload
else:
tlvs = b"".join([tlv.serialize() for tlv in self.extension()])
tlvs = b"".join([tlv.serialize() for tlv in self.extension])
return self.type + self.payload + tlvs


Expand All @@ -84,16 +101,16 @@ def __init__(self, global_features, local_features, networks=None):
if not isinstance(networks, NetworksTLV):
raise TypeError("networks must be of type NetworksTLV (if set)")

global_features = global_features.serialize()
local_features = local_features.serialize()
gflen = len(global_features).to_bytes(2, "big")
flen = len(local_features).to_bytes(2, "big")
payload = gflen + global_features + flen + local_features
gf = global_features.serialize()
lf = local_features.serialize()
gflen = len(gf).to_bytes(2, "big")
flen = len(lf).to_bytes(2, "big")
payload = gflen + gf + flen + lf

# Add extensions if needed (this follows TLV format)
# FIXME: Only networks for now
if networks:
super().__init__(mtype=message_types["init"], payload=payload, extension=networks.serialize())
super().__init__(mtype=message_types["init"], payload=payload, extension=[networks])
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, this addresses my earlier TLV vs. TLVStream comment 👍

else:
super().__init__(mtype=message_types["init"], payload=payload)
self.global_features = global_features
Expand All @@ -103,7 +120,6 @@ def __init__(self, global_features, local_features, networks=None):
@classmethod
def from_bytes(cls, message):
"""Builds an InitMessage from its byte representation."""

# Message should be at least: type (2-byte) + gflen (2-byte) + flen (2 byte)
message_sanity_checks(message, message_types["init"], 6)

Expand All @@ -112,6 +128,8 @@ def from_bytes(cls, message):
global_features = FeatureVector.from_bytes(message[4 : gflen + 4])
flen = int.from_bytes(message[gflen + 4 : gflen + 6], "big")
local_features = FeatureVector.from_bytes(message[gflen + 6 : gflen + flen + 6])
Comment on lines +133 to +136
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bookkeeping would also be easier if you used an io instance 😉

if gflen + flen + 6 > len(message):
raise ValueError() # Unexpected EOF
Comment thread
bigspider marked this conversation as resolved.

# Check if there are TLVs (optional)
if len(message) > gflen + flen + 6:
Expand All @@ -121,7 +139,7 @@ def from_bytes(cls, message):

return cls(global_features, local_features)

except (IndexError, ValueError):
except ValueError:
raise ValueError("Wrong message format. Unexpected EOF")


Expand Down Expand Up @@ -167,14 +185,13 @@ def from_bytes(cls, message):

# There's associated data
if data_len:
try:
data = message[36 : 36 + data_len]
if len(message) != 36 + data_len:
raise ValueError("Wrong data format. message has additional tailing data")
return cls(channel_id, data)
data = message[36 : 36 + data_len]

except IndexError:
if len(data) < data_len:
raise ValueError("Wrong message format. Unexpected EOF")
if len(message) > 36 + data_len:
raise ValueError("Wrong data format. message has additional tailing data")
Comment thread
sr-gi marked this conversation as resolved.
Outdated
return cls(channel_id, data.decode("utf-8"))

return cls(channel_id)

Expand All @@ -189,16 +206,18 @@ class PingMessage(Message):
"""

def __init__(self, num_pong_bytes, ignored_bytes=None):
if num_pong_bytes > pow(2, 16):
raise ValueError(f"num_pong_bytes cannot be higher than {pow(2, 16)}")
if ignored_bytes and not isinstance(ignored_bytes, bytes):
raise TypeError("ignored_bytes must be bytes if set")
if len(ignored_bytes) > pow(2, 16) - 4:
raise ValueError(f"ignored_bytes cannot be higher than {pow(2, 16) -4}")
if not 0 <= num_pong_bytes < pow(2, 16):
raise ValueError(f"num_pong_bytes must be between 0 and {pow(2, 16)}")
Comment thread
bigspider marked this conversation as resolved.
Outdated

payload = num_pong_bytes.to_bytes(2, "big")

if ignored_bytes:
if not isinstance(ignored_bytes, bytes):
raise TypeError("ignored_bytes must be bytes if set")
if len(ignored_bytes) > pow(2, 16) - 4:
raise ValueError(f"ignored_bytes cannot be higher than {pow(2, 16) - 4}")
payload += len(ignored_bytes).to_bytes(2, "big") + ignored_bytes

super().__init__(message_types["ping"], payload)
self.num_pong_bytes = num_pong_bytes
self.ignored_bytes = ignored_bytes
Expand All @@ -212,14 +231,12 @@ def from_bytes(cls, message):
byteslen = int.from_bytes(message[4:6], "big")

if byteslen:
try:
ignored = message[6 : 6 + byteslen]
if len(message) != 6 + byteslen:
raise ValueError("Wrong data format. message has additional tailing data")
return cls(num_pong_bytes, ignored)

except IndexError:
ignored = message[6 : 6 + byteslen]
if len(ignored) < byteslen:
raise ValueError("Wrong message format. Unexpected EOF")
if len(message) > 6 + byteslen:
raise ValueError("Wrong data format. message has additional tailing data")
Comment thread
bigspider marked this conversation as resolved.
Outdated
return cls(num_pong_bytes, ignored)

return cls(num_pong_bytes)

Expand Down Expand Up @@ -256,13 +273,11 @@ def from_bytes(cls, message):
byteslen = int.from_bytes(message[2:4], "big")

if byteslen:
try:
ignored_bytes = message[4 : 4 + byteslen]
if len(message) != 4 + byteslen:
raise ValueError("Wrong data format. message has additional tailing data")
return cls(ignored_bytes)

except IndexError:
ignored_bytes = message[4 : 4 + byteslen]
if len(ignored_bytes) < byteslen:
raise ValueError("Wrong message format. Unexpected EOF")
if len(message) != 4 + byteslen:
raise ValueError("Wrong data format. message has additional tailing data")
Comment thread
sr-gi marked this conversation as resolved.
Outdated
return cls(ignored_bytes)

return cls()