diff --git a/common/net/__init__.py b/common/net/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/common/net/bigsize.py b/common/net/bigsize.py new file mode 100644 index 00000000..30303b66 --- /dev/null +++ b/common/net/bigsize.py @@ -0,0 +1,117 @@ +def encode(value): + """ + Encodes a value to BigSize. + + Args: + value (:obj:`int`): the integer value to be encoded. + + Returns: + :obj:`bytes`: the BigSize encoding of the given value. + + Raises: + :obj:`TypeError`: If the provided value is not an integer. + :obj:`ValueError`: If the provided value is negative or bigger than ``pow(2, 64) - 1``. + """ + + if not isinstance(value, int): + raise TypeError(f"value must be integer, {type(value)} received") + + if value < 0: + raise ValueError(f"value must be a positive integer, {value} received") + + if value < 253: + return value.to_bytes(1, "big") + elif value < pow(2, 16): + return b"\xfd" + value.to_bytes(2, "big") + elif value < pow(2, 32): + return b"\xfe" + value.to_bytes(4, "big") + elif value < pow(2, 64): + return b"\xff" + value.to_bytes(8, "big") + else: + raise ValueError("BigSize can only encode up to 8-byte values") + + +def decode(value): + """ + Decodes a value from BigSize. + + Args: + value (:obj:`bytes`): the value to be decoded. + + Returns: + :obj:`int`: the integer decoding of the provided value. + + Raises: + :obj:`TypeError`: If the provided value is not in bytes. + :obj:`ValueError`: If the provided value is bigger than 9-bytes or the value is not properly encoded. + """ + + if not isinstance(value, bytes): + raise TypeError(f"value must be bytes, {type(value)} received") + + if len(value) == 0: + raise ValueError("Unexpected EOF while decoding BigSize") + + if len(value) > 9: + raise ValueError(f"value must be, at most, 9-bytes long, {len(value)} received") + + if len(value) == 1 and value[0] < 253: + return value[0] + + prefix = value[0] + decoded_value = int.from_bytes(value[1:], "big") + + if prefix == 253: + length = 3 + min_v = 253 + max_v = pow(2, 16) + elif prefix == 254: + length = 5 + min_v = pow(2, 16) + max_v = pow(2, 32) + else: + length = 9 + min_v = pow(2, 32) + max_v = pow(2, 64) + + if not len(value) == length: + raise ValueError("Unexpected EOF while decoding BigSize") + elif not min_v <= decoded_value < max_v: + raise ValueError("Encoded BigSize is non-canonical") + else: + return decoded_value + + +def parse(value): + """ + Parses a BigSize from a bytearray. + + Args: + value (:obj:`bytes`): the bytearray from where the BigSize value will be parsed. + + Returns: + :obj:`tuple`: A 2 items tuple containing the parsed BigSize and its encoded length. + + Raises: + :obj:`TypeError`: If the provided value is not in bytes. + :obj:`ValueError`: If the provided value is not, at least, 1-byte long or if the value cannot be parsed. + """ + + if not isinstance(value, bytes): + raise TypeError("value must be bytes") + if len(value) < 1: + raise ValueError("value must be at least 1-byte long") + + prefix = value[0] + + # message length is not explicitly checked here, but wrong length will fail at decode. + if prefix < 253: + # prefix is actually the value to be parsed + return decode(value[0:1]), 1 + else: + if prefix == 253: + return decode(value[0:3]), 3 + elif prefix == 254: + return decode(value[0:5]), 5 + else: + return decode(value[0:9]), 9 diff --git a/common/net/bolt1.py b/common/net/bolt1.py new file mode 100644 index 00000000..b09b7b34 --- /dev/null +++ b/common/net/bolt1.py @@ -0,0 +1,288 @@ +from common.tools import is_256b_hex_str + +from common.net.tlv import TLVRecord, NetworksTLV +from common.net.bolt9 import FeatureVector +from common.net.utils import message_sanity_checks + + +message_types = {"init": b"\x00\x10", "error": b"\x00\x11", "ping": b"\x00\x12", "pong": b"\x00\x13"} + + +class Message: + """ + Message class. Used as a base class for all other messages. + + Args: + mtype (:obj:`bytes`): the message type. + payload (:obj:`bytes`): the message payload. + extension (:obj:`bytes`): the message extension, if any (optional). + + Attributes: + type (:obj:`bytes`): the message type. + payload (:obj:`bytes`): the message payload. + extension (:obj:`bytes`): the message extension, if any (optional). + """ + + def __init__(self, mtype, payload, extension=None): + # Normalize the default extension type to empty list + if extension is None: + extension = [] + if not isinstance(mtype, bytes): + raise TypeError("mtype must be bytes") + elif not isinstance(payload, bytes): + raise TypeError("payload must be bytes") + elif not isinstance(extension, list): + raise TypeError("extension must be a list if set") + elif not all(isinstance(tlv, TLVRecord) for tlv in extension): + raise TypeError("All items in extension must be TLVRecords") + + self.type = mtype + self.payload = payload + self.extension = extension + + @classmethod + def from_bytes(cls, message): + """ + Builds a message from its byte representation. + + Args: + message (:obj:`bytes`): the byte-encoded message. + + Returns: + The Message children class depending on the received message type. Check ``message_types`` for more info. + + Raises: + :obj:`TypeError`: If the given message is not in bytes. + :obj:`ValueError`: If the message can not be parsed. + """ + + if not isinstance(message, bytes): + raise TypeError("message be must a bytearray") + if len(message) < 2: + raise ValueError("message be must at least 2-byte long") + + mtype = message[:2] + + if mtype == message_types["init"]: + return InitMessage.from_bytes(message) + elif mtype == message_types["error"]: + return ErrorMessage.from_bytes(message) + elif mtype == message_types["ping"]: + return PingMessage.from_bytes(message) + elif mtype == message_types["pong"]: + return PongMessage.from_bytes(message) + else: + raise ValueError("Cannot decode unknown message type") + + def serialize(self): + """Serializes the message.""" + if not self.extension: + return self.type + self.payload + else: + tlvs = b"".join([tlv.serialize() for tlv in self.extension]) + return self.type + self.payload + tlvs + + def to_dict(self): + return { + "type": self.type.hex(), + "payload": self.payload.hex(), + "extension": [extension.serialize().hex() for extension in self.extension], + } + + +class InitMessage(Message): + """ + First message exchange by the nodes, it reveals the features supported by each end. + + Args: + global_features (:obj:`FeatureVector `): the global features vector. + local_features (:obj:`FeatureVector `): the local features vector. + networks (:obj:`NetworksTLV `): a networks tlv (optional). + """ + + def __init__(self, global_features, local_features, networks=None): + if not (isinstance(global_features, FeatureVector) and isinstance(local_features, FeatureVector)): + raise TypeError("global_features and local_features must be FeatureVector instances") + if networks: + if not isinstance(networks, NetworksTLV): + raise TypeError("networks must be of type NetworksTLV (if set)") + + gf = global_features.serialize() + lf = local_features.serialize() + gflen = len(gf).to_bytes(2, "big") + flen = len(lf).to_bytes(2, "big") + payload = gflen + gf + flen + lf + + # Add extensions if needed (this follows TLV format) + # FIXME: Only networks for now + if networks: + super().__init__(mtype=message_types["init"], payload=payload, extension=[networks]) + else: + super().__init__(mtype=message_types["init"], payload=payload) + self.global_features = global_features + self.local_features = local_features + self.networks = networks + + @classmethod + def from_bytes(cls, message): + """Builds an InitMessage from its byte representation.""" + # Message should be at least: type (2-byte) + gflen (2-byte) + flen (2 byte) + message_sanity_checks(message, message_types["init"], 6) + + try: + gflen = int.from_bytes(message[2:4], "big") + global_features = FeatureVector.from_bytes(message[4 : gflen + 4]) + flen = int.from_bytes(message[gflen + 4 : gflen + 6], "big") + local_features = FeatureVector.from_bytes(message[gflen + 6 : gflen + flen + 6]) + if gflen + flen + 6 > len(message): + raise ValueError() # Unexpected EOF + + # Check if there are TLVs (optional) + if len(message) > gflen + flen + 6: + # FIXME: Only accepting networks TLV for now + networks = NetworksTLV.from_bytes(message[gflen + flen + 6 :]) + return cls(global_features, local_features, networks) + + return cls(global_features, local_features) + + except ValueError: + raise ValueError("Wrong message format. Unexpected EOF") + + +class ErrorMessage(Message): + """ + Message for error reporting. + + Args: + channel_id (:obj:`str`): a 32-byte long hex str identifying the channel that originated the error, or 0 if it + refers to all channels. + data (:obj:`str`): the error message. + """ + + def __init__(self, channel_id, data=""): + if not is_256b_hex_str(channel_id): + raise ValueError("channel_id must be a 256-bit hex string") + + if not isinstance(data, str): + raise ValueError("data must be string if set") + + payload = bytes.fromhex(channel_id) + encoded_message = data.encode("utf-8") + + if len(encoded_message) >= pow(2, 16): + raise ValueError( + f"Encoded data length cannot be bigger than {pow(2, 16) - 1}, {len(encoded_message)} received" + ) + + payload += len(encoded_message).to_bytes(2, "big") + encoded_message + + super().__init__(message_types["error"], payload) + self.channel_id = channel_id + self.data = data + + @classmethod + def from_bytes(cls, message): + """Builds an ErrorMessage from its byte representation.""" + # Message should be at least: type (2-byte) + channel_id (32-byte) + data_len (2-bytes) + message_sanity_checks(message, message_types["error"], 36) + channel_id = message[2:34].hex() + data_len = int.from_bytes(message[34:36], "big") + + # There's associated data + if data_len: + data = message[36 : 36 + data_len] + + if len(data) < data_len: + raise ValueError("Wrong message format. Unexpected EOF") + if len(message) > 36 + data_len: + raise ValueError("Wrong data format. message has additional trailing data") + return cls(channel_id, data.decode("utf-8")) + + return cls(channel_id) + + +class PingMessage(Message): + """ + Message to test the reachability of the other side of the channel. Useful to allow long lived communications. + + Args: + num_pong_bytes (:obj:`int`): the number of bytes to be responded by the peer. + ignored_bytes (:obj:`bytes`): filling bytes added to the message by the sender. + """ + + def __init__(self, num_pong_bytes, ignored_bytes=b""): + if not 0 <= num_pong_bytes < pow(2, 16): + raise ValueError(f"num_pong_bytes must be between 0 and {pow(2, 16) - 1}") + + payload = num_pong_bytes.to_bytes(2, "big") + + if ignored_bytes: + if not isinstance(ignored_bytes, bytes): + raise TypeError("ignored_bytes must be bytes if set") + if len(ignored_bytes) > pow(2, 16) - 4: + raise ValueError(f"ignored_bytes cannot be higher than {pow(2, 16) - 4}") + payload += len(ignored_bytes).to_bytes(2, "big") + ignored_bytes + + super().__init__(message_types["ping"], payload) + self.num_pong_bytes = num_pong_bytes + self.ignored_bytes = ignored_bytes + + @classmethod + def from_bytes(cls, message): + """Builds a PingMessage from its byte representation.""" + # Message should be at least: type (2-byte) + num_pong_bytes (2-bytes) + byteslen (2-bytes) + message_sanity_checks(message, message_types["ping"], 6) + num_pong_bytes = int.from_bytes(message[2:4], "big") + byteslen = int.from_bytes(message[4:6], "big") + + if byteslen: + ignored = message[6 : 6 + byteslen] + if len(ignored) < byteslen: + raise ValueError("Wrong message format. Unexpected EOF") + if len(message) > 6 + byteslen: + raise ValueError("Wrong data format. message has additional trailing data") + return cls(num_pong_bytes, ignored) + + return cls(num_pong_bytes) + + +class PongMessage(Message): + """ + Message to be sent in response to a ``PingMessage``. + + Args: + ignored_bytes (:obj:`bytes`): filling bytes added to the message by the sender. Should match the ones requested + by the sender of the ``PingMessage``. + """ + + def __init__(self, ignored_bytes=b""): + if ignored_bytes: + if not isinstance(ignored_bytes, bytes): + raise TypeError("ignored_bytes must be bytes if set") + if len(ignored_bytes) > pow(2, 16) - 4: + raise ValueError(f"ignored_bytes cannot be higher than {pow(2, 16) - 4}") + + payload = len(ignored_bytes).to_bytes(2, "big") + ignored_bytes + + else: + payload = int.to_bytes(0, 2, "big") + + super().__init__(message_types["pong"], payload) + self.ignored_bytes = ignored_bytes + + @classmethod + def from_bytes(cls, message): + """Builds a PongMessage from its byte representation.""" + # Message should be at least: type (2-byte) + byteslen (2-bytes) + message_sanity_checks(message, message_types["pong"], 4) + byteslen = int.from_bytes(message[2:4], "big") + + if byteslen: + ignored_bytes = message[4 : 4 + byteslen] + if len(ignored_bytes) < byteslen: + raise ValueError("Wrong message format. Unexpected EOF") + if len(message) != 4 + byteslen: + raise ValueError("Wrong data format. message has additional trailing data") + return cls(ignored_bytes) + + return cls() diff --git a/common/net/bolt9.py b/common/net/bolt9.py new file mode 100644 index 00000000..89347f48 --- /dev/null +++ b/common/net/bolt9.py @@ -0,0 +1,189 @@ +from math import ceil + +# feature_name: odd_bit +known_features = { + "option_data_loss_protect": 0, + "initial_routing_sync": 2, + "option_upfront_shutdown_script": 4, + "gossip_queries": 6, + "var_onion_optin": 8, + "gossip_queries_ex": 10, + "option_static_remotekey": 12, + "payment_secret": 14, + "basic_mpp": 16, + "option_support_large_channel": 18, + "option_anchor_outputs": 20, +} + +# Reversed map -> odd_bit : feature_name +known_odd_bits = {v: k for k, v in known_features.items()} + + +def check_feature_name_bit_pair(name, bit): + """ + Checks whether a given name and bit pair match for known features. + + For unknown features, it returns True as long as they are not using a known name nor bit. + + Args: + name (:obj:`str`): the feature name. + bit (:obj:`int`): the bit position. + + Returns: + :obj:`bool`: For known features, returns True if the pair matches. For unknown features, returns True if the bit + is unknown. + """ + + if name in known_features: + # The pair matches + return bit in [known_features[name], known_features[name] + 1] + else: + # The name and bit are unknown + return not (bit in known_features.values() or bit + 1 in known_features.values()) + + +class Feature: + """ + Feature represents a feature bit. + + Args: + bit (:obj:`int`): the index that the feature bit holds in the feature vector. + is_set (:obj:`bool`): whether the feature is set or not. + + Attributes: + is_odd (:obj:`bool`): whether the bit is odd or even. + """ + + def __init__(self, bit, is_set): + if not isinstance(bit, int): + raise TypeError("bit must be int") + if not isinstance(is_set, bool): + raise TypeError("is_set must be bool") + + self.bit = bit + self.is_set = is_set + self.is_odd = bool(self.bit % 2) + + +class FeatureVector: + """The FeatureVector encapsulates all the features.""" + + def __init__(self, **kwargs): + self._features = {} + for key, value in kwargs.items(): + if not isinstance(value, Feature): + raise TypeError(f"Features must be of type Feature, {type(value)} received") + elif key == "initial_routing_sync" and value.is_set and not value.is_odd: + raise ValueError("initial_routing_sync has no even bit") + elif not check_feature_name_bit_pair(key, value.bit): + raise ValueError("Feature name and bit do not match") + + vars(self)[key] = value + self._features[key] = value + + for name in set(known_features.keys()).difference(kwargs.keys()): + vars(self)[name] = Feature(known_features[name], is_set=False) + self._features[name] = vars(self)[name] + + @classmethod + def from_bytes(cls, features): + """ + Builds the FeatureVector from its byte representation. + + Unknown features are parsed as unknown_i where i is the odd_byte of the encoded feature. + + Args: + features (:obj:`bytes`): the byte-encoded feature vector. + + Returns: + :obj:`FeatureVector`: The FeatureVector created from the given bytes. + + Raises: + :obj:`TypeError`: If the provided features are not in bytes. + :obj:`ValueError`: If two bits from the same pair are set. Or if there is a mismatch between name and bit + for known features. + """ + + if not isinstance(features, bytes): + raise TypeError(f"Features must be bytes, {type(features)} received") + + int_features = int.from_bytes(features, "big") + padding = max(2 * len(known_features), int_features.bit_length()) + padding += padding % 2 # round up to the nearest even number + + bit_features = f"{int_features:b}".zfill(padding) + bit_pairs = [bit_features[i : i + 2] for i in range(0, len(bit_features), 2)] + features_dict = {} + + for i, pair in enumerate(reversed(bit_pairs)): + if pair == "11": + raise ValueError("Both odd and even bits cannot be set in a pair") + + # Known features are stored no matter if they are set or not + odd_bit = 2 * i + feature_name = known_odd_bits.get(odd_bit) + if feature_name: + if pair == "00": + features_dict[feature_name] = Feature(odd_bit, is_set=False) + elif pair == "01": + features_dict[feature_name] = Feature(odd_bit, is_set=True) + elif pair == "10": + features_dict[feature_name] = Feature(odd_bit + 1, is_set=True) + + # For unknown features, we only store the ones that are set + else: + feature_name = f"unknown_{odd_bit}" + if pair == "01": + features_dict[feature_name] = Feature(odd_bit, is_set=True) + elif pair == "10": + features_dict[feature_name] = Feature(odd_bit + 1, is_set=True) + + return cls(**features_dict) + + def set_feature(self, name, bit): + """ + Sets a feature from the FeatureVector identified by its name and bit. + + Args: + name (:obj:`str`): the name of the feature. + bit (:obj:`int`): the index that the feature bit holds in the feature vector. + + Raises: + :obj:`TypeError`: If name is not str or bit is not integer. + :obj:`ValueError`: If the given name and bit do not match (for known features). + """ + + if not isinstance(name, str): + raise TypeError("name must be str") + if not isinstance(bit, int): + raise TypeError("bit must be integer") + + # Features we know about or features we don't know about and that do not collide with the ones we know about + if check_feature_name_bit_pair(name, bit): + vars(self)[name] = Feature(bit, is_set=True) + self._features[name] = vars(self)[name] + else: + raise ValueError("Feature name and bit do not match") + + def serialize(self): + """Computes the serialization of the FeatureVector.""" + serialized_features = 0 + for feature in self._features.values(): + if feature.is_set: + serialized_features += pow(2, feature.bit) + + return serialized_features.to_bytes(ceil(serialized_features.bit_length() / 8), "big") + + def to_dict(self): + """Creates the dictionary representation of the Feature.""" + features = {} + for name in self._features: + feature = vars(self)[name] + if feature.is_set: + if feature.is_odd: + features[name] = "odd" + else: + features[name] = "even" + else: + features[name] = 0 + return features diff --git a/common/net/tlv.py b/common/net/tlv.py new file mode 100644 index 00000000..1870373b --- /dev/null +++ b/common/net/tlv.py @@ -0,0 +1,150 @@ +from common.tools import is_256b_hex_str + +import common.net.bigsize as bigsize +from common.net.utils import message_sanity_checks + +tlv_types = { + "networks": b"\x01", + "amt_to_forward": b"\x02", + "outgoing_cltv_value": b"\x04", + "short_channel_id": b"\x06", + "payment_data": b"\x08", +} + + +class TLVRecord: + """ + Base class for TLV records. + + Args: + t (:obj:`bytes`): the message type. + l (:obj:`bytes`): the value length. + v (:obj:`bytes`): the message value. + """ + + def __init__(self, t=b"", l=b"", v=b""): + if not isinstance(t, bytes): + raise TypeError("t must be bytes") + if not isinstance(l, bytes): + raise TypeError("l must be bytes") + if not isinstance(v, bytes): + raise TypeError("v must be bytes") + + self.type = t + self.length = l + self.value = v + + def __len__(self): + """Returns the length of the serialized TLV record""" + return len(self.serialize()) + + def __eq__(self, other): + return isinstance(other, TLVRecord) and self.serialize() == other.serialize() + + @classmethod + def from_bytes(cls, message): + """ + Builds a TLV record from bytes. + + Args: + message (:obj:`bytes`): the byte representation of the TLV record. + + Returns: + :obj:`TLVRecord`: The TLVRecord built from the provided bytes. + + Raises: + :obj:`TypeError`: If the provided message is not in bytes. + :obj:`ValueError`: If the provided message is not properly encoded. + """ + + if not isinstance(message, bytes): + raise TypeError("message must be bytes") + + try: + t, t_length = bigsize.parse(message) + if t.to_bytes(t_length, "big") == tlv_types["networks"]: + return NetworksTLV.from_bytes(message) + else: + l, l_length = bigsize.parse(message[t_length:]) + v = message[t_length + l_length :] + if l > len(v): + # Value is not long enough + raise ValueError() # This message gets overwritten so it does not matter + + if len(message) != t_length + l_length + len(v): + # There is additional trailing data + raise ValueError() # This message gets overwritten so it does not matter + + return cls(t.to_bytes(t_length, "big"), l.to_bytes(l_length, "big"), v) + except ValueError as e: + raise ValueError("Wrong tlv message format. Unexpected EOF") + + def serialize(self): + """Returns the serialized representation of the TLV record.""" + return self.type + self.length + self.value + + +class NetworksTLV(TLVRecord): + """ + TLV record for networks in the init message. Contains the genesis block hash of the networks the node is interested + in. + + Args: + networks (:obj:`list`): a list of genesis block hashes (hex str). This parameter is optional. + + Raises: + :obj:`TypeError`: If networks is set and it is not a list. + :obj:`ValueError`: If networks is set and all its elements are not 32-byte hex strings. + """ + + def __init__(self, networks=None): + if not networks: + super().__init__(tlv_types["networks"], bigsize.encode(0)) + self.networks = [] + elif isinstance(networks, list): + chains = b"" + for chain in networks: + if not is_256b_hex_str(chain): + raise ValueError("All networks must be 32-byte hex str") + chains += bytes.fromhex(chain) + super().__init__(tlv_types["networks"], bigsize.encode(32 * len(networks)), chains) + self.networks = networks + else: + raise TypeError("networks must be a list if set") + + @classmethod + def from_bytes(cls, message): + """ + Builds a NetworksTLV record from bytes. + + Args: + message (:obj:`bytes`): the byte representation of the TLV record. + + Returns: + :obj:`NetworksTLV`: The NetworksTLV built from the provided bytes. + + Raises: + :obj:`TypeError`: If the provided message is not in bytes or networks is not a list. + :obj:`ValueError`: If the provided message is not properly encoded or the items in networks are not 32-byte + hex strings. + """ + + message_sanity_checks(message, tlv_types["networks"], 2, tlv=True) + + try: + clen, length_offset = bigsize.parse(message[1:]) + except ValueError: + # TLV can be defined with no data. + return cls() + + # Chains is an array of genesis block hashes (32-byte each) + if clen % 32: + raise ValueError(f"chains must be multiple of 32, {clen} received") + + networks = [] + offset = 1 + length_offset # type + length fields + for i in range(clen // 32): + networks.append(message[offset : offset + 32].hex()) + offset += 32 + + return cls(networks) diff --git a/common/net/utils.py b/common/net/utils.py new file mode 100644 index 00000000..cab41737 --- /dev/null +++ b/common/net/utils.py @@ -0,0 +1,41 @@ +import common.net.bigsize as bigsize + + +def message_sanity_checks(message, expected_type, min_len, tlv=False): + """ + Runs sanity checks to a received byte-encoded message, such as checking its minimum length or message type. + + Args: + message (:obj:`bytes`): the bytes-encoded message. + expected_type (:obj:`str`): the expected type of the message. + min_len (:obj:`int`): the minimum expected length of the message. + tlv (:obj:`bool`): whether the message is a tlv record or not. + + Raises: + :obj:`TypeError`: If the provided message is not in bytes. + :obj:`ValueError`: If the provided message is not long enough or not of the expected type. + """ + + if not isinstance(message, bytes): + raise TypeError("message be must a bytearray") + if not isinstance(expected_type, bytes): + raise TypeError("expected_type be must bytes") + if not isinstance(min_len, int): + raise TypeError("min_len be must int") + if not isinstance(tlv, bool): + raise TypeError("tlv be must bool if set") + if len(message) < min_len: + raise ValueError(f"message be must at least {min_len}-byte long") + + if tlv: + tlv_type, type_length = bigsize.parse(message) + tlv_type_byte = tlv_type.to_bytes(type_length, "big") + if tlv_type_byte != expected_type: + raise ValueError( + f"Wrong message format. types do not match (expected: {expected_type}, received: {tlv_type_byte}" + ) + else: + if message[:2] != expected_type: + raise ValueError( + f"Wrong message format. types do not match (expected: {expected_type}, received: {message[:2]}" + ) diff --git a/test/common/unit/net/test_bigsize.py b/test/common/unit/net/test_bigsize.py new file mode 100644 index 00000000..a065640d --- /dev/null +++ b/test/common/unit/net/test_bigsize.py @@ -0,0 +1,95 @@ +import pytest +import common.net.bigsize as bigsize + +# Test cases are copied from +# https://github.com/lightningnetwork/lightning-rfc/blob/bdd42711014643d5b2d4cbe179677451b940a9de/01-messaging.md + +value_encoding_pair = { + 0: b"\x00", + 252: b"\xfc", + 253: b"\xfd\x00\xfd", + 65535: b"\xfd\xff\xff", + 65536: b"\xfe\x00\x01\x00\x00", + 4294967295: b"\xfe\xff\xff\xff\xff", + 4294967296: b"\xff\x00\x00\x00\x01\x00\x00\x00\x00", + 18446744073709551615: b"\xff\xff\xff\xff\xff\xff\xff\xff\xff", +} + +non_canonical = [b"\xfd\x00\xfc", b"\xfe\x00\x00\xff\xff", b"\xff\x00\x00\x00\x00\xff\xff\xff\xff"] +unexpected_eof = [b"\xfd\x00", b"\xfe\xff\xff", b"\xff\xff\xff\xff\xff", b"", b"\xfd", b"\xfe", b"\xff"] + +no_int = ["", 1.1, object(), b"\x00"] +no_bytes = ["", 1.1, object(), 0] + + +def test_encode(): + for k, v in value_encoding_pair.items(): + assert bigsize.encode(k) == v + + +def test_encode_wrong(): + # Wrong type + for v in no_int: + with pytest.raises(TypeError): + bigsize.encode(v) + + # Negative value + for i in range(-1, -100): + with pytest.raises(ValueError, match="value must be a positive integer"): + bigsize.encode(i) + + # Value bigger than 8-bytes + with pytest.raises(ValueError, match="BigSize can only encode up to 8-byte values"): + bigsize.encode(pow(2, 64) + 1) + + +def test_decode(): + for k, v in value_encoding_pair.items(): + assert bigsize.decode(v) == k + + +def test_decode_wrong(): + # Wrong type + for v in no_bytes: + with pytest.raises(TypeError): + bigsize.decode(v) + + # Value too big (> 9-bytes) + with pytest.raises(ValueError, match="value must be, at most, 9-bytes long"): + bigsize.decode(bytes(10)) + + # Wrong encoding + for v in non_canonical: + with pytest.raises(ValueError, match="Encoded BigSize is non-canonical"): + bigsize.decode(v) + + for v in unexpected_eof: + with pytest.raises(ValueError, match="Unexpected EOF while decoding BigSize"): + bigsize.decode(v) + + +def test_parse(): + # Parsing should work for the properly encoded ones + for k, v in value_encoding_pair.items(): + int_value, offset = bigsize.parse(v) + assert int_value == k and offset == len(v) + + # Wrong encoding (behaves exactly like decode_wrong) + for v in non_canonical: + with pytest.raises(ValueError, match="Encoded BigSize is non-canonical"): + bigsize.parse(v) + + +def test_parse_wrong(): + # Wrong type + for v in no_bytes: + with pytest.raises(TypeError): + bigsize.parse(v) + + # Empty bytearray + with pytest.raises(ValueError, match="value must be at least 1-byte long"): + bigsize.parse(b"") + + # Value too big (> 9-bytes) + with pytest.raises(ValueError, match="value must be, at most, 9-bytes long"): + bigsize.decode(bytes(10)) diff --git a/test/common/unit/net/test_bolt1.py b/test/common/unit/net/test_bolt1.py new file mode 100644 index 00000000..5146574d --- /dev/null +++ b/test/common/unit/net/test_bolt1.py @@ -0,0 +1,408 @@ +import pytest + +from common.net.bolt1 import Message, InitMessage, ErrorMessage, PingMessage, PongMessage +from common.net.bolt9 import FeatureVector +from common.net.tlv import TLVRecord, NetworksTLV + +from test.common.unit.conftest import get_random_value_hex + + +def test_message(): + # Messages are built from a message_type (bytes) a payload (bytes), and an optional list of TLVRecords + mtype = b"\x00" + payload = b"\x00\x01\x02" + extension = [] + m = Message(mtype, payload, extension) + assert isinstance(m, Message) + assert m.type == mtype and m.payload == payload and m.extension == [] + + # Same with some tlvs + extension = [TLVRecord(), NetworksTLV()] + m2 = Message(mtype, payload, extension) + assert isinstance(m2, Message) + assert m2.type == mtype and m2.payload == payload and m2.extension == extension + + +def test_message_wrong_types(): + # Wrong mtype + with pytest.raises(TypeError, match="mtype must be bytes"): + Message("", b"") + + # Wrong payload + with pytest.raises(TypeError, match="payload must be bytes"): + Message(b"", "") + + # Wrong extension type + with pytest.raises(TypeError, match="extension must be a list if set"): + Message(b"", b"", "") + + # Wrong extension content + with pytest.raises(TypeError, match="All items in extension must be TLVRecords"): + Message(b"", b"", [TLVRecord(), 1]) + + +def test_message_from_bytes(): + # From bytes builds an instance of a children class as long as the type is known, raises ValueError otherwise + # Not testing particular cases for the child classes since they will be covered in their own tests + + # Init + m = b"\x00\x10\x00\x00\x00\x00" + assert isinstance(Message.from_bytes(m), InitMessage) + + # Error + m = b"\x00\x11" + bytes.fromhex(get_random_value_hex(32)) + b"\x00\x00" + assert isinstance(Message.from_bytes(m), ErrorMessage) + + # Ping + m = b"\x00\x12\x00\x00\x00\x00" + assert isinstance(Message.from_bytes(m), PingMessage) + + # Pong + m = b"\x00\x13\x00\x00" + assert isinstance(Message.from_bytes(m), PongMessage) + + # Unknown + with pytest.raises(ValueError, match="Cannot decode unknown message type"): + Message.from_bytes(b"\x00\xff") + + +def test_message_from_bytes_wrong(): + # Message must be bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + Message.from_bytes("random_message") + + # Message must be at least 2-byte long to account for the type + with pytest.raises(ValueError, match="message be must at least 2-byte long"): + Message.from_bytes(b"\x00") + + +def test_message_serialize(): + # Serialize returns the concatenation of the byte representation of each field: + # type + payload + [extension] + + # No extension + mtype = b"\x00\x001" + payload = b"\x00\x00\x00\x00" + assert Message(mtype, payload).serialize() == mtype + payload + + # With extensions + extension = [TLVRecord(t=b"\x00", l=b"\x01", v=b"\x02")] + assert Message(mtype, payload, extension).serialize() == mtype + payload + b"\x00" + b"\x01" + b"\x02" + + +def test_message_to_dict(): + mtype = b"\x00" + payload = b"\x00\x01\x02" + extension = [] + m = Message(mtype, payload, extension) + m_dict = m.to_dict() + + assert isinstance(m_dict, dict) + assert m_dict.get("type") == mtype.hex() + assert m_dict.get("payload") == payload.hex() + assert m_dict.get("extension") == extension + + # Add extension + extension1 = TLVRecord() + random_hash = get_random_value_hex(32) + extension2 = NetworksTLV([random_hash]) + m2 = Message(mtype, payload, [extension1, extension2]) + m2_dict = m2.to_dict() + assert m2_dict.get("type") == mtype.hex() + assert m2_dict.get("payload") == payload.hex() + assert m2_dict.get("extension") == ["", "0120" + random_hash] + + +def test_init_message(): + # Init message requires global_features(FeatureVector), local_features (FeatureVector) and optionally a NetworksTLV + gf = FeatureVector.from_bytes(b"\x02") + lf = FeatureVector() + im = InitMessage(gf, lf) + assert isinstance(im, InitMessage) + assert im.global_features == gf and im.local_features == lf and im.networks is None + + # Same with networks + networks = NetworksTLV([get_random_value_hex(32) for _ in range(5)]) + im2 = InitMessage(gf, lf, networks) + assert isinstance(im2, InitMessage) + assert im2.global_features == gf and im2.local_features == lf and im2.networks is networks + + +def test_init_message_wrong(): + # No FeatureVectors + with pytest.raises(TypeError, match="global_features and local_features must be FeatureVector instances"): + InitMessage("features", FeatureVector()) + with pytest.raises(TypeError, match="global_features and local_features must be FeatureVector instances"): + InitMessage(FeatureVector(), "features") + + # TLV must be NetworksTLV is fet (for now) + with pytest.raises(TypeError, match="networks must be of type NetworksTLV"): + InitMessage(FeatureVector(), FeatureVector(), "TLV") + + +def test_init_message_from_bytes(): + # Message type must be init and size at least 6 (type + gflen + flen) + mtype = b"\x00\x10" + gflen = b"\x00\x00" + flen = gflen + im = InitMessage.from_bytes(mtype + gflen + flen) + assert ( + isinstance(im, InitMessage) + and im.type == mtype + and im.global_features.serialize() == im.local_features.serialize() == b"" + ) + + # A more meaningful init (with some features) + mtype = b"\x00\x10" + global_features = b"\x2a\xaa\xaa" # All odd + gflen = b"\x00\x03" + local_features = b"\x01" # Feature 1 even + flen = b"\x00\x01" + im2 = InitMessage.from_bytes(mtype + gflen + global_features + flen + local_features) + assert ( + isinstance(im2, InitMessage) + and im2.type == mtype + and im2.global_features.serialize() == global_features + and im2.local_features.serialize() == local_features + ) + + # With some networks + networks = NetworksTLV([get_random_value_hex(32) for _ in range(5)]) + im3 = InitMessage.from_bytes(mtype + gflen + global_features + flen + local_features + networks.serialize()) + assert ( + isinstance(im3, InitMessage) + and im3.type == mtype + and im3.global_features.serialize() == global_features + and im3.local_features.serialize() == local_features + and im3.networks == networks + ) + + +def test_init_message_from_bytes_wrong(): + # Message is not bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + InitMessage.from_bytes("message") + + # Message is not long enough < 6 + with pytest.raises(ValueError, match="message be must at least 6-byte long"): + InitMessage.from_bytes(b"\x00\x10\x00\x01\x02") + + # Type is not init + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + InitMessage.from_bytes(b"\x00\x00\x00\x00\x00\x00") + + # Encoded lengths are wrong causing an unexpected EOF + with pytest.raises(ValueError, match="Wrong message format. Unexpected EOF"): + InitMessage.from_bytes(b"\x00\x10\x00\x01\x00\x00") + + +def test_error_message(): + # Error message expects a channel_id (32-hex encoded str) and an optional data field + cid = get_random_value_hex(32) + em = ErrorMessage(cid) + assert isinstance(em, ErrorMessage) + assert em.channel_id == cid + assert not em.data + + # Same with associated data + data = "error message data" + em2 = ErrorMessage(cid, data) + assert isinstance(em, ErrorMessage) + assert em2.channel_id == cid + assert em2.data == data + + +def test_error_message_wrong(): + # Channel id must be a 32-byte hex str + # Data must be string if set and no longer than the message cap size when encoded pow(2, 16) - 1 + + # Wrong channel id + with pytest.raises(ValueError, match="channel_id must be a 256-bit hex string"): + ErrorMessage(get_random_value_hex(31)) + + with pytest.raises(ValueError, match="channel_id must be a 256-bit hex string"): + ErrorMessage(dict()) + + # Wrong data type + with pytest.raises(ValueError, match="data must be string if set"): + ErrorMessage(get_random_value_hex(32), b"message") + + # Data too long + with pytest.raises(ValueError, match=f"Encoded data length cannot be bigger than {pow(2, 16) - 1}"): + ErrorMessage(get_random_value_hex(32), "A" * (pow(2, 16))) + + +def test_error_from_bytes(): + # Message must be, at least, 36-bytes long + mtype = b"\x00\x11" + cid = bytes.fromhex(get_random_value_hex(32)) + data_len = b"\x00\x00" + + em = ErrorMessage.from_bytes(mtype + cid + data_len) + assert isinstance(em, ErrorMessage) + assert em.channel_id == cid.hex() + assert not em.data + + # Same with associated data + data = "message" + data_len = len(data).to_bytes(2, "big") + em2 = ErrorMessage.from_bytes(mtype + cid + data_len + data.encode("utf-8")) + assert isinstance(em2, ErrorMessage) + assert em2.channel_id == cid.hex() + assert em2.data == data + + +def test_error_from_bytes_wrong(): + # Message is not bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + ErrorMessage.from_bytes("message") + + # Message is not long enough < 36 + with pytest.raises(ValueError, match="message be must at least 36-byte long"): + ErrorMessage.from_bytes(b"\x00\x11" + bytes(33)) + + # Type is not error + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + ErrorMessage.from_bytes(b"\x00\x10" + bytes.fromhex(get_random_value_hex(32)) + b"\x00\x00") + + # Encoded lengths are wrong causing an unexpected EOF + with pytest.raises(ValueError, match="Wrong message format. Unexpected EOF"): + ErrorMessage.from_bytes(b"\x00\x11" + bytes.fromhex(get_random_value_hex(32)) + b"\x00\x02\x00") + + # Encoded lengths are wrong leaving additional data at the end + with pytest.raises(ValueError, match="Wrong data format. message has additional trailing data"): + ErrorMessage.from_bytes(b"\x00\x11" + bytes.fromhex(get_random_value_hex(32)) + b"\x00\x01\x00\x00") + + +def test_ping_message(): + # Ping expects a number of pong bytes and optionally some ignored data (bytes) + num_pong_bytes = 10 + pm = PingMessage(num_pong_bytes) + assert isinstance(pm, PingMessage) + assert pm.num_pong_bytes == num_pong_bytes + + # Same with ignore_bytes + ignored_bytes = b"\x01\x04\xff\x00" + pm2 = PingMessage(num_pong_bytes, ignored_bytes) + assert isinstance(pm2, PingMessage) + assert pm2.num_pong_bytes == num_pong_bytes + assert pm2.ignored_bytes == ignored_bytes + + +def test_ping_message_wrong(): + # num_pong_bytes must be an integer between 0 and pow(2, 16) + with pytest.raises(ValueError, match=f"num_pong_bytes must be between 0 and {pow(2, 16) -1}"): + PingMessage(-1) + with pytest.raises(ValueError, match=f"num_pong_bytes must be between 0 and {pow(2, 16) - 1}"): + PingMessage(pow(2, 16)) + + # ignore_bytes must be bytes if set + with pytest.raises(TypeError, match="ignored_bytes must be bytes if set"): + PingMessage(pow(2, 16) - 1, "ignored_bytes") + + # ignore_bytes length cannot be bigger than pow(2, 16) - 4 + with pytest.raises(ValueError, match=f"ignored_bytes cannot be higher than {pow(2, 16) - 4}"): + PingMessage(10, bytes(pow(2, 16) - 3)) + + +def test_ping_message_from_bytes(): + # message must be at least 6 bytes long (type + num_pong_bytes + byteslen) + mtype = b"\x00\x12" + num_pong_bytes = b"\x00\x01" + bytes_len = b"\x00\x00" + + pm = PingMessage.from_bytes(mtype + num_pong_bytes + bytes_len) + assert isinstance(pm, PingMessage) + assert pm.num_pong_bytes == int.from_bytes(num_pong_bytes, "big") + assert not pm.ignored_bytes + + # Same with some ignored data + ignored_data = b"\x00\x01\x02\x03" + bytes_len = b"\x00\x04" + pm2 = PingMessage.from_bytes(mtype + num_pong_bytes + bytes_len + ignored_data) + assert isinstance(pm2, PingMessage) + assert pm2.num_pong_bytes == int.from_bytes(num_pong_bytes, "big") + assert pm2.ignored_bytes == ignored_data + + +def test_ping_message_from_bytes_wrong(): + # Message is not bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + PingMessage.from_bytes("message") + + # Message is not long enough < 6 + with pytest.raises(ValueError, match="message be must at least 6-byte long"): + PingMessage.from_bytes(b"\x00\x12\x00\x01\x02") + + # Type is not ping + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + PingMessage.from_bytes(b"\x00\x10\x00\x01\x00\x00") + + # Encoded lengths are wrong causing an unexpected EOF + with pytest.raises(ValueError, match="Wrong message format. Unexpected EOF"): + PingMessage.from_bytes(b"\x00\x12\x00\x00\x00\x01") + + # Encoded lengths are wrong leaving additional data at the end + with pytest.raises(ValueError, match="Wrong data format. message has additional trailing data"): + PingMessage.from_bytes(b"\x00\x12\x00\x00\x00\x01\x00\x00") + + +def test_pong_message(): + # Pong can be empty, and optionally can receive some ignored bytes + pm = PongMessage() + assert isinstance(pm, PongMessage) + assert not pm.ignored_bytes + + # With some ignored_bytes + ignored_bytes = b"\x00\x02\x06" + pm2 = PongMessage(ignored_bytes) + assert isinstance(pm2, PongMessage) + assert pm2.ignored_bytes is ignored_bytes + + +def test_pong_message_wrong(): + # ignored_bytes must be bytes if set + with pytest.raises(TypeError, match="ignored_bytes must be bytes if set"): + PongMessage("ignored_bytes") + + # ignore_bytes length cannot be bigger than pow(2, 16) - 4 + with pytest.raises(ValueError, match=f"ignored_bytes cannot be higher than {pow(2, 16) - 4}"): + PongMessage(bytes(pow(2, 16) - 3)) + + +def test_pong_message_from_bytes(): + # message must be bytes and length at least 4 (mtype + byteslen) + mtype = b"\x00\x13" + bytes_len = b"\x00\x00" + pm = PongMessage.from_bytes(mtype + bytes_len) + assert isinstance(pm, PongMessage) + assert not pm.ignored_bytes + + # Add some ignored data + ignored_data = b"\x03\xfd\xef" + data_len = b"\x00\x03" + pm2 = PongMessage.from_bytes(mtype + data_len + ignored_data) + assert isinstance(pm2, PongMessage) + assert pm2.ignored_bytes == ignored_data + + +def test_pong_message_from_bytes_wrong(): + # Message is not bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + PongMessage.from_bytes("message") + + # Message is not long enough < 4 + with pytest.raises(ValueError, match="message be must at least 4-byte long"): + PongMessage.from_bytes(b"\x00\x13\x00") + + # Type is not pong + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + PongMessage.from_bytes(b"\x00\x10\x00\x00") + + # Encoded lengths are wrong causing an unexpected EOF + with pytest.raises(ValueError, match="Wrong message format. Unexpected EOF"): + PongMessage.from_bytes(b"\x00\x13\x00\x01") + + # Encoded lengths are wrong leaving additional data at the end + with pytest.raises(ValueError, match="Wrong data format. message has additional trailing data"): + PongMessage.from_bytes(b"\x00\x13\x00\x01\x00\x01") diff --git a/test/common/unit/net/test_bolt9.py b/test/common/unit/net/test_bolt9.py new file mode 100644 index 00000000..007607c3 --- /dev/null +++ b/test/common/unit/net/test_bolt9.py @@ -0,0 +1,219 @@ +import pytest +from common.net.bolt9 import Feature, FeatureVector, known_features + + +def test_feature(): + # Features expect two params, an integer representing the bit of the feature and a boolean with whether that bit + # is set or not. + odd_feature_set = Feature(1, True) + assert odd_feature_set.bit == 1 + assert odd_feature_set.is_set is True + assert odd_feature_set.is_odd is True + + odd_feature_unset = Feature(1, False) + assert odd_feature_unset.bit == 1 + assert odd_feature_unset.is_set is False + assert odd_feature_unset.is_odd is True + + even_feature_set = Feature(0, True) + assert even_feature_set.bit == 0 + assert even_feature_set.is_set is True + assert even_feature_set.is_odd is False + + even_feature_unset = Feature(0, False) + assert even_feature_unset.bit == 0 + assert even_feature_unset.is_set is False + assert even_feature_unset.is_odd is False + + +def test_feature_vector(): + # FeatureVector expects kwarg features (name=Feature) + f1 = Feature(1, True) + f2 = Feature(3, False) + fv = FeatureVector(option_data_loss_protect=f1, initial_routing_sync=f2) + assert fv.option_data_loss_protect == f1 and fv.initial_routing_sync == f2 + + # initial_routing_sync is a special feature with no even bit + with pytest.raises(ValueError, match="initial_routing_sync has no even bit"): + FeatureVector(initial_routing_sync=Feature(2, True)) + + # known features have known bits, mismatches in known pairs are not allowed + with pytest.raises(ValueError, match="Feature name and bit do not match"): + FeatureVector(option_data_loss_protect=Feature(10, True)) + + # Unknown features can have whatever name and bit they like, as long as they do not collide with known features + with pytest.raises(ValueError, match="Feature name and bit do not match"): + FeatureVector(option_data_loss_protect=Feature(42, True)) # known name, unknown bit + with pytest.raises(ValueError, match="Feature name and bit do not match"): + FeatureVector(another_unknown_name=Feature(1, True)) # unknown name, know bit + + # unknown name and bits are allowed + FeatureVector(unknown_feature_name=Feature(42, True)) + + # Finally, all kwargs must have a Feature value + no_feature_dicts = [0, 1.1, True, object, {}, dict()] + for value in no_feature_dicts: + with pytest.raises(TypeError): + FeatureVector(random_name=value) + + +# Encoded features are correct as long as two bits are set from the same pair +# Known features are parsed with its name, whereas unknown are given unknown_i where i is the feature odd bit +def test_feature_vector_from_bytes(): + # The easiest way of testing this is to create the FeatureVector and serialize it + no_features = b"" + assert no_features == FeatureVector.from_bytes(no_features).serialize() + + f0 = b"\x02" + fv0 = FeatureVector.from_bytes(f0) + assert fv0.option_data_loss_protect.is_set and fv0.option_data_loss_protect.is_odd + assert f0 == fv0.serialize() + + f0_2 = b"\x0a" + fv0_2 = FeatureVector.from_bytes(f0_2) + assert fv0_2.option_data_loss_protect.is_set and fv0_2.option_data_loss_protect.is_odd + assert fv0_2.initial_routing_sync.is_set and fv0_2.initial_routing_sync.is_odd + assert f0_2 == fv0_2.serialize() + + # Unknown feature (set bit 22) + f22 = b"\x40\x00\x00" + fv22 = FeatureVector.from_bytes(f22) + assert fv22.unknown_22.is_set + assert fv22.serialize() == f22 + + # All odd features + f_all_odd = b"\x2a\xaa\xaa" + fv_all_odd = FeatureVector.from_bytes(f_all_odd) + assert fv_all_odd.option_data_loss_protect.is_set and fv_all_odd.option_data_loss_protect.is_odd + assert fv_all_odd.initial_routing_sync.is_set and fv_all_odd.initial_routing_sync.is_odd + assert fv_all_odd.option_upfront_shutdown_script.is_set and fv_all_odd.option_upfront_shutdown_script.is_odd + assert fv_all_odd.gossip_queries.is_set and fv_all_odd.gossip_queries.is_odd + assert fv_all_odd.var_onion_optin.is_set and fv_all_odd.var_onion_optin.is_odd + assert fv_all_odd.gossip_queries_ex.is_set and fv_all_odd.gossip_queries_ex.is_odd + assert fv_all_odd.option_static_remotekey.is_set and fv_all_odd.option_static_remotekey.is_odd + assert fv_all_odd.payment_secret.is_set and fv_all_odd.payment_secret.is_odd + assert fv_all_odd.basic_mpp.is_set and fv_all_odd.basic_mpp.is_odd + assert fv_all_odd.option_support_large_channel.is_set and fv_all_odd.option_support_large_channel.is_odd + assert fv_all_odd.option_anchor_outputs.is_set and fv_all_odd.option_anchor_outputs.is_odd + assert fv_all_odd.serialize() == f_all_odd + + # All even features (but initial_routing_sync) + f_all_even = b"\x15\x55\x59" + fv_all_even = FeatureVector.from_bytes(f_all_even) + assert fv_all_even.option_data_loss_protect.is_set and not fv_all_even.option_data_loss_protect.is_odd + assert fv_all_even.initial_routing_sync.is_set and fv_all_even.initial_routing_sync.is_odd + assert fv_all_even.option_upfront_shutdown_script.is_set and not fv_all_even.option_upfront_shutdown_script.is_odd + assert fv_all_even.gossip_queries.is_set and not fv_all_even.gossip_queries.is_odd + assert fv_all_even.var_onion_optin.is_set and not fv_all_even.var_onion_optin.is_odd + assert fv_all_even.gossip_queries_ex.is_set and not fv_all_even.gossip_queries_ex.is_odd + assert fv_all_even.option_static_remotekey.is_set and not fv_all_even.option_static_remotekey.is_odd + assert fv_all_even.payment_secret.is_set and not fv_all_even.payment_secret.is_odd + assert fv_all_even.basic_mpp.is_set and not fv_all_even.basic_mpp.is_odd + assert fv_all_even.option_support_large_channel.is_set and not fv_all_even.option_support_large_channel.is_odd + assert fv_all_even.option_anchor_outputs.is_set and not fv_all_even.option_anchor_outputs.is_odd + assert fv_all_even.serialize() == f_all_even + + +def test_feature_vector_from_bytes_both_set(): + # The same feature cannot be set with both bits set + f0_1 = b"\x03" + with pytest.raises(ValueError, match="Both odd and even bits cannot be set in a pair"): + FeatureVector.from_bytes(f0_1) + + +def test_feature_vector_from_bytes_wrong_type(): + # Features must be bytes + with pytest.raises(TypeError, match="Features must be bytes"): + FeatureVector.from_bytes("random string") + + +def test_feature_vector_set_feature(): + # A feature can be set as long as the name and bit match, or a wrong pair (known name, unknown bit or vice versa) is + # not set. + fv = FeatureVector.from_bytes(b"\x00") + + # Set option_upfront_shutdown_script + fv.set_feature("option_upfront_shutdown_script", 4) + assert fv.option_upfront_shutdown_script.is_set + assert not fv.option_upfront_shutdown_script.is_odd + + # We can set it to odd too + fv.set_feature("option_upfront_shutdown_script", 5) + assert fv.option_upfront_shutdown_script.is_set + assert fv.option_upfront_shutdown_script.is_odd + + # Unknown features work too as long as they don't mismatch + fv.set_feature("random_feature", 24) + assert fv.random_feature.is_set + assert not fv.random_feature.is_odd + + +def test_feature_vector_set_feature_mismatch(): + # If the feature name and the bit do not match, set_feature will fail + # Set option_upfront_shutdown_script + fv = FeatureVector.from_bytes(b"\x00") + with pytest.raises(ValueError, match="Feature name and bit do not match"): + fv.set_feature("option_upfront_shutdown_script", 3) + + # Unknown features that mismatch are not accepted either + with pytest.raises(ValueError, match="Feature name and bit do not match"): + fv.set_feature("random_feature", 3) + + +def test_feature_vector_set_wrong_types(): + fv = FeatureVector.from_bytes(b"\x00") + # Name must be str and bit must be int + with pytest.raises(TypeError): + fv.set_feature(int(), int()) + + with pytest.raises(TypeError): + fv.set_feature(str(), str()) + + +def test_feature_vector_serialize(): + # This has been covered in test_feature_vector_from_bytes + pass + + +def test_feature_vector_to_dict(): + # Converts feature names to dict + fv = FeatureVector.from_bytes(b"\x00") + + # There is no feature set + for k, v in fv.to_dict().items(): + assert v is 0 + + # The dict contains only known features, as long as an unknown is not set + assert fv.to_dict().keys() == known_features.keys() + + fv.set_feature("option_data_loss_protect", 0) + # Only option_data_loss_protect is set (and it's even) + for k, v in fv.to_dict().items(): + if k == "option_data_loss_protect": + assert v == "even" + else: + assert v == 0 + + fv.set_feature("option_upfront_shutdown_script", 5) + # option_data_loss_protect is set (and it's even) and option_upfront_shutdown_script is set (and it's odd) + for k, v in fv.to_dict().items(): + if k == "option_data_loss_protect": + assert v == "even" + elif k == "option_upfront_shutdown_script": + assert v == "odd" + else: + assert v == 0 + + # It works with unknown features too (name is unknown_i) + fv.set_feature("unknown_24", 24) + for k, v in fv.to_dict().items(): + if k == "option_data_loss_protect": + assert v == "even" + elif k == "option_upfront_shutdown_script": + assert v == "odd" + elif k == "unknown_24": + assert v == "even" + else: + assert v == 0 + + assert set(fv.to_dict().keys()).difference(known_features.keys()) == {"unknown_24"} diff --git a/test/common/unit/net/test_tlv.py b/test/common/unit/net/test_tlv.py new file mode 100644 index 00000000..a6a60fd5 --- /dev/null +++ b/test/common/unit/net/test_tlv.py @@ -0,0 +1,149 @@ +import pytest + +import common.net.bigsize as bigsize +from common.net.tlv import TLVRecord, NetworksTLV + +from test.common.unit.conftest import get_random_value_hex + + +def test_tlv_record(): + # The TLV record only enforces the fields to be bytes + tlv = TLVRecord(b"\x01", b"\x02", b"\x03") + assert tlv.type == b"\x01" and tlv.length == b"\x02" and tlv.value == b"\x03" + + # If any of the fields is not byte it'll fail + with pytest.raises(TypeError, match="t must be bytes"): + TLVRecord("", b"\x02", b"\x03") + with pytest.raises(TypeError, match="l must be bytes"): + TLVRecord(b"\x01", "", b"\x03") + with pytest.raises(TypeError, match="v must be bytes"): + TLVRecord(b"\x01", b"\x02", "") + + +def test_tlv_record_len(): + # The TLV length is defined as the length of its serialized fields + t = b"\x01" + l = b"\x02" + v = b"\x03" + tlv = TLVRecord(t, l, v) + assert len(tlv) == len(t) + len(l) + len(v) + + +def test_tlv_record_from_bytes(): + # from_bytes builds an instance of a child class depending on the data type. Currently it only supports Networks. + + # NetworksTLV + t = bigsize.encode(1) + l = bigsize.encode(32) + v = bytes.fromhex(get_random_value_hex(32)) + ntlv = TLVRecord.from_bytes(t + l + v) + + assert isinstance(ntlv, NetworksTLV) + assert ntlv.type == t and ntlv.length == l and ntlv.value == v + + # Any other (unknown types) will return TLVRecord + t = bigsize.encode(0) + tlv = TLVRecord.from_bytes(t + l + v) + assert isinstance(ntlv, TLVRecord) + assert tlv.type == t and tlv.length == l and tlv.value == v + + +# Test cases are copied from +# https://github.com/lightningnetwork/lightning-rfc/blob/bdd42711014643d5b2d4cbe179677451b940a9de/01-messaging.md +def test_tlv_record_from_bytes_failures(): + # We do not count unknown even types since we are only decoding here + unexpected_eof = [ + b"\xfd", + b"\xfd\x01", + b"\xfd\x00\x01\x00", + b"\xfd\x01\x01", + b"\x0f\xfd", + b"\x0f\xfd\x26", + b"\x0f\xfd\x26\x02", + b"\x0f\xfd\x00\x01\x00", + b"\x0f\xfd\x02\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + ] + + for v in unexpected_eof: + with pytest.raises(ValueError, match="Wrong tlv message format. Unexpected EOF"): + TLVRecord.from_bytes(v) + + +def test_tlv_record_from_bytes_wrong_types(): + # If the provided message is not in bytes, from_bytes will fail + with pytest.raises(TypeError, match="message must be bytes"): + TLVRecord.from_bytes("random_message") + + +def test_networks_tlv(): + # Networks TLV expects a list of genesis block hashes (32-byte hex str elements) or an empty list, if no network + # is supported + + # Empty can be achieved with an empty list or no networks at all + empty_ntlv = NetworksTLV() + empty_ntlv2 = NetworksTLV(networks=[]) + assert empty_ntlv.networks == empty_ntlv2.networks == [] + assert empty_ntlv.type == empty_ntlv2.type == b"\x01" + assert empty_ntlv.length == empty_ntlv2.length == b"\x00" + assert empty_ntlv.value == empty_ntlv2.value == b"" + + random_networks = [get_random_value_hex(32) for _ in range(10)] + random_network_bytearray = b"".join(bytes.fromhex(network) for network in random_networks) + ntlv_random = NetworksTLV(random_networks) + assert ntlv_random.type == b"\x01" + assert ntlv_random.length == bigsize.encode(32 * 10) + assert ntlv_random.value == random_network_bytearray + assert ntlv_random.networks == random_networks + + +def test_networks_tlv_wrong_data(): + # If networks is not a list we'll get an error + with pytest.raises(TypeError, match="networks must be a list if set"): + NetworksTLV(1) + + # If the list does not contain only 32-byte hex encoded values, it will fail + wrong_lists = [[1, 2, 3], [""], [get_random_value_hex(32), get_random_value_hex(31)]] + for networks in wrong_lists: + with pytest.raises(ValueError, match="All networks must be 32-byte hex str"): + NetworksTLV(networks) + + +def test_networks_tlv_from_bytes(): + # from_bytes from NetworksTLV expects the type to match (01) and the data a collection of 32-byte hashes (if set) + t = bigsize.encode(1) + l = bigsize.encode(128) + v = b"".join(bytes.fromhex(get_random_value_hex(32)) for _ in range(4)) + ntlv = NetworksTLV.from_bytes(t + l + v) + assert ntlv.type == t and ntlv.length == l and ntlv.value == v + + # Works for empty too + empty_l = b"\x00" + empty_ntlv = NetworksTLV.from_bytes(t + empty_l) + assert empty_ntlv.type == t and empty_ntlv.length == empty_l and empty_ntlv.value == b"" + + +def test_networks_tlv_from_bytes_wrong(): + # message must be bytes + with pytest.raises(TypeError, match="message be must a bytearray"): + NetworksTLV.from_bytes("random_message") + + # If the type is not networks, it will fail + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + l = bigsize.encode(128) + v = bytes.fromhex(get_random_value_hex(32)) + NetworksTLV.from_bytes(b"\x00" + l + v) + + # Data must be multiple of 32 + # Encoding 128, data_len = 127 + with pytest.raises(ValueError, match="All networks must be 32-byte hex str"): + t = b"\x01" + l = bigsize.encode(128) + v = bytes.fromhex(get_random_value_hex(32)) + NetworksTLV.from_bytes(t + l + v[:-1]) + + # Encoding 127, data_len = 127 + with pytest.raises(ValueError, match="chains must be multiple of 32"): + t = b"\x01" + l = bigsize.encode(127) + v = bytes.fromhex(get_random_value_hex(32)) + NetworksTLV.from_bytes(t + l + v[:-1]) diff --git a/test/common/unit/net/test_utils.py b/test/common/unit/net/test_utils.py new file mode 100644 index 00000000..d74810be --- /dev/null +++ b/test/common/unit/net/test_utils.py @@ -0,0 +1,58 @@ +import pytest +import common.net.bigsize as bigsize +from common.net.utils import message_sanity_checks + + +def test_message_sanity_checks(): + # message_sanity_checks checks that: + # - A message is of the proper data type (bytes) + # - A message is at least ``min_len`` long + # - The message encoded type matches ``expected_type``: + # - If the message is a TLV the expected type is the first bigsize value of the message + # - Otherwise it is a u16. + + # Normal message + min_len = 4 + expected_type = b"\x00\x01" + message = 2 * expected_type + assert message_sanity_checks(message, expected_type, min_len) is None + + # TLV (bigsize encoded) + min_len = 3 + expected_type = bigsize.encode(1) + message = expected_type + b"\x00\x01" + assert message_sanity_checks(message, expected_type, min_len, tlv=True) is None + + +def test_message_sanity_checks_wrong_types(): + with pytest.raises(TypeError, match="message be must a bytearray"): + message_sanity_checks("random_message", None, None) + with pytest.raises(TypeError, match="expected_type be must bytes"): + message_sanity_checks(b"", "random_type", None) + with pytest.raises(TypeError, match="min_len be must int"): + message_sanity_checks(b"", b"", 1.1) + with pytest.raises(TypeError, match="tlv be must bool if set"): + message_sanity_checks(b"", b"", 1, tlv=1.1) + + +def test_message_sanity_checks_wrong_data(): + # minimum size not met + min_len = 3 + expected_type = b"\x01" + message = expected_type + b"\x00" + with pytest.raises(ValueError, match=f"message be must at least {min_len}-byte long"): + message_sanity_checks(message, expected_type, min_len) + + # Wrong type (no TLV) + min_len = 3 + expected_type = b"\x01" + message = 3 * b"\x00" + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + message_sanity_checks(message, expected_type, min_len) + + # Wrong type (TLV) + min_len = 3 + expected_type = bigsize.encode(1) + message = 3 * b"\x00" + with pytest.raises(ValueError, match="Wrong message format. types do not match"): + message_sanity_checks(message, expected_type, min_len, tlv=True)