diff --git a/.coveragerc b/.coveragerc index 51f4bf59..3cb20809 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,7 @@ [report] exclude_lines = pragma: no cover + @overload ((t|typing)\.)?TYPE_CHECKING ^\s\.\.\.\s$ def __repr__ diff --git a/mcstatus/_protocol/connection.py b/mcstatus/_protocol/connection.py deleted file mode 100644 index 72c36057..00000000 --- a/mcstatus/_protocol/connection.py +++ /dev/null @@ -1,705 +0,0 @@ -from __future__ import annotations - -import asyncio -import errno -import socket -import struct -from abc import ABC, abstractmethod -from ctypes import c_int32 as signed_int32, c_int64 as signed_int64, c_uint32 as unsigned_int32, c_uint64 as unsigned_int64 -from ipaddress import ip_address -from typing import TYPE_CHECKING, TypeAlias, cast - -import asyncio_dgram - -if TYPE_CHECKING: - from collections.abc import Iterable - - from typing_extensions import Self, SupportsIndex - - from mcstatus._net.address import Address - -__all__ = [ - "BaseAsyncConnection", - "BaseAsyncReadSyncWriteConnection", - "BaseConnection", - "BaseReadAsync", - "BaseReadSync", - "BaseSyncConnection", - "BaseWriteAsync", - "BaseWriteSync", - "Connection", - "SocketConnection", - "TCPAsyncSocketConnection", - "TCPSocketConnection", - "UDPAsyncSocketConnection", - "UDPSocketConnection", -] - -BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" - - -def _ip_type(address: int | str) -> int | None: - """Determine the IP version (IPv4 or IPv6). - - :param address: - A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied. - Integers less than 2**32 will be considered to be IPv4 by default. - :return: ``4`` or ``6`` if the IP is IPv4 or IPv6, respectively. :obj:`None` if the IP is invalid. - """ - try: - return ip_address(address).version - except ValueError: - return None - - -class BaseWriteSync(ABC): - """Base synchronous write class.""" - - __slots__ = () - - @abstractmethod - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Write data to ``self``.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _pack(format_: str, data: int) -> bytes: - """Pack data in with format in big-endian mode.""" - return struct.pack(">" + format_, data) - - def write_varint(self, value: int) -> None: - """Write varint with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int32(value).value - for _ in range(5): - if not remaining & -0x80: # remaining & ~0x7F == 0: - self.write(struct.pack("!B", remaining)) - if value > 2**31 - 1 or value < -(2**31): - break - return - self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varint') - - def write_varlong(self, value: int) -> None: - """Write varlong with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int64(value).value - for _ in range(10): - if not remaining & -0x80: # remaining & ~0x7F == 0: - self.write(struct.pack("!B", remaining)) - if value > 2**63 - 1 or value < -(2**31): - break - return - self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varlong') - - def write_utf(self, value: str) -> None: - """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" - self.write_varint(len(value)) - self.write(bytearray(value, "utf8")) - - def write_ascii(self, value: str) -> None: - """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" - self.write(bytearray(value, "ISO-8859-1")) - self.write(bytearray.fromhex("00")) - - def write_short(self, value: int) -> None: - """Write 2 bytes for value ``-32768 - 32767``.""" - self.write(self._pack("h", value)) - - def write_ushort(self, value: int) -> None: - """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" - self.write(self._pack("H", value)) - - def write_int(self, value: int) -> None: - """Write 4 bytes for value ``-2147483648 - 2147483647``.""" - self.write(self._pack("i", value)) - - def write_uint(self, value: int) -> None: - """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" - self.write(self._pack("I", value)) - - def write_long(self, value: int) -> None: - """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" - self.write(self._pack("q", value)) - - def write_ulong(self, value: int) -> None: - """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" - self.write(self._pack("Q", value)) - - def write_bool(self, value: bool) -> None: # noqa: FBT001 # Boolean positional argument - """Write 1 byte for boolean `True` or `False`.""" - self.write(self._pack("?", value)) - - def write_buffer(self, buffer: Connection) -> None: - """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" - data = buffer.flush() - self.write_varint(len(data)) - self.write(data) - - -class BaseWriteAsync(ABC): - """Base synchronous write class.""" - - __slots__ = () - - @abstractmethod - async def write(self, data: Connection | str | bytearray | bytes) -> None: - """Write data to ``self``.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _pack(format_: str, data: int) -> bytes: - """Pack data in with format in big-endian mode.""" - return struct.pack(">" + format_, data) - - async def write_varint(self, value: int) -> None: - """Write varint with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int32(value).value - for _ in range(5): - if not remaining & -0x80: # remaining & ~0x7F == 0: - await self.write(struct.pack("!B", remaining)) - if value > 2**31 - 1 or value < -(2**31): - break - return - await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varint') - - async def write_varlong(self, value: int) -> None: - """Write varlong with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int64(value).value - for _ in range(10): - if not remaining & -0x80: # remaining & ~0x7F == 0: - await self.write(struct.pack("!B", remaining)) - if value > 2**63 - 1 or value < -(2**31): - break - return - await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varlong') - - async def write_utf(self, value: str) -> None: - """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" - await self.write_varint(len(value)) - await self.write(bytearray(value, "utf8")) - - async def write_ascii(self, value: str) -> None: - """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" - await self.write(bytearray(value, "ISO-8859-1")) - await self.write(bytearray.fromhex("00")) - - async def write_short(self, value: int) -> None: - """Write 2 bytes for value ``-32768 - 32767``.""" - await self.write(self._pack("h", value)) - - async def write_ushort(self, value: int) -> None: - """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" - await self.write(self._pack("H", value)) - - async def write_int(self, value: int) -> None: - """Write 4 bytes for value ``-2147483648 - 2147483647``.""" - await self.write(self._pack("i", value)) - - async def write_uint(self, value: int) -> None: - """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" - await self.write(self._pack("I", value)) - - async def write_long(self, value: int) -> None: - """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" - await self.write(self._pack("q", value)) - - async def write_ulong(self, value: int) -> None: - """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" - await self.write(self._pack("Q", value)) - - async def write_bool(self, value: bool) -> None: # noqa: FBT001 # Boolean positional argument - """Write 1 byte for boolean `True` or `False`.""" - await self.write(self._pack("?", value)) - - async def write_buffer(self, buffer: Connection) -> None: - """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" - data = buffer.flush() - await self.write_varint(len(data)) - await self.write(data) - - -class BaseReadSync(ABC): - """Base synchronous read class.""" - - __slots__ = () - - @abstractmethod - def read(self, length: int, /) -> bytearray: - """Read length bytes from ``self``, and return a byte array.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _unpack(format_: str, data: bytes) -> int: - """Unpack data as bytes with format in big-endian.""" - return struct.unpack(">" + format_, bytes(data))[0] - - def read_varint(self) -> int: - """Read varint from ``self`` and return it. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(5): - part = self.read(1)[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int32(result).value - raise OSError("Received varint is too big!") - - def read_varlong(self) -> int: - """Read varlong from ``self`` and return it. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(10): - part = self.read(1)[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int64(result).value - raise OSError("Received varlong is too big!") - - def read_utf(self) -> str: - """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" - length = self.read_varint() - return self.read(length).decode("utf8") - - def read_ascii(self) -> str: - """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``.""" - result = bytearray() - while len(result) == 0 or result[-1] != 0: - result.extend(self.read(1)) - return result[:-1].decode("ISO-8859-1") - - def read_short(self) -> int: - """Return ``-32768 - 32767``. Read 2 bytes.""" - return self._unpack("h", self.read(2)) - - def read_ushort(self) -> int: - """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" - return self._unpack("H", self.read(2)) - - def read_int(self) -> int: - """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" - return self._unpack("i", self.read(4)) - - def read_uint(self) -> int: - """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" - return self._unpack("I", self.read(4)) - - def read_long(self) -> int: - """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" - return self._unpack("q", self.read(8)) - - def read_ulong(self) -> int: - """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" - return self._unpack("Q", self.read(8)) - - def read_bool(self) -> bool: - """Return `True` or `False`. Read 1 byte.""" - return cast("bool", self._unpack("?", self.read(1))) - - def read_buffer(self) -> Connection: - """Read a varint for length, then return a new connection from length read bytes.""" - length = self.read_varint() - result = Connection() - result.receive(self.read(length)) - return result - - -class BaseReadAsync(ABC): - """Asynchronous Read connection base class.""" - - __slots__ = () - - @abstractmethod - async def read(self, length: int, /) -> bytearray: - """Read length bytes from ``self``, return a byte array.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _unpack(format_: str, data: bytes) -> int: - """Unpack data as bytes with format in big-endian.""" - return struct.unpack(">" + format_, bytes(data))[0] - - async def read_varint(self) -> int: - """Read varint from ``self`` and return it. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(5): - part = (await self.read(1))[0] - result |= (part & 0x7F) << 7 * i - if not part & 0x80: - return signed_int32(result).value - raise OSError("Received a varint that was too big!") - - async def read_varlong(self) -> int: - """Read varlong from ``self`` and return it. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(10): - part = (await self.read(1))[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int64(result).value - raise OSError("Received varlong is too big!") - - async def read_utf(self) -> str: - """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" - length = await self.read_varint() - return (await self.read(length)).decode("utf8") - - async def read_ascii(self) -> str: - """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``.""" - result = bytearray() - while len(result) == 0 or result[-1] != 0: - result.extend(await self.read(1)) - return result[:-1].decode("ISO-8859-1") - - async def read_short(self) -> int: - """Return ``-32768 - 32767``. Read 2 bytes.""" - return self._unpack("h", await self.read(2)) - - async def read_ushort(self) -> int: - """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" - return self._unpack("H", await self.read(2)) - - async def read_int(self) -> int: - """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" - return self._unpack("i", await self.read(4)) - - async def read_uint(self) -> int: - """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" - return self._unpack("I", await self.read(4)) - - async def read_long(self) -> int: - """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" - return self._unpack("q", await self.read(8)) - - async def read_ulong(self) -> int: - """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" - return self._unpack("Q", await self.read(8)) - - async def read_bool(self) -> bool: - """Return `True` or `False`. Read 1 byte.""" - return cast("bool", self._unpack("?", await self.read(1))) - - async def read_buffer(self) -> Connection: - """Read a varint for length, then return a new connection from length read bytes.""" - length = await self.read_varint() - result = Connection() - result.receive(await self.read(length)) - return result - - -class BaseConnection: - """Base Connection class. Implements flush, receive, and remaining.""" - - __slots__ = () - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - def flush(self) -> bytearray: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support flush()") - - def receive(self, _data: BytesConvertable | bytearray) -> None: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support receive()") - - def remaining(self) -> int: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support remaining()") - - -class BaseSyncConnection(BaseConnection, BaseReadSync, BaseWriteSync): - """Base synchronous read and write class.""" - - __slots__ = () - - -class BaseAsyncReadSyncWriteConnection(BaseConnection, BaseReadAsync, BaseWriteSync): - """Base asynchronous read and synchronous write class.""" - - __slots__ = () - - -class BaseAsyncConnection(BaseConnection, BaseReadAsync, BaseWriteAsync): - """Base asynchronous read and write class.""" - - __slots__ = () - - -class Connection(BaseSyncConnection): - """Base connection class.""" - - __slots__ = ("received", "sent") - - def __init__(self) -> None: - self.sent = bytearray() - self.received = bytearray() - - def read(self, length: int, /) -> bytearray: - """Return :attr:`.received` up to length bytes, then cut received up to that point.""" - if len(self.received) < length: - raise OSError(f"Not enough data to read! {len(self.received)} < {length}") - - result = self.received[:length] - self.received = self.received[length:] - return result - - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Extend :attr:`.sent` from ``data``.""" - if isinstance(data, Connection): - data = data.flush() - if isinstance(data, str): - data = bytearray(data, "utf-8") - self.sent.extend(data) - - def receive(self, data: BytesConvertable | bytearray) -> None: - """Extend :attr:`.received` with ``data``.""" - if not isinstance(data, bytearray): - data = bytearray(data) - self.received.extend(data) - - def remaining(self) -> int: - """Return length of :attr:`.received`.""" - return len(self.received) - - def flush(self) -> bytearray: - """Return :attr:`.sent`, also clears :attr:`.sent`.""" - result, self.sent = self.sent, bytearray() - return result - - def copy(self) -> Connection: - """Return a copy of ``self``.""" - new = self.__class__() - new.receive(self.received) - new.write(self.sent) - return new - - -class SocketConnection(BaseSyncConnection): - """Socket connection.""" - - __slots__ = ("socket",) - - def __init__(self) -> None: - # These will only be None until connect is called, ignore the None type assignment - self.socket: socket.socket = None # pyright: ignore[reportAttributeAccessIssue] - - def close(self) -> None: - """Close :attr:`.socket`.""" - if self.socket is not None: # If initialized - try: - self.socket.shutdown(socket.SHUT_RDWR) - except OSError as exception: # Socket wasn't connected (nothing to shut down) - if exception.errno != errno.ENOTCONN: - raise - - self.socket.close() - - def __enter__(self) -> Self: - return self - - def __exit__(self, *_: object) -> None: - self.close() - - -class TCPSocketConnection(SocketConnection): - """TCP Connection to address. Timeout defaults to 3 seconds.""" - - __slots__ = () - - def __init__(self, addr: tuple[str | None, int], timeout: float = 3) -> None: - super().__init__() - self.socket = socket.create_connection(addr, timeout=timeout) - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - def read(self, length: int, /) -> bytearray: - """Return length bytes read from :attr:`.socket`. Raises :exc:`IOError` when server doesn't respond.""" - result = bytearray() - while len(result) < length: - new = self.socket.recv(length - len(result)) - if len(new) == 0: - raise OSError("Server did not respond with any information!") - result.extend(new) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Send data on :attr:`.socket`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.socket.send(data) - - -class UDPSocketConnection(SocketConnection): - """UDP Connection class.""" - - __slots__ = ("addr",) - - def __init__(self, addr: Address, timeout: float = 3) -> None: - super().__init__() - self.addr = addr - self.socket = socket.socket( - socket.AF_INET if _ip_type(addr[0]) == 4 else socket.AF_INET6, - socket.SOCK_DGRAM, - ) - self.socket.settimeout(timeout) - - def remaining(self) -> int: - """Always return ``65535`` (``2 ** 16 - 1``).""" # noqa: D401 # imperative mood - return 65535 - - def read(self, _length: int, /) -> bytearray: - """Return up to :meth:`.remaining` bytes. Length does nothing here.""" - result = bytearray() - while len(result) == 0: - result.extend(self.socket.recvfrom(self.remaining())[0]) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Use :attr:`.socket` to send data to :attr:`.addr`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.socket.sendto(data, self.addr) - - -class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection): - """Asynchronous TCP Connection class.""" - - __slots__ = ("_addr", "reader", "timeout", "writer") - - def __init__(self, addr: Address, timeout: float = 3) -> None: - # These will only be None until connect is called, ignore the None type assignment - self.reader: asyncio.StreamReader = None # pyright: ignore[reportAttributeAccessIssue] - self.writer: asyncio.StreamWriter = None # pyright: ignore[reportAttributeAccessIssue] - self.timeout: float = timeout - self._addr = addr - - async def connect(self) -> None: - """Use :mod:`asyncio` to open a connection to address. Timeout is in seconds.""" - conn = asyncio.open_connection(*self._addr) - self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) - if self.writer is not None: # it might be None in unittest - sock: socket.socket = self.writer.transport.get_extra_info("socket") - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - async def read(self, length: int, /) -> bytearray: - """Read up to ``length`` bytes from :attr:`.reader`.""" - result = bytearray() - while len(result) < length: - new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) - if len(new) == 0: - raise OSError("Socket did not respond with any information!") - result.extend(new) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Write data to :attr:`.writer`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.writer.write(data) - - def close(self) -> None: - """Close :attr:`.writer`.""" - if self.writer is not None: # If initialized - self.writer.close() - - async def __aenter__(self) -> Self: - await self.connect() - return self - - async def __aexit__(self, *_: object) -> None: - self.close() - - -class UDPAsyncSocketConnection(BaseAsyncConnection): - """Asynchronous UDP Connection class.""" - - __slots__ = ("_addr", "stream", "timeout") - - def __init__(self, addr: Address, timeout: float = 3) -> None: - # This will only be None until connect is called, ignore the None type assignment - self.stream: asyncio_dgram.aio.DatagramClient = None # pyright: ignore[reportAttributeAccessIssue] - self.timeout: float = timeout - self._addr = addr - - async def connect(self) -> None: - """Connect to address. Timeout is in seconds.""" - conn = asyncio_dgram.connect(self._addr) - self.stream = await asyncio.wait_for(conn, timeout=self.timeout) - - def remaining(self) -> int: - """Always return ``65535`` (``2 ** 16 - 1``).""" # noqa: D401 # imperative mood - return 65535 - - async def read(self, _length: int, /) -> bytearray: - """Read from :attr:`.stream`. Length does nothing here.""" - data, _remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) - return bytearray(data) - - async def write(self, data: Connection | str | bytes | bytearray) -> None: - """Send data with :attr:`.stream`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - await self.stream.send(data) - - def close(self) -> None: - """Close :attr:`.stream`.""" - if self.stream is not None: # If initialized - self.stream.close() - - async def __aenter__(self) -> Self: - await self.connect() - return self - - async def __aexit__(self, *_: object) -> None: - self.close() diff --git a/mcstatus/_protocol/io/__init__.py b/mcstatus/_protocol/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py new file mode 100644 index 00000000..c7f99d17 --- /dev/null +++ b/mcstatus/_protocol/io/base_io.py @@ -0,0 +1,777 @@ +from __future__ import annotations + +import math +import struct +from abc import ABC, abstractmethod +from enum import Enum +from typing import Literal, TYPE_CHECKING, TypeAlias, TypeVar, overload + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + +__all__ = [ + "FLOAT_FORMATS_TYPE", + "INT_FORMATS_TYPE", + "BaseAsyncReader", + "BaseAsyncWriter", + "BaseSyncReader", + "BaseSyncWriter", + "StructFormat", +] + +T = TypeVar("T") +R = TypeVar("R") + +# region: Utils + + +def to_twos_complement(number: int, bits: int) -> int: + """Convert a given ``number`` into two's complement format with the specified number of ``bits``. + + :param number: The signed integer to convert. + :param bits: Number of bits used for the two's complement representation. + :return: The integer encoded in two's complement form within the given bit width. + :raises ValueError: + If given ``number`` is out of range, and can't be converted into twos complement format, + since it wouldn't fit into the given amount of ``bits``. + """ + value_max = 1 << (bits - 1) + value_min = value_max * -1 + # With two's complement, we have one more negative number than positive + # this means we can't be exactly at value_max, but we can be at exactly value_min + if number >= value_max or number < value_min: + raise ValueError(f"Can't convert number {number} into {bits}-bit twos complement format - out of range") + + return number + (1 << bits) if number < 0 else number + + +def from_twos_complement(number: int, bits: int) -> int: + """Convert a ``number`` from a two's complement representation with the specified number of ``bits``. + + :param number: The integer encoded in two's complement form. + :param bits: Number of bits used for the two's complement representation. + :return: The decoded signed integer. + :raises ValueError: + If given ``number`` doesn't fit into given amount of ``bits``. This likely means that you're + using the wrong number, or that the number was converted into twos complement with higher + amount of `bits`. + """ + value_max = (1 << bits) - 1 + if number < 0 or number > value_max: + raise ValueError(f"Can't convert number {number} from {bits}-bit twos complement format - out of range") + + if number & (1 << (bits - 1)) != 0: + number -= 1 << bits + + return number + + +# endregion + + +# region: Format types + + +class StructFormat(str, Enum): + """All possible struct format types used for reading and writing binary data. + + These values correspond directly to format characters accepted by the + :mod:`struct` module. + """ + + BOOL = "?" + CHAR = "c" + BYTE = "b" + UBYTE = "B" + SHORT = "h" + USHORT = "H" + INT = "i" + UINT = "I" + LONG = "l" + ULONG = "L" + FLOAT = "f" + DOUBLE = "d" + HALFFLOAT = "e" + LONGLONG = "q" + ULONGLONG = "Q" + + +INT_FORMATS_TYPE: TypeAlias = Literal[ + StructFormat.BYTE, + StructFormat.UBYTE, + StructFormat.SHORT, + StructFormat.USHORT, + StructFormat.INT, + StructFormat.UINT, + StructFormat.LONG, + StructFormat.ULONG, + StructFormat.LONGLONG, + StructFormat.ULONGLONG, +] + +FLOAT_FORMATS_TYPE: TypeAlias = Literal[ + StructFormat.FLOAT, + StructFormat.DOUBLE, + StructFormat.HALFFLOAT, +] + +# endregion + +# region: Writer classes + + +class BaseAsyncWriter(ABC): + """Base class holding asynchronous write buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + async def write(self, data: bytes | bytearray, /) -> None: + """Underlying write method, sending/storing the data. + + All of the writer functions will eventually call this method. + """ + + @overload + async def write_value(self, fmt: INT_FORMATS_TYPE, value: int, /) -> None: ... + + @overload + async def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: ... + + @overload + async def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 + + @overload + async def write_value(self, fmt: Literal[StructFormat.CHAR], value: bytes, /) -> None: ... + + async def write_value(self, fmt: StructFormat, value: object, /) -> None: + """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" + await self.write(struct.pack(">" + fmt.value, value)) + + async def _write_varuint(self, value: int, /, *, max_bits: int | None = None) -> None: + """Write an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, allowing + smaller numbers take up fewer bytes. + + Writing is limited to integers representable within ``max_bits`` bits. Attempting to write a larger + value will raise :class:`ValueError`. Note that limiting the value to e.g. 32 bits does not mean + that at most 4 bytes will be written. Due to the encoding overhead, values within 32 bits may require + up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte store data, + while the most significant bit acts as a continuation flag. If this bit is set (``1``), another byte + follows. + + The least significant group is written first, followed by increasingly significant groups, making the + encoding little-endian in 7-bit groups. + + :param value: Unsigned integer to encode. + :param max_bits: Maximum allowed bit width for the encoded value. + :raises ValueError: If ``value`` is negative or exceeds the allowed bit width. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + if value < 0 or value > value_max: + raise ValueError(f"Tried to write varint outside of the range of {max_bits}-bit int.") + + remaining = value + while True: + if remaining & ~0x7F == 0: # final byte + await self.write_value(StructFormat.UBYTE, remaining) + return + # Write only 7 least significant bits with the first bit being 1, marking there will be another byte + await self.write_value(StructFormat.UBYTE, remaining & 0x7F | 0x80) + # Subtract the value we've already sent (7 least significant bits) + remaining >>= 7 + + async def write_varint(self, value: int, /) -> None: + """Write a 32-bit signed integer using a variable-length encoding. + + The value is first converted to a 32-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=32) + await self._write_varuint(val, max_bits=32) + + async def write_varlong(self, value: int, /) -> None: + """Write a 64-bit signed integer using a variable-length encoding. + + The value is first converted to a 64-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=64) + await self._write_varuint(val, max_bits=64) + + async def write_bytearray(self, data: bytes | bytearray, /) -> None: + """Write an arbitrary sequence of bytes, prefixed with a varint of its size.""" + await self.write_varint(len(data)) + await self.write(data) + + async def write_ascii(self, value: str, /) -> None: + """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" + data = bytes(value, "ISO-8859-1") + await self.write(data) + await self.write(bytes([0])) + + async def write_utf(self, value: str, /) -> None: + """Write a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be written + 3 additional bytes from + the varint encoding overhead. + + :param value: String to encode. + :raises ValueError: If the string exceeds the maximum allowed length. + """ + if len(value) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + + data = bytes(value, "utf-8") + await self.write_varint(len(data)) + await self.write(data) + + async def write_optional(self, value: T | None, /, writer: Callable[[T], Awaitable[R]]) -> R | None: + """Write a bool indicating whether ``value`` is present and, if so, serialize the value using ``writer``. + + * If ``value`` is ``None``, ``False`` is written and ``None`` is returned. + * If ``value`` is not ``None``, ``True`` is written and ``writer`` is called + with the value. The return value of ``writer`` is then forwarded. + + :param value: Optional value to serialize. + :param writer: Callable used to serialize the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``writer``. + """ + if value is None: + await self.write_value(StructFormat.BOOL, False) # noqa: FBT003 + return None + + await self.write_value(StructFormat.BOOL, True) # noqa: FBT003 + return await writer(value) + + +class BaseSyncWriter(ABC): + """Base class holding synchronous write buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def write(self, data: bytes | bytearray, /) -> None: + """Underlying write method, sending/storing the data. + + All of the writer functions will eventually call this method. + """ + + @overload + def write_value(self, fmt: INT_FORMATS_TYPE, value: int, /) -> None: ... + + @overload + def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: ... + + @overload + def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 + + @overload + def write_value(self, fmt: Literal[StructFormat.CHAR], value: bytes, /) -> None: ... + + def write_value(self, fmt: StructFormat, value: object, /) -> None: + """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" + self.write(struct.pack(">" + fmt.value, value)) + + def _write_varuint(self, value: int, /, *, max_bits: int | None = None) -> None: + """Write an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, allowing + smaller numbers take up fewer bytes. + + Writing is limited to integers representable within ``max_bits`` bits. Attempting to write a larger + value will raise :class:`ValueError`. Note that limiting the value to e.g. 32 bits does not mean + that at most 4 bytes will be written. Due to the encoding overhead, values within 32 bits may require + up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte store data, + while the most significant bit acts as a continuation flag. If this bit is set (``1``), another byte + follows. + + The least significant group is written first, followed by increasingly significant groups, making the + encoding little-endian in 7-bit groups. + + :param value: Unsigned integer to encode. + :param max_bits: Maximum allowed bit width for the encoded value. + :raises ValueError: If ``value`` is negative or exceeds the allowed bit width. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + if value < 0 or value > value_max: + raise ValueError(f"Tried to write varint outside of the range of {max_bits}-bit int.") + + remaining = value + while True: + if remaining & ~0x7F == 0: # final byte + self.write_value(StructFormat.UBYTE, remaining) + return + # Write only 7 least significant bits with the first bit being 1, marking there will be another byte + self.write_value(StructFormat.UBYTE, remaining & 0x7F | 0x80) + # Subtract the value we've already sent (7 least significant bits) + remaining >>= 7 + + def write_varint(self, value: int, /) -> None: + """Write a 32-bit signed integer using a variable-length encoding. + + The value is first converted to a 32-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=32) + self._write_varuint(val, max_bits=32) + + def write_varlong(self, value: int, /) -> None: + """Write a 64-bit signed integer using a variable-length encoding. + + The value is first converted to a 64-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=64) + self._write_varuint(val, max_bits=64) + + def write_bytearray(self, data: bytes | bytearray, /) -> None: + """Write an arbitrary sequence of bytes, prefixed with a varint of its size.""" + self.write_varint(len(data)) + self.write(data) + + def write_ascii(self, value: str, /) -> None: + """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" + data = bytes(value, "ISO-8859-1") + self.write(data) + self.write(bytes([0])) + + def write_utf(self, value: str, /) -> None: + """Write a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be written + 3 additional bytes from + the varint encoding overhead. + + :param value: String to encode. + :raises ValueError: If the string exceeds the maximum allowed length. + """ + if len(value) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + + data = bytes(value, "utf-8") + self.write_varint(len(data)) + self.write(data) + + def write_optional(self, value: T | None, /, writer: Callable[[T], R]) -> R | None: + """Write a bool indicating whether ``value`` is present and, if so, serialize the value using ``writer``. + + * If ``value`` is ``None``, ``False`` is written and ``None`` is returned. + * If ``value`` is not ``None``, ``True`` is written and ``writer`` is called + with the value. The return value of ``writer`` is then forwarded. + + :param value: Optional value to serialize. + :param writer: Callable used to serialize the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``writer``. + """ + if value is None: + self.write_value(StructFormat.BOOL, False) # noqa: FBT003 + return None + + self.write_value(StructFormat.BOOL, True) # noqa: FBT003 + return writer(value) + + +# endregion +# region: Reader classes + + +class BaseAsyncReader(ABC): + """Base class holding asynchronous read buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + async def read(self, length: int, /) -> bytes: + """Underlying read method, obtaining the raw data. + + All of the reader functions will eventually call this method. + """ + + @overload + async def read_value(self, fmt: INT_FORMATS_TYPE, /) -> int: ... + + @overload + async def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... + + @overload + async def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... + + @overload + async def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> bytes: ... + + async def read_value(self, fmt: StructFormat, /) -> object: + """Read a value as given struct format (``fmt``) in big-endian mode. + + The amount of bytes to read will be determined based on the struct format automatically. + """ + length = struct.calcsize(fmt.value) + data = await self.read(length) + unpacked = struct.unpack(">" + fmt.value, data) + return unpacked[0] + + async def _read_varuint(self, *, max_bits: int | None = None) -> int: + """Read an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, + allowing smaller numbers to take fewer bytes. + + Reading is limited to integers representable within ``max_bits`` bits. Attempting to read + a larger value will raise :class:`OSError`. Note that limiting the value to e.g. 32 bits + does not mean that at most 4 bytes will be read. Due to the encoding overhead, values + within 32 bits may require up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte + store data, while the most significant bit acts as a continuation flag. If this bit is set + (``1``), another byte follows. + + The least significant group is read first, followed by increasingly significant groups, + making the encoding little-endian in 7-bit groups. + + :param max_bits: Maximum allowed bit width for the decoded value. + :raises OSError: + * If the decoded value exceeds the allowed bit width. + * If more bytes are received than can possibly encode a value constrained by ``max_bits``. + :return: The decoded unsigned integer. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + # Varints carry 7 value bits per byte, so this is the exact maximum encoded length. + byte_limit = math.ceil(max_bits / 7) if max_bits is not None else None + + result = 0 + i = 0 + while byte_limit is None or i < byte_limit: + byte = await self.read_value(StructFormat.UBYTE) + # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place + # then simply add them (OR) as additional 7 most significant bits in our result + result |= (byte & 0x7F) << (7 * i) + + # Ensure that we stop reading and raise an error if the size gets over the maximum + # (if the current amount of bits is higher than allowed size in bits) + if result > value_max: + raise OSError(f"Received varint was outside the range of {max_bits}-bit int.") + + # If the most significant bit is 0, we should stop reading + if not byte & 0x80: + break + + i += 1 + else: + raise OSError( + f"Received varint had too many bytes for {max_bits}-bit int (continuation bit set on byte {byte_limit})." + ) + + return result + + async def read_varint(self) -> int: + """Read a 32-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 32-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = await self._read_varuint(max_bits=32) + return from_twos_complement(unsigned_num, bits=32) + + async def read_varlong(self) -> int: + """Read a 64-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 64-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = await self._read_varuint(max_bits=64) + return from_twos_complement(unsigned_num, bits=64) + + async def read_bytearray(self, /) -> bytes: + """Read a sequence of bytes prefixed with its length encoded as a varint. + + :raises OSError: If the encoded length prefix is negative. + :return: The decoded byte sequence. + """ + length = await self.read_varint() + if length < 0: + raise OSError(f"Length prefix for byte arrays must be non-negative, got {length}.") + return await self.read(length) + + async def read_ascii(self) -> str: + """Read ISO-8859-1 encoded string, until we encounter NULL (0x00) at the end indicating string end. + + Bytes are read until a NULL terminator is encountered. The terminator itself is not included in the + returned string. There is no limit that can be set for how long this string can end up being. + + :return: Decoded string. + """ + # Keep reading bytes until we find NULL + result = bytearray() + while len(result) == 0 or result[-1] != 0: + byte = await self.read(1) + result.extend(byte) + return result[:-1].decode("ISO-8859-1") + + async def read_utf(self) -> str: + """Read a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be read + 3 additional bytes from + the varint encoding overhead. + + :raises OSError: + * If the length prefix is negative. + * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, + and the error will be raised immediately after reading the prefix. + * If the decoded string contains more than ``32767`` characters. In this case the data is still + fully read because it fits within the byte limit. This behavior mirrors Minecraft's implementation. + + :return: Decoded UTF-8 string. + """ + length = await self.read_varint() + if length < 0: + raise OSError(f"Length prefix for utf strings must be non-negative, got {length}.") + if length > 131068: + raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") + + data = await self.read(length) + chars = data.decode("utf-8") + + if len(chars) > 32767: + raise OSError(f"Maximum read limit for utf strings is 32767 characters, got {len(chars)}.") + + return chars + + async def read_optional(self, reader: Callable[[], Awaitable[R]]) -> R | None: + """Read a boolean indicating whether a value is present and, if so, deserialize it using ``reader``. + + * If ``False`` is read, nothing further is read and ``None`` is returned. + * If ``True`` is read, ``reader`` is called and its return value is forwarded. + + :param reader: Callable used to read the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``reader``. + """ + if not await self.read_value(StructFormat.BOOL): + return None + + return await reader() + + +class BaseSyncReader(ABC): + """Base class holding synchronous read buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def read(self, length: int, /) -> bytes: + """Underlying read method, obtaining the raw data. + + All of the reader functions will eventually call this method. + """ + + @overload + def read_value(self, fmt: INT_FORMATS_TYPE, /) -> int: ... + + @overload + def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... + + @overload + def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... + + @overload + def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> bytes: ... + + def read_value(self, fmt: StructFormat, /) -> object: + """Read a value as given struct format (``fmt``) in big-endian mode. + + The amount of bytes to read will be determined based on the struct format automatically. + """ + length = struct.calcsize(fmt.value) + data = self.read(length) + unpacked = struct.unpack(">" + fmt.value, data) + return unpacked[0] + + def _read_varuint(self, *, max_bits: int | None = None) -> int: + """Read an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, + allowing smaller numbers to take fewer bytes. + + Reading is limited to integers representable within ``max_bits`` bits. Attempting to read + a larger value will raise :class:`OSError`. Note that limiting the value to e.g. 32 bits + does not mean that at most 4 bytes will be read. Due to the encoding overhead, values + within 32 bits may require up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte + store data, while the most significant bit acts as a continuation flag. If this bit is set + (``1``), another byte follows. + + The least significant group is read first, followed by increasingly significant groups, + making the encoding little-endian in 7-bit groups. + + :param max_bits: Maximum allowed bit width for the decoded value. + :raises OSError: + * If the decoded value exceeds the allowed bit width. + * If more bytes are received than can possibly encode a value constrained by ``max_bits``. + :return: The decoded unsigned integer. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + # Varints carry 7 value bits per byte, so this is the exact maximum encoded length. + byte_limit = math.ceil(max_bits / 7) if max_bits is not None else None + + result = 0 + i = 0 + while byte_limit is None or i < byte_limit: + byte = self.read_value(StructFormat.UBYTE) + # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place + # then simply add them (OR) as additional 7 most significant bits in our result + result |= (byte & 0x7F) << (7 * i) + + # Ensure that we stop reading and raise an error if the size gets over the maximum + # (if the current amount of bits is higher than allowed size in bits) + if result > value_max: + raise OSError(f"Received varint was outside the range of {max_bits}-bit int.") + + # If the most significant bit is 0, we should stop reading + if not byte & 0x80: + break + + i += 1 + else: + raise OSError( + f"Received varint had too many bytes for {max_bits}-bit int (continuation bit set on byte {byte_limit})." + ) + + return result + + def read_varint(self) -> int: + """Read a 32-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 32-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = self._read_varuint(max_bits=32) + return from_twos_complement(unsigned_num, bits=32) + + def read_varlong(self) -> int: + """Read a 64-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 64-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = self._read_varuint(max_bits=64) + return from_twos_complement(unsigned_num, bits=64) + + def read_bytearray(self) -> bytes: + """Read a sequence of bytes prefixed with its length encoded as a varint. + + :raises OSError: If the encoded length prefix is negative. + :return: The decoded byte sequence. + """ + length = self.read_varint() + if length < 0: + raise OSError(f"Length prefix for byte arrays must be non-negative, got {length}.") + return self.read(length) + + def read_ascii(self) -> str: + """Read ISO-8859-1 encoded string, until we encounter NULL (0x00) at the end indicating string end. + + Bytes are read until a NULL terminator is encountered. The terminator itself is not included in the + returned string. There is no limit that can be set for how long this string can end up being. + + :return: Decoded string. + """ + # Keep reading bytes until we find NULL + result = bytearray() + while len(result) == 0 or result[-1] != 0: + byte = self.read(1) + result.extend(byte) + return result[:-1].decode("ISO-8859-1") + + def read_utf(self) -> str: + """Read a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be read + 3 additional bytes from + the varint encoding overhead. + + :raises OSError: + * If the length prefix is negative. + * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, + and the error will be raised immediately after reading the prefix. + * If the decoded string contains more than ``32767`` characters. In this case the data is still + fully read because it fits within the byte limit. This behavior mirrors Minecraft's implementation. + + :return: Decoded UTF-8 string. + """ + length = self.read_varint() + if length < 0: + raise OSError(f"Length prefix for utf strings must be non-negative, got {length}.") + if length > 131068: + raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") + + data = self.read(length) + chars = data.decode("utf-8") + + if len(chars) > 32767: + raise OSError(f"Maximum read limit for utf strings is 32767 characters, got {len(chars)}.") + + return chars + + def read_optional(self, reader: Callable[[], R]) -> R | None: + """Read a boolean indicating whether a value is present and, if so, deserialize it using ``reader``. + + * If ``False`` is read, nothing further is read and ``None`` is returned. + * If ``True`` is read, ``reader`` is called and its return value is forwarded. + + :param reader: Callable used to read the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``reader``. + """ + if not self.read_value(StructFormat.BOOL): + return None + + return reader() + + +# endregion diff --git a/mcstatus/_protocol/io/buffer.py b/mcstatus/_protocol/io/buffer.py new file mode 100644 index 00000000..e0d594a2 --- /dev/null +++ b/mcstatus/_protocol/io/buffer.py @@ -0,0 +1,120 @@ +from __future__ import annotations + +from typing import SupportsIndex, TYPE_CHECKING, final, overload + +from mcstatus._protocol.io.base_io import BaseSyncReader, BaseSyncWriter + +if TYPE_CHECKING: + from collections.abc import Iterable + + from _typeshed import ReadableBuffer + from typing_extensions import override +else: + override = lambda f: f # noqa: E731 + +__all__ = ["Buffer"] + + +@final +class Buffer(BaseSyncWriter, BaseSyncReader, bytearray): + """In-memory bytearray-like buffer supporting the common read/write operations.""" + + __slots__ = ("pos",) + + @overload + def __init__(self) -> None: ... + @overload + def __init__(self, ints: Iterable[SupportsIndex] | SupportsIndex | ReadableBuffer, /) -> None: ... + @overload + def __init__(self, string: str, /, encoding: str, errors: str = "strict") -> None: ... + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pos = 0 + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Write/Store given ``data`` into the buffer.""" + self.extend(data) + + @override + def read(self, length: int, /) -> bytes: + """Read data stored in the buffer. + + Reading data doesn't remove that data, rather that data is treated as already read, and + next read will start from the first unread byte. If freeing the data is necessary, check + the :meth:`clear` method. + + :param length: + Amount of bytes to be read. + + If the requested amount can't be read (buffer doesn't contain that much data/buffer + doesn't contain any data), an :exc:`OSError` will be raised. + + If there were some data in the buffer, but it was less than requested, this remaining + data will still be depleted and the partial data that was read will be a part of the + error message in the :exc:`OSError`. This behavior is here to mimic reading from a real + socket connection. + + If ``length`` is negative, an :exc:`OSError` will be raised. + """ + if length < 0: + raise OSError(f"Requested to read a negative amount of data: {length}.") + + end = self.pos + length + + if end > len(self): + data = bytearray(self.unread_view()) + bytes_read = len(self) - self.pos + self.pos = len(self) + raise OSError( + "Requested to read more data than available." + f" Read {bytes_read} bytes: {data}, out of {length} requested bytes." + ) + + try: + return bytes(self.unread_view()[:length]) + finally: + self.pos = end + + @override + def clear(self, *, only_already_read: bool = False) -> None: + """Clear out the stored data and reset position. + + :param only_already_read: + When set to ``True``, only the data that was already marked as read will be cleared, + and the position will be reset (to start at the remaining data). This can be useful + for avoiding needlessly storing large amounts of data in memory, if this data is no + longer useful. + + Otherwise, if set to ``False``, all of the data is cleared, and the position is reset, + essentially resulting in a blank buffer. + """ + if only_already_read: + del self[: self.pos] + else: + super().clear() + self.pos = 0 + + def reset(self) -> None: + """Reset the position in the buffer. + + Since the buffer doesn't automatically clear the already read data, it is possible to simply + reset the position and read the data it contains again. + """ + self.pos = 0 + + def unread_view(self) -> memoryview: + """Return a zero-copy view of unread data without modifying buffer state.""" + return memoryview(self)[self.pos :] + + def flush(self) -> bytes: + """Read all of the remaining data in the buffer and clear it out.""" + data = bytes(self.unread_view()) + self.clear() + return data + + @property + def remaining(self) -> int: + """Get the amount of bytes that's still remaining in the buffer to be read.""" + return len(self) - self.pos diff --git a/mcstatus/_protocol/io/connection.py b/mcstatus/_protocol/io/connection.py new file mode 100644 index 00000000..a84e9874 --- /dev/null +++ b/mcstatus/_protocol/io/connection.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import asyncio +import errno +import socket +from ipaddress import ip_address +from typing import TYPE_CHECKING, TypeAlias, final + +import asyncio_dgram + +from mcstatus._protocol.io.base_io import BaseAsyncReader, BaseAsyncWriter, BaseSyncReader, BaseSyncWriter + +if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import Self, SupportsIndex, override + + from mcstatus._net.address import Address +else: + override = lambda f: f # noqa: E731 + +__all__ = [ + "BaseAsyncConnection", + "BaseSyncConnection", + "TCPAsyncSocketConnection", + "TCPSocketConnection", + "UDPAsyncSocketConnection", + "UDPSocketConnection", +] + +BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" + + +class BaseSyncConnection(BaseSyncReader, BaseSyncWriter): + """Base synchronous read and write class.""" + + __slots__ = () + + +class BaseAsyncConnection(BaseAsyncReader, BaseAsyncWriter): + """Base asynchronous read and write class.""" + + __slots__ = () + + +class _SocketConnection(BaseSyncConnection): + """Socket connection.""" + + __slots__ = ("socket",) + + def __init__(self) -> None: + # These will only be None until connect is called, ignore the None type assignment + self.socket: socket.socket = None # pyright: ignore[reportAttributeAccessIssue] + + def close(self) -> None: + """Close :attr:`.socket`.""" + if self.socket is not None: # If initialized + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError as exception: # Socket wasn't connected (nothing to shut down) + if exception.errno != errno.ENOTCONN: + raise + + self.socket.close() + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_: object) -> None: + self.close() + + +@final +class TCPSocketConnection(_SocketConnection): + """TCP Connection to address. Timeout defaults to 3 seconds.""" + + __slots__ = () + + def __init__(self, addr: tuple[str | None, int], timeout: float = 3) -> None: + super().__init__() + self.socket = socket.create_connection(addr, timeout=timeout) + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + @override + def read(self, length: int, /) -> bytes: + """Return length bytes read from :attr:`.socket`. Raises :exc:`OSError` when server doesn't respond.""" + result = bytearray() + while len(result) < length: + new = self.socket.recv(length - len(result)) + if len(new) == 0: + raise OSError("Server did not respond with any information!") + result.extend(new) + return bytes(result) + + def write(self, data: bytes | bytearray, /) -> None: + """Send data on :attr:`.socket`.""" + self.socket.sendall(data) + + +@final +class UDPSocketConnection(_SocketConnection): + """UDP Connection class.""" + + __slots__ = ("addr",) + + def __init__(self, addr: Address, timeout: float = 3) -> None: + super().__init__() + self.addr = addr + self.socket = socket.socket( + socket.AF_INET if ip_address(addr[0]).version == 4 else socket.AF_INET6, + socket.SOCK_DGRAM, + ) + self.socket.settimeout(timeout) + + @property + def remaining(self) -> int: + """Always return ``65535`` (``2 ** 16 - 1``).""" + return 65535 + + @override + def read(self, _length: int, /) -> bytes: + """Return up to :meth:`.remaining` bytes. Length does nothing here.""" + result = b"" + while len(result) == 0: + result = self.socket.recvfrom(self.remaining)[0] + return result + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Use :attr:`.socket` to send data to :attr:`.addr`.""" + self.socket.sendto(data, self.addr) + + +@final +class TCPAsyncSocketConnection(BaseAsyncConnection): + """Asynchronous TCP Connection class.""" + + __slots__ = ("_addr", "reader", "timeout", "writer") + + def __init__(self, addr: Address, timeout: float = 3) -> None: + # These will only be None until connect is called, ignore the None type assignment + self.reader: asyncio.StreamReader = None # pyright: ignore[reportAttributeAccessIssue] + self.writer: asyncio.StreamWriter = None # pyright: ignore[reportAttributeAccessIssue] + self.timeout: float = timeout + self._addr = addr + + async def connect(self) -> None: + """Use :mod:`asyncio` to open a connection to address. Timeout is in seconds.""" + conn = asyncio.open_connection(*self._addr) + self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) + if self.writer is not None: # it might be None in unittest + sock: socket.socket = self.writer.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + @override + async def read(self, length: int, /) -> bytes: + """Read up to ``length`` bytes from :attr:`.reader`.""" + result = bytearray() + while len(result) < length: + new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) + if len(new) == 0: + # No information at all + if len(result) == 0: + raise OSError("Server did not respond with any information!") + # We did get a few bytes, but we requested more + raise OSError( + f"Server stopped responding (got {len(result)} bytes, but expected {length} bytes)." + f" Partial obtained data: {result!r}" + ) + result.extend(new) + return bytes(result) + + @override + async def write(self, data: bytes | bytearray, /) -> None: + """Write data to :attr:`.writer`.""" + self.writer.write(data) + await self.writer.drain() + + async def close(self) -> None: + """Close :attr:`.writer`.""" + if self.writer is not None: # If initialized + self.writer.close() + await self.writer.wait_closed() + + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *_: object) -> None: + await self.close() + + +@final +class UDPAsyncSocketConnection(BaseAsyncConnection): + """Asynchronous UDP Connection class.""" + + __slots__ = ("_addr", "stream", "timeout") + + def __init__(self, addr: Address, timeout: float = 3) -> None: + # This will only be None until connect is called, ignore the None type assignment + self.stream: asyncio_dgram.aio.DatagramClient = None # pyright: ignore[reportAttributeAccessIssue] + self.timeout: float = timeout + self._addr = addr + + async def connect(self) -> None: + """Connect to address. Timeout is in seconds.""" + conn = asyncio_dgram.connect(self._addr) + self.stream = await asyncio.wait_for(conn, timeout=self.timeout) + + @property + def remaining(self) -> int: + """Always return ``65535`` (``2 ** 16 - 1``).""" + return 65535 + + @override + async def read(self, _length: int, /) -> bytes: + """Read from :attr:`.stream`. Length does nothing here.""" + data, _remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) + return data + + @override + async def write(self, data: bytes | bytearray, /) -> None: + """Send data with :attr:`.stream`.""" + await self.stream.send(bytes(data)) + + def close(self) -> None: + """Close :attr:`.stream`.""" + if self.stream is not None: # If initialized + self.stream.close() + + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *_: object) -> None: + self.close() diff --git a/mcstatus/_protocol/java_client.py b/mcstatus/_protocol/java_client.py index e27b3d95..05785e6b 100644 --- a/mcstatus/_protocol/java_client.py +++ b/mcstatus/_protocol/java_client.py @@ -7,13 +7,15 @@ from time import perf_counter from typing import TYPE_CHECKING, final -from mcstatus._protocol.connection import Connection, TCPAsyncSocketConnection, TCPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus.responses import JavaStatusResponse if TYPE_CHECKING: from collections.abc import Awaitable from mcstatus._net.address import Address + from mcstatus._protocol.io.connection import TCPAsyncSocketConnection, TCPSocketConnection from mcstatus.responses._raw import RawJavaResponse __all__ = ["AsyncJavaClient", "JavaClient"] @@ -32,16 +34,15 @@ def __post_init__(self) -> None: if self.ping_token is None: self.ping_token = random.randint(0, (1 << 63) - 1) - def handshake(self) -> None: - """Write the initial handshake packet to the connection.""" - packet = Connection() + def _build_handshake_packet(self) -> Buffer: + """Build the initial handshake packet.""" + packet = Buffer() packet.write_varint(0) packet.write_varint(self.version) packet.write_utf(self.address.host) - packet.write_ushort(self.address.port) + packet.write_value(StructFormat.USHORT, self.address.port) packet.write_varint(1) # Intention to query status - - self.connection.write_buffer(packet) + return packet @abstractmethod def read_status(self) -> JavaStatusResponse | Awaitable[JavaStatusResponse]: @@ -53,7 +54,7 @@ def test_ping(self) -> float | Awaitable[float]: """Send a ping token and measure the latency.""" raise NotImplementedError - def _handle_status_response(self, response: Connection, start: float, end: float) -> JavaStatusResponse: + def _handle_status_response(self, response: Buffer, start: float, end: float) -> JavaStatusResponse: """Given a response buffer (already read from connection), parse and build the JavaStatusResponse.""" if response.read_varint() != 0: raise OSError("Received invalid status response packet.") @@ -68,11 +69,11 @@ def _handle_status_response(self, response: Connection, start: float, end: float except KeyError as e: raise OSError("Received invalid status response") from e - def _handle_ping_response(self, response: Connection, start: float, end: float) -> float: + def _handle_ping_response(self, response: Buffer, start: float, end: float) -> float: """Given a ping response buffer, validate token and compute latency.""" if response.read_varint() != 1: raise OSError("Received invalid ping response packet.") - received_token = response.read_long() + received_token = response.read_value(StructFormat.LONGLONG) if received_token != self.ping_token: raise OSError(f"Received mangled ping response (expected token {self.ping_token}, got {received_token})") return (end - start) * 1000 @@ -83,26 +84,30 @@ def _handle_ping_response(self, response: Connection, start: float, end: float) class JavaClient(_BaseJavaClient): connection: TCPSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] + def handshake(self) -> None: + """Write the initial handshake packet to the connection.""" + self.connection.write_bytearray(self._build_handshake_packet()) + def read_status(self) -> JavaStatusResponse: """Send the status request and read the response.""" - request = Connection() + request = Buffer() request.write_varint(0) # Request status - self.connection.write_buffer(request) + self.connection.write_bytearray(request) start = perf_counter() - response = self.connection.read_buffer() + response = Buffer(self.connection.read_bytearray()) end = perf_counter() return self._handle_status_response(response, start, end) def test_ping(self) -> float: """Send a ping token and measure the latency.""" - request = Connection() + request = Buffer() request.write_varint(1) # Test ping - request.write_long(self.ping_token) + request.write_value(StructFormat.LONGLONG, self.ping_token) start = perf_counter() - self.connection.write_buffer(request) + self.connection.write_bytearray(request) - response = self.connection.read_buffer() + response = Buffer(self.connection.read_bytearray()) end = perf_counter() return self._handle_ping_response(response, start, end) @@ -112,25 +117,29 @@ def test_ping(self) -> float: class AsyncJavaClient(_BaseJavaClient): connection: TCPAsyncSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] + async def handshake(self) -> None: + """Write the initial handshake packet to the connection.""" + await self.connection.write_bytearray(self._build_handshake_packet()) + async def read_status(self) -> JavaStatusResponse: """Send the status request and read the response.""" - request = Connection() + request = Buffer() request.write_varint(0) # Request status - self.connection.write_buffer(request) + await self.connection.write_bytearray(request) start = perf_counter() - response = await self.connection.read_buffer() + response = Buffer(await self.connection.read_bytearray()) end = perf_counter() return self._handle_status_response(response, start, end) async def test_ping(self) -> float: """Send a ping token and measure the latency.""" - request = Connection() + request = Buffer() request.write_varint(1) # Test ping - request.write_long(self.ping_token) + request.write_value(StructFormat.LONGLONG, self.ping_token) start = perf_counter() - self.connection.write_buffer(request) + await self.connection.write_bytearray(request) - response = await self.connection.read_buffer() + response = Buffer(await self.connection.read_bytearray()) end = perf_counter() return self._handle_ping_response(response, start, end) diff --git a/mcstatus/_protocol/legacy_client.py b/mcstatus/_protocol/legacy_client.py index 2e2270de..a90274bd 100644 --- a/mcstatus/_protocol/legacy_client.py +++ b/mcstatus/_protocol/legacy_client.py @@ -1,6 +1,7 @@ from time import perf_counter -from mcstatus._protocol.connection import BaseAsyncReadSyncWriteConnection, BaseSyncConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.connection import BaseAsyncConnection, BaseSyncConnection from mcstatus.responses import LegacyStatusResponse __all__ = ["AsyncLegacyClient", "LegacyClient"] @@ -35,24 +36,24 @@ def read_status(self) -> LegacyStatusResponse: id = self.connection.read(1) if id != b"\xff": raise OSError("Received invalid packet ID") - length = self.connection.read_ushort() + length = self.connection.read_value(StructFormat.USHORT) data = self.connection.read(length * 2) end = perf_counter() return self.parse_response(data, (end - start) * 1000) class AsyncLegacyClient(_BaseLegacyClient): - def __init__(self, connection: BaseAsyncReadSyncWriteConnection) -> None: + def __init__(self, connection: BaseAsyncConnection) -> None: self.connection = connection async def read_status(self) -> LegacyStatusResponse: """Send the status request and read the response.""" start = perf_counter() - self.connection.write(self.request_status_data) + await self.connection.write(self.request_status_data) id = await self.connection.read(1) if id != b"\xff": raise OSError("Received invalid packet ID") - length = await self.connection.read_ushort() + length = await self.connection.read_value(StructFormat.USHORT) data = await self.connection.read(length * 2) end = perf_counter() return self.parse_response(data, (end - start) * 1000) diff --git a/mcstatus/_protocol/query_client.py b/mcstatus/_protocol/query_client.py index 7ede695a..4092a7f5 100644 --- a/mcstatus/_protocol/query_client.py +++ b/mcstatus/_protocol/query_client.py @@ -2,12 +2,12 @@ import random import re -import struct from abc import abstractmethod from dataclasses import dataclass, field from typing import ClassVar, TYPE_CHECKING, final -from mcstatus._protocol.connection import Connection, UDPAsyncSocketConnection, UDPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus.responses import QueryResponse from mcstatus.responses._raw import RawQueryResponse @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Awaitable + from mcstatus._protocol.io.connection import UDPAsyncSocketConnection, UDPSocketConnection + @dataclass class _BaseQueryClient: @@ -32,24 +34,24 @@ def _generate_session_id() -> int: # minecraft only supports lower 4 bits return random.randint(0, 2**31) & 0x0F0F0F0F - def _create_packet(self) -> Connection: - packet = Connection() + def _create_packet(self) -> Buffer: + packet = Buffer() packet.write(self.MAGIC_PREFIX) - packet.write(struct.pack("!B", self.PACKET_TYPE_QUERY)) - packet.write_uint(self._generate_session_id()) - packet.write_int(self.challenge) + packet.write_value(StructFormat.UBYTE, self.PACKET_TYPE_QUERY) + packet.write_value(StructFormat.UINT, self._generate_session_id()) + packet.write_value(StructFormat.INT, self.challenge) packet.write(self.PADDING) return packet - def _create_handshake_packet(self) -> Connection: - packet = Connection() + def _create_handshake_packet(self) -> Buffer: + packet = Buffer() packet.write(self.MAGIC_PREFIX) - packet.write(struct.pack("!B", self.PACKET_TYPE_CHALLENGE)) - packet.write_uint(self._generate_session_id()) + packet.write_value(StructFormat.UBYTE, self.PACKET_TYPE_CHALLENGE) + packet.write_value(StructFormat.UINT, self._generate_session_id()) return packet @abstractmethod - def _read_packet(self) -> Connection | Awaitable[Connection]: + def _read_packet(self) -> Buffer | Awaitable[Buffer]: raise NotImplementedError @abstractmethod @@ -60,7 +62,7 @@ def handshake(self) -> None | Awaitable[None]: def read_query(self) -> QueryResponse | Awaitable[QueryResponse]: raise NotImplementedError - def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[str]]: + def _parse_response(self, response: Buffer) -> tuple[RawQueryResponse, list[str]]: """Transform the connection object (the result) into dict which is passed to the QueryResponse constructor. :return: A tuple with two elements. First is `raw` answer and second is list of players. @@ -73,7 +75,7 @@ def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[ if key == "hostname": # hostname is actually motd in the query protocol match = re.search( b"(.*?)\x00(hostip|hostport|game_id|gametype|map|maxplayers|numplayers|plugins|version)", - response.received, + bytes(response.unread_view()), flags=re.DOTALL, ) motd = match.group(1) if match else "" @@ -105,9 +107,8 @@ def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[ class QueryClient(_BaseQueryClient): connection: UDPSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] - def _read_packet(self) -> Connection: - packet = Connection() - packet.receive(self.connection.read(self.connection.remaining())) + def _read_packet(self) -> Buffer: + packet = Buffer(self.connection.read(self.connection.remaining)) packet.read(1 + 4) return packet @@ -130,9 +131,8 @@ def read_query(self) -> QueryResponse: class AsyncQueryClient(_BaseQueryClient): connection: UDPAsyncSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] - async def _read_packet(self) -> Connection: - packet = Connection() - packet.receive(await self.connection.read(self.connection.remaining())) + async def _read_packet(self) -> Buffer: + packet = Buffer(await self.connection.read(self.connection.remaining)) packet.read(1 + 4) return packet diff --git a/mcstatus/responses/forge.py b/mcstatus/responses/forge.py index 0b53ed3c..03cf4b51 100644 --- a/mcstatus/responses/forge.py +++ b/mcstatus/responses/forge.py @@ -18,7 +18,8 @@ from io import StringIO from typing import Final, TYPE_CHECKING -from mcstatus._protocol.connection import BaseConnection, BaseReadSync, Connection +from mcstatus._protocol.io.base_io import BaseSyncReader, StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus._utils import or_none if TYPE_CHECKING: @@ -57,10 +58,10 @@ def build(cls, raw: RawForgeDataChannel) -> Self: return cls(name=raw["res"], version=raw["version"], required=raw["required"]) @classmethod - def decode(cls, buffer: Connection, mod_id: str | None = None) -> Self: + def decode(cls, buffer: Buffer, mod_id: str | None = None) -> Self: """Decode an object about Forge channel from decoded optimized buffer. - :param buffer: :class:`Connection` object from UTF-16 encoded binary data. + :param buffer: :class:`Buffer` object from UTF-16 encoded binary data. :param mod_id: Optional mod id prefix :class:`str`. :return: :class:`ForgeDataChannel` object. """ @@ -68,7 +69,7 @@ def decode(cls, buffer: Connection, mod_id: str | None = None) -> Self: if mod_id is not None: channel_identifier = f"{mod_id}:{channel_identifier}" version = buffer.read_utf() - client_required = buffer.read_bool() + client_required = buffer.read_value(StructFormat.BOOL) return cls( name=channel_identifier, @@ -106,10 +107,10 @@ def build(cls, raw: RawForgeDataMod) -> Self: return cls(name=mod_id, marker=mod_version) @classmethod - def decode(cls, buffer: Connection) -> tuple[Self, list[ForgeDataChannel]]: + def decode(cls, buffer: Buffer) -> tuple[Self, list[ForgeDataChannel]]: """Decode data about a Forge mod from decoded optimized buffer. - :param buffer: :class:`Connection` object from UTF-16 encoded binary data. + :param buffer: :class:`Buffer` object from UTF-16 encoded binary data. :return: :class:`tuple` object of :class:`ForgeDataMod` object and :class:`list` of :class:`ForgeDataChannel` objects. """ channel_version_flags = buffer.read_varint() @@ -127,7 +128,7 @@ def decode(cls, buffer: Connection) -> tuple[Self, list[ForgeDataChannel]]: return cls(name=mod_id, marker=mod_version), channels -class _StringBuffer(BaseReadSync, BaseConnection): +class _StringBuffer(BaseSyncReader): """String Buffer for reading utf-16 encoded binary data.""" __slots__ = ("received", "stringio") @@ -136,7 +137,7 @@ def __init__(self, stringio: StringIO) -> None: self.stringio = stringio self.received = bytearray() - def read(self, length: int) -> bytearray: + def read(self, length: int, /) -> bytes: """Read length bytes from ``self``, and return a byte array.""" data = bytearray() while self.received and len(data) < length: @@ -148,7 +149,7 @@ def read(self, length: int) -> bytearray: data.extend(result.encode("utf-16be")) while len(data) > length: self.received.append(data.pop()) - return data + return bytes(data) def remaining(self) -> int: """Return number of reads remaining.""" @@ -156,20 +157,20 @@ def remaining(self) -> int: def read_optimized_size(self) -> int: """Read encoded data length.""" - return self.read_short() | (self.read_short() << 15) + return self.read_value(StructFormat.SHORT) | (self.read_value(StructFormat.SHORT) << 15) - def read_optimized_buffer(self) -> Connection: + def read_optimized_buffer(self) -> Buffer: """Read encoded buffer.""" size = self.read_optimized_size() - buffer = Connection() + buffer = Buffer() value, bits = 0, 0 - while buffer.remaining() < size: + while buffer.remaining < size: if bits < 8 and self.remaining(): # Ignoring sign bit - value |= (self.read_short() & 0x7FFF) << bits + value |= (self.read_value(StructFormat.SHORT) & 0x7FFF) << bits bits += 15 - buffer.receive((value & 0xFF).to_bytes(1, "big")) + buffer.write((value & 0xFF).to_bytes(1, "big")) value >>= 8 bits -= 8 @@ -190,7 +191,7 @@ class ForgeData: """Is the mods list and or channel list incomplete?""" @staticmethod - def _decode_optimized(string: str) -> Connection: + def _decode_optimized(string: str) -> Buffer: """Decode buffer from UTF-16 optimized binary data ``string``.""" with StringIO(string) as text: str_buffer = _StringBuffer(text) @@ -223,8 +224,8 @@ def build(cls, raw: RawForgeData) -> Self: channels: list[ForgeDataChannel] = [] mods: list[ForgeDataMod] = [] - truncated = buffer.read_bool() - mod_count = buffer.read_ushort() + truncated = buffer.read_value(StructFormat.BOOL) + mod_count = buffer.read_value(StructFormat.USHORT) try: for _ in range(mod_count): mod, mod_channels = ForgeDataMod.decode(buffer) diff --git a/mcstatus/server.py b/mcstatus/server.py index fc9340e5..0a78c13c 100644 --- a/mcstatus/server.py +++ b/mcstatus/server.py @@ -5,7 +5,7 @@ from mcstatus._net.address import Address, async_minecraft_srv_address_lookup, minecraft_srv_address_lookup from mcstatus._protocol.bedrock_client import BedrockClient -from mcstatus._protocol.connection import ( +from mcstatus._protocol.io.connection import ( TCPAsyncSocketConnection, TCPSocketConnection, UDPAsyncSocketConnection, @@ -170,7 +170,7 @@ async def _retry_async_ping( version=version, ping_token=ping_token, # pyright: ignore[reportArgumentType] # None is not assignable to int ) - java_client.handshake() + await java_client.handshake() ping = await java_client.test_ping() return ping @@ -230,7 +230,7 @@ async def _retry_async_status( version=version, ping_token=ping_token, # pyright: ignore[reportArgumentType] # None is not assignable to int ) - java_client.handshake() + await java_client.handshake() result = await java_client.read_status() return result diff --git a/pyproject.toml b/pyproject.toml index 67418446..c923e6c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ pythonPlatform = "All" pythonVersion = "3.10" typeCheckingMode = "standard" -disableBytesTypePromotions = false +disableBytesTypePromotions = true enableTypeIgnoreComments = false reportUnnecessaryTypeIgnoreComment = true diff --git a/tests/protocol/helpers.py b/tests/protocol/helpers.py new file mode 100644 index 00000000..0853f209 --- /dev/null +++ b/tests/protocol/helpers.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import asyncio +from typing import Any, ParamSpec, TYPE_CHECKING, TypeVar + +from mcstatus._protocol.io.buffer import Buffer +from mcstatus._protocol.io.connection import BaseAsyncConnection, BaseSyncConnection + +if TYPE_CHECKING: + from collections.abc import Callable, Coroutine + + from typing_extensions import override +else: + override = lambda f: f # noqa: E731 + +P = ParamSpec("P") +T = TypeVar("T") + + +def async_decorator(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]: + """Wrap an async callable so it can be invoked from synchronous tests.""" + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + """Execute the wrapped coroutine function using ``asyncio.run``.""" + return asyncio.run(f(*args, **kwargs)) + + return wrapper + + +class SyncBufferConnection(BaseSyncConnection): + """In-memory synchronous stream-style transport used by protocol tests.""" + + def __init__(self, incoming: bytes | bytearray = b"") -> None: + """Initialize in-memory read/write buffers with optional incoming bytes.""" + self.sent = Buffer() + self.received = Buffer(incoming) + + @override + def read(self, length: int, /) -> bytes: + """Read ``length`` bytes from the queued incoming data.""" + return self.received.read(length) + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing payload data to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue incoming payload data for subsequent reads.""" + self.received.write(data) + + def remaining(self) -> int: + """Return unread byte count from the incoming buffer.""" + return self.received.remaining + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class AsyncBufferConnection(BaseAsyncConnection): + """In-memory asynchronous stream-style transport used by protocol tests.""" + + def __init__(self, incoming: bytes | bytearray = b"") -> None: + """Initialize in-memory read/write buffers with optional incoming bytes.""" + self.sent = Buffer() + self.received = Buffer(incoming) + + @override + async def read(self, length: int, /) -> bytes: + """Read ``length`` bytes from the queued incoming data.""" + return self.received.read(length) + + @override + async def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing payload data to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue incoming payload data for subsequent reads.""" + self.received.write(data) + + def remaining(self) -> int: + """Return unread byte count from the incoming buffer.""" + return self.received.remaining + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class SyncDatagramConnection(BaseSyncConnection): + """Datagram-like synchronous transport returning one queued packet per read.""" + + def __init__(self) -> None: + self.sent = Buffer() + self.received: list[bytes] = [] + + @override + def read(self, _length: int, /) -> bytes: + """Pop and return the next queued datagram payload.""" + if not self.received: + raise OSError("No datagram data to read.") + return self.received.pop(0) + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing datagram payload to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue one incoming datagram payload for future reads.""" + self.received.append(bytes(data)) + + def remaining(self) -> int: + """Return size of the next queued datagram, or ``0`` when empty.""" + if not self.received: + return 0 + return len(self.received[0]) + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class AsyncDatagramConnection(BaseAsyncConnection): + """Datagram-like asynchronous transport returning one queued packet per read.""" + + def __init__(self) -> None: + self.sent = Buffer() + self.received: list[bytes] = [] + + @override + async def read(self, _length: int, /) -> bytes: + """Pop and return the next queued datagram payload.""" + if not self.received: + raise OSError("No datagram data to read.") + return self.received.pop(0) + + @override + async def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing datagram payload to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue one incoming datagram payload for future reads.""" + self.received.append(bytes(data)) + + def remaining(self) -> int: + """Return size of the next queued datagram, or ``0`` when empty.""" + if not self.received: + return 0 + return len(self.received[0]) + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() diff --git a/tests/protocol/test_async_support.py b/tests/protocol/test_async_support.py deleted file mode 100644 index badeb3cc..00000000 --- a/tests/protocol/test_async_support.py +++ /dev/null @@ -1,23 +0,0 @@ -from inspect import iscoroutinefunction - -from mcstatus._protocol.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection - - -def test_is_completely_asynchronous(): - conn = TCPAsyncSocketConnection - assertions = 0 - for attribute in dir(conn): - if attribute.startswith("read_"): - assert iscoroutinefunction(getattr(conn, attribute)) - assertions += 1 - assert assertions > 0, "None of the read_* attributes were async" - - -def test_query_is_completely_asynchronous(): - conn = UDPAsyncSocketConnection - assertions = 0 - for attribute in dir(conn): - if attribute.startswith("read_"): - assert iscoroutinefunction(getattr(conn, attribute)) - assertions += 1 - assert assertions > 0, "None of the read_* attributes were async" diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py new file mode 100644 index 00000000..9e0465cf --- /dev/null +++ b/tests/protocol/test_base_io.py @@ -0,0 +1,508 @@ +from __future__ import annotations + +import struct +from inspect import isawaitable +from typing import TYPE_CHECKING, TypeVar, overload +from unittest.mock import AsyncMock, Mock + +import pytest + +from mcstatus._protocol.io.base_io import INT_FORMATS_TYPE, StructFormat +from mcstatus._protocol.io.buffer import Buffer +from tests.protocol.helpers import AsyncBufferConnection, SyncBufferConnection + +if TYPE_CHECKING: + from collections.abc import Awaitable + +ConnectionClass = type[SyncBufferConnection] | type[AsyncBufferConnection] +T = TypeVar("T") + + +@pytest.fixture( + params=[ + pytest.param(SyncBufferConnection, id="sync"), + pytest.param(AsyncBufferConnection, id="async"), + ] +) +def conn_cls(request: pytest.FixtureRequest) -> ConnectionClass: + """Provide a parametrized sync/async connection class for each test.""" + return request.param + + +@overload +async def maybe_await(value: Awaitable[T], /) -> T: ... + + +@overload +async def maybe_await(value: T, /) -> T: ... + + +async def maybe_await(value: Awaitable[T] | T, /) -> T: + """Return a value directly or await it when needed. + + This keeps test calls explicit and type-checkable for both sync and async IO APIs. + """ + if isawaitable(value): + return await value + return value + + +@pytest.mark.parametrize( + ("fmt", "value", "expected"), + [ + pytest.param(StructFormat.UBYTE, 0, b"\x00", id="ubyte-0"), + pytest.param(StructFormat.UBYTE, 15, b"\x0f", id="ubyte-15"), + pytest.param(StructFormat.UBYTE, 255, b"\xff", id="ubyte-max"), + pytest.param(StructFormat.BYTE, 0, b"\x00", id="byte-0"), + pytest.param(StructFormat.BYTE, 15, b"\x0f", id="byte-15"), + pytest.param(StructFormat.BYTE, 127, b"\x7f", id="byte-max"), + pytest.param(StructFormat.BYTE, -20, b"\xec", id="byte-neg-20"), + pytest.param(StructFormat.BYTE, -128, b"\x80", id="byte-min"), + ], +) +@pytest.mark.asyncio +async def test_write_value_matches_reference( + conn_cls: ConnectionClass, + fmt: INT_FORMATS_TYPE, + value: int, + expected: bytes, +): + conn = conn_cls() + await maybe_await(conn.write_value(fmt, value)) + assert conn.flush() == expected + + +@pytest.mark.asyncio +async def test_write_value_char_uses_single_byte(conn_cls: ConnectionClass): + conn = conn_cls() + await maybe_await(conn.write_value(StructFormat.CHAR, b"a")) + assert conn.flush() == b"a" + + +@pytest.mark.asyncio +async def test_write_value_char_rejects_non_single_byte(conn_cls: ConnectionClass): + conn = conn_cls() + with pytest.raises(struct.error): + await maybe_await(conn.write_value(StructFormat.CHAR, b"ab")) + + +@pytest.mark.parametrize( + ("fmt", "value"), + [ + pytest.param(StructFormat.UBYTE, -1, id="ubyte-neg"), + pytest.param(StructFormat.UBYTE, 256, id="ubyte-overflow"), + pytest.param(StructFormat.BYTE, -129, id="byte-underflow"), + pytest.param(StructFormat.BYTE, 128, id="byte-overflow"), + ], +) +@pytest.mark.asyncio +async def test_write_value_rejects_out_of_range(conn_cls: ConnectionClass, fmt: INT_FORMATS_TYPE, value: int): + conn = conn_cls() + with pytest.raises(struct.error): + await maybe_await(conn.write_value(fmt, value)) + + +@pytest.mark.parametrize( + ("encoded", "fmt", "expected"), + [ + pytest.param(b"\x00", StructFormat.UBYTE, 0, id="ubyte-0"), + pytest.param(b"\x0f", StructFormat.UBYTE, 15, id="ubyte-15"), + pytest.param(b"\xff", StructFormat.UBYTE, 255, id="ubyte-max"), + pytest.param(b"\x00", StructFormat.BYTE, 0, id="byte-0"), + pytest.param(b"\x0f", StructFormat.BYTE, 15, id="byte-15"), + pytest.param(b"\x7f", StructFormat.BYTE, 127, id="byte-max"), + pytest.param(b"\xec", StructFormat.BYTE, -20, id="byte-neg-20"), + pytest.param(b"\x80", StructFormat.BYTE, -128, id="byte-min"), + ], +) +@pytest.mark.asyncio +async def test_read_value_matches_reference( + conn_cls: ConnectionClass, + encoded: bytes, + fmt: INT_FORMATS_TYPE, + expected: int, +): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_value(fmt)) == expected + + +@pytest.mark.asyncio +async def test_read_value_char_returns_bytes(conn_cls: ConnectionClass): + conn = conn_cls(b"a") + value = await maybe_await(conn.read_value(StructFormat.CHAR)) + assert value == b"a" + assert isinstance(value, bytes) + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(0, b"\x00", id="0"), + pytest.param(127, b"\x7f", id="127"), + pytest.param(128, b"\x80\x01", id="128"), + pytest.param(255, b"\xff\x01", id="255"), + pytest.param(1_000_000, b"\xc0\x84\x3d", id="1m"), + pytest.param((2**31) - 1, b"\xff\xff\xff\xff\x07", id="max-32"), + ], +) +@pytest.mark.asyncio +async def test_write_varuint_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn._write_varuint(number)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x00", 0, id="0"), + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x01", 128, id="128"), + pytest.param(b"\xff\x01", 255, id="255"), + pytest.param(b"\xc0\x84\x3d", 1_000_000, id="1m"), + pytest.param(b"\xff\xff\xff\xff\x07", (2**31) - 1, id="max-32"), + ], +) +@pytest.mark.asyncio +async def test_read_varuint_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn._read_varuint()) == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("number", "max_bits"), + [ + (-1, 128), + (-1, 1), + (2**16, 16), + (2**32, 32), + ], +) +async def test_write_varuint_rejects_out_of_range(conn_cls: ConnectionClass, number: int, max_bits: int): + conn = conn_cls() + with pytest.raises(ValueError, match=r"outside of the range of"): + await maybe_await(conn._write_varuint(number, max_bits=max_bits)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("encoded", "max_bits"), + [ + (b"\x80\x80\x04", 16), + (b"\x80\x80\x80\x80\x10", 32), + ], +) +async def test_read_varuint_rejects_out_of_range(conn_cls: ConnectionClass, encoded: bytes, max_bits: int): + conn = conn_cls(encoded) + with pytest.raises(OSError, match=r"outside the range of"): + await maybe_await(conn._read_varuint(max_bits=max_bits)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("encoded", "max_bits", "max_bytes"), + [ + (b"\x80\x80\x80\x00", 16, 3), + (b"\x80\x80\x80\x80\x80\x00", 32, 5), + ], +) +async def test_read_varuint_rejects_too_many_bytes( + conn_cls: ConnectionClass, + encoded: bytes, + max_bits: int, + max_bytes: int, +): + conn = conn_cls(encoded) + with pytest.raises( + OSError, + match=rf"^Received varint had too many bytes for {max_bits}-bit int \(continuation bit set on byte {max_bytes}\)\.$", + ): + await maybe_await(conn._read_varuint(max_bits=max_bits)) + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(127, b"\x7f", id="127"), + pytest.param(16_384, b"\x80\x80\x01", id="16384"), + pytest.param(-128, b"\x80\xff\xff\xff\x0f", id="-128"), + pytest.param(-16_383, b"\x81\x80\xff\xff\x0f", id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_write_varint_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_varint(number)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x80\x01", 16_384, id="16384"), + pytest.param(b"\x80\xff\xff\xff\x0f", -128, id="-128"), + pytest.param(b"\x81\x80\xff\xff\x0f", -16_383, id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_read_varint_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_varint()) == expected + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(127, b"\x7f", id="127"), + pytest.param(16_384, b"\x80\x80\x01", id="16384"), + pytest.param(-128, b"\x80\xff\xff\xff\xff\xff\xff\xff\xff\x01", id="-128"), + pytest.param(-16_383, b"\x81\x80\xff\xff\xff\xff\xff\xff\xff\x01", id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_write_varlong_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_varlong(number)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x80\x01", 16_384, id="16384"), + pytest.param(b"\x80\xff\xff\xff\xff\xff\xff\xff\xff\x01", -128, id="-128"), + pytest.param(b"\x81\x80\xff\xff\xff\xff\xff\xff\xff\x01", -16_383, id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_read_varlong_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_varlong()) == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [0, 1, 127, 16_384, -1, -(2**31), (2**31) - 1]) +async def test_varint_roundtrip(conn_cls: ConnectionClass, number: int): + conn = conn_cls() + await maybe_await(conn.write_varint(number)) + conn.receive(conn.flush()) + assert await maybe_await(conn.read_varint()) == number + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [127, 16_384, -128, -16_383, -(2**63), (2**63) - 1]) +async def test_varlong_roundtrip(conn_cls: ConnectionClass, number: int): + conn = conn_cls() + await maybe_await(conn.write_varlong(number)) + conn.receive(conn.flush()) + assert await maybe_await(conn.read_varlong()) == number + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [-(2**63) - 1, 2**63]) +async def test_write_varlong_rejects_out_of_range(conn_cls: ConnectionClass, number: int): + conn = conn_cls() + with pytest.raises(ValueError, match=r"out of range"): + await maybe_await(conn.write_varlong(number)) + + +@pytest.mark.asyncio +async def test_optional_helpers(conn_cls: ConnectionClass): + conn = conn_cls() + + if isinstance(conn, AsyncBufferConnection): + writer = AsyncMock(return_value="written") + + assert await conn.write_optional(None, writer) is None + writer.assert_not_awaited() + assert conn.flush() == b"\x00" + + assert await conn.write_optional("value", writer) == "written" + writer.assert_awaited_once_with("value") + assert conn.flush() == b"\x01" + + reader = AsyncMock(return_value="parsed") + conn.receive(b"\x00") + assert await conn.read_optional(reader) is None + reader.assert_not_awaited() + + conn.receive(b"\x01") + assert await conn.read_optional(reader) == "parsed" + reader.assert_awaited_once_with() + return + + writer = Mock(return_value="written") + + assert conn.write_optional(None, writer) is None + writer.assert_not_called() + assert conn.flush() == b"\x00" + + assert conn.write_optional("value", writer) == "written" + writer.assert_called_once_with("value") + assert conn.flush() == b"\x01" + + reader = Mock(return_value="parsed") + conn.receive(b"\x00") + assert conn.read_optional(reader) is None + reader.assert_not_called() + + conn.receive(b"\x01") + assert conn.read_optional(reader) == "parsed" + reader.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_write_and_read_ascii(conn_cls: ConnectionClass): + conn = conn_cls() + await maybe_await(conn.write_ascii("hello")) + + conn.receive(conn.flush()) + assert await maybe_await(conn.read_ascii()) == "hello" + + +@pytest.mark.asyncio +async def test_write_and_read_bytearray(conn_cls: ConnectionClass): + conn = conn_cls() + data = b"\x00\x01hello\xff" + + await maybe_await(conn.write_bytearray(data)) + conn.receive(conn.flush()) + + assert await maybe_await(conn.read_bytearray()) == data + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + pytest.param(b"", b"\x00", id="empty"), + pytest.param(b"\x01", b"\x01\x01", id="single"), + pytest.param(b"hello\x00world", b"\x0bhello\x00world", id="with-null"), + pytest.param(b"\x01\x02\x03four\x05", b"\x08\x01\x02\x03four\x05", id="mixed"), + ], +) +@pytest.mark.asyncio +async def test_write_bytearray_matches_reference(conn_cls: ConnectionClass, data: bytes, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_bytearray(data)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x00", b"", id="empty"), + pytest.param(b"\x01\x01", b"\x01", id="single"), + pytest.param(b"\x0bhello\x00world", b"hello\x00world", id="with-null"), + pytest.param(b"\x08\x01\x02\x03four\x05", b"\x01\x02\x03four\x05", id="mixed"), + ], +) +@pytest.mark.asyncio +async def test_read_bytearray_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: bytes): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_bytearray()) == expected + + +@pytest.mark.asyncio +async def test_read_bytearray_rejects_negative_length(conn_cls: ConnectionClass): + conn = conn_cls(b"\xff\xff\xff\xff\x0f") + with pytest.raises(OSError, match=r"^Length prefix for byte arrays must be non-negative, got -1\.$"): + await maybe_await(conn.read_bytearray()) + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param("test", b"test\x00", id="test"), + pytest.param("a" * 100, b"a" * 100 + b"\x00", id="100-as"), + pytest.param("", b"\x00", id="empty"), + ], +) +@pytest.mark.asyncio +async def test_write_ascii_matches_reference(conn_cls: ConnectionClass, value: str, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_ascii(value)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"test\x00", "test", id="test"), + pytest.param(b"a" * 100 + b"\x00", "a" * 100, id="100-as"), + pytest.param(b"\x00", "", id="empty"), + ], +) +@pytest.mark.asyncio +async def test_read_ascii_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: str): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_ascii()) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param("test", b"\x04test", id="test"), + pytest.param("a" * 100, b"\x64" + b"a" * 100, id="100-as"), + pytest.param("", b"\x00", id="empty"), + pytest.param("नमस्ते", b"\x12" + "नमस्ते".encode(), id="hindi"), + ], +) +@pytest.mark.asyncio +async def test_write_utf_matches_reference(conn_cls: ConnectionClass, value: str, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_utf(value)) + assert conn.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x04test", "test", id="test"), + pytest.param(b"\x64" + b"a" * 100, "a" * 100, id="100-as"), + pytest.param(b"\x00", "", id="empty"), + pytest.param(b"\x12" + "नमस्ते".encode(), "नमस्ते", id="hindi"), + ], +) +@pytest.mark.asyncio +async def test_read_utf_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: str): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_utf()) == expected + + +@pytest.mark.asyncio +async def test_write_utf_rejects_too_many_characters(conn_cls: ConnectionClass): + conn = conn_cls() + with pytest.raises(ValueError, match=r"Maximum character limit for writing strings is 32767 characters"): + await maybe_await(conn.write_utf("a" * 32768)) + + +@pytest.mark.asyncio +async def test_read_utf_rejects_too_many_bytes(conn_cls: ConnectionClass): + payload = Buffer() + payload.write_varint(131069) + + conn = conn_cls(payload) + with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 131068 bytes, got 131069"): + await maybe_await(conn.read_utf()) + + +@pytest.mark.asyncio +async def test_read_utf_rejects_negative_length(conn_cls: ConnectionClass): + conn = conn_cls(b"\xff\xff\xff\xff\x0f") + with pytest.raises(OSError, match=r"^Length prefix for utf strings must be non-negative, got -1\.$"): + await maybe_await(conn.read_utf()) + + +@pytest.mark.asyncio +async def test_read_utf_rejects_too_many_characters(conn_cls: ConnectionClass): + text = "a" * 32768 + payload = Buffer() + payload.write_varint(len(text.encode("utf-8"))) + payload.write(text.encode("utf-8")) + + conn = conn_cls(payload) + with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 32767 characters, got 32768"): + await maybe_await(conn.read_utf()) diff --git a/tests/protocol/test_base_io_twos_complement.py b/tests/protocol/test_base_io_twos_complement.py new file mode 100644 index 00000000..08e3eb25 --- /dev/null +++ b/tests/protocol/test_base_io_twos_complement.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import pytest + +from mcstatus._protocol.io.base_io import from_twos_complement, to_twos_complement + +TWOS_COMPLEMENT_CASES = [ + (-128, 8, 0x80), + (-2, 8, 0xFE), + (-1, 8, 0xFF), + (0, 8, 0x00), + (1, 8, 0x01), + (127, 8, 0x7F), + (-(2**15), 16, 0x8000), + (-1, 16, 0xFFFF), + ((2**15) - 1, 16, 0x7FFF), + (-(2**31), 32, 0x80000000), + (-1, 32, 0xFFFFFFFF), + ((2**31) - 1, 32, 0x7FFFFFFF), + (-(2**63), 64, 0x8000000000000000), + (-9_876_543_210_123_456, 64, 0xFFDCE956165A0F40), + (-1, 64, 0xFFFFFFFFFFFFFFFF), + (0, 64, 0x0000000000000000), + (9_876_543_210_123_456, 64, 0x002316A9E9A5F0C0), + ((2**63) - 1, 64, 0x7FFFFFFFFFFFFFFF), +] + + +@pytest.mark.parametrize( + ("number", "bits", "expected_twos"), + TWOS_COMPLEMENT_CASES, +) +def test_to_twos_complement_matches_expected_values(number: int, bits: int, expected_twos: int): + assert to_twos_complement(number, bits=bits) == expected_twos + + +@pytest.mark.parametrize( + ("twos_value", "bits", "expected_number"), + [(twos_value, bits, number) for number, bits, twos_value in TWOS_COMPLEMENT_CASES], +) +def test_from_twos_complement_matches_expected_values(twos_value: int, bits: int, expected_number: int): + assert from_twos_complement(twos_value, bits=bits) == expected_number + + +@pytest.mark.parametrize( + ("number", "bits"), + [ + (-129, 8), + (128, 8), + (-(2**31) - 1, 32), + (2**31, 32), + (-(2**63) - 1, 64), + (2**63, 64), + ], +) +def test_to_twos_complement_rejects_out_of_range(number: int, bits: int): + with pytest.raises(ValueError, match=r"out of range"): + to_twos_complement(number, bits=bits) + + +@pytest.mark.parametrize( + ("number", "bits"), + [ + (-1, 8), + (256, 8), + (2**32, 32), + (2**64, 64), + ], +) +def test_from_twos_complement_rejects_out_of_range(number: int, bits: int): + with pytest.raises(ValueError, match=r"out of range"): + from_twos_complement(number, bits=bits) diff --git a/tests/protocol/test_connection.py b/tests/protocol/test_connection.py index 18062d30..7819a3cd 100644 --- a/tests/protocol/test_connection.py +++ b/tests/protocol/test_connection.py @@ -3,32 +3,54 @@ import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection, TCPSocketConnection, UDPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer +from mcstatus._protocol.io.connection import TCPSocketConnection, UDPSocketConnection -class TestConnection: - connection: Connection - +class TestBuffer: def setup_method(self): - self.connection = Connection() + self.connection = Buffer() def test_flush(self): - self.connection.sent = bytearray.fromhex("7FAABB") + self.connection.write(bytearray.fromhex("7FAABB")) assert self.connection.flush() == bytearray.fromhex("7FAABB") - assert self.connection.sent == bytearray() + assert self.connection == bytearray() - def test_receive(self): - self.connection.receive(bytearray.fromhex("7F")) - self.connection.receive(bytearray.fromhex("AABB")) + def test_remaining(self): + self.connection.write(bytearray.fromhex("7FAABB")) - assert self.connection.received == bytearray.fromhex("7FAABB") + assert self.connection.remaining == 3 - def test_remaining(self): - self.connection.receive(bytearray.fromhex("7F")) - self.connection.receive(bytearray.fromhex("AABB")) + def test_reset(self): + self.connection.write(b"abcdef") + + assert self.connection.read(3) == b"abc" + self.connection.reset() + assert self.connection.read(6) == b"abcdef" + + def test_clear_only_already_read(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" + + self.connection.clear(only_already_read=True) + + assert self.connection == bytearray(b"cdef") + assert self.connection.remaining == 4 + + def test_unread_view(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" + + assert bytes(self.connection.unread_view()) == b"cdef" + + def test_flush_only_returns_unread_data(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" - assert self.connection.remaining() == 3 + assert self.connection.flush() == b"cdef" + assert self.connection == bytearray() def test_send(self): self.connection.write(bytearray.fromhex("7F")) @@ -37,13 +59,13 @@ def test_send(self): assert self.connection.flush() == bytearray.fromhex("7FAABB") def test_read(self): - self.connection.receive(bytearray.fromhex("7FAABB")) + self.connection.write(bytearray.fromhex("7FAABB")) assert self.connection.read(2) == bytearray.fromhex("7FAA") assert self.connection.read(1) == bytearray.fromhex("BB") def _assert_varint_read_write(self, hexstr, value) -> None: - self.connection.receive(bytearray.fromhex(hexstr)) + self.connection.write(bytearray.fromhex(hexstr)) assert self.connection.read_varint() == value self.connection.write_varint(value) @@ -59,19 +81,25 @@ def test_varint_cases(self): self._assert_varint_read_write("8080808008", -2147483648) def test_read_invalid_varint(self): - self.connection.receive(bytearray.fromhex("FFFFFFFF80")) + self.connection.write(bytearray.fromhex("FFFFFFFF10")) - with pytest.raises(IOError, match=r"^Received varint is too big!$"): + with pytest.raises(IOError, match=r"^Received varint was outside the range of 32-bit int.$"): self.connection.read_varint() def test_write_invalid_varint(self): - with pytest.raises(ValueError, match=r'^The value "2147483648" is too big to send in a varint$'): + with pytest.raises( + ValueError, + match=r"^Can't convert number 2147483648 into 32-bit twos complement format - out of range$", + ): self.connection.write_varint(2147483648) - with pytest.raises(ValueError, match=r'^The value "-2147483649" is too big to send in a varint$'): + with pytest.raises( + ValueError, + match=r"^Can't convert number -2147483649 into 32-bit twos complement format - out of range$", + ): self.connection.write_varint(-2147483649) def test_read_utf(self): - self.connection.receive(bytearray.fromhex("0D48656C6C6F2C20776F726C6421")) + self.connection.write(bytearray.fromhex("0D48656C6C6F2C20776F726C6421")) assert self.connection.read_utf() == "Hello, world!" @@ -86,7 +114,7 @@ def test_read_empty_utf(self): assert self.connection.flush() == bytearray.fromhex("00") def test_read_ascii(self): - self.connection.receive(bytearray.fromhex("48656C6C6F2C20776F726C642100")) + self.connection.write(bytearray.fromhex("48656C6C6F2C20776F726C642100")) assert self.connection.read_ascii() == "Hello, world!" @@ -101,133 +129,137 @@ def test_read_empty_ascii(self): assert self.connection.flush() == bytearray.fromhex("00") def test_read_short_negative(self): - self.connection.receive(bytearray.fromhex("8000")) + self.connection.write(bytearray.fromhex("8000")) - assert self.connection.read_short() == -32768 + assert self.connection.read_value(StructFormat.SHORT) == -32768 def test_write_short_negative(self): - self.connection.write_short(-32768) + self.connection.write_value(StructFormat.SHORT, -32768) assert self.connection.flush() == bytearray.fromhex("8000") def test_read_short_positive(self): - self.connection.receive(bytearray.fromhex("7FFF")) + self.connection.write(bytearray.fromhex("7FFF")) - assert self.connection.read_short() == 32767 + assert self.connection.read_value(StructFormat.SHORT) == 32767 def test_write_short_positive(self): - self.connection.write_short(32767) + self.connection.write_value(StructFormat.SHORT, 32767) assert self.connection.flush() == bytearray.fromhex("7FFF") def test_read_ushort_positive(self): - self.connection.receive(bytearray.fromhex("8000")) + self.connection.write(bytearray.fromhex("8000")) - assert self.connection.read_ushort() == 32768 + assert self.connection.read_value(StructFormat.USHORT) == 32768 def test_write_ushort_positive(self): - self.connection.write_ushort(32768) + self.connection.write_value(StructFormat.USHORT, 32768) assert self.connection.flush() == bytearray.fromhex("8000") def test_read_int_negative(self): - self.connection.receive(bytearray.fromhex("80000000")) + self.connection.write(bytearray.fromhex("80000000")) - assert self.connection.read_int() == -2147483648 + assert self.connection.read_value(StructFormat.INT) == -2147483648 def test_write_int_negative(self): - self.connection.write_int(-2147483648) + self.connection.write_value(StructFormat.INT, -2147483648) assert self.connection.flush() == bytearray.fromhex("80000000") def test_read_int_positive(self): - self.connection.receive(bytearray.fromhex("7FFFFFFF")) + self.connection.write(bytearray.fromhex("7FFFFFFF")) - assert self.connection.read_int() == 2147483647 + assert self.connection.read_value(StructFormat.INT) == 2147483647 def test_write_int_positive(self): - self.connection.write_int(2147483647) + self.connection.write_value(StructFormat.INT, 2147483647) assert self.connection.flush() == bytearray.fromhex("7FFFFFFF") def test_read_uint_positive(self): - self.connection.receive(bytearray.fromhex("80000000")) + self.connection.write(bytearray.fromhex("80000000")) - assert self.connection.read_uint() == 2147483648 + assert self.connection.read_value(StructFormat.UINT) == 2147483648 def test_write_uint_positive(self): - self.connection.write_uint(2147483648) + self.connection.write_value(StructFormat.UINT, 2147483648) assert self.connection.flush() == bytearray.fromhex("80000000") def test_read_long_negative(self): - self.connection.receive(bytearray.fromhex("8000000000000000")) + self.connection.write(bytearray.fromhex("8000000000000000")) - assert self.connection.read_long() == -9223372036854775808 + assert self.connection.read_value(StructFormat.LONGLONG) == -9223372036854775808 def test_write_long_negative(self): - self.connection.write_long(-9223372036854775808) + self.connection.write_value(StructFormat.LONGLONG, -9223372036854775808) assert self.connection.flush() == bytearray.fromhex("8000000000000000") def test_read_long_positive(self): - self.connection.receive(bytearray.fromhex("7FFFFFFFFFFFFFFF")) + self.connection.write(bytearray.fromhex("7FFFFFFFFFFFFFFF")) - assert self.connection.read_long() == 9223372036854775807 + assert self.connection.read_value(StructFormat.LONGLONG) == 9223372036854775807 def test_write_long_positive(self): - self.connection.write_long(9223372036854775807) + self.connection.write_value(StructFormat.LONGLONG, 9223372036854775807) assert self.connection.flush() == bytearray.fromhex("7FFFFFFFFFFFFFFF") def test_read_ulong_positive(self): - self.connection.receive(bytearray.fromhex("8000000000000000")) + self.connection.write(bytearray.fromhex("8000000000000000")) - assert self.connection.read_ulong() == 9223372036854775808 + assert self.connection.read_value(StructFormat.ULONGLONG) == 9223372036854775808 def test_write_ulong_positive(self): - self.connection.write_ulong(9223372036854775808) + self.connection.write_value(StructFormat.ULONGLONG, 9223372036854775808) assert self.connection.flush() == bytearray.fromhex("8000000000000000") @pytest.mark.parametrize(("as_bytes", "as_bool"), [("01", True), ("00", False)]) def test_read_bool(self, as_bytes: str, as_bool: bool) -> None: - self.connection.receive(bytearray.fromhex(as_bytes)) + self.connection.write(bytearray.fromhex(as_bytes)) - assert self.connection.read_bool() is as_bool + assert self.connection.read_value(StructFormat.BOOL) is as_bool @pytest.mark.parametrize(("as_bytes", "as_bool"), [("01", True), ("00", False)]) def test_write_bool(self, as_bytes: str, as_bool: bool) -> None: - self.connection.write_bool(as_bool) + self.connection.write_value(StructFormat.BOOL, as_bool) assert self.connection.flush() == bytearray.fromhex(as_bytes) - def test_read_buffer(self): - self.connection.receive(bytearray.fromhex("027FAA")) - buffer = self.connection.read_buffer() + def test_read_bytearray(self): + self.connection.write(bytearray.fromhex("027FAA")) - assert buffer.received == bytearray.fromhex("7FAA") - assert self.connection.flush() == bytearray() + assert self.connection.read_bytearray() == bytearray.fromhex("7FAA") - def test_write_buffer(self): - buffer = Connection() - buffer.write(bytearray.fromhex("7FAA")) - self.connection.write_buffer(buffer) + def test_write_bytearray(self): + self.connection.write_bytearray(bytearray.fromhex("7FAA")) assert self.connection.flush() == bytearray.fromhex("027FAA") def test_read_empty(self): - self.connection.received = bytearray() - - with pytest.raises(IOError, match=r"^Not enough data to read! 0 < 1$"): + with pytest.raises( + IOError, + match=r"^Requested to read more data than available. Read 0 bytes: bytearray\(b''\), out of 1 requested bytes.$", + ): self.connection.read(1) def test_read_not_enough(self): - self.connection.received = bytearray(b"a") + self.connection.write(bytearray(b"a")) - with pytest.raises(IOError, match=r"^Not enough data to read! 1 < 2$"): + with pytest.raises( + IOError, + match=r"^Requested to read more data than available. Read 1 bytes: bytearray\(b'a'\), out of 2 requested bytes.$", + ): self.connection.read(2) + def test_read_negative_length(self): + with pytest.raises(IOError, match=r"^Requested to read a negative amount of data: -1\.$"): + self.connection.read(-1) + class TestTCPSocketConnection: @pytest.fixture(scope="class") @@ -236,24 +268,12 @@ def connection(self): socket = Mock() socket.recv = Mock() - socket.send = Mock() + socket.sendall = Mock() with patch("socket.create_connection") as create_connection: create_connection.return_value = socket with TCPSocketConnection(test_addr) as connection: yield connection - def test_flush(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support flush\(\)$"): - connection.flush() - - def test_receive(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support receive\(\)$"): - connection.receive("") - - def test_remaining(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support remaining\(\)$"): - connection.remaining() - def test_read(self, connection): connection.socket.recv.return_value = bytearray.fromhex("7FAA") @@ -274,13 +294,13 @@ def test_read_not_enough(self, connection): def test_write(self, connection): connection.write(bytearray.fromhex("7FAA")) - connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) + connection.socket.sendall.assert_called_once_with(bytearray.fromhex("7FAA")) class TestUDPSocketConnection: @pytest.fixture(scope="class") def connection(self): - test_addr = Address("localhost", 1234) + test_addr = Address("127.0.0.1", 1234) socket = Mock() socket.recvfrom = Mock() @@ -290,16 +310,8 @@ def connection(self): with UDPSocketConnection(test_addr) as connection: yield connection - def test_flush(self, connection): - with pytest.raises(TypeError, match=r"^UDPSocketConnection does not support flush\(\)$"): - connection.flush() - - def test_receive(self, connection): - with pytest.raises(TypeError, match=r"^UDPSocketConnection does not support receive\(\)$"): - connection.receive("") - def test_remaining(self, connection): - assert connection.remaining() == 65535 + assert connection.remaining == 65535 def test_read(self, connection): connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")] @@ -317,5 +329,5 @@ def test_write(self, connection): connection.socket.sendto.assert_called_once_with( bytearray.fromhex("7FAA"), - Address("localhost", 1234), + Address("127.0.0.1", 1234), ) diff --git a/tests/protocol/test_java_client.py b/tests/protocol/test_java_client.py index 39d5ba56..6acd54e2 100644 --- a/tests/protocol/test_java_client.py +++ b/tests/protocol/test_java_client.py @@ -5,14 +5,15 @@ import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection from mcstatus._protocol.java_client import JavaClient +from tests.protocol.helpers import SyncBufferConnection class TestJavaClient: def setup_method(self): + self.connection = SyncBufferConnection() self.java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] + self.connection, # pyright: ignore[reportArgumentType] address=Address("localhost", 25565), version=44, ) @@ -20,10 +21,10 @@ def setup_method(self): def test_handshake(self): self.java_client.handshake() - assert self.java_client.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") + assert self.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") def test_read_status(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" @@ -37,15 +38,15 @@ def test_read_status(self): "players": {"max": 20, "online": 0}, "version": {"name": "1.8-pre1", "protocol": 44}, } - assert self.java_client.connection.flush() == bytearray.fromhex("0100") + assert self.connection.flush() == bytearray.fromhex("0100") def test_read_status_invalid_json(self): - self.java_client.connection.receive(bytearray.fromhex("0300017B")) + self.connection.receive(bytearray.fromhex("0300017B")) with pytest.raises(IOError, match=r"^Received invalid JSON$"): self.java_client.read_status() def test_read_status_invalid_reply(self): - self.java_client.connection.receive( + self.connection.receive( # no motd, see also #922 bytearray.fromhex( "4F004D7B22706C6179657273223A7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616" @@ -56,91 +57,75 @@ def test_read_status_invalid_reply(self): self.java_client.read_status() def test_read_status_invalid_status(self): - self.java_client.connection.receive(bytearray.fromhex("0105")) + self.connection.receive(bytearray.fromhex("0105")) with pytest.raises(IOError, match=r"^Received invalid status response packet.$"): self.java_client.read_status() def test_test_ping(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 14515484 assert self.java_client.test_ping() >= 0 - assert self.java_client.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") + assert self.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") def test_test_ping_invalid(self): - self.java_client.connection.receive(bytearray.fromhex("011F")) + self.connection.receive(bytearray.fromhex("011F")) self.java_client.ping_token = 14515484 with pytest.raises(IOError, match=r"^Received invalid ping response packet.$"): self.java_client.test_ping() def test_test_ping_wrong_token(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 12345 with pytest.raises(IOError, match=r"^Received mangled ping response \(expected token 12345, got 14515484\)$"): self.java_client.test_ping() + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) def test_latency_is_real_number(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" + self.connection.receive( + bytearray.fromhex( + "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" + "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" + "70726F746F636F6C223A34347D7D" + ) + ) - def mocked_read_buffer(): + def mocked_read_bytearray() -> bytes: time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(Connection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint.return_value = 0 - mocked.return_value.read_utf.return_value = """ - { - "description": "A Minecraft Server", - "players": {"max": 20, "online": 0}, - "version": {"name": "1.8-pre1", "protocol": 44} - } - """ - java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ) + return original_read_bytearray() - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert java_client.read_status().latency >= 1 + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = self.java_client.read_status().latency + assert 1 <= latency <= 20 + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) def test_test_ping_is_in_milliseconds(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" + self.java_client.ping_token = 14515484 + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) - def mocked_read_buffer(): + def mocked_read_bytearray() -> bytes: time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(Connection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint.return_value = 1 - mocked.return_value.read_long.return_value = 123456789 - java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ping_token=123456789, - ) + return original_read_bytearray() - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert java_client.test_ping() >= 1 + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = self.java_client.test_ping() + assert 1 <= latency <= 20 diff --git a/tests/protocol/test_java_client_async.py b/tests/protocol/test_java_client_async.py index d126de58..5c7e29b4 100644 --- a/tests/protocol/test_java_client_async.py +++ b/tests/protocol/test_java_client_async.py @@ -1,42 +1,30 @@ import asyncio import sys -import time from unittest import mock import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection from mcstatus._protocol.java_client import AsyncJavaClient - - -def async_decorator(f): - def wrapper(*args, **kwargs): - return asyncio.run(f(*args, **kwargs)) - - return wrapper - - -class FakeAsyncConnection(Connection): - async def read_buffer(self): # pyright: ignore[reportIncompatibleMethodOverride] - return super().read_buffer() +from tests.protocol.helpers import AsyncBufferConnection, async_decorator class TestAsyncJavaClient: def setup_method(self): + self.connection = AsyncBufferConnection() self.java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] + self.connection, # pyright: ignore[reportArgumentType] address=Address("localhost", 25565), version=44, ) def test_handshake(self): - self.java_client.handshake() + async_decorator(self.java_client.handshake)() - assert self.java_client.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") + assert self.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") def test_read_status(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" @@ -50,15 +38,15 @@ def test_read_status(self): "players": {"max": 20, "online": 0}, "version": {"name": "1.8-pre1", "protocol": 44}, } - assert self.java_client.connection.flush() == bytearray.fromhex("0100") + assert self.connection.flush() == bytearray.fromhex("0100") def test_read_status_invalid_json(self): - self.java_client.connection.receive(bytearray.fromhex("0300017B")) + self.connection.receive(bytearray.fromhex("0300017B")) with pytest.raises(IOError, match=r"^Received invalid JSON$"): async_decorator(self.java_client.read_status)() def test_read_status_invalid_reply(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "4F004D7B22706C6179657273223A7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616" "D65223A22312E382D70726531222C2270726F746F636F6C223A34347D7D" @@ -68,88 +56,79 @@ def test_read_status_invalid_reply(self): async_decorator(self.java_client.read_status)() def test_read_status_invalid_status(self): - self.java_client.connection.receive(bytearray.fromhex("0105")) + self.connection.receive(bytearray.fromhex("0105")) with pytest.raises(IOError, match=r"^Received invalid status response packet.$"): async_decorator(self.java_client.read_status)() def test_test_ping(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 14515484 assert async_decorator(self.java_client.test_ping)() >= 0 - assert self.java_client.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") + assert self.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") def test_test_ping_invalid(self): - self.java_client.connection.receive(bytearray.fromhex("011F")) + self.connection.receive(bytearray.fromhex("011F")) self.java_client.ping_token = 14515484 with pytest.raises(IOError, match=r"^Received invalid ping response packet.$"): async_decorator(self.java_client.test_ping)() def test_test_ping_wrong_token(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 12345 with pytest.raises(IOError, match=r"^Received mangled ping response \(expected token 12345, got 14515484\)$"): async_decorator(self.java_client.test_ping)() @pytest.mark.asyncio + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) async def test_latency_is_real_number(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" - - def mocked_read_buffer(): - time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(FakeAsyncConnection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - # overwrite `async` here - mocked.return_value.read_varint = lambda: 0 - mocked.return_value.read_utf = lambda: ( - """ - { - "description": "A Minecraft Server", - "players": {"max": 20, "online": 0}, - "version": {"name": "1.8-pre1", "protocol": 44} - } - """ - ) - java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, + self.connection.receive( + bytearray.fromhex( + "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" + "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" + "70726F746F636F6C223A34347D7D" ) + ) - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert (await java_client.read_status()).latency >= 1 + async def mocked_read_bytearray() -> bytes: + """Delay reads while still delegating to the real async connection implementation.""" + await asyncio.sleep(0.001) + return await original_read_bytearray() + + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = (await self.java_client.read_status()).latency + assert 1 <= latency <= 20 @pytest.mark.asyncio + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) async def test_test_ping_is_in_milliseconds(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" - - def mocked_read_buffer(): - time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(FakeAsyncConnection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint = lambda: 1 # overwrite `async` here - mocked.return_value.read_long = lambda: 123456789 # overwrite `async` here - java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ping_token=123456789, - ) - # we slept 1ms, so this should be always ~1. - assert await java_client.test_ping() >= 1 + self.java_client.ping_token = 14515484 + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + + async def mocked_read_bytearray() -> bytes: + """Delay reads while still delegating to the real async connection implementation.""" + await asyncio.sleep(0.001) + return await original_read_bytearray() + + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = await self.java_client.test_ping() + assert 1 <= latency <= 20 diff --git a/tests/protocol/test_legacy_client.py b/tests/protocol/test_legacy_client.py index cf55abe2..3e6dfbf1 100644 --- a/tests/protocol/test_legacy_client.py +++ b/tests/protocol/test_legacy_client.py @@ -1,9 +1,9 @@ import pytest -from mcstatus._protocol.connection import Connection from mcstatus._protocol.legacy_client import LegacyClient from mcstatus.motd import Motd from mcstatus.responses.legacy import LegacyStatusPlayers, LegacyStatusResponse, LegacyStatusVersion +from tests.protocol.helpers import SyncBufferConnection def test_invalid_kick_reason(): @@ -40,7 +40,7 @@ def test_parse_response(response: bytes, expected: LegacyStatusResponse): def test_invalid_packet_id(): - socket = Connection() + socket = SyncBufferConnection() socket.receive(bytearray.fromhex("00")) server = LegacyClient(socket) with pytest.raises(IOError, match=r"^Received invalid packet ID$"): diff --git a/tests/protocol/test_query_client.py b/tests/protocol/test_query_client.py index b4f36abb..b46e8a87 100644 --- a/tests/protocol/test_query_client.py +++ b/tests/protocol/test_query_client.py @@ -1,24 +1,25 @@ from unittest.mock import Mock -from mcstatus._protocol.connection import Connection from mcstatus._protocol.query_client import QueryClient from mcstatus.motd import Motd +from tests.protocol.helpers import SyncDatagramConnection class TestQueryClient: def setup_method(self): - self.query_client = QueryClient(Connection()) # pyright: ignore[reportArgumentType] + self.connection = SyncDatagramConnection() + self.query_client = QueryClient(self.connection) # pyright: ignore[reportArgumentType] def test_handshake(self): - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) self.query_client.handshake() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert self.query_client.challenge == 570350778 def test_query(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -28,7 +29,7 @@ def test_query(self): ) ) response = self.query_client.read_query() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD00") assert conn_bytes[7:] == bytearray.fromhex("0000000000000000") assert response.raw == { @@ -46,7 +47,7 @@ def test_query(self): assert response.players.list == ["Dinnerbone", "Djinnibone", "Steve"] def test_query_handles_unorderd_map_response(self): - self.query_client.connection.receive( + self.connection.receive( bytearray( b"\x00\x00\x00\x00\x00GeyserMC\x00\x80\x00hostname\x00Geyser\x00hostip\x001.1.1.1\x00plugins\x00\x00numplayers" b"\x001\x00gametype\x00SMP\x00maxplayers\x00100\x00hostport\x0019132\x00version\x00Geyser" @@ -54,14 +55,14 @@ def test_query_handles_unorderd_map_response(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("Geyser") assert response.software.version == "Geyser (git-master-0fd903e) 1.18.10" def test_query_handles_unicode_motd_with_nulls(self): - self.query_client.connection.receive( + self.connection.receive( bytearray( b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00hostname\x00\x00*K\xd5\x00gametype\x00SMP" b"\x00game_id\x00MINECRAFT\x00version\x001.16.5\x00plugins\x00Paper on 1.16.5-R0.1-SNAPSHOT\x00map\x00world" @@ -70,13 +71,13 @@ def test_query_handles_unicode_motd_with_nulls(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("\x00*KÕ") def test_query_handles_unicode_motd_with_2a00_at_the_start(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d6500006f746865720067616d657479706500534d500067616d655f6964004d" "494e4543524146540076657273696f6e00312e31382e3100706c7567696e7300006d617000776f726c64006e756d706c617965727300" @@ -85,7 +86,7 @@ def test_query_handles_unicode_motd_with_2a00_at_the_start(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("\x00other") # "\u2a00other" is actually what is expected, @@ -96,12 +97,12 @@ def test_session_id(self): def session_id(): return 0x01010101 - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) self.query_client._generate_session_id = Mock() self.query_client._generate_session_id = session_id self.query_client.handshake() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert conn_bytes[3:] == session_id().to_bytes(4, byteorder="big") assert self.query_client.challenge == 570350778 diff --git a/tests/protocol/test_query_client_async.py b/tests/protocol/test_query_client_async.py index dfc9f58f..3c091bb4 100644 --- a/tests/protocol/test_query_client_async.py +++ b/tests/protocol/test_query_client_async.py @@ -1,29 +1,21 @@ -from mcstatus._protocol.connection import Connection from mcstatus._protocol.query_client import AsyncQueryClient -from tests.protocol.test_java_client_async import async_decorator - - -class FakeUDPAsyncConnection(Connection): - async def read(self, length): # pyright: ignore[reportIncompatibleMethodOverride] - return super().read(length) - - async def write(self, data): # pyright: ignore[reportIncompatibleMethodOverride] - return super().write(data) +from tests.protocol.helpers import AsyncDatagramConnection, async_decorator class TestAsyncQueryClient: def setup_method(self): - self.query_client = AsyncQueryClient(FakeUDPAsyncConnection()) # pyright: ignore[reportArgumentType] + self.connection = AsyncDatagramConnection() + self.query_client = AsyncQueryClient(self.connection) # pyright: ignore[reportArgumentType] def test_handshake(self): - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) async_decorator(self.query_client.handshake)() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert self.query_client.challenge == 570350778 def test_query(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -33,7 +25,7 @@ def test_query(self): ) ) response = async_decorator(self.query_client.read_query)() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD00") assert conn_bytes[7:] == bytearray.fromhex("0000000000000000") assert response.raw == { diff --git a/tests/protocol/test_timeout.py b/tests/protocol/test_timeout.py index d6c0b2ec..32a8bcd9 100644 --- a/tests/protocol/test_timeout.py +++ b/tests/protocol/test_timeout.py @@ -1,12 +1,12 @@ import asyncio import typing from asyncio.exceptions import TimeoutError as AsyncioTimeoutError -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import TCPAsyncSocketConnection +from mcstatus._protocol.io.connection import TCPAsyncSocketConnection class FakeAsyncStream(asyncio.StreamReader): @@ -26,3 +26,57 @@ async def test_tcp_socket_read(self): async with TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) as tcp_async_socket: with pytest.raises(AsyncioTimeoutError): await tcp_async_socket.read(10) + + @pytest.mark.asyncio + async def test_tcp_socket_read_partial_data_then_eof(self): + """Raise when the stream ends after only partial payload delivery.""" + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.reader = Mock(read=AsyncMock(side_effect=[b"a", b""])) + + with pytest.raises( + OSError, + match=( + r"^Server stopped responding \(got 1 bytes, but expected 2 bytes\). " + r"Partial obtained data: bytearray\(b'a'\)$" + ), + ): + await tcp_async_socket.read(2) + + @pytest.mark.asyncio + async def test_tcp_socket_read_eof_without_any_data(self): + """Raise when the stream immediately ends without returning any data.""" + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.reader = Mock(read=AsyncMock(return_value=b"")) + + with pytest.raises(OSError, match=r"^Server did not respond with any information!$"): + await tcp_async_socket.read(2) + + @pytest.mark.asyncio + async def test_tcp_socket_write_awaits_drain(self): + """Ensure writes await ``drain`` so buffered data is flushed.""" + writer = Mock() + writer.write = Mock() + writer.drain = AsyncMock() + + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.writer = writer + + await tcp_async_socket.write(b"hello") + + writer.write.assert_called_once_with(b"hello") + writer.drain.assert_awaited_once_with() + + @pytest.mark.asyncio + async def test_tcp_socket_close_waits_for_writer(self): + """Ensure close awaits ``wait_closed`` on the stream writer.""" + writer = Mock() + writer.close = Mock() + writer.wait_closed = AsyncMock() + + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.writer = typing.cast("asyncio.StreamWriter", writer) + + await tcp_async_socket.close() + + writer.close.assert_called_once_with() + writer.wait_closed.assert_awaited_once_with() diff --git a/tests/test_server.py b/tests/test_server.py index 57c1d741..9b3d05de 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,58 +1,14 @@ from __future__ import annotations import asyncio -from typing import SupportsIndex, TYPE_CHECKING, TypeAlias from unittest.mock import call, patch import pytest import pytest_asyncio from mcstatus._net.address import Address -from mcstatus._protocol.connection import BaseAsyncReadSyncWriteConnection, Connection from mcstatus.server import BedrockServer, JavaServer, LegacyServer - -if TYPE_CHECKING: - from collections.abc import Iterable - -BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" - - -class AsyncConnection(BaseAsyncReadSyncWriteConnection): - def __init__(self) -> None: - self.sent = bytearray() - self.received = bytearray() - - async def read(self, length: int) -> bytearray: - """Return :attr:`.received` up to length bytes, then cut received up to that point.""" - if len(self.received) < length: - raise OSError(f"Not enough data to read! {len(self.received)} < {length}") - - result = self.received[:length] - self.received = self.received[length:] - return result - - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Extend :attr:`.sent` from ``data``.""" - if isinstance(data, Connection): - data = data.flush() - if isinstance(data, str): - data = bytearray(data, "utf-8") - self.sent.extend(data) - - def receive(self, data: BytesConvertable | bytearray) -> None: - """Extend :attr:`.received` with ``data``.""" - if not isinstance(data, bytearray): - data = bytearray(data) - self.received.extend(data) - - def remaining(self) -> int: - """Return length of :attr:`.received`.""" - return len(self.received) - - def flush(self) -> bytearray: - """Return :attr:`.sent`, also clears :attr:`.sent`.""" - result, self.sent = self.sent, bytearray() - return result +from tests.protocol.helpers import AsyncBufferConnection, SyncBufferConnection, SyncDatagramConnection class MockProtocolFactory(asyncio.Protocol): @@ -86,13 +42,15 @@ def resume_writing(self): @pytest_asyncio.fixture() async def create_mock_packet_server(): + """Create a temporary asyncio packet servers used by tests.""" event_loop = asyncio.get_running_loop() servers = [] async def create_server(port, data_expected_to_receive, data_to_respond_with): + """Start a server that validates one request pattern and returns a fixed payload.""" server = await event_loop.create_server( lambda: MockProtocolFactory(data_expected_to_receive, data_to_respond_with), - host="localhost", + host="127.0.0.1", port=port, ) servers.append(server) @@ -107,7 +65,7 @@ async def create_server(port, data_expected_to_receive, data_to_respond_with): class TestBedrockServer: def setup_method(self): - self.server = BedrockServer("localhost") + self.server = BedrockServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 19132 @@ -126,7 +84,7 @@ async def test_async_ping(self, unused_tcp_port, create_mock_packet_server): data_expected_to_receive=bytearray.fromhex("09010000000001C54246"), data_to_respond_with=bytearray.fromhex("0F002F096C6F63616C686F737463DD0109010000000001C54246"), ) - minecraft_server = JavaServer("localhost", port=unused_tcp_port) + minecraft_server = JavaServer("127.0.0.1", port=unused_tcp_port) latency = await minecraft_server.async_ping(ping_token=29704774, version=47) assert latency >= 0 @@ -140,7 +98,7 @@ async def test_async_lookup_constructor(self): def test_java_server_with_query_port(): with patch("mcstatus.server.JavaServer._retry_query") as patched_query_func: - server = JavaServer("localhost", query_port=12345) + server = JavaServer("127.0.0.1", query_port=12345) server.query() assert server.query_port == 12345 assert patched_query_func.call_args == call(Address("127.0.0.1", port=12345), tries=3) @@ -149,7 +107,7 @@ def test_java_server_with_query_port(): @pytest.mark.asyncio async def test_java_server_with_query_port_async(): with patch("mcstatus.server.JavaServer._retry_async_query") as patched_query_func: - server = JavaServer("localhost", query_port=12345) + server = JavaServer("127.0.0.1", query_port=12345) await server.async_query() assert server.query_port == 12345 assert patched_query_func.call_args == call(Address("127.0.0.1", port=12345), tries=3) @@ -157,8 +115,8 @@ async def test_java_server_with_query_port_async(): class TestJavaServer: def setup_method(self): - self.socket = Connection() - self.server = JavaServer("localhost") + self.socket = SyncBufferConnection() + self.server = JavaServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 25565 @@ -170,7 +128,7 @@ def test_ping(self): connection.return_value.__enter__.return_value = self.socket latency = self.server.ping(ping_token=29704774, version=47) - assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD0109010000000001C54246") + assert self.socket.flush() == bytearray.fromhex("0F002F093132372E302E302E3163DD0109010000000001C54246") assert self.socket.remaining() == 0, "Data is pending to be read, but should be empty" assert latency >= 0 @@ -195,7 +153,7 @@ def test_status(self): connection.return_value.__enter__.return_value = self.socket info = self.server.status(version=47) - assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD010100") + assert self.socket.flush() == bytearray.fromhex("0F002F093132372E302E302E3163DD010100") assert self.socket.remaining() == 0, "Data is pending to be read, but should be empty" assert info.raw == { "description": "A Minecraft Server", @@ -213,8 +171,9 @@ def test_status_retry(self): assert java_client.call_count == 3 def test_query(self): - self.socket.receive(bytearray.fromhex("090000000035373033353037373800")) - self.socket.receive( + socket = SyncDatagramConnection() + socket.receive(bytearray.fromhex("090000000035373033353037373800")) + socket.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -224,38 +183,30 @@ def test_query(self): ) ) - with patch("mcstatus._protocol.connection.Connection.remaining") as mock_remaining: - mock_remaining.side_effect = [15, 208] - - with ( - patch("mcstatus.server.UDPSocketConnection") as connection, - patch.object(self.server.address, "resolve_ip") as resolve_ip, - ): - connection.return_value.__enter__.return_value = self.socket - resolve_ip.return_value = "127.0.0.1" - info = self.server.query() - - conn_bytes = self.socket.flush() - assert conn_bytes[:3] == bytearray.fromhex("FEFD09") - assert info.raw == { - "hostname": "A Minecraft Server", - "gametype": "SMP", - "game_id": "MINECRAFT", - "version": "1.8", - "plugins": "", - "map": "world", - "numplayers": "3", - "maxplayers": "20", - "hostport": "25565", - "hostip": "192.168.56.1", - } + with patch("mcstatus.server.UDPSocketConnection") as connection: + connection.return_value.__enter__.return_value = socket + info = self.server.query() + + conn_bytes = socket.flush() + assert conn_bytes[:3] == bytearray.fromhex("FEFD09") + assert info.raw == { + "hostname": "A Minecraft Server", + "gametype": "SMP", + "game_id": "MINECRAFT", + "version": "1.8", + "plugins": "", + "map": "world", + "numplayers": "3", + "maxplayers": "20", + "hostport": "25565", + "hostip": "192.168.56.1", + } def test_query_retry(self): # Use a blank mock for the connection, we don't want to actually create any connections with patch("mcstatus.server.UDPSocketConnection"), patch("mcstatus.server.QueryClient") as query_client: query_client.side_effect = [RuntimeError, RuntimeError, RuntimeError] - with pytest.raises(RuntimeError, match=r"^$"), patch.object(self.server.address, "resolve_ip") as resolve_ip: # noqa: PT012 - resolve_ip.return_value = "127.0.0.1" + with pytest.raises(RuntimeError, match=r"^$"): self.server.query() assert query_client.call_count == 3 @@ -267,8 +218,8 @@ def test_lookup_constructor(self): class TestLegacyServer: def setup_method(self): - self.socket = Connection() - self.server = LegacyServer("localhost") + self.socket = SyncBufferConnection() + self.server = LegacyServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 25565 @@ -306,8 +257,8 @@ def test_status(self): class TestAsyncLegacyServer: def setup_method(self): - self.socket = AsyncConnection() - self.server = LegacyServer("localhost") + self.socket = AsyncBufferConnection() + self.server = LegacyServer("127.0.0.1") @pytest.mark.asyncio async def test_async_lookup_constructor(self): diff --git a/tests/utils/test_retry.py b/tests/utils/test_retry.py index 6be65834..e2861735 100644 --- a/tests/utils/test_retry.py +++ b/tests/utils/test_retry.py @@ -1,7 +1,7 @@ import pytest from mcstatus._utils.retry import retry -from tests.protocol.test_java_client_async import async_decorator +from tests.protocol.helpers import async_decorator def test_sync_success():