From 121c3bd6bc3e9153596064c4b93a223eb559fb9f Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Thu, 5 Mar 2026 12:53:41 +0100 Subject: [PATCH 01/10] Rework protocol IO, add Buffer abstraction & fix socket write issues The previous protocol implementation mixed multiple responsibilities in `Connection`. It acted both as a network transport and as an in-memory buffer for packet construction/parsing. This made the code confusing, hard to reason about, and resulted in duplicated serialization logic. This change restructures the protocol IO layer and cleans up the design: - Extract binary read/write primitives into `_protocol.io.base_io` (BaseSync/AsyncReader and BaseSync/AsyncWriter). - Introduce a dedicated `Buffer` type for in-memory packet construction and decoding instead of abusing connection objects as buffers. - Move socket implementations into `_protocol.io.connection`. - Update protocol clients to use the new IO abstractions and the `StructFormat` helpers for typed serialization. Besides the structural cleanup, this also fixes several important issues in the previous implementation: * The synchronous TCP connection used `socket.send(data)`, which does not guarantee that all bytes are sent. The new implementation correctly uses `socket.sendall()`. * The asynchronous TCP write method was synchronous. It was just calling `writer.write(data)` but never awaited `drain()`, meaning the actual network write could happen at an arbitrary later time. In practice this rarely broke because a read always followed the write, which implicitly allowed the event loop to flush the buffer, but the behavior was still incorrect and fragile. The new async connection now uses async write method and awaits `writer.drain()` to ensure proper write semantics. --- mcstatus/_protocol/connection.py | 705 -------------------------- mcstatus/_protocol/io/__init__.py | 0 mcstatus/_protocol/io/base_io.py | 743 ++++++++++++++++++++++++++++ mcstatus/_protocol/io/buffer.py | 115 +++++ mcstatus/_protocol/io/connection.py | 244 +++++++++ mcstatus/_protocol/java_client.py | 57 ++- mcstatus/_protocol/legacy_client.py | 11 +- mcstatus/_protocol/query_client.py | 40 +- mcstatus/responses/forge.py | 35 +- mcstatus/server.py | 6 +- 10 files changed, 1182 insertions(+), 774 deletions(-) delete mode 100644 mcstatus/_protocol/connection.py create mode 100644 mcstatus/_protocol/io/__init__.py create mode 100644 mcstatus/_protocol/io/base_io.py create mode 100644 mcstatus/_protocol/io/buffer.py create mode 100644 mcstatus/_protocol/io/connection.py diff --git a/mcstatus/_protocol/connection.py b/mcstatus/_protocol/connection.py deleted file mode 100644 index 72c36057..00000000 --- a/mcstatus/_protocol/connection.py +++ /dev/null @@ -1,705 +0,0 @@ -from __future__ import annotations - -import asyncio -import errno -import socket -import struct -from abc import ABC, abstractmethod -from ctypes import c_int32 as signed_int32, c_int64 as signed_int64, c_uint32 as unsigned_int32, c_uint64 as unsigned_int64 -from ipaddress import ip_address -from typing import TYPE_CHECKING, TypeAlias, cast - -import asyncio_dgram - -if TYPE_CHECKING: - from collections.abc import Iterable - - from typing_extensions import Self, SupportsIndex - - from mcstatus._net.address import Address - -__all__ = [ - "BaseAsyncConnection", - "BaseAsyncReadSyncWriteConnection", - "BaseConnection", - "BaseReadAsync", - "BaseReadSync", - "BaseSyncConnection", - "BaseWriteAsync", - "BaseWriteSync", - "Connection", - "SocketConnection", - "TCPAsyncSocketConnection", - "TCPSocketConnection", - "UDPAsyncSocketConnection", - "UDPSocketConnection", -] - -BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" - - -def _ip_type(address: int | str) -> int | None: - """Determine the IP version (IPv4 or IPv6). - - :param address: - A string or integer, the IP address. Either IPv4 or IPv6 addresses may be supplied. - Integers less than 2**32 will be considered to be IPv4 by default. - :return: ``4`` or ``6`` if the IP is IPv4 or IPv6, respectively. :obj:`None` if the IP is invalid. - """ - try: - return ip_address(address).version - except ValueError: - return None - - -class BaseWriteSync(ABC): - """Base synchronous write class.""" - - __slots__ = () - - @abstractmethod - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Write data to ``self``.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _pack(format_: str, data: int) -> bytes: - """Pack data in with format in big-endian mode.""" - return struct.pack(">" + format_, data) - - def write_varint(self, value: int) -> None: - """Write varint with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int32(value).value - for _ in range(5): - if not remaining & -0x80: # remaining & ~0x7F == 0: - self.write(struct.pack("!B", remaining)) - if value > 2**31 - 1 or value < -(2**31): - break - return - self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varint') - - def write_varlong(self, value: int) -> None: - """Write varlong with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int64(value).value - for _ in range(10): - if not remaining & -0x80: # remaining & ~0x7F == 0: - self.write(struct.pack("!B", remaining)) - if value > 2**63 - 1 or value < -(2**31): - break - return - self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varlong') - - def write_utf(self, value: str) -> None: - """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" - self.write_varint(len(value)) - self.write(bytearray(value, "utf8")) - - def write_ascii(self, value: str) -> None: - """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" - self.write(bytearray(value, "ISO-8859-1")) - self.write(bytearray.fromhex("00")) - - def write_short(self, value: int) -> None: - """Write 2 bytes for value ``-32768 - 32767``.""" - self.write(self._pack("h", value)) - - def write_ushort(self, value: int) -> None: - """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" - self.write(self._pack("H", value)) - - def write_int(self, value: int) -> None: - """Write 4 bytes for value ``-2147483648 - 2147483647``.""" - self.write(self._pack("i", value)) - - def write_uint(self, value: int) -> None: - """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" - self.write(self._pack("I", value)) - - def write_long(self, value: int) -> None: - """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" - self.write(self._pack("q", value)) - - def write_ulong(self, value: int) -> None: - """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" - self.write(self._pack("Q", value)) - - def write_bool(self, value: bool) -> None: # noqa: FBT001 # Boolean positional argument - """Write 1 byte for boolean `True` or `False`.""" - self.write(self._pack("?", value)) - - def write_buffer(self, buffer: Connection) -> None: - """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" - data = buffer.flush() - self.write_varint(len(data)) - self.write(data) - - -class BaseWriteAsync(ABC): - """Base synchronous write class.""" - - __slots__ = () - - @abstractmethod - async def write(self, data: Connection | str | bytearray | bytes) -> None: - """Write data to ``self``.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _pack(format_: str, data: int) -> bytes: - """Pack data in with format in big-endian mode.""" - return struct.pack(">" + format_, data) - - async def write_varint(self, value: int) -> None: - """Write varint with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int32(value).value - for _ in range(5): - if not remaining & -0x80: # remaining & ~0x7F == 0: - await self.write(struct.pack("!B", remaining)) - if value > 2**31 - 1 or value < -(2**31): - break - return - await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varint') - - async def write_varlong(self, value: int) -> None: - """Write varlong with value ``value`` to ``self``. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises ValueError: If value is out of range. - """ - remaining = unsigned_int64(value).value - for _ in range(10): - if not remaining & -0x80: # remaining & ~0x7F == 0: - await self.write(struct.pack("!B", remaining)) - if value > 2**63 - 1 or value < -(2**31): - break - return - await self.write(struct.pack("!B", remaining & 0x7F | 0x80)) - remaining >>= 7 - raise ValueError(f'The value "{value}" is too big to send in a varlong') - - async def write_utf(self, value: str) -> None: - """Write varint of length of ``value`` up to 32767 bytes, then write ``value`` encoded with ``UTF-8``.""" - await self.write_varint(len(value)) - await self.write(bytearray(value, "utf8")) - - async def write_ascii(self, value: str) -> None: - """Write value encoded with ``ISO-8859-1``, then write an additional ``0x00`` at the end.""" - await self.write(bytearray(value, "ISO-8859-1")) - await self.write(bytearray.fromhex("00")) - - async def write_short(self, value: int) -> None: - """Write 2 bytes for value ``-32768 - 32767``.""" - await self.write(self._pack("h", value)) - - async def write_ushort(self, value: int) -> None: - """Write 2 bytes for value ``0 - 65535 (2 ** 16 - 1)``.""" - await self.write(self._pack("H", value)) - - async def write_int(self, value: int) -> None: - """Write 4 bytes for value ``-2147483648 - 2147483647``.""" - await self.write(self._pack("i", value)) - - async def write_uint(self, value: int) -> None: - """Write 4 bytes for value ``0 - 4294967295 (2 ** 32 - 1)``.""" - await self.write(self._pack("I", value)) - - async def write_long(self, value: int) -> None: - """Write 8 bytes for value ``-9223372036854775808 - 9223372036854775807``.""" - await self.write(self._pack("q", value)) - - async def write_ulong(self, value: int) -> None: - """Write 8 bytes for value ``0 - 18446744073709551613 (2 ** 64 - 1)``.""" - await self.write(self._pack("Q", value)) - - async def write_bool(self, value: bool) -> None: # noqa: FBT001 # Boolean positional argument - """Write 1 byte for boolean `True` or `False`.""" - await self.write(self._pack("?", value)) - - async def write_buffer(self, buffer: Connection) -> None: - """Flush buffer, then write a varint of the length of the buffer's data, then write buffer data.""" - data = buffer.flush() - await self.write_varint(len(data)) - await self.write(data) - - -class BaseReadSync(ABC): - """Base synchronous read class.""" - - __slots__ = () - - @abstractmethod - def read(self, length: int, /) -> bytearray: - """Read length bytes from ``self``, and return a byte array.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _unpack(format_: str, data: bytes) -> int: - """Unpack data as bytes with format in big-endian.""" - return struct.unpack(">" + format_, bytes(data))[0] - - def read_varint(self) -> int: - """Read varint from ``self`` and return it. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(5): - part = self.read(1)[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int32(result).value - raise OSError("Received varint is too big!") - - def read_varlong(self) -> int: - """Read varlong from ``self`` and return it. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(10): - part = self.read(1)[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int64(result).value - raise OSError("Received varlong is too big!") - - def read_utf(self) -> str: - """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" - length = self.read_varint() - return self.read(length).decode("utf8") - - def read_ascii(self) -> str: - """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``.""" - result = bytearray() - while len(result) == 0 or result[-1] != 0: - result.extend(self.read(1)) - return result[:-1].decode("ISO-8859-1") - - def read_short(self) -> int: - """Return ``-32768 - 32767``. Read 2 bytes.""" - return self._unpack("h", self.read(2)) - - def read_ushort(self) -> int: - """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" - return self._unpack("H", self.read(2)) - - def read_int(self) -> int: - """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" - return self._unpack("i", self.read(4)) - - def read_uint(self) -> int: - """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" - return self._unpack("I", self.read(4)) - - def read_long(self) -> int: - """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" - return self._unpack("q", self.read(8)) - - def read_ulong(self) -> int: - """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" - return self._unpack("Q", self.read(8)) - - def read_bool(self) -> bool: - """Return `True` or `False`. Read 1 byte.""" - return cast("bool", self._unpack("?", self.read(1))) - - def read_buffer(self) -> Connection: - """Read a varint for length, then return a new connection from length read bytes.""" - length = self.read_varint() - result = Connection() - result.receive(self.read(length)) - return result - - -class BaseReadAsync(ABC): - """Asynchronous Read connection base class.""" - - __slots__ = () - - @abstractmethod - async def read(self, length: int, /) -> bytearray: - """Read length bytes from ``self``, return a byte array.""" - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - @staticmethod - def _unpack(format_: str, data: bytes) -> int: - """Unpack data as bytes with format in big-endian.""" - return struct.unpack(">" + format_, bytes(data))[0] - - async def read_varint(self) -> int: - """Read varint from ``self`` and return it. - - :param value: Maximum is ``2 ** 31 - 1``, minimum is ``-(2 ** 31)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(5): - part = (await self.read(1))[0] - result |= (part & 0x7F) << 7 * i - if not part & 0x80: - return signed_int32(result).value - raise OSError("Received a varint that was too big!") - - async def read_varlong(self) -> int: - """Read varlong from ``self`` and return it. - - :param value: Maximum is ``2 ** 63 - 1``, minimum is ``-(2 ** 63)``. - :raises IOError: If varint received is out of range. - """ - result = 0 - for i in range(10): - part = (await self.read(1))[0] - result |= (part & 0x7F) << (7 * i) - if not part & 0x80: - return signed_int64(result).value - raise OSError("Received varlong is too big!") - - async def read_utf(self) -> str: - """Read up to 32767 bytes by reading a varint, then decode bytes as ``UTF-8``.""" - length = await self.read_varint() - return (await self.read(length)).decode("utf8") - - async def read_ascii(self) -> str: - """Read ``self`` until last value is not zero, then return that decoded with ``ISO-8859-1``.""" - result = bytearray() - while len(result) == 0 or result[-1] != 0: - result.extend(await self.read(1)) - return result[:-1].decode("ISO-8859-1") - - async def read_short(self) -> int: - """Return ``-32768 - 32767``. Read 2 bytes.""" - return self._unpack("h", await self.read(2)) - - async def read_ushort(self) -> int: - """Return ``0 - 65535 (2 ** 16 - 1)``. Read 2 bytes.""" - return self._unpack("H", await self.read(2)) - - async def read_int(self) -> int: - """Return ``-2147483648 - 2147483647``. Read 4 bytes.""" - return self._unpack("i", await self.read(4)) - - async def read_uint(self) -> int: - """Return ``0 - 4294967295 (2 ** 32 - 1)``. 4 bytes read.""" - return self._unpack("I", await self.read(4)) - - async def read_long(self) -> int: - """Return ``-9223372036854775808 - 9223372036854775807``. Read 8 bytes.""" - return self._unpack("q", await self.read(8)) - - async def read_ulong(self) -> int: - """Return ``0 - 18446744073709551613 (2 ** 64 - 1)``. Read 8 bytes.""" - return self._unpack("Q", await self.read(8)) - - async def read_bool(self) -> bool: - """Return `True` or `False`. Read 1 byte.""" - return cast("bool", self._unpack("?", await self.read(1))) - - async def read_buffer(self) -> Connection: - """Read a varint for length, then return a new connection from length read bytes.""" - length = await self.read_varint() - result = Connection() - result.receive(await self.read(length)) - return result - - -class BaseConnection: - """Base Connection class. Implements flush, receive, and remaining.""" - - __slots__ = () - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} Object>" - - def flush(self) -> bytearray: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support flush()") - - def receive(self, _data: BytesConvertable | bytearray) -> None: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support receive()") - - def remaining(self) -> int: - """Raise :exc:`TypeError`, unsupported.""" - raise TypeError(f"{self.__class__.__name__} does not support remaining()") - - -class BaseSyncConnection(BaseConnection, BaseReadSync, BaseWriteSync): - """Base synchronous read and write class.""" - - __slots__ = () - - -class BaseAsyncReadSyncWriteConnection(BaseConnection, BaseReadAsync, BaseWriteSync): - """Base asynchronous read and synchronous write class.""" - - __slots__ = () - - -class BaseAsyncConnection(BaseConnection, BaseReadAsync, BaseWriteAsync): - """Base asynchronous read and write class.""" - - __slots__ = () - - -class Connection(BaseSyncConnection): - """Base connection class.""" - - __slots__ = ("received", "sent") - - def __init__(self) -> None: - self.sent = bytearray() - self.received = bytearray() - - def read(self, length: int, /) -> bytearray: - """Return :attr:`.received` up to length bytes, then cut received up to that point.""" - if len(self.received) < length: - raise OSError(f"Not enough data to read! {len(self.received)} < {length}") - - result = self.received[:length] - self.received = self.received[length:] - return result - - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Extend :attr:`.sent` from ``data``.""" - if isinstance(data, Connection): - data = data.flush() - if isinstance(data, str): - data = bytearray(data, "utf-8") - self.sent.extend(data) - - def receive(self, data: BytesConvertable | bytearray) -> None: - """Extend :attr:`.received` with ``data``.""" - if not isinstance(data, bytearray): - data = bytearray(data) - self.received.extend(data) - - def remaining(self) -> int: - """Return length of :attr:`.received`.""" - return len(self.received) - - def flush(self) -> bytearray: - """Return :attr:`.sent`, also clears :attr:`.sent`.""" - result, self.sent = self.sent, bytearray() - return result - - def copy(self) -> Connection: - """Return a copy of ``self``.""" - new = self.__class__() - new.receive(self.received) - new.write(self.sent) - return new - - -class SocketConnection(BaseSyncConnection): - """Socket connection.""" - - __slots__ = ("socket",) - - def __init__(self) -> None: - # These will only be None until connect is called, ignore the None type assignment - self.socket: socket.socket = None # pyright: ignore[reportAttributeAccessIssue] - - def close(self) -> None: - """Close :attr:`.socket`.""" - if self.socket is not None: # If initialized - try: - self.socket.shutdown(socket.SHUT_RDWR) - except OSError as exception: # Socket wasn't connected (nothing to shut down) - if exception.errno != errno.ENOTCONN: - raise - - self.socket.close() - - def __enter__(self) -> Self: - return self - - def __exit__(self, *_: object) -> None: - self.close() - - -class TCPSocketConnection(SocketConnection): - """TCP Connection to address. Timeout defaults to 3 seconds.""" - - __slots__ = () - - def __init__(self, addr: tuple[str | None, int], timeout: float = 3) -> None: - super().__init__() - self.socket = socket.create_connection(addr, timeout=timeout) - self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - def read(self, length: int, /) -> bytearray: - """Return length bytes read from :attr:`.socket`. Raises :exc:`IOError` when server doesn't respond.""" - result = bytearray() - while len(result) < length: - new = self.socket.recv(length - len(result)) - if len(new) == 0: - raise OSError("Server did not respond with any information!") - result.extend(new) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Send data on :attr:`.socket`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.socket.send(data) - - -class UDPSocketConnection(SocketConnection): - """UDP Connection class.""" - - __slots__ = ("addr",) - - def __init__(self, addr: Address, timeout: float = 3) -> None: - super().__init__() - self.addr = addr - self.socket = socket.socket( - socket.AF_INET if _ip_type(addr[0]) == 4 else socket.AF_INET6, - socket.SOCK_DGRAM, - ) - self.socket.settimeout(timeout) - - def remaining(self) -> int: - """Always return ``65535`` (``2 ** 16 - 1``).""" # noqa: D401 # imperative mood - return 65535 - - def read(self, _length: int, /) -> bytearray: - """Return up to :meth:`.remaining` bytes. Length does nothing here.""" - result = bytearray() - while len(result) == 0: - result.extend(self.socket.recvfrom(self.remaining())[0]) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Use :attr:`.socket` to send data to :attr:`.addr`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.socket.sendto(data, self.addr) - - -class TCPAsyncSocketConnection(BaseAsyncReadSyncWriteConnection): - """Asynchronous TCP Connection class.""" - - __slots__ = ("_addr", "reader", "timeout", "writer") - - def __init__(self, addr: Address, timeout: float = 3) -> None: - # These will only be None until connect is called, ignore the None type assignment - self.reader: asyncio.StreamReader = None # pyright: ignore[reportAttributeAccessIssue] - self.writer: asyncio.StreamWriter = None # pyright: ignore[reportAttributeAccessIssue] - self.timeout: float = timeout - self._addr = addr - - async def connect(self) -> None: - """Use :mod:`asyncio` to open a connection to address. Timeout is in seconds.""" - conn = asyncio.open_connection(*self._addr) - self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) - if self.writer is not None: # it might be None in unittest - sock: socket.socket = self.writer.transport.get_extra_info("socket") - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - async def read(self, length: int, /) -> bytearray: - """Read up to ``length`` bytes from :attr:`.reader`.""" - result = bytearray() - while len(result) < length: - new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) - if len(new) == 0: - raise OSError("Socket did not respond with any information!") - result.extend(new) - return result - - def write(self, data: Connection | str | bytes | bytearray) -> None: - """Write data to :attr:`.writer`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - self.writer.write(data) - - def close(self) -> None: - """Close :attr:`.writer`.""" - if self.writer is not None: # If initialized - self.writer.close() - - async def __aenter__(self) -> Self: - await self.connect() - return self - - async def __aexit__(self, *_: object) -> None: - self.close() - - -class UDPAsyncSocketConnection(BaseAsyncConnection): - """Asynchronous UDP Connection class.""" - - __slots__ = ("_addr", "stream", "timeout") - - def __init__(self, addr: Address, timeout: float = 3) -> None: - # This will only be None until connect is called, ignore the None type assignment - self.stream: asyncio_dgram.aio.DatagramClient = None # pyright: ignore[reportAttributeAccessIssue] - self.timeout: float = timeout - self._addr = addr - - async def connect(self) -> None: - """Connect to address. Timeout is in seconds.""" - conn = asyncio_dgram.connect(self._addr) - self.stream = await asyncio.wait_for(conn, timeout=self.timeout) - - def remaining(self) -> int: - """Always return ``65535`` (``2 ** 16 - 1``).""" # noqa: D401 # imperative mood - return 65535 - - async def read(self, _length: int, /) -> bytearray: - """Read from :attr:`.stream`. Length does nothing here.""" - data, _remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) - return bytearray(data) - - async def write(self, data: Connection | str | bytes | bytearray) -> None: - """Send data with :attr:`.stream`.""" - if isinstance(data, Connection): - data = bytearray(data.flush()) - elif isinstance(data, str): - data = bytearray(data, "utf-8") - await self.stream.send(data) - - def close(self) -> None: - """Close :attr:`.stream`.""" - if self.stream is not None: # If initialized - self.stream.close() - - async def __aenter__(self) -> Self: - await self.connect() - return self - - async def __aexit__(self, *_: object) -> None: - self.close() diff --git a/mcstatus/_protocol/io/__init__.py b/mcstatus/_protocol/io/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py new file mode 100644 index 00000000..99746253 --- /dev/null +++ b/mcstatus/_protocol/io/base_io.py @@ -0,0 +1,743 @@ +from __future__ import annotations + +import struct +from abc import ABC, abstractmethod +from enum import Enum +from itertools import count +from typing import Literal, TYPE_CHECKING, TypeAlias, TypeVar, overload + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + +__all__ = [ + "FLOAT_FORMATS_TYPE", + "INT_FORMATS_TYPE", + "BaseAsyncReader", + "BaseAsyncWriter", + "BaseSyncReader", + "BaseSyncWriter", + "StructFormat", +] + +T = TypeVar("T") +R = TypeVar("R") + +# region: Utils + + +def to_twos_complement(number: int, bits: int) -> int: + """Convert a given ``number`` into two's complement format with the specified number of ``bits``. + + :param number: The signed integer to convert. + :param bits: Number of bits used for the two's complement representation. + :return: The integer encoded in two's complement form within the given bit width. + :raises ValueError: + If given ``number`` is out of range, and can't be converted into twos complement format, + since it wouldn't fit into the given amount of ``bits``. + """ + value_max = 1 << (bits - 1) + value_min = value_max * -1 + # With two's complement, we have one more negative number than positive + # this means we can't be exactly at value_max, but we can be at exactly value_min + if number >= value_max or number < value_min: + raise ValueError(f"Can't convert number {number} into {bits}-bit twos complement format - out of range") + + return number + (1 << bits) if number < 0 else number + + +def from_twos_complement(number: int, bits: int) -> int: + """Convert a ``number`` from a two's complement representation with the specified number of ``bits``. + + :param number: The integer encoded in two's complement form. + :param bits: Number of bits used for the two's complement representation. + :return: The decoded signed integer. + :raises ValueError: + If given ``number`` doesn't fit into given amount of ``bits``. This likely means that you're + using the wrong number, or that the number was converted into twos complement with higher + amount of `bits`. + """ + value_max = (1 << bits) - 1 + if number < 0 or number > value_max: + raise ValueError(f"Can't convert number {number} from {bits}-bit twos complement format - out of range") + + if number & (1 << (bits - 1)) != 0: + number -= 1 << bits + + return number + + +# endregion + + +# region: Format types + + +class StructFormat(str, Enum): + """All possible struct format types used for reading and writing binary data. + + These values correspond directly to format characters accepted by the + :mod:`struct` module. + """ + + BOOL = "?" + CHAR = "c" + BYTE = "b" + UBYTE = "B" + SHORT = "h" + USHORT = "H" + INT = "i" + UINT = "I" + LONG = "l" + ULONG = "L" + FLOAT = "f" + DOUBLE = "d" + HALFFLOAT = "e" + LONGLONG = "q" + ULONGLONG = "Q" + + +INT_FORMATS_TYPE: TypeAlias = Literal[ + StructFormat.BYTE, + StructFormat.UBYTE, + StructFormat.SHORT, + StructFormat.USHORT, + StructFormat.INT, + StructFormat.UINT, + StructFormat.LONG, + StructFormat.ULONG, + StructFormat.LONGLONG, + StructFormat.ULONGLONG, +] + +FLOAT_FORMATS_TYPE: TypeAlias = Literal[ + StructFormat.FLOAT, + StructFormat.DOUBLE, + StructFormat.HALFFLOAT, +] + +# endregion + +# region: Writer classes + + +class BaseAsyncWriter(ABC): + """Base class holding asynchronous write buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + async def write(self, data: bytes | bytearray, /) -> None: + """Underlying write method, sending/storing the data. + + All of the writer functions will eventually call this method. + """ + + @overload + async def write_value(self, fmt: INT_FORMATS_TYPE, value: int, /) -> None: ... + + @overload + async def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: ... + + @overload + async def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 + + @overload + async def write_value(self, fmt: Literal[StructFormat.CHAR], value: str, /) -> None: ... + + async def write_value(self, fmt: StructFormat, value: object, /) -> None: + """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" + await self.write(struct.pack(">" + fmt.value, value)) + + async def _write_varuint(self, value: int, /, *, max_bits: int | None = None) -> None: + """Write an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, allowing + smaller numbers take up fewer bytes. + + Writing is limited to integers representable within ``max_bits`` bits. Attempting to write a larger + value will raise :class:`ValueError`. Note that limiting the value to e.g. 32 bits does not mean + that at most 4 bytes will be written. Due to the encoding overhead, values within 32 bits may require + up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte store data, + while the most significant bit acts as a continuation flag. If this bit is set (``1``), another byte + follows. + + The least significant group is written first, followed by increasingly significant groups, making the + encoding little-endian in 7-bit groups. + + :param value: Unsigned integer to encode. + :param max_bits: Maximum allowed bit width for the encoded value. + :raises ValueError: If ``value`` is negative or exceeds the allowed bit width. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + if value < 0 or value > value_max: + raise ValueError(f"Tried to write varint outside of the range of {max_bits}-bit int.") + + remaining = value + while True: + if remaining & ~0x7F == 0: # final byte + await self.write_value(StructFormat.UBYTE, remaining) + return + # Write only 7 least significant bits with the first bit being 1, marking there will be another byte + await self.write_value(StructFormat.UBYTE, remaining & 0x7F | 0x80) + # Subtract the value we've already sent (7 least significant bits) + remaining >>= 7 + + async def write_varint(self, value: int, /) -> None: + """Write a 32-bit signed integer using a variable-length encoding. + + The value is first converted to a 32-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=32) + await self._write_varuint(val, max_bits=32) + + async def write_varlong(self, value: int, /) -> None: + """Write a 64-bit signed integer using a variable-length encoding. + + The value is first converted to a 64-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=64) + await self._write_varuint(val, max_bits=64) + + async def write_bytearray(self, data: bytes | bytearray, /) -> None: + """Write an arbitrary sequence of bytes, prefixed with a varint of its size.""" + await self.write_varint(len(data)) + await self.write(data) + + async def write_ascii(self, value: str, /) -> None: + """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" + data = bytes(value, "ISO-8859-1") + await self.write(data) + await self.write(bytes([0])) + + async def write_utf(self, value: str, /) -> None: + """Write a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be written + 3 additional bytes from + the varint encoding overhead. + + :param value: String to encode. + :raises ValueError: If the string exceeds the maximum allowed length. + """ + if len(value) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + + data = bytes(value, "utf-8") + await self.write_varint(len(data)) + await self.write(data) + + async def write_optional(self, value: T | None, /, writer: Callable[[T], Awaitable[R]]) -> R | None: + """Write a bool indicating whether ``value`` is present and, if so, serialize the value using ``writer``. + + * If ``value`` is ``None``, ``False`` is written and ``None`` is returned. + * If ``value`` is not ``None``, ``True`` is written and ``writer`` is called + with the value. The return value of ``writer`` is then forwarded. + + :param value: Optional value to serialize. + :param writer: Callable used to serialize the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``writer``. + """ + if value is None: + await self.write_value(StructFormat.BOOL, False) # noqa: FBT003 + return None + + await self.write_value(StructFormat.BOOL, True) # noqa: FBT003 + return await writer(value) + + +class BaseSyncWriter(ABC): + """Base class holding synchronous write buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def write(self, data: bytes | bytearray, /) -> None: + """Underlying write method, sending/storing the data. + + All of the writer functions will eventually call this method. + """ + + @overload + def write_value(self, fmt: INT_FORMATS_TYPE, value: int, /) -> None: ... + + @overload + def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: ... + + @overload + def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 + + @overload + def write_value(self, fmt: Literal[StructFormat.CHAR], value: str, /) -> None: ... + + def write_value(self, fmt: StructFormat, value: object, /) -> None: + """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" + self.write(struct.pack(">" + fmt.value, value)) + + def _write_varuint(self, value: int, /, *, max_bits: int | None = None) -> None: + """Write an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, allowing + smaller numbers take up fewer bytes. + + Writing is limited to integers representable within ``max_bits`` bits. Attempting to write a larger + value will raise :class:`ValueError`. Note that limiting the value to e.g. 32 bits does not mean + that at most 4 bytes will be written. Due to the encoding overhead, values within 32 bits may require + up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte store data, + while the most significant bit acts as a continuation flag. If this bit is set (``1``), another byte + follows. + + The least significant group is written first, followed by increasingly significant groups, making the + encoding little-endian in 7-bit groups. + + :param value: Unsigned integer to encode. + :param max_bits: Maximum allowed bit width for the encoded value. + :raises ValueError: If ``value`` is negative or exceeds the allowed bit width. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + if value < 0 or value > value_max: + raise ValueError(f"Tried to write varint outside of the range of {max_bits}-bit int.") + + remaining = value + while True: + if remaining & ~0x7F == 0: # final byte + self.write_value(StructFormat.UBYTE, remaining) + return + # Write only 7 least significant bits with the first bit being 1, marking there will be another byte + self.write_value(StructFormat.UBYTE, remaining & 0x7F | 0x80) + # Subtract the value we've already sent (7 least significant bits) + remaining >>= 7 + + def write_varint(self, value: int, /) -> None: + """Write a 32-bit signed integer using a variable-length encoding. + + The value is first converted to a 32-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=32) + self._write_varuint(val, max_bits=32) + + def write_varlong(self, value: int, /) -> None: + """Write a 64-bit signed integer using a variable-length encoding. + + The value is first converted to a 64-bit two's complement representation + and then encoded using the varint format. + + See :meth:`_write_varuint` for details about the encoding. + + :param value: Signed integer to encode. + """ + val = to_twos_complement(value, bits=64) + self._write_varuint(val, max_bits=64) + + def write_bytearray(self, data: bytes | bytearray, /) -> None: + """Write an arbitrary sequence of bytes, prefixed with a varint of its size.""" + self.write_varint(len(data)) + self.write(data) + + def write_ascii(self, value: str, /) -> None: + """Write ISO-8859-1 encoded string, with NULL (0x00) at the end to indicate string end.""" + data = bytes(value, "ISO-8859-1") + self.write(data) + self.write(bytes([0])) + + def write_utf(self, value: str, /) -> None: + """Write a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be written + 3 additional bytes from + the varint encoding overhead. + + :param value: String to encode. + :raises ValueError: If the string exceeds the maximum allowed length. + """ + if len(value) > 32767: + raise ValueError("Maximum character limit for writing strings is 32767 characters.") + + data = bytes(value, "utf-8") + self.write_varint(len(data)) + self.write(data) + + def write_optional(self, value: T | None, /, writer: Callable[[T], R]) -> R | None: + """Write a bool indicating whether ``value`` is present and, if so, serialize the value using ``writer``. + + * If ``value`` is ``None``, ``False`` is written and ``None`` is returned. + * If ``value`` is not ``None``, ``True`` is written and ``writer`` is called + with the value. The return value of ``writer`` is then forwarded. + + :param value: Optional value to serialize. + :param writer: Callable used to serialize the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``writer``. + """ + if value is None: + self.write_value(StructFormat.BOOL, False) # noqa: FBT003 + return None + + self.write_value(StructFormat.BOOL, True) # noqa: FBT003 + return writer(value) + + +# endregion +# region: Reader classes + + +class BaseAsyncReader(ABC): + """Base class holding asynchronous read buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + async def read(self, length: int, /) -> bytes: + """Underlying read method, obtaining the raw data. + + All of the reader functions will eventually call this method. + """ + + @overload + async def read_value(self, fmt: INT_FORMATS_TYPE, /) -> int: ... + + @overload + async def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... + + @overload + async def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... + + @overload + async def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> str: ... + + async def read_value(self, fmt: StructFormat, /) -> object: + """Read a value as given struct format (``fmt``) in big-endian mode. + + The amount of bytes to read will be determined based on the struct format automatically. + """ + length = struct.calcsize(fmt.value) + data = await self.read(length) + unpacked = struct.unpack(">" + fmt.value, data) + return unpacked[0] + + async def _read_varuint(self, *, max_bits: int | None = None) -> int: + """Read an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, + allowing smaller numbers to take fewer bytes. + + Reading is limited to integers representable within ``max_bits`` bits. Attempting to read + a larger value will raise :class:`OSError`. Note that limiting the value to e.g. 32 bits + does not mean that at most 4 bytes will be read. Due to the encoding overhead, values + within 32 bits may require up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte + store data, while the most significant bit acts as a continuation flag. If this bit is set + (``1``), another byte follows. + + The least significant group is read first, followed by increasingly significant groups, + making the encoding little-endian in 7-bit groups. + + :param max_bits: Maximum allowed bit width for the decoded value. + :raises OSError: If the decoded value exceeds the allowed bit width. + :return: The decoded unsigned integer. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + + result = 0 + for i in count(): + byte = await self.read_value(StructFormat.UBYTE) + # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place + # then simply add them (OR) as additional 7 most significant bits in our result + result |= (byte & 0x7F) << (7 * i) + + # Ensure that we stop reading and raise an error if the size gets over the maximum + # (if the current amount of bits is higher than allowed size in bits) + if result > value_max: + raise OSError(f"Received varint was outside the range of {max_bits}-bit int.") + + # If the most significant bit is 0, we should stop reading + if not byte & 0x80: + break + + return result + + async def read_varint(self) -> int: + """Read a 32-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 32-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = await self._read_varuint(max_bits=32) + return from_twos_complement(unsigned_num, bits=32) + + async def read_varlong(self) -> int: + """Read a 64-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 64-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = await self._read_varuint(max_bits=64) + return from_twos_complement(unsigned_num, bits=64) + + async def read_bytearray(self, /) -> bytes: + """Read a sequence of bytes prefixed with its length encoded as a varint. + + :return: The decoded byte sequence. + """ + length = await self.read_varint() + return await self.read(length) + + async def read_ascii(self) -> str: + """Read ISO-8859-1 encoded string, until we encounter NULL (0x00) at the end indicating string end. + + Bytes are read until a NULL terminator is encountered. The terminator itself is not included in the + returned string. There is no limit that can be set for how long this string can end up being. + + :return: Decoded string. + """ + # Keep reading bytes until we find NULL + result = bytearray() + while len(result) == 0 or result[-1] != 0: + byte = await self.read(1) + result.extend(byte) + return result[:-1].decode("ISO-8859-1") + + async def read_utf(self) -> str: + """Read a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be read + 3 additional bytes from + the varint encoding overhead. + + :raises OSError: + * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, + and the error will be raised immediately after reading the prefix. + * If the decoded string contains more than ``32767`` characters. In this case the data is still + fully read because it fits within the byte limit. This behavior mirrors Minecraft's implementation. + + :return: Decoded UTF-8 string. + """ + length = await self.read_varint() + if length > 131068: + raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") + + data = await self.read(length) + chars = data.decode("utf-8") + + if len(chars) > 32767: + raise OSError(f"Maximum read limit for utf strings is 32767 characters, got {len(chars)}.") + + return chars + + async def read_optional(self, reader: Callable[[], Awaitable[R]]) -> R | None: + """Read a boolean indicating whether a value is present and, if so, deserialize it using ``reader``. + + * If ``False`` is read, nothing further is read and ``None`` is returned. + * If ``True`` is read, ``reader`` is called and its return value is forwarded. + + :param reader: Callable used to read the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``reader``. + """ + if not await self.read_value(StructFormat.BOOL): + return None + + return await reader() + + +class BaseSyncReader(ABC): + """Base class holding synchronous read buffer/connection interactions.""" + + __slots__ = () + + @abstractmethod + def read(self, length: int, /) -> bytes: + """Underlying read method, obtaining the raw data. + + All of the reader functions will eventually call this method. + """ + + @overload + def read_value(self, fmt: INT_FORMATS_TYPE, /) -> int: ... + + @overload + def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... + + @overload + def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... + + @overload + def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> str: ... + + def read_value(self, fmt: StructFormat, /) -> object: + """Read a value as given struct format (``fmt``) in big-endian mode. + + The amount of bytes to read will be determined based on the struct format automatically. + """ + length = struct.calcsize(fmt.value) + data = self.read(length) + unpacked = struct.unpack(">" + fmt.value, data) + return unpacked[0] + + def _read_varuint(self, *, max_bits: int | None = None) -> int: + """Read an arbitrarily large unsigned integer using a variable-length encoding. + + This is a standard way of transmitting integers with variable length over the network, + allowing smaller numbers to take fewer bytes. + + Reading is limited to integers representable within ``max_bits`` bits. Attempting to read + a larger value will raise :class:`OSError`. Note that limiting the value to e.g. 32 bits + does not mean that at most 4 bytes will be read. Due to the encoding overhead, values + within 32 bits may require up to 5 bytes. + + Varints encode integers using groups of 7 bits. The 7 least significant bits of each byte + store data, while the most significant bit acts as a continuation flag. If this bit is set + (``1``), another byte follows. + + The least significant group is read first, followed by increasingly significant groups, + making the encoding little-endian in 7-bit groups. + + :param max_bits: Maximum allowed bit width for the decoded value. + :raises OSError: If the decoded value exceeds the allowed bit width. + :return: The decoded unsigned integer. + """ + value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + + result = 0 + for i in count(): + byte = self.read_value(StructFormat.UBYTE) + # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place + # then simply add them (OR) as additional 7 most significant bits in our result + result |= (byte & 0x7F) << (7 * i) + + # Ensure that we stop reading and raise an error if the size gets over the maximum + # (if the current amount of bits is higher than allowed size in bits) + if result > value_max: + raise OSError(f"Received varint was outside the range of {max_bits}-bit int.") + + # If the most significant bit is 0, we should stop reading + if not byte & 0x80: + break + + return result + + def read_varint(self) -> int: + """Read a 32-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 32-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = self._read_varuint(max_bits=32) + return from_twos_complement(unsigned_num, bits=32) + + def read_varlong(self) -> int: + """Read a 64-bit signed integer using a variable-length encoding. + + The value is read as an unsigned varint and then converted from a + 64-bit two's complement representation. + + See :meth:`_read_varuint` for details about the encoding. + + :return: Decoded signed integer. + """ + unsigned_num = self._read_varuint(max_bits=64) + return from_twos_complement(unsigned_num, bits=64) + + def read_bytearray(self) -> bytes: + """Read a sequence of bytes prefixed with its length encoded as a varint. + + :return: The decoded byte sequence. + """ + length = self.read_varint() + return self.read(length) + + def read_ascii(self) -> str: + """Read ISO-8859-1 encoded string, until we encounter NULL (0x00) at the end indicating string end. + + Bytes are read until a NULL terminator is encountered. The terminator itself is not included in the + returned string. There is no limit that can be set for how long this string can end up being. + + :return: Decoded string. + """ + # Keep reading bytes until we find NULL + result = bytearray() + while len(result) == 0 or result[-1] != 0: + byte = self.read(1) + result.extend(byte) + return result[:-1].decode("ISO-8859-1") + + def read_utf(self) -> str: + """Read a UTF-8 encoded string prefixed with its byte length as a varint. + + The maximum number of characters allowed is ``32767``. + + Individual UTF-8 characters can take up to 4 bytes, however most of the common ones take up less. Assuming the + worst case of 4 bytes per every character, at most 131068 data bytes will be read + 3 additional bytes from + the varint encoding overhead. + + :raises OSError: + * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, + and the error will be raised immediately after reading the prefix. + * If the decoded string contains more than ``32767`` characters. In this case the data is still + fully read because it fits within the byte limit. This behavior mirrors Minecraft's implementation. + + :return: Decoded UTF-8 string. + """ + length = self.read_varint() + if length > 131068: + raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") + + data = self.read(length) + chars = data.decode("utf-8") + + if len(chars) > 32767: + raise OSError(f"Maximum read limit for utf strings is 32767 characters, got {len(chars)}.") + + return chars + + def read_optional(self, reader: Callable[[], R]) -> R | None: + """Read a boolean indicating whether a value is present and, if so, deserialize it using ``reader``. + + * If ``False`` is read, nothing further is read and ``None`` is returned. + * If ``True`` is read, ``reader`` is called and its return value is forwarded. + + :param reader: Callable used to read the value when it is present. + :return: ``None`` if the value is absent, otherwise the result of ``reader``. + """ + if not self.read_value(StructFormat.BOOL): + return None + + return reader() + + +# endregion diff --git a/mcstatus/_protocol/io/buffer.py b/mcstatus/_protocol/io/buffer.py new file mode 100644 index 00000000..772cac8f --- /dev/null +++ b/mcstatus/_protocol/io/buffer.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from typing import SupportsIndex, TYPE_CHECKING, final, overload + +from mcstatus._protocol.io.base_io import BaseSyncReader, BaseSyncWriter + +if TYPE_CHECKING: + from collections.abc import Iterable + + from _typeshed import ReadableBuffer + from typing_extensions import override +else: + override = lambda f: f # noqa: E731 + +__all__ = ["Buffer"] + + +@final +class Buffer(BaseSyncWriter, BaseSyncReader, bytearray): + """In-memory bytearray-like buffer supporting the common read/write operations.""" + + __slots__ = ("pos",) + + @overload + def __init__(self) -> None: ... + @overload + def __init__(self, ints: Iterable[SupportsIndex] | SupportsIndex | ReadableBuffer, /) -> None: ... + @overload + def __init__(self, string: str, /, encoding: str, errors: str = "strict") -> None: ... + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pos = 0 + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Write/Store given ``data`` into the buffer.""" + self.extend(data) + + @override + def read(self, length: int, /) -> bytes: + """Read data stored in the buffer. + + Reading data doesn't remove that data, rather that data is treated as already read, and + next read will start from the first unread byte. If freeing the data is necessary, check + the :meth:`clear` method. + + :param length: + Amount of bytes to be read. + + If the requested amount can't be read (buffer doesn't contain that much data/buffer + doesn't contain any data), an :exc:`OSError` will be raised. + + If there were some data in the buffer, but it was less than requested, this remaining + data will still be depleted and the partial data that was read will be a part of the + error message in the :exc:`OSError`. This behavior is here to mimic reading from a real + socket connection. + """ + end = self.pos + length + + if end > len(self): + data = bytearray(self.unread_view()) + bytes_read = len(self) - self.pos + self.pos = len(self) + raise OSError( + "Requested to read more data than available." + f" Read {bytes_read} bytes: {data}, out of {length} requested bytes." + ) + + try: + return bytes(self.unread_view()[:length]) + finally: + self.pos = end + + @override + def clear(self, *, only_already_read: bool = False) -> None: + """Clear out the stored data and reset position. + + :param only_already_read: + When set to ``True``, only the data that was already marked as read will be cleared, + and the position will be reset (to start at the remaining data). This can be useful + for avoiding needlessly storing large amounts of data in memory, if this data is no + longer useful. + + Otherwise, if set to ``False``, all of the data is cleared, and the position is reset, + essentially resulting in a blank buffer. + """ + if only_already_read: + del self[: self.pos] + else: + super().clear() + self.pos = 0 + + def reset(self) -> None: + """Reset the position in the buffer. + + Since the buffer doesn't automatically clear the already read data, it is possible to simply + reset the position and read the data it contains again. + """ + self.pos = 0 + + def unread_view(self) -> memoryview: + """Return a zero-copy view of unread data without modifying buffer state.""" + return memoryview(self)[self.pos :] + + def flush(self) -> bytes: + """Read all of the remaining data in the buffer and clear it out.""" + data = bytes(self.unread_view()) + self.clear() + return data + + @property + def remaining(self) -> int: + """Get the amount of bytes that's still remaining in the buffer to be read.""" + return len(self) - self.pos diff --git a/mcstatus/_protocol/io/connection.py b/mcstatus/_protocol/io/connection.py new file mode 100644 index 00000000..09aedf69 --- /dev/null +++ b/mcstatus/_protocol/io/connection.py @@ -0,0 +1,244 @@ +from __future__ import annotations + +import asyncio +import errno +import socket +from ipaddress import ip_address +from typing import TYPE_CHECKING, TypeAlias, final + +import asyncio_dgram + +from mcstatus._protocol.io.base_io import BaseAsyncReader, BaseAsyncWriter, BaseSyncReader, BaseSyncWriter + +if TYPE_CHECKING: + from collections.abc import Iterable + + from typing_extensions import Self, SupportsIndex, override + + from mcstatus._net.address import Address +else: + override = lambda f: f # noqa: E731 + +__all__ = [ + "BaseAsyncConnection", + "BaseSyncConnection", + "TCPAsyncSocketConnection", + "TCPSocketConnection", + "UDPAsyncSocketConnection", + "UDPSocketConnection", +] + +BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" + + +class BaseSyncConnection(BaseSyncReader, BaseSyncWriter): + """Base synchronous read and write class.""" + + __slots__ = () + + +class BaseAsyncConnection(BaseAsyncReader, BaseAsyncWriter): + """Base asynchronous read and write class.""" + + __slots__ = () + + +class _SocketConnection(BaseSyncConnection): + """Socket connection.""" + + __slots__ = ("socket",) + + def __init__(self) -> None: + # These will only be None until connect is called, ignore the None type assignment + self.socket: socket.socket = None # pyright: ignore[reportAttributeAccessIssue] + + def close(self) -> None: + """Close :attr:`.socket`.""" + if self.socket is not None: # If initialized + try: + self.socket.shutdown(socket.SHUT_RDWR) + except OSError as exception: # Socket wasn't connected (nothing to shut down) + if exception.errno != errno.ENOTCONN: + raise + + self.socket.close() + + def __enter__(self) -> Self: + return self + + def __exit__(self, *_: object) -> None: + self.close() + + +@final +class TCPSocketConnection(_SocketConnection): + """TCP Connection to address. Timeout defaults to 3 seconds.""" + + __slots__ = () + + def __init__(self, addr: tuple[str | None, int], timeout: float = 3) -> None: + super().__init__() + self.socket = socket.create_connection(addr, timeout=timeout) + self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + @override + def read(self, length: int, /) -> bytearray: + """Return length bytes read from :attr:`.socket`. Raises :exc:`OSError` when server doesn't respond.""" + result = bytearray() + while len(result) < length: + new = self.socket.recv(length - len(result)) + if len(new) == 0: + raise OSError("Server did not respond with any information!") + result.extend(new) + return result + + def write(self, data: str | bytes | bytearray, /) -> None: + """Send data on :attr:`.socket`.""" + if isinstance(data, str): + data = data.encode("utf-8") + self.socket.sendall(data) + + +@final +class UDPSocketConnection(_SocketConnection): + """UDP Connection class.""" + + __slots__ = ("addr",) + + def __init__(self, addr: Address, timeout: float = 3) -> None: + super().__init__() + self.addr = addr + self.socket = socket.socket( + socket.AF_INET if ip_address(addr[0]).version == 4 else socket.AF_INET6, + socket.SOCK_DGRAM, + ) + self.socket.settimeout(timeout) + + @property + def remaining(self) -> int: + """Always return ``65535`` (``2 ** 16 - 1``).""" + return 65535 + + @override + def read(self, _length: int, /) -> bytearray: + """Return up to :meth:`.remaining` bytes. Length does nothing here.""" + result = bytearray() + while len(result) == 0: + result.extend(self.socket.recvfrom(self.remaining)[0]) + return result + + @override + def write(self, data: str | bytes | bytearray, /) -> None: + """Use :attr:`.socket` to send data to :attr:`.addr`.""" + if isinstance(data, str): + data = data.encode("utf-8") + self.socket.sendto(data, self.addr) + + +@final +class TCPAsyncSocketConnection(BaseAsyncConnection): + """Asynchronous TCP Connection class.""" + + __slots__ = ("_addr", "reader", "timeout", "writer") + + def __init__(self, addr: Address, timeout: float = 3) -> None: + # These will only be None until connect is called, ignore the None type assignment + self.reader: asyncio.StreamReader = None # pyright: ignore[reportAttributeAccessIssue] + self.writer: asyncio.StreamWriter = None # pyright: ignore[reportAttributeAccessIssue] + self.timeout: float = timeout + self._addr = addr + + async def connect(self) -> None: + """Use :mod:`asyncio` to open a connection to address. Timeout is in seconds.""" + conn = asyncio.open_connection(*self._addr) + self.reader, self.writer = await asyncio.wait_for(conn, timeout=self.timeout) + if self.writer is not None: # it might be None in unittest + sock: socket.socket = self.writer.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + @override + async def read(self, length: int, /) -> bytearray: + """Read up to ``length`` bytes from :attr:`.reader`.""" + result = bytearray() + while len(result) < length: + new = await asyncio.wait_for(self.reader.read(length - len(result)), timeout=self.timeout) + if len(new) == 0: + # No information at all + if len(result) == 0: + raise OSError("Server did not respond with any information!") + # We did get a few bytes, but we requested more + raise OSError( + f"Server stopped responding (got {len(result)} bytes, but expected {length} bytes)." + f" Partial obtained data: {result!r}" + ) + result.extend(new) + return result + + @override + async def write(self, data: str | bytes | bytearray, /) -> None: + """Write data to :attr:`.writer`.""" + if isinstance(data, str): + data = data.encode("utf-8") + self.writer.write(data) + await self.writer.drain() + + async def close(self) -> None: + """Close :attr:`.writer`.""" + if self.writer is not None: # If initialized + self.writer.close() + await self.writer.wait_closed() + + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *_: object) -> None: + await self.close() + + +@final +class UDPAsyncSocketConnection(BaseAsyncConnection): + """Asynchronous UDP Connection class.""" + + __slots__ = ("_addr", "stream", "timeout") + + def __init__(self, addr: Address, timeout: float = 3) -> None: + # This will only be None until connect is called, ignore the None type assignment + self.stream: asyncio_dgram.aio.DatagramClient = None # pyright: ignore[reportAttributeAccessIssue] + self.timeout: float = timeout + self._addr = addr + + async def connect(self) -> None: + """Connect to address. Timeout is in seconds.""" + conn = asyncio_dgram.connect(self._addr) + self.stream = await asyncio.wait_for(conn, timeout=self.timeout) + + @property + def remaining(self) -> int: + """Always return ``65535`` (``2 ** 16 - 1``).""" + return 65535 + + @override + async def read(self, _length: int, /) -> bytearray: + """Read from :attr:`.stream`. Length does nothing here.""" + data, _remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) + return bytearray(data) + + @override + async def write(self, data: str | bytes | bytearray, /) -> None: + """Send data with :attr:`.stream`.""" + if isinstance(data, str): + data = data.encode("utf-8") + await self.stream.send(data) + + def close(self) -> None: + """Close :attr:`.stream`.""" + if self.stream is not None: # If initialized + self.stream.close() + + async def __aenter__(self) -> Self: + await self.connect() + return self + + async def __aexit__(self, *_: object) -> None: + self.close() diff --git a/mcstatus/_protocol/java_client.py b/mcstatus/_protocol/java_client.py index e27b3d95..05785e6b 100644 --- a/mcstatus/_protocol/java_client.py +++ b/mcstatus/_protocol/java_client.py @@ -7,13 +7,15 @@ from time import perf_counter from typing import TYPE_CHECKING, final -from mcstatus._protocol.connection import Connection, TCPAsyncSocketConnection, TCPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus.responses import JavaStatusResponse if TYPE_CHECKING: from collections.abc import Awaitable from mcstatus._net.address import Address + from mcstatus._protocol.io.connection import TCPAsyncSocketConnection, TCPSocketConnection from mcstatus.responses._raw import RawJavaResponse __all__ = ["AsyncJavaClient", "JavaClient"] @@ -32,16 +34,15 @@ def __post_init__(self) -> None: if self.ping_token is None: self.ping_token = random.randint(0, (1 << 63) - 1) - def handshake(self) -> None: - """Write the initial handshake packet to the connection.""" - packet = Connection() + def _build_handshake_packet(self) -> Buffer: + """Build the initial handshake packet.""" + packet = Buffer() packet.write_varint(0) packet.write_varint(self.version) packet.write_utf(self.address.host) - packet.write_ushort(self.address.port) + packet.write_value(StructFormat.USHORT, self.address.port) packet.write_varint(1) # Intention to query status - - self.connection.write_buffer(packet) + return packet @abstractmethod def read_status(self) -> JavaStatusResponse | Awaitable[JavaStatusResponse]: @@ -53,7 +54,7 @@ def test_ping(self) -> float | Awaitable[float]: """Send a ping token and measure the latency.""" raise NotImplementedError - def _handle_status_response(self, response: Connection, start: float, end: float) -> JavaStatusResponse: + def _handle_status_response(self, response: Buffer, start: float, end: float) -> JavaStatusResponse: """Given a response buffer (already read from connection), parse and build the JavaStatusResponse.""" if response.read_varint() != 0: raise OSError("Received invalid status response packet.") @@ -68,11 +69,11 @@ def _handle_status_response(self, response: Connection, start: float, end: float except KeyError as e: raise OSError("Received invalid status response") from e - def _handle_ping_response(self, response: Connection, start: float, end: float) -> float: + def _handle_ping_response(self, response: Buffer, start: float, end: float) -> float: """Given a ping response buffer, validate token and compute latency.""" if response.read_varint() != 1: raise OSError("Received invalid ping response packet.") - received_token = response.read_long() + received_token = response.read_value(StructFormat.LONGLONG) if received_token != self.ping_token: raise OSError(f"Received mangled ping response (expected token {self.ping_token}, got {received_token})") return (end - start) * 1000 @@ -83,26 +84,30 @@ def _handle_ping_response(self, response: Connection, start: float, end: float) class JavaClient(_BaseJavaClient): connection: TCPSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] + def handshake(self) -> None: + """Write the initial handshake packet to the connection.""" + self.connection.write_bytearray(self._build_handshake_packet()) + def read_status(self) -> JavaStatusResponse: """Send the status request and read the response.""" - request = Connection() + request = Buffer() request.write_varint(0) # Request status - self.connection.write_buffer(request) + self.connection.write_bytearray(request) start = perf_counter() - response = self.connection.read_buffer() + response = Buffer(self.connection.read_bytearray()) end = perf_counter() return self._handle_status_response(response, start, end) def test_ping(self) -> float: """Send a ping token and measure the latency.""" - request = Connection() + request = Buffer() request.write_varint(1) # Test ping - request.write_long(self.ping_token) + request.write_value(StructFormat.LONGLONG, self.ping_token) start = perf_counter() - self.connection.write_buffer(request) + self.connection.write_bytearray(request) - response = self.connection.read_buffer() + response = Buffer(self.connection.read_bytearray()) end = perf_counter() return self._handle_ping_response(response, start, end) @@ -112,25 +117,29 @@ def test_ping(self) -> float: class AsyncJavaClient(_BaseJavaClient): connection: TCPAsyncSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] + async def handshake(self) -> None: + """Write the initial handshake packet to the connection.""" + await self.connection.write_bytearray(self._build_handshake_packet()) + async def read_status(self) -> JavaStatusResponse: """Send the status request and read the response.""" - request = Connection() + request = Buffer() request.write_varint(0) # Request status - self.connection.write_buffer(request) + await self.connection.write_bytearray(request) start = perf_counter() - response = await self.connection.read_buffer() + response = Buffer(await self.connection.read_bytearray()) end = perf_counter() return self._handle_status_response(response, start, end) async def test_ping(self) -> float: """Send a ping token and measure the latency.""" - request = Connection() + request = Buffer() request.write_varint(1) # Test ping - request.write_long(self.ping_token) + request.write_value(StructFormat.LONGLONG, self.ping_token) start = perf_counter() - self.connection.write_buffer(request) + await self.connection.write_bytearray(request) - response = await self.connection.read_buffer() + response = Buffer(await self.connection.read_bytearray()) end = perf_counter() return self._handle_ping_response(response, start, end) diff --git a/mcstatus/_protocol/legacy_client.py b/mcstatus/_protocol/legacy_client.py index 2e2270de..a90274bd 100644 --- a/mcstatus/_protocol/legacy_client.py +++ b/mcstatus/_protocol/legacy_client.py @@ -1,6 +1,7 @@ from time import perf_counter -from mcstatus._protocol.connection import BaseAsyncReadSyncWriteConnection, BaseSyncConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.connection import BaseAsyncConnection, BaseSyncConnection from mcstatus.responses import LegacyStatusResponse __all__ = ["AsyncLegacyClient", "LegacyClient"] @@ -35,24 +36,24 @@ def read_status(self) -> LegacyStatusResponse: id = self.connection.read(1) if id != b"\xff": raise OSError("Received invalid packet ID") - length = self.connection.read_ushort() + length = self.connection.read_value(StructFormat.USHORT) data = self.connection.read(length * 2) end = perf_counter() return self.parse_response(data, (end - start) * 1000) class AsyncLegacyClient(_BaseLegacyClient): - def __init__(self, connection: BaseAsyncReadSyncWriteConnection) -> None: + def __init__(self, connection: BaseAsyncConnection) -> None: self.connection = connection async def read_status(self) -> LegacyStatusResponse: """Send the status request and read the response.""" start = perf_counter() - self.connection.write(self.request_status_data) + await self.connection.write(self.request_status_data) id = await self.connection.read(1) if id != b"\xff": raise OSError("Received invalid packet ID") - length = await self.connection.read_ushort() + length = await self.connection.read_value(StructFormat.USHORT) data = await self.connection.read(length * 2) end = perf_counter() return self.parse_response(data, (end - start) * 1000) diff --git a/mcstatus/_protocol/query_client.py b/mcstatus/_protocol/query_client.py index 7ede695a..4092a7f5 100644 --- a/mcstatus/_protocol/query_client.py +++ b/mcstatus/_protocol/query_client.py @@ -2,12 +2,12 @@ import random import re -import struct from abc import abstractmethod from dataclasses import dataclass, field from typing import ClassVar, TYPE_CHECKING, final -from mcstatus._protocol.connection import Connection, UDPAsyncSocketConnection, UDPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus.responses import QueryResponse from mcstatus.responses._raw import RawQueryResponse @@ -16,6 +16,8 @@ if TYPE_CHECKING: from collections.abc import Awaitable + from mcstatus._protocol.io.connection import UDPAsyncSocketConnection, UDPSocketConnection + @dataclass class _BaseQueryClient: @@ -32,24 +34,24 @@ def _generate_session_id() -> int: # minecraft only supports lower 4 bits return random.randint(0, 2**31) & 0x0F0F0F0F - def _create_packet(self) -> Connection: - packet = Connection() + def _create_packet(self) -> Buffer: + packet = Buffer() packet.write(self.MAGIC_PREFIX) - packet.write(struct.pack("!B", self.PACKET_TYPE_QUERY)) - packet.write_uint(self._generate_session_id()) - packet.write_int(self.challenge) + packet.write_value(StructFormat.UBYTE, self.PACKET_TYPE_QUERY) + packet.write_value(StructFormat.UINT, self._generate_session_id()) + packet.write_value(StructFormat.INT, self.challenge) packet.write(self.PADDING) return packet - def _create_handshake_packet(self) -> Connection: - packet = Connection() + def _create_handshake_packet(self) -> Buffer: + packet = Buffer() packet.write(self.MAGIC_PREFIX) - packet.write(struct.pack("!B", self.PACKET_TYPE_CHALLENGE)) - packet.write_uint(self._generate_session_id()) + packet.write_value(StructFormat.UBYTE, self.PACKET_TYPE_CHALLENGE) + packet.write_value(StructFormat.UINT, self._generate_session_id()) return packet @abstractmethod - def _read_packet(self) -> Connection | Awaitable[Connection]: + def _read_packet(self) -> Buffer | Awaitable[Buffer]: raise NotImplementedError @abstractmethod @@ -60,7 +62,7 @@ def handshake(self) -> None | Awaitable[None]: def read_query(self) -> QueryResponse | Awaitable[QueryResponse]: raise NotImplementedError - def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[str]]: + def _parse_response(self, response: Buffer) -> tuple[RawQueryResponse, list[str]]: """Transform the connection object (the result) into dict which is passed to the QueryResponse constructor. :return: A tuple with two elements. First is `raw` answer and second is list of players. @@ -73,7 +75,7 @@ def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[ if key == "hostname": # hostname is actually motd in the query protocol match = re.search( b"(.*?)\x00(hostip|hostport|game_id|gametype|map|maxplayers|numplayers|plugins|version)", - response.received, + bytes(response.unread_view()), flags=re.DOTALL, ) motd = match.group(1) if match else "" @@ -105,9 +107,8 @@ def _parse_response(self, response: Connection) -> tuple[RawQueryResponse, list[ class QueryClient(_BaseQueryClient): connection: UDPSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] - def _read_packet(self) -> Connection: - packet = Connection() - packet.receive(self.connection.read(self.connection.remaining())) + def _read_packet(self) -> Buffer: + packet = Buffer(self.connection.read(self.connection.remaining)) packet.read(1 + 4) return packet @@ -130,9 +131,8 @@ def read_query(self) -> QueryResponse: class AsyncQueryClient(_BaseQueryClient): connection: UDPAsyncSocketConnection # pyright: ignore[reportIncompatibleVariableOverride] - async def _read_packet(self) -> Connection: - packet = Connection() - packet.receive(await self.connection.read(self.connection.remaining())) + async def _read_packet(self) -> Buffer: + packet = Buffer(await self.connection.read(self.connection.remaining)) packet.read(1 + 4) return packet diff --git a/mcstatus/responses/forge.py b/mcstatus/responses/forge.py index 0b53ed3c..70480860 100644 --- a/mcstatus/responses/forge.py +++ b/mcstatus/responses/forge.py @@ -18,7 +18,8 @@ from io import StringIO from typing import Final, TYPE_CHECKING -from mcstatus._protocol.connection import BaseConnection, BaseReadSync, Connection +from mcstatus._protocol.io.base_io import BaseSyncReader, StructFormat +from mcstatus._protocol.io.buffer import Buffer from mcstatus._utils import or_none if TYPE_CHECKING: @@ -57,10 +58,10 @@ def build(cls, raw: RawForgeDataChannel) -> Self: return cls(name=raw["res"], version=raw["version"], required=raw["required"]) @classmethod - def decode(cls, buffer: Connection, mod_id: str | None = None) -> Self: + def decode(cls, buffer: Buffer, mod_id: str | None = None) -> Self: """Decode an object about Forge channel from decoded optimized buffer. - :param buffer: :class:`Connection` object from UTF-16 encoded binary data. + :param buffer: :class:`Buffer` object from UTF-16 encoded binary data. :param mod_id: Optional mod id prefix :class:`str`. :return: :class:`ForgeDataChannel` object. """ @@ -68,7 +69,7 @@ def decode(cls, buffer: Connection, mod_id: str | None = None) -> Self: if mod_id is not None: channel_identifier = f"{mod_id}:{channel_identifier}" version = buffer.read_utf() - client_required = buffer.read_bool() + client_required = buffer.read_value(StructFormat.BOOL) return cls( name=channel_identifier, @@ -106,10 +107,10 @@ def build(cls, raw: RawForgeDataMod) -> Self: return cls(name=mod_id, marker=mod_version) @classmethod - def decode(cls, buffer: Connection) -> tuple[Self, list[ForgeDataChannel]]: + def decode(cls, buffer: Buffer) -> tuple[Self, list[ForgeDataChannel]]: """Decode data about a Forge mod from decoded optimized buffer. - :param buffer: :class:`Connection` object from UTF-16 encoded binary data. + :param buffer: :class:`Buffer` object from UTF-16 encoded binary data. :return: :class:`tuple` object of :class:`ForgeDataMod` object and :class:`list` of :class:`ForgeDataChannel` objects. """ channel_version_flags = buffer.read_varint() @@ -127,7 +128,7 @@ def decode(cls, buffer: Connection) -> tuple[Self, list[ForgeDataChannel]]: return cls(name=mod_id, marker=mod_version), channels -class _StringBuffer(BaseReadSync, BaseConnection): +class _StringBuffer(BaseSyncReader): """String Buffer for reading utf-16 encoded binary data.""" __slots__ = ("received", "stringio") @@ -136,7 +137,7 @@ def __init__(self, stringio: StringIO) -> None: self.stringio = stringio self.received = bytearray() - def read(self, length: int) -> bytearray: + def read(self, length: int, /) -> bytearray: """Read length bytes from ``self``, and return a byte array.""" data = bytearray() while self.received and len(data) < length: @@ -156,20 +157,20 @@ def remaining(self) -> int: def read_optimized_size(self) -> int: """Read encoded data length.""" - return self.read_short() | (self.read_short() << 15) + return self.read_value(StructFormat.SHORT) | (self.read_value(StructFormat.SHORT) << 15) - def read_optimized_buffer(self) -> Connection: + def read_optimized_buffer(self) -> Buffer: """Read encoded buffer.""" size = self.read_optimized_size() - buffer = Connection() + buffer = Buffer() value, bits = 0, 0 - while buffer.remaining() < size: + while buffer.remaining < size: if bits < 8 and self.remaining(): # Ignoring sign bit - value |= (self.read_short() & 0x7FFF) << bits + value |= (self.read_value(StructFormat.SHORT) & 0x7FFF) << bits bits += 15 - buffer.receive((value & 0xFF).to_bytes(1, "big")) + buffer.write((value & 0xFF).to_bytes(1, "big")) value >>= 8 bits -= 8 @@ -190,7 +191,7 @@ class ForgeData: """Is the mods list and or channel list incomplete?""" @staticmethod - def _decode_optimized(string: str) -> Connection: + def _decode_optimized(string: str) -> Buffer: """Decode buffer from UTF-16 optimized binary data ``string``.""" with StringIO(string) as text: str_buffer = _StringBuffer(text) @@ -223,8 +224,8 @@ def build(cls, raw: RawForgeData) -> Self: channels: list[ForgeDataChannel] = [] mods: list[ForgeDataMod] = [] - truncated = buffer.read_bool() - mod_count = buffer.read_ushort() + truncated = buffer.read_value(StructFormat.BOOL) + mod_count = buffer.read_value(StructFormat.USHORT) try: for _ in range(mod_count): mod, mod_channels = ForgeDataMod.decode(buffer) diff --git a/mcstatus/server.py b/mcstatus/server.py index fc9340e5..0a78c13c 100644 --- a/mcstatus/server.py +++ b/mcstatus/server.py @@ -5,7 +5,7 @@ from mcstatus._net.address import Address, async_minecraft_srv_address_lookup, minecraft_srv_address_lookup from mcstatus._protocol.bedrock_client import BedrockClient -from mcstatus._protocol.connection import ( +from mcstatus._protocol.io.connection import ( TCPAsyncSocketConnection, TCPSocketConnection, UDPAsyncSocketConnection, @@ -170,7 +170,7 @@ async def _retry_async_ping( version=version, ping_token=ping_token, # pyright: ignore[reportArgumentType] # None is not assignable to int ) - java_client.handshake() + await java_client.handshake() ping = await java_client.test_ping() return ping @@ -230,7 +230,7 @@ async def _retry_async_status( version=version, ping_token=ping_token, # pyright: ignore[reportArgumentType] # None is not assignable to int ) - java_client.handshake() + await java_client.handshake() result = await java_client.read_status() return result From e959b2c11b496315b99f620525af7bad54bdea99 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Thu, 26 Mar 2026 22:31:30 +0100 Subject: [PATCH 02/10] Add tests for new IO implementation --- tests/protocol/helpers.py | 161 +++++++ tests/protocol/test_async_support.py | 10 +- tests/protocol/test_base_io.py | 440 ++++++++++++++++++ .../protocol/test_base_io_twos_complement.py | 72 +++ tests/protocol/test_connection.py | 192 ++++---- tests/protocol/test_java_client.py | 111 ++--- tests/protocol/test_java_client_async.py | 131 +++--- tests/protocol/test_legacy_client.py | 4 +- tests/protocol/test_query_client.py | 29 +- tests/protocol/test_query_client_async.py | 22 +- tests/protocol/test_timeout.py | 58 ++- tests/test_server.py | 125 ++--- tests/utils/test_retry.py | 2 +- 13 files changed, 1000 insertions(+), 357 deletions(-) create mode 100644 tests/protocol/helpers.py create mode 100644 tests/protocol/test_base_io.py create mode 100644 tests/protocol/test_base_io_twos_complement.py diff --git a/tests/protocol/helpers.py b/tests/protocol/helpers.py new file mode 100644 index 00000000..cdb7e4ce --- /dev/null +++ b/tests/protocol/helpers.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import asyncio +from typing import Any, ParamSpec, TYPE_CHECKING, TypeVar + +from mcstatus._protocol.io.buffer import Buffer +from mcstatus._protocol.io.connection import BaseAsyncConnection, BaseSyncConnection + +if TYPE_CHECKING: + from collections.abc import Callable, Coroutine + + from typing_extensions import override +else: + override = lambda f: f # noqa: E731 + +P = ParamSpec("P") +T = TypeVar("T") + + +def async_decorator(f: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]: + """Wrap an async callable so it can be invoked from synchronous tests.""" + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + """Execute the wrapped coroutine function using ``asyncio.run``.""" + return asyncio.run(f(*args, **kwargs)) + + return wrapper + + +class SyncBufferConnection(BaseSyncConnection): + """In-memory synchronous stream-style transport used by protocol tests.""" + + def __init__(self, incoming: bytes | bytearray = b"") -> None: + """Initialize in-memory read/write buffers with optional incoming bytes.""" + self.sent = Buffer() + self.received = Buffer(incoming) + + @override + def read(self, length: int, /) -> bytes: + """Read ``length`` bytes from the queued incoming data.""" + return self.received.read(length) + + @override + def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing payload data to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue incoming payload data for subsequent reads.""" + self.received.write(data) + + def remaining(self) -> int: + """Return unread byte count from the incoming buffer.""" + return self.received.remaining + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class AsyncBufferConnection(BaseAsyncConnection): + """In-memory asynchronous stream-style transport used by protocol tests.""" + + def __init__(self, incoming: bytes | bytearray = b"") -> None: + """Initialize in-memory read/write buffers with optional incoming bytes.""" + self.sent = Buffer() + self.received = Buffer(incoming) + + @override + async def read(self, length: int, /) -> bytes: + """Read ``length`` bytes from the queued incoming data.""" + return self.received.read(length) + + @override + async def write(self, data: bytes | bytearray, /) -> None: + """Append outgoing payload data to the send buffer.""" + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue incoming payload data for subsequent reads.""" + self.received.write(data) + + def remaining(self) -> int: + """Return unread byte count from the incoming buffer.""" + return self.received.remaining + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class SyncDatagramConnection(BaseSyncConnection): + """Datagram-like synchronous transport returning one queued packet per read.""" + + def __init__(self) -> None: + self.sent = Buffer() + self.received: list[bytes] = [] + + @override + def read(self, _length: int, /) -> bytes: + """Pop and return the next queued datagram payload.""" + if not self.received: + raise OSError("No datagram data to read.") + return self.received.pop(0) + + @override + def write(self, data: bytes | bytearray | str, /) -> None: + """Append outgoing datagram payload to the send buffer.""" + if isinstance(data, str): + data = data.encode("utf-8") + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue one incoming datagram payload for future reads.""" + self.received.append(bytes(data)) + + def remaining(self) -> int: + """Return size of the next queued datagram, or ``0`` when empty.""" + if not self.received: + return 0 + return len(self.received[0]) + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() + + +class AsyncDatagramConnection(BaseAsyncConnection): + """Datagram-like asynchronous transport returning one queued packet per read.""" + + def __init__(self) -> None: + self.sent = Buffer() + self.received: list[bytes] = [] + + @override + async def read(self, _length: int, /) -> bytes: + """Pop and return the next queued datagram payload.""" + if not self.received: + raise OSError("No datagram data to read.") + return self.received.pop(0) + + @override + async def write(self, data: bytes | bytearray | str, /) -> None: + """Append outgoing datagram payload to the send buffer.""" + if isinstance(data, str): + data = data.encode("utf-8") + self.sent.write(data) + + def receive(self, data: bytes | bytearray) -> None: + """Queue one incoming datagram payload for future reads.""" + self.received.append(bytes(data)) + + def remaining(self) -> int: + """Return size of the next queued datagram, or ``0`` when empty.""" + if not self.received: + return 0 + return len(self.received[0]) + + def flush(self) -> bytes: + """Return and clear all bytes written to the send buffer.""" + return self.sent.flush() diff --git a/tests/protocol/test_async_support.py b/tests/protocol/test_async_support.py index badeb3cc..0647346b 100644 --- a/tests/protocol/test_async_support.py +++ b/tests/protocol/test_async_support.py @@ -1,23 +1,23 @@ from inspect import iscoroutinefunction -from mcstatus._protocol.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection +from mcstatus._protocol.io.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection def test_is_completely_asynchronous(): conn = TCPAsyncSocketConnection assertions = 0 for attribute in dir(conn): - if attribute.startswith("read_"): + if attribute.startswith(("read_", "write_")): assert iscoroutinefunction(getattr(conn, attribute)) assertions += 1 - assert assertions > 0, "None of the read_* attributes were async" + assert assertions > 0, "No read_*/write_* attributes were found" def test_query_is_completely_asynchronous(): conn = UDPAsyncSocketConnection assertions = 0 for attribute in dir(conn): - if attribute.startswith("read_"): + if attribute.startswith(("read_", "write_")): assert iscoroutinefunction(getattr(conn, attribute)) assertions += 1 - assert assertions > 0, "None of the read_* attributes were async" + assert assertions > 0, "No read_*/write_* attributes were found" diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py new file mode 100644 index 00000000..022715e2 --- /dev/null +++ b/tests/protocol/test_base_io.py @@ -0,0 +1,440 @@ +from __future__ import annotations + +import struct +from inspect import isawaitable +from typing import TYPE_CHECKING, TypeVar, overload +from unittest.mock import AsyncMock, Mock + +import pytest + +from mcstatus._protocol.io.base_io import INT_FORMATS_TYPE, StructFormat +from mcstatus._protocol.io.buffer import Buffer +from tests.protocol.helpers import AsyncBufferConnection, SyncBufferConnection + +if TYPE_CHECKING: + from collections.abc import Awaitable + +IO_TYPE = type[SyncBufferConnection] | type[AsyncBufferConnection] +T = TypeVar("T") + + +@pytest.fixture( + params=[ + pytest.param(SyncBufferConnection, id="sync"), + pytest.param(AsyncBufferConnection, id="async"), + ] +) +def io_type(request: pytest.FixtureRequest) -> IO_TYPE: + """Provide a parametrized sync and async IO connections to each test.""" + return request.param + + +@overload +async def maybe_await(value: Awaitable[T], /) -> T: ... + + +@overload +async def maybe_await(value: T, /) -> T: ... + + +async def maybe_await(value: Awaitable[T] | T, /) -> T: + """Return a value directly or await it when needed. + + This keeps test calls explicit and type-checkable for both sync and async IO APIs. + """ + if isawaitable(value): + return await value + return value + + +@pytest.mark.parametrize( + ("fmt", "value", "expected"), + [ + pytest.param(StructFormat.UBYTE, 0, b"\x00", id="ubyte-0"), + pytest.param(StructFormat.UBYTE, 15, b"\x0f", id="ubyte-15"), + pytest.param(StructFormat.UBYTE, 255, b"\xff", id="ubyte-max"), + pytest.param(StructFormat.BYTE, 0, b"\x00", id="byte-0"), + pytest.param(StructFormat.BYTE, 15, b"\x0f", id="byte-15"), + pytest.param(StructFormat.BYTE, 127, b"\x7f", id="byte-max"), + pytest.param(StructFormat.BYTE, -20, b"\xec", id="byte-neg-20"), + pytest.param(StructFormat.BYTE, -128, b"\x80", id="byte-min"), + ], +) +@pytest.mark.asyncio +async def test_write_value_matches_reference(io_type: IO_TYPE, fmt: INT_FORMATS_TYPE, value: int, expected: bytes): + io = io_type() + await maybe_await(io.write_value(fmt, value)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("fmt", "value"), + [ + pytest.param(StructFormat.UBYTE, -1, id="ubyte-neg"), + pytest.param(StructFormat.UBYTE, 256, id="ubyte-overflow"), + pytest.param(StructFormat.BYTE, -129, id="byte-underflow"), + pytest.param(StructFormat.BYTE, 128, id="byte-overflow"), + ], +) +@pytest.mark.asyncio +async def test_write_value_rejects_out_of_range(io_type: IO_TYPE, fmt: INT_FORMATS_TYPE, value: int): + io = io_type() + with pytest.raises(struct.error): + await maybe_await(io.write_value(fmt, value)) + + +@pytest.mark.parametrize( + ("encoded", "fmt", "expected"), + [ + pytest.param(b"\x00", StructFormat.UBYTE, 0, id="ubyte-0"), + pytest.param(b"\x0f", StructFormat.UBYTE, 15, id="ubyte-15"), + pytest.param(b"\xff", StructFormat.UBYTE, 255, id="ubyte-max"), + pytest.param(b"\x00", StructFormat.BYTE, 0, id="byte-0"), + pytest.param(b"\x0f", StructFormat.BYTE, 15, id="byte-15"), + pytest.param(b"\x7f", StructFormat.BYTE, 127, id="byte-max"), + pytest.param(b"\xec", StructFormat.BYTE, -20, id="byte-neg-20"), + pytest.param(b"\x80", StructFormat.BYTE, -128, id="byte-min"), + ], +) +@pytest.mark.asyncio +async def test_read_value_matches_reference(io_type: IO_TYPE, encoded: bytes, fmt: INT_FORMATS_TYPE, expected: int): + io = io_type(encoded) + assert await maybe_await(io.read_value(fmt)) == expected + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(0, b"\x00", id="0"), + pytest.param(127, b"\x7f", id="127"), + pytest.param(128, b"\x80\x01", id="128"), + pytest.param(255, b"\xff\x01", id="255"), + pytest.param(1_000_000, b"\xc0\x84\x3d", id="1m"), + pytest.param((2**31) - 1, b"\xff\xff\xff\xff\x07", id="max-32"), + ], +) +@pytest.mark.asyncio +async def test_write_varuint_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): + io = io_type() + await maybe_await(io._write_varuint(number)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x00", 0, id="0"), + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x01", 128, id="128"), + pytest.param(b"\xff\x01", 255, id="255"), + pytest.param(b"\xc0\x84\x3d", 1_000_000, id="1m"), + pytest.param(b"\xff\xff\xff\xff\x07", (2**31) - 1, id="max-32"), + ], +) +@pytest.mark.asyncio +async def test_read_varuint_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): + io = io_type(encoded) + assert await maybe_await(io._read_varuint()) == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("number", "max_bits"), + [ + (-1, 128), + (-1, 1), + (2**16, 16), + (2**32, 32), + ], +) +async def test_write_varuint_rejects_out_of_range(io_type: IO_TYPE, number: int, max_bits: int): + io = io_type() + with pytest.raises(ValueError, match=r"outside of the range of"): + await maybe_await(io._write_varuint(number, max_bits=max_bits)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("encoded", "max_bits"), + [ + (b"\x80\x80\x04", 16), + (b"\x80\x80\x80\x80\x10", 32), + ], +) +async def test_read_varuint_rejects_out_of_range(io_type: IO_TYPE, encoded: bytes, max_bits: int): + io = io_type(encoded) + with pytest.raises(OSError, match=r"outside the range of"): + await maybe_await(io._read_varuint(max_bits=max_bits)) + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(127, b"\x7f", id="127"), + pytest.param(16_384, b"\x80\x80\x01", id="16384"), + pytest.param(-128, b"\x80\xff\xff\xff\x0f", id="-128"), + pytest.param(-16_383, b"\x81\x80\xff\xff\x0f", id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_write_varint_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): + io = io_type() + await maybe_await(io.write_varint(number)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x80\x01", 16_384, id="16384"), + pytest.param(b"\x80\xff\xff\xff\x0f", -128, id="-128"), + pytest.param(b"\x81\x80\xff\xff\x0f", -16_383, id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_read_varint_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): + io = io_type(encoded) + assert await maybe_await(io.read_varint()) == expected + + +@pytest.mark.parametrize( + ("number", "expected"), + [ + pytest.param(127, b"\x7f", id="127"), + pytest.param(16_384, b"\x80\x80\x01", id="16384"), + pytest.param(-128, b"\x80\xff\xff\xff\xff\xff\xff\xff\xff\x01", id="-128"), + pytest.param(-16_383, b"\x81\x80\xff\xff\xff\xff\xff\xff\xff\x01", id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_write_varlong_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): + io = io_type() + await maybe_await(io.write_varlong(number)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x7f", 127, id="127"), + pytest.param(b"\x80\x80\x01", 16_384, id="16384"), + pytest.param(b"\x80\xff\xff\xff\xff\xff\xff\xff\xff\x01", -128, id="-128"), + pytest.param(b"\x81\x80\xff\xff\xff\xff\xff\xff\xff\x01", -16_383, id="-16383"), + ], +) +@pytest.mark.asyncio +async def test_read_varlong_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): + io = io_type(encoded) + assert await maybe_await(io.read_varlong()) == expected + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [0, 1, 127, 16_384, -1, -(2**31), (2**31) - 1]) +async def test_varint_roundtrip(io_type: IO_TYPE, number: int): + io = io_type() + await maybe_await(io.write_varint(number)) + io.receive(io.flush()) + assert await maybe_await(io.read_varint()) == number + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [127, 16_384, -128, -16_383, -(2**63), (2**63) - 1]) +async def test_varlong_roundtrip(io_type: IO_TYPE, number: int): + io = io_type() + await maybe_await(io.write_varlong(number)) + io.receive(io.flush()) + assert await maybe_await(io.read_varlong()) == number + + +@pytest.mark.asyncio +@pytest.mark.parametrize("number", [-(2**63) - 1, 2**63]) +async def test_write_varlong_rejects_out_of_range(io_type: IO_TYPE, number: int): + io = io_type() + with pytest.raises(ValueError, match=r"out of range"): + await maybe_await(io.write_varlong(number)) + + +@pytest.mark.asyncio +async def test_optional_helpers(io_type: IO_TYPE): + io = io_type() + + if isinstance(io, AsyncBufferConnection): + writer = AsyncMock(return_value="written") + + assert await io.write_optional(None, writer) is None + writer.assert_not_awaited() + assert io.flush() == b"\x00" + + assert await io.write_optional("value", writer) == "written" + writer.assert_awaited_once_with("value") + assert io.flush() == b"\x01" + + reader = AsyncMock(return_value="parsed") + io.receive(b"\x00") + assert await io.read_optional(reader) is None + reader.assert_not_awaited() + + io.receive(b"\x01") + assert await io.read_optional(reader) == "parsed" + reader.assert_awaited_once_with() + return + + writer = Mock(return_value="written") + + assert io.write_optional(None, writer) is None + writer.assert_not_called() + assert io.flush() == b"\x00" + + assert io.write_optional("value", writer) == "written" + writer.assert_called_once_with("value") + assert io.flush() == b"\x01" + + reader = Mock(return_value="parsed") + io.receive(b"\x00") + assert io.read_optional(reader) is None + reader.assert_not_called() + + io.receive(b"\x01") + assert io.read_optional(reader) == "parsed" + reader.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_write_and_read_ascii(io_type: IO_TYPE): + io = io_type() + await maybe_await(io.write_ascii("hello")) + + io.receive(io.flush()) + assert await maybe_await(io.read_ascii()) == "hello" + + +@pytest.mark.asyncio +async def test_write_and_read_bytearray(io_type: IO_TYPE): + io = io_type() + data = b"\x00\x01hello\xff" + + await maybe_await(io.write_bytearray(data)) + io.receive(io.flush()) + + assert await maybe_await(io.read_bytearray()) == data + + +@pytest.mark.parametrize( + ("data", "expected"), + [ + pytest.param(b"", b"\x00", id="empty"), + pytest.param(b"\x01", b"\x01\x01", id="single"), + pytest.param(b"hello\x00world", b"\x0bhello\x00world", id="with-null"), + pytest.param(b"\x01\x02\x03four\x05", b"\x08\x01\x02\x03four\x05", id="mixed"), + ], +) +@pytest.mark.asyncio +async def test_write_bytearray_matches_reference(io_type: IO_TYPE, data: bytes, expected: bytes): + io = io_type() + await maybe_await(io.write_bytearray(data)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x00", b"", id="empty"), + pytest.param(b"\x01\x01", b"\x01", id="single"), + pytest.param(b"\x0bhello\x00world", b"hello\x00world", id="with-null"), + pytest.param(b"\x08\x01\x02\x03four\x05", b"\x01\x02\x03four\x05", id="mixed"), + ], +) +@pytest.mark.asyncio +async def test_read_bytearray_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: bytes): + io = io_type(encoded) + assert await maybe_await(io.read_bytearray()) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param("test", b"test\x00", id="test"), + pytest.param("a" * 100, b"a" * 100 + b"\x00", id="100-as"), + pytest.param("", b"\x00", id="empty"), + ], +) +@pytest.mark.asyncio +async def test_write_ascii_matches_reference(io_type: IO_TYPE, value: str, expected: bytes): + io = io_type() + await maybe_await(io.write_ascii(value)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"test\x00", "test", id="test"), + pytest.param(b"a" * 100 + b"\x00", "a" * 100, id="100-as"), + pytest.param(b"\x00", "", id="empty"), + ], +) +@pytest.mark.asyncio +async def test_read_ascii_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: str): + io = io_type(encoded) + assert await maybe_await(io.read_ascii()) == expected + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + pytest.param("test", b"\x04test", id="test"), + pytest.param("a" * 100, b"\x64" + b"a" * 100, id="100-as"), + pytest.param("", b"\x00", id="empty"), + pytest.param("नमस्ते", b"\x12" + "नमस्ते".encode(), id="hindi"), + ], +) +@pytest.mark.asyncio +async def test_write_utf_matches_reference(io_type: IO_TYPE, value: str, expected: bytes): + io = io_type() + await maybe_await(io.write_utf(value)) + assert io.flush() == expected + + +@pytest.mark.parametrize( + ("encoded", "expected"), + [ + pytest.param(b"\x04test", "test", id="test"), + pytest.param(b"\x64" + b"a" * 100, "a" * 100, id="100-as"), + pytest.param(b"\x00", "", id="empty"), + pytest.param(b"\x12" + "नमस्ते".encode(), "नमस्ते", id="hindi"), + ], +) +@pytest.mark.asyncio +async def test_read_utf_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: str): + io = io_type(encoded) + assert await maybe_await(io.read_utf()) == expected + + +@pytest.mark.asyncio +async def test_write_utf_rejects_too_many_characters(io_type: IO_TYPE): + io = io_type() + with pytest.raises(ValueError, match=r"Maximum character limit for writing strings is 32767 characters"): + await maybe_await(io.write_utf("a" * 32768)) + + +@pytest.mark.asyncio +async def test_read_utf_rejects_too_many_bytes(io_type: IO_TYPE): + payload = Buffer() + payload.write_varint(131069) + + io = io_type(payload) + with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 131068 bytes, got 131069"): + await maybe_await(io.read_utf()) + + +@pytest.mark.asyncio +async def test_read_utf_rejects_too_many_characters(io_type: IO_TYPE): + text = "a" * 32768 + payload = Buffer() + payload.write_varint(len(text.encode("utf-8"))) + payload.write(text.encode("utf-8")) + + io = io_type(payload) + with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 32767 characters, got 32768"): + await maybe_await(io.read_utf()) diff --git a/tests/protocol/test_base_io_twos_complement.py b/tests/protocol/test_base_io_twos_complement.py new file mode 100644 index 00000000..08e3eb25 --- /dev/null +++ b/tests/protocol/test_base_io_twos_complement.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +import pytest + +from mcstatus._protocol.io.base_io import from_twos_complement, to_twos_complement + +TWOS_COMPLEMENT_CASES = [ + (-128, 8, 0x80), + (-2, 8, 0xFE), + (-1, 8, 0xFF), + (0, 8, 0x00), + (1, 8, 0x01), + (127, 8, 0x7F), + (-(2**15), 16, 0x8000), + (-1, 16, 0xFFFF), + ((2**15) - 1, 16, 0x7FFF), + (-(2**31), 32, 0x80000000), + (-1, 32, 0xFFFFFFFF), + ((2**31) - 1, 32, 0x7FFFFFFF), + (-(2**63), 64, 0x8000000000000000), + (-9_876_543_210_123_456, 64, 0xFFDCE956165A0F40), + (-1, 64, 0xFFFFFFFFFFFFFFFF), + (0, 64, 0x0000000000000000), + (9_876_543_210_123_456, 64, 0x002316A9E9A5F0C0), + ((2**63) - 1, 64, 0x7FFFFFFFFFFFFFFF), +] + + +@pytest.mark.parametrize( + ("number", "bits", "expected_twos"), + TWOS_COMPLEMENT_CASES, +) +def test_to_twos_complement_matches_expected_values(number: int, bits: int, expected_twos: int): + assert to_twos_complement(number, bits=bits) == expected_twos + + +@pytest.mark.parametrize( + ("twos_value", "bits", "expected_number"), + [(twos_value, bits, number) for number, bits, twos_value in TWOS_COMPLEMENT_CASES], +) +def test_from_twos_complement_matches_expected_values(twos_value: int, bits: int, expected_number: int): + assert from_twos_complement(twos_value, bits=bits) == expected_number + + +@pytest.mark.parametrize( + ("number", "bits"), + [ + (-129, 8), + (128, 8), + (-(2**31) - 1, 32), + (2**31, 32), + (-(2**63) - 1, 64), + (2**63, 64), + ], +) +def test_to_twos_complement_rejects_out_of_range(number: int, bits: int): + with pytest.raises(ValueError, match=r"out of range"): + to_twos_complement(number, bits=bits) + + +@pytest.mark.parametrize( + ("number", "bits"), + [ + (-1, 8), + (256, 8), + (2**32, 32), + (2**64, 64), + ], +) +def test_from_twos_complement_rejects_out_of_range(number: int, bits: int): + with pytest.raises(ValueError, match=r"out of range"): + from_twos_complement(number, bits=bits) diff --git a/tests/protocol/test_connection.py b/tests/protocol/test_connection.py index 18062d30..a70a5ad4 100644 --- a/tests/protocol/test_connection.py +++ b/tests/protocol/test_connection.py @@ -3,32 +3,54 @@ import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection, TCPSocketConnection, UDPSocketConnection +from mcstatus._protocol.io.base_io import StructFormat +from mcstatus._protocol.io.buffer import Buffer +from mcstatus._protocol.io.connection import TCPSocketConnection, UDPSocketConnection -class TestConnection: - connection: Connection - +class TestBuffer: def setup_method(self): - self.connection = Connection() + self.connection = Buffer() def test_flush(self): - self.connection.sent = bytearray.fromhex("7FAABB") + self.connection.write(bytearray.fromhex("7FAABB")) assert self.connection.flush() == bytearray.fromhex("7FAABB") - assert self.connection.sent == bytearray() + assert self.connection == bytearray() - def test_receive(self): - self.connection.receive(bytearray.fromhex("7F")) - self.connection.receive(bytearray.fromhex("AABB")) + def test_remaining(self): + self.connection.write(bytearray.fromhex("7FAABB")) - assert self.connection.received == bytearray.fromhex("7FAABB") + assert self.connection.remaining == 3 - def test_remaining(self): - self.connection.receive(bytearray.fromhex("7F")) - self.connection.receive(bytearray.fromhex("AABB")) + def test_reset(self): + self.connection.write(b"abcdef") + + assert self.connection.read(3) == b"abc" + self.connection.reset() + assert self.connection.read(6) == b"abcdef" + + def test_clear_only_already_read(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" + + self.connection.clear(only_already_read=True) + + assert self.connection == bytearray(b"cdef") + assert self.connection.remaining == 4 + + def test_unread_view(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" + + assert bytes(self.connection.unread_view()) == b"cdef" - assert self.connection.remaining() == 3 + def test_flush_only_returns_unread_data(self): + self.connection.write(b"abcdef") + assert self.connection.read(2) == b"ab" + + assert self.connection.flush() == b"cdef" + assert self.connection == bytearray() def test_send(self): self.connection.write(bytearray.fromhex("7F")) @@ -37,13 +59,13 @@ def test_send(self): assert self.connection.flush() == bytearray.fromhex("7FAABB") def test_read(self): - self.connection.receive(bytearray.fromhex("7FAABB")) + self.connection.write(bytearray.fromhex("7FAABB")) assert self.connection.read(2) == bytearray.fromhex("7FAA") assert self.connection.read(1) == bytearray.fromhex("BB") def _assert_varint_read_write(self, hexstr, value) -> None: - self.connection.receive(bytearray.fromhex(hexstr)) + self.connection.write(bytearray.fromhex(hexstr)) assert self.connection.read_varint() == value self.connection.write_varint(value) @@ -59,19 +81,25 @@ def test_varint_cases(self): self._assert_varint_read_write("8080808008", -2147483648) def test_read_invalid_varint(self): - self.connection.receive(bytearray.fromhex("FFFFFFFF80")) + self.connection.write(bytearray.fromhex("FFFFFFFF10")) - with pytest.raises(IOError, match=r"^Received varint is too big!$"): + with pytest.raises(IOError, match=r"^Received varint was outside the range of 32-bit int.$"): self.connection.read_varint() def test_write_invalid_varint(self): - with pytest.raises(ValueError, match=r'^The value "2147483648" is too big to send in a varint$'): + with pytest.raises( + ValueError, + match=r"^Can't convert number 2147483648 into 32-bit twos complement format - out of range$", + ): self.connection.write_varint(2147483648) - with pytest.raises(ValueError, match=r'^The value "-2147483649" is too big to send in a varint$'): + with pytest.raises( + ValueError, + match=r"^Can't convert number -2147483649 into 32-bit twos complement format - out of range$", + ): self.connection.write_varint(-2147483649) def test_read_utf(self): - self.connection.receive(bytearray.fromhex("0D48656C6C6F2C20776F726C6421")) + self.connection.write(bytearray.fromhex("0D48656C6C6F2C20776F726C6421")) assert self.connection.read_utf() == "Hello, world!" @@ -86,7 +114,7 @@ def test_read_empty_utf(self): assert self.connection.flush() == bytearray.fromhex("00") def test_read_ascii(self): - self.connection.receive(bytearray.fromhex("48656C6C6F2C20776F726C642100")) + self.connection.write(bytearray.fromhex("48656C6C6F2C20776F726C642100")) assert self.connection.read_ascii() == "Hello, world!" @@ -101,131 +129,131 @@ def test_read_empty_ascii(self): assert self.connection.flush() == bytearray.fromhex("00") def test_read_short_negative(self): - self.connection.receive(bytearray.fromhex("8000")) + self.connection.write(bytearray.fromhex("8000")) - assert self.connection.read_short() == -32768 + assert self.connection.read_value(StructFormat.SHORT) == -32768 def test_write_short_negative(self): - self.connection.write_short(-32768) + self.connection.write_value(StructFormat.SHORT, -32768) assert self.connection.flush() == bytearray.fromhex("8000") def test_read_short_positive(self): - self.connection.receive(bytearray.fromhex("7FFF")) + self.connection.write(bytearray.fromhex("7FFF")) - assert self.connection.read_short() == 32767 + assert self.connection.read_value(StructFormat.SHORT) == 32767 def test_write_short_positive(self): - self.connection.write_short(32767) + self.connection.write_value(StructFormat.SHORT, 32767) assert self.connection.flush() == bytearray.fromhex("7FFF") def test_read_ushort_positive(self): - self.connection.receive(bytearray.fromhex("8000")) + self.connection.write(bytearray.fromhex("8000")) - assert self.connection.read_ushort() == 32768 + assert self.connection.read_value(StructFormat.USHORT) == 32768 def test_write_ushort_positive(self): - self.connection.write_ushort(32768) + self.connection.write_value(StructFormat.USHORT, 32768) assert self.connection.flush() == bytearray.fromhex("8000") def test_read_int_negative(self): - self.connection.receive(bytearray.fromhex("80000000")) + self.connection.write(bytearray.fromhex("80000000")) - assert self.connection.read_int() == -2147483648 + assert self.connection.read_value(StructFormat.INT) == -2147483648 def test_write_int_negative(self): - self.connection.write_int(-2147483648) + self.connection.write_value(StructFormat.INT, -2147483648) assert self.connection.flush() == bytearray.fromhex("80000000") def test_read_int_positive(self): - self.connection.receive(bytearray.fromhex("7FFFFFFF")) + self.connection.write(bytearray.fromhex("7FFFFFFF")) - assert self.connection.read_int() == 2147483647 + assert self.connection.read_value(StructFormat.INT) == 2147483647 def test_write_int_positive(self): - self.connection.write_int(2147483647) + self.connection.write_value(StructFormat.INT, 2147483647) assert self.connection.flush() == bytearray.fromhex("7FFFFFFF") def test_read_uint_positive(self): - self.connection.receive(bytearray.fromhex("80000000")) + self.connection.write(bytearray.fromhex("80000000")) - assert self.connection.read_uint() == 2147483648 + assert self.connection.read_value(StructFormat.UINT) == 2147483648 def test_write_uint_positive(self): - self.connection.write_uint(2147483648) + self.connection.write_value(StructFormat.UINT, 2147483648) assert self.connection.flush() == bytearray.fromhex("80000000") def test_read_long_negative(self): - self.connection.receive(bytearray.fromhex("8000000000000000")) + self.connection.write(bytearray.fromhex("8000000000000000")) - assert self.connection.read_long() == -9223372036854775808 + assert self.connection.read_value(StructFormat.LONGLONG) == -9223372036854775808 def test_write_long_negative(self): - self.connection.write_long(-9223372036854775808) + self.connection.write_value(StructFormat.LONGLONG, -9223372036854775808) assert self.connection.flush() == bytearray.fromhex("8000000000000000") def test_read_long_positive(self): - self.connection.receive(bytearray.fromhex("7FFFFFFFFFFFFFFF")) + self.connection.write(bytearray.fromhex("7FFFFFFFFFFFFFFF")) - assert self.connection.read_long() == 9223372036854775807 + assert self.connection.read_value(StructFormat.LONGLONG) == 9223372036854775807 def test_write_long_positive(self): - self.connection.write_long(9223372036854775807) + self.connection.write_value(StructFormat.LONGLONG, 9223372036854775807) assert self.connection.flush() == bytearray.fromhex("7FFFFFFFFFFFFFFF") def test_read_ulong_positive(self): - self.connection.receive(bytearray.fromhex("8000000000000000")) + self.connection.write(bytearray.fromhex("8000000000000000")) - assert self.connection.read_ulong() == 9223372036854775808 + assert self.connection.read_value(StructFormat.ULONGLONG) == 9223372036854775808 def test_write_ulong_positive(self): - self.connection.write_ulong(9223372036854775808) + self.connection.write_value(StructFormat.ULONGLONG, 9223372036854775808) assert self.connection.flush() == bytearray.fromhex("8000000000000000") @pytest.mark.parametrize(("as_bytes", "as_bool"), [("01", True), ("00", False)]) def test_read_bool(self, as_bytes: str, as_bool: bool) -> None: - self.connection.receive(bytearray.fromhex(as_bytes)) + self.connection.write(bytearray.fromhex(as_bytes)) - assert self.connection.read_bool() is as_bool + assert self.connection.read_value(StructFormat.BOOL) is as_bool @pytest.mark.parametrize(("as_bytes", "as_bool"), [("01", True), ("00", False)]) def test_write_bool(self, as_bytes: str, as_bool: bool) -> None: - self.connection.write_bool(as_bool) + self.connection.write_value(StructFormat.BOOL, as_bool) assert self.connection.flush() == bytearray.fromhex(as_bytes) - def test_read_buffer(self): - self.connection.receive(bytearray.fromhex("027FAA")) - buffer = self.connection.read_buffer() + def test_read_bytearray(self): + self.connection.write(bytearray.fromhex("027FAA")) - assert buffer.received == bytearray.fromhex("7FAA") - assert self.connection.flush() == bytearray() + assert self.connection.read_bytearray() == bytearray.fromhex("7FAA") - def test_write_buffer(self): - buffer = Connection() - buffer.write(bytearray.fromhex("7FAA")) - self.connection.write_buffer(buffer) + def test_write_bytearray(self): + self.connection.write_bytearray(bytearray.fromhex("7FAA")) assert self.connection.flush() == bytearray.fromhex("027FAA") def test_read_empty(self): - self.connection.received = bytearray() - - with pytest.raises(IOError, match=r"^Not enough data to read! 0 < 1$"): + with pytest.raises( + IOError, + match=r"^Requested to read more data than available. Read 0 bytes: bytearray\(b''\), out of 1 requested bytes.$", + ): self.connection.read(1) def test_read_not_enough(self): - self.connection.received = bytearray(b"a") + self.connection.write(bytearray(b"a")) - with pytest.raises(IOError, match=r"^Not enough data to read! 1 < 2$"): + with pytest.raises( + IOError, + match=r"^Requested to read more data than available. Read 1 bytes: bytearray\(b'a'\), out of 2 requested bytes.$", + ): self.connection.read(2) @@ -236,24 +264,12 @@ def connection(self): socket = Mock() socket.recv = Mock() - socket.send = Mock() + socket.sendall = Mock() with patch("socket.create_connection") as create_connection: create_connection.return_value = socket with TCPSocketConnection(test_addr) as connection: yield connection - def test_flush(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support flush\(\)$"): - connection.flush() - - def test_receive(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support receive\(\)$"): - connection.receive("") - - def test_remaining(self, connection): - with pytest.raises(TypeError, match=r"^TCPSocketConnection does not support remaining\(\)$"): - connection.remaining() - def test_read(self, connection): connection.socket.recv.return_value = bytearray.fromhex("7FAA") @@ -274,13 +290,13 @@ def test_read_not_enough(self, connection): def test_write(self, connection): connection.write(bytearray.fromhex("7FAA")) - connection.socket.send.assert_called_once_with(bytearray.fromhex("7FAA")) + connection.socket.sendall.assert_called_once_with(bytearray.fromhex("7FAA")) class TestUDPSocketConnection: @pytest.fixture(scope="class") def connection(self): - test_addr = Address("localhost", 1234) + test_addr = Address("127.0.0.1", 1234) socket = Mock() socket.recvfrom = Mock() @@ -290,16 +306,8 @@ def connection(self): with UDPSocketConnection(test_addr) as connection: yield connection - def test_flush(self, connection): - with pytest.raises(TypeError, match=r"^UDPSocketConnection does not support flush\(\)$"): - connection.flush() - - def test_receive(self, connection): - with pytest.raises(TypeError, match=r"^UDPSocketConnection does not support receive\(\)$"): - connection.receive("") - def test_remaining(self, connection): - assert connection.remaining() == 65535 + assert connection.remaining == 65535 def test_read(self, connection): connection.socket.recvfrom.return_value = [bytearray.fromhex("7FAA")] @@ -317,5 +325,5 @@ def test_write(self, connection): connection.socket.sendto.assert_called_once_with( bytearray.fromhex("7FAA"), - Address("localhost", 1234), + Address("127.0.0.1", 1234), ) diff --git a/tests/protocol/test_java_client.py b/tests/protocol/test_java_client.py index 39d5ba56..6acd54e2 100644 --- a/tests/protocol/test_java_client.py +++ b/tests/protocol/test_java_client.py @@ -5,14 +5,15 @@ import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection from mcstatus._protocol.java_client import JavaClient +from tests.protocol.helpers import SyncBufferConnection class TestJavaClient: def setup_method(self): + self.connection = SyncBufferConnection() self.java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] + self.connection, # pyright: ignore[reportArgumentType] address=Address("localhost", 25565), version=44, ) @@ -20,10 +21,10 @@ def setup_method(self): def test_handshake(self): self.java_client.handshake() - assert self.java_client.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") + assert self.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") def test_read_status(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" @@ -37,15 +38,15 @@ def test_read_status(self): "players": {"max": 20, "online": 0}, "version": {"name": "1.8-pre1", "protocol": 44}, } - assert self.java_client.connection.flush() == bytearray.fromhex("0100") + assert self.connection.flush() == bytearray.fromhex("0100") def test_read_status_invalid_json(self): - self.java_client.connection.receive(bytearray.fromhex("0300017B")) + self.connection.receive(bytearray.fromhex("0300017B")) with pytest.raises(IOError, match=r"^Received invalid JSON$"): self.java_client.read_status() def test_read_status_invalid_reply(self): - self.java_client.connection.receive( + self.connection.receive( # no motd, see also #922 bytearray.fromhex( "4F004D7B22706C6179657273223A7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616" @@ -56,91 +57,75 @@ def test_read_status_invalid_reply(self): self.java_client.read_status() def test_read_status_invalid_status(self): - self.java_client.connection.receive(bytearray.fromhex("0105")) + self.connection.receive(bytearray.fromhex("0105")) with pytest.raises(IOError, match=r"^Received invalid status response packet.$"): self.java_client.read_status() def test_test_ping(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 14515484 assert self.java_client.test_ping() >= 0 - assert self.java_client.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") + assert self.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") def test_test_ping_invalid(self): - self.java_client.connection.receive(bytearray.fromhex("011F")) + self.connection.receive(bytearray.fromhex("011F")) self.java_client.ping_token = 14515484 with pytest.raises(IOError, match=r"^Received invalid ping response packet.$"): self.java_client.test_ping() def test_test_ping_wrong_token(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 12345 with pytest.raises(IOError, match=r"^Received mangled ping response \(expected token 12345, got 14515484\)$"): self.java_client.test_ping() + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) def test_latency_is_real_number(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" + self.connection.receive( + bytearray.fromhex( + "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" + "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" + "70726F746F636F6C223A34347D7D" + ) + ) - def mocked_read_buffer(): + def mocked_read_bytearray() -> bytes: time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(Connection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint.return_value = 0 - mocked.return_value.read_utf.return_value = """ - { - "description": "A Minecraft Server", - "players": {"max": 20, "online": 0}, - "version": {"name": "1.8-pre1", "protocol": 44} - } - """ - java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ) + return original_read_bytearray() - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert java_client.read_status().latency >= 1 + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = self.java_client.read_status().latency + assert 1 <= latency <= 20 + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) def test_test_ping_is_in_milliseconds(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" + self.java_client.ping_token = 14515484 + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) - def mocked_read_buffer(): + def mocked_read_bytearray() -> bytes: time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(Connection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint.return_value = 1 - mocked.return_value.read_long.return_value = 123456789 - java_client = JavaClient( - Connection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ping_token=123456789, - ) + return original_read_bytearray() - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert java_client.test_ping() >= 1 + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = self.java_client.test_ping() + assert 1 <= latency <= 20 diff --git a/tests/protocol/test_java_client_async.py b/tests/protocol/test_java_client_async.py index d126de58..5c7e29b4 100644 --- a/tests/protocol/test_java_client_async.py +++ b/tests/protocol/test_java_client_async.py @@ -1,42 +1,30 @@ import asyncio import sys -import time from unittest import mock import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import Connection from mcstatus._protocol.java_client import AsyncJavaClient - - -def async_decorator(f): - def wrapper(*args, **kwargs): - return asyncio.run(f(*args, **kwargs)) - - return wrapper - - -class FakeAsyncConnection(Connection): - async def read_buffer(self): # pyright: ignore[reportIncompatibleMethodOverride] - return super().read_buffer() +from tests.protocol.helpers import AsyncBufferConnection, async_decorator class TestAsyncJavaClient: def setup_method(self): + self.connection = AsyncBufferConnection() self.java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] + self.connection, # pyright: ignore[reportArgumentType] address=Address("localhost", 25565), version=44, ) def test_handshake(self): - self.java_client.handshake() + async_decorator(self.java_client.handshake)() - assert self.java_client.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") + assert self.connection.flush() == bytearray.fromhex("0F002C096C6F63616C686F737463DD01") def test_read_status(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" @@ -50,15 +38,15 @@ def test_read_status(self): "players": {"max": 20, "online": 0}, "version": {"name": "1.8-pre1", "protocol": 44}, } - assert self.java_client.connection.flush() == bytearray.fromhex("0100") + assert self.connection.flush() == bytearray.fromhex("0100") def test_read_status_invalid_json(self): - self.java_client.connection.receive(bytearray.fromhex("0300017B")) + self.connection.receive(bytearray.fromhex("0300017B")) with pytest.raises(IOError, match=r"^Received invalid JSON$"): async_decorator(self.java_client.read_status)() def test_read_status_invalid_reply(self): - self.java_client.connection.receive( + self.connection.receive( bytearray.fromhex( "4F004D7B22706C6179657273223A7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616" "D65223A22312E382D70726531222C2270726F746F636F6C223A34347D7D" @@ -68,88 +56,79 @@ def test_read_status_invalid_reply(self): async_decorator(self.java_client.read_status)() def test_read_status_invalid_status(self): - self.java_client.connection.receive(bytearray.fromhex("0105")) + self.connection.receive(bytearray.fromhex("0105")) with pytest.raises(IOError, match=r"^Received invalid status response packet.$"): async_decorator(self.java_client.read_status)() def test_test_ping(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 14515484 assert async_decorator(self.java_client.test_ping)() >= 0 - assert self.java_client.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") + assert self.connection.flush() == bytearray.fromhex("09010000000000DD7D1C") def test_test_ping_invalid(self): - self.java_client.connection.receive(bytearray.fromhex("011F")) + self.connection.receive(bytearray.fromhex("011F")) self.java_client.ping_token = 14515484 with pytest.raises(IOError, match=r"^Received invalid ping response packet.$"): async_decorator(self.java_client.test_ping)() def test_test_ping_wrong_token(self): - self.java_client.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) self.java_client.ping_token = 12345 with pytest.raises(IOError, match=r"^Received mangled ping response \(expected token 12345, got 14515484\)$"): async_decorator(self.java_client.test_ping)() @pytest.mark.asyncio + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) async def test_latency_is_real_number(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" - - def mocked_read_buffer(): - time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(FakeAsyncConnection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - # overwrite `async` here - mocked.return_value.read_varint = lambda: 0 - mocked.return_value.read_utf = lambda: ( - """ - { - "description": "A Minecraft Server", - "players": {"max": 20, "online": 0}, - "version": {"name": "1.8-pre1", "protocol": 44} - } - """ - ) - java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, + self.connection.receive( + bytearray.fromhex( + "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A7B2" + "26D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531222C22" + "70726F746F636F6C223A34347D7D" ) + ) - java_client.connection.receive( - bytearray.fromhex( - "7200707B226465736372697074696F6E223A2241204D696E65637261667420536572766572222C22706C6179657273223A" - "7B226D6178223A32302C226F6E6C696E65223A307D2C2276657273696F6E223A7B226E616D65223A22312E382D70726531" - "222C2270726F746F636F6C223A34347D7D" - ) - ) - # we slept 1ms, so this should be always ~1. - assert (await java_client.read_status()).latency >= 1 + async def mocked_read_bytearray() -> bytes: + """Delay reads while still delegating to the real async connection implementation.""" + await asyncio.sleep(0.001) + return await original_read_bytearray() + + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = (await self.java_client.read_status()).latency + assert 1 <= latency <= 20 @pytest.mark.asyncio + # Windows CI can occasionally measure <1ms despite a 1ms sleep; + # see https://github.com/py-mine/mcstatus/issues/442. @pytest.mark.flaky(reruns=5, condition=sys.platform.startswith("win32")) async def test_test_ping_is_in_milliseconds(self): - """``time.perf_counter`` returns fractional seconds, we must convert it to milliseconds.""" - - def mocked_read_buffer(): - time.sleep(0.001) - return mock.DEFAULT - - with mock.patch.object(FakeAsyncConnection, "read_buffer") as mocked: - mocked.side_effect = mocked_read_buffer - mocked.return_value.read_varint = lambda: 1 # overwrite `async` here - mocked.return_value.read_long = lambda: 123456789 # overwrite `async` here - java_client = AsyncJavaClient( - FakeAsyncConnection(), # pyright: ignore[reportArgumentType] - address=Address("localhost", 25565), - version=44, - ping_token=123456789, - ) - # we slept 1ms, so this should be always ~1. - assert await java_client.test_ping() >= 1 + self.java_client.ping_token = 14515484 + self.connection.receive(bytearray.fromhex("09010000000000DD7D1C")) + + async def mocked_read_bytearray() -> bytes: + """Delay reads while still delegating to the real async connection implementation.""" + await asyncio.sleep(0.001) + return await original_read_bytearray() + + original_read_bytearray = self.connection.read_bytearray + + with mock.patch.object(self.connection, "read_bytearray", side_effect=mocked_read_bytearray): + # Latency should be in milliseconds, so somewhere just above 1 + # + # We give it a pretty big leeway with the max here, as the MacOS CI runs can + # sometimes take quite long (upwards of 10s). + latency = await self.java_client.test_ping() + assert 1 <= latency <= 20 diff --git a/tests/protocol/test_legacy_client.py b/tests/protocol/test_legacy_client.py index cf55abe2..3e6dfbf1 100644 --- a/tests/protocol/test_legacy_client.py +++ b/tests/protocol/test_legacy_client.py @@ -1,9 +1,9 @@ import pytest -from mcstatus._protocol.connection import Connection from mcstatus._protocol.legacy_client import LegacyClient from mcstatus.motd import Motd from mcstatus.responses.legacy import LegacyStatusPlayers, LegacyStatusResponse, LegacyStatusVersion +from tests.protocol.helpers import SyncBufferConnection def test_invalid_kick_reason(): @@ -40,7 +40,7 @@ def test_parse_response(response: bytes, expected: LegacyStatusResponse): def test_invalid_packet_id(): - socket = Connection() + socket = SyncBufferConnection() socket.receive(bytearray.fromhex("00")) server = LegacyClient(socket) with pytest.raises(IOError, match=r"^Received invalid packet ID$"): diff --git a/tests/protocol/test_query_client.py b/tests/protocol/test_query_client.py index b4f36abb..b46e8a87 100644 --- a/tests/protocol/test_query_client.py +++ b/tests/protocol/test_query_client.py @@ -1,24 +1,25 @@ from unittest.mock import Mock -from mcstatus._protocol.connection import Connection from mcstatus._protocol.query_client import QueryClient from mcstatus.motd import Motd +from tests.protocol.helpers import SyncDatagramConnection class TestQueryClient: def setup_method(self): - self.query_client = QueryClient(Connection()) # pyright: ignore[reportArgumentType] + self.connection = SyncDatagramConnection() + self.query_client = QueryClient(self.connection) # pyright: ignore[reportArgumentType] def test_handshake(self): - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) self.query_client.handshake() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert self.query_client.challenge == 570350778 def test_query(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -28,7 +29,7 @@ def test_query(self): ) ) response = self.query_client.read_query() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD00") assert conn_bytes[7:] == bytearray.fromhex("0000000000000000") assert response.raw == { @@ -46,7 +47,7 @@ def test_query(self): assert response.players.list == ["Dinnerbone", "Djinnibone", "Steve"] def test_query_handles_unorderd_map_response(self): - self.query_client.connection.receive( + self.connection.receive( bytearray( b"\x00\x00\x00\x00\x00GeyserMC\x00\x80\x00hostname\x00Geyser\x00hostip\x001.1.1.1\x00plugins\x00\x00numplayers" b"\x001\x00gametype\x00SMP\x00maxplayers\x00100\x00hostport\x0019132\x00version\x00Geyser" @@ -54,14 +55,14 @@ def test_query_handles_unorderd_map_response(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("Geyser") assert response.software.version == "Geyser (git-master-0fd903e) 1.18.10" def test_query_handles_unicode_motd_with_nulls(self): - self.query_client.connection.receive( + self.connection.receive( bytearray( b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00hostname\x00\x00*K\xd5\x00gametype\x00SMP" b"\x00game_id\x00MINECRAFT\x00version\x001.16.5\x00plugins\x00Paper on 1.16.5-R0.1-SNAPSHOT\x00map\x00world" @@ -70,13 +71,13 @@ def test_query_handles_unicode_motd_with_nulls(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("\x00*KÕ") def test_query_handles_unicode_motd_with_2a00_at_the_start(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d6500006f746865720067616d657479706500534d500067616d655f6964004d" "494e4543524146540076657273696f6e00312e31382e3100706c7567696e7300006d617000776f726c64006e756d706c617965727300" @@ -85,7 +86,7 @@ def test_query_handles_unicode_motd_with_2a00_at_the_start(self): ) ) response = self.query_client.read_query() - self.query_client.connection.flush() + self.connection.flush() assert response.raw["game_id"] == "MINECRAFT" assert response.motd == Motd.parse("\x00other") # "\u2a00other" is actually what is expected, @@ -96,12 +97,12 @@ def test_session_id(self): def session_id(): return 0x01010101 - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) self.query_client._generate_session_id = Mock() self.query_client._generate_session_id = session_id self.query_client.handshake() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert conn_bytes[3:] == session_id().to_bytes(4, byteorder="big") assert self.query_client.challenge == 570350778 diff --git a/tests/protocol/test_query_client_async.py b/tests/protocol/test_query_client_async.py index dfc9f58f..3c091bb4 100644 --- a/tests/protocol/test_query_client_async.py +++ b/tests/protocol/test_query_client_async.py @@ -1,29 +1,21 @@ -from mcstatus._protocol.connection import Connection from mcstatus._protocol.query_client import AsyncQueryClient -from tests.protocol.test_java_client_async import async_decorator - - -class FakeUDPAsyncConnection(Connection): - async def read(self, length): # pyright: ignore[reportIncompatibleMethodOverride] - return super().read(length) - - async def write(self, data): # pyright: ignore[reportIncompatibleMethodOverride] - return super().write(data) +from tests.protocol.helpers import AsyncDatagramConnection, async_decorator class TestAsyncQueryClient: def setup_method(self): - self.query_client = AsyncQueryClient(FakeUDPAsyncConnection()) # pyright: ignore[reportArgumentType] + self.connection = AsyncDatagramConnection() + self.query_client = AsyncQueryClient(self.connection) # pyright: ignore[reportArgumentType] def test_handshake(self): - self.query_client.connection.receive(bytearray.fromhex("090000000035373033353037373800")) + self.connection.receive(bytearray.fromhex("090000000035373033353037373800")) async_decorator(self.query_client.handshake)() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD09") assert self.query_client.challenge == 570350778 def test_query(self): - self.query_client.connection.receive( + self.connection.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -33,7 +25,7 @@ def test_query(self): ) ) response = async_decorator(self.query_client.read_query)() - conn_bytes = self.query_client.connection.flush() + conn_bytes = self.connection.flush() assert conn_bytes[:3] == bytearray.fromhex("FEFD00") assert conn_bytes[7:] == bytearray.fromhex("0000000000000000") assert response.raw == { diff --git a/tests/protocol/test_timeout.py b/tests/protocol/test_timeout.py index d6c0b2ec..32a8bcd9 100644 --- a/tests/protocol/test_timeout.py +++ b/tests/protocol/test_timeout.py @@ -1,12 +1,12 @@ import asyncio import typing from asyncio.exceptions import TimeoutError as AsyncioTimeoutError -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch import pytest from mcstatus._net.address import Address -from mcstatus._protocol.connection import TCPAsyncSocketConnection +from mcstatus._protocol.io.connection import TCPAsyncSocketConnection class FakeAsyncStream(asyncio.StreamReader): @@ -26,3 +26,57 @@ async def test_tcp_socket_read(self): async with TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) as tcp_async_socket: with pytest.raises(AsyncioTimeoutError): await tcp_async_socket.read(10) + + @pytest.mark.asyncio + async def test_tcp_socket_read_partial_data_then_eof(self): + """Raise when the stream ends after only partial payload delivery.""" + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.reader = Mock(read=AsyncMock(side_effect=[b"a", b""])) + + with pytest.raises( + OSError, + match=( + r"^Server stopped responding \(got 1 bytes, but expected 2 bytes\). " + r"Partial obtained data: bytearray\(b'a'\)$" + ), + ): + await tcp_async_socket.read(2) + + @pytest.mark.asyncio + async def test_tcp_socket_read_eof_without_any_data(self): + """Raise when the stream immediately ends without returning any data.""" + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.reader = Mock(read=AsyncMock(return_value=b"")) + + with pytest.raises(OSError, match=r"^Server did not respond with any information!$"): + await tcp_async_socket.read(2) + + @pytest.mark.asyncio + async def test_tcp_socket_write_awaits_drain(self): + """Ensure writes await ``drain`` so buffered data is flushed.""" + writer = Mock() + writer.write = Mock() + writer.drain = AsyncMock() + + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.writer = writer + + await tcp_async_socket.write(b"hello") + + writer.write.assert_called_once_with(b"hello") + writer.drain.assert_awaited_once_with() + + @pytest.mark.asyncio + async def test_tcp_socket_close_waits_for_writer(self): + """Ensure close awaits ``wait_closed`` on the stream writer.""" + writer = Mock() + writer.close = Mock() + writer.wait_closed = AsyncMock() + + tcp_async_socket = TCPAsyncSocketConnection(Address("dummy_address", 1234), timeout=0.01) + tcp_async_socket.writer = typing.cast("asyncio.StreamWriter", writer) + + await tcp_async_socket.close() + + writer.close.assert_called_once_with() + writer.wait_closed.assert_awaited_once_with() diff --git a/tests/test_server.py b/tests/test_server.py index 57c1d741..9b3d05de 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,58 +1,14 @@ from __future__ import annotations import asyncio -from typing import SupportsIndex, TYPE_CHECKING, TypeAlias from unittest.mock import call, patch import pytest import pytest_asyncio from mcstatus._net.address import Address -from mcstatus._protocol.connection import BaseAsyncReadSyncWriteConnection, Connection from mcstatus.server import BedrockServer, JavaServer, LegacyServer - -if TYPE_CHECKING: - from collections.abc import Iterable - -BytesConvertable: TypeAlias = "SupportsIndex | Iterable[SupportsIndex]" - - -class AsyncConnection(BaseAsyncReadSyncWriteConnection): - def __init__(self) -> None: - self.sent = bytearray() - self.received = bytearray() - - async def read(self, length: int) -> bytearray: - """Return :attr:`.received` up to length bytes, then cut received up to that point.""" - if len(self.received) < length: - raise OSError(f"Not enough data to read! {len(self.received)} < {length}") - - result = self.received[:length] - self.received = self.received[length:] - return result - - def write(self, data: Connection | str | bytearray | bytes) -> None: - """Extend :attr:`.sent` from ``data``.""" - if isinstance(data, Connection): - data = data.flush() - if isinstance(data, str): - data = bytearray(data, "utf-8") - self.sent.extend(data) - - def receive(self, data: BytesConvertable | bytearray) -> None: - """Extend :attr:`.received` with ``data``.""" - if not isinstance(data, bytearray): - data = bytearray(data) - self.received.extend(data) - - def remaining(self) -> int: - """Return length of :attr:`.received`.""" - return len(self.received) - - def flush(self) -> bytearray: - """Return :attr:`.sent`, also clears :attr:`.sent`.""" - result, self.sent = self.sent, bytearray() - return result +from tests.protocol.helpers import AsyncBufferConnection, SyncBufferConnection, SyncDatagramConnection class MockProtocolFactory(asyncio.Protocol): @@ -86,13 +42,15 @@ def resume_writing(self): @pytest_asyncio.fixture() async def create_mock_packet_server(): + """Create a temporary asyncio packet servers used by tests.""" event_loop = asyncio.get_running_loop() servers = [] async def create_server(port, data_expected_to_receive, data_to_respond_with): + """Start a server that validates one request pattern and returns a fixed payload.""" server = await event_loop.create_server( lambda: MockProtocolFactory(data_expected_to_receive, data_to_respond_with), - host="localhost", + host="127.0.0.1", port=port, ) servers.append(server) @@ -107,7 +65,7 @@ async def create_server(port, data_expected_to_receive, data_to_respond_with): class TestBedrockServer: def setup_method(self): - self.server = BedrockServer("localhost") + self.server = BedrockServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 19132 @@ -126,7 +84,7 @@ async def test_async_ping(self, unused_tcp_port, create_mock_packet_server): data_expected_to_receive=bytearray.fromhex("09010000000001C54246"), data_to_respond_with=bytearray.fromhex("0F002F096C6F63616C686F737463DD0109010000000001C54246"), ) - minecraft_server = JavaServer("localhost", port=unused_tcp_port) + minecraft_server = JavaServer("127.0.0.1", port=unused_tcp_port) latency = await minecraft_server.async_ping(ping_token=29704774, version=47) assert latency >= 0 @@ -140,7 +98,7 @@ async def test_async_lookup_constructor(self): def test_java_server_with_query_port(): with patch("mcstatus.server.JavaServer._retry_query") as patched_query_func: - server = JavaServer("localhost", query_port=12345) + server = JavaServer("127.0.0.1", query_port=12345) server.query() assert server.query_port == 12345 assert patched_query_func.call_args == call(Address("127.0.0.1", port=12345), tries=3) @@ -149,7 +107,7 @@ def test_java_server_with_query_port(): @pytest.mark.asyncio async def test_java_server_with_query_port_async(): with patch("mcstatus.server.JavaServer._retry_async_query") as patched_query_func: - server = JavaServer("localhost", query_port=12345) + server = JavaServer("127.0.0.1", query_port=12345) await server.async_query() assert server.query_port == 12345 assert patched_query_func.call_args == call(Address("127.0.0.1", port=12345), tries=3) @@ -157,8 +115,8 @@ async def test_java_server_with_query_port_async(): class TestJavaServer: def setup_method(self): - self.socket = Connection() - self.server = JavaServer("localhost") + self.socket = SyncBufferConnection() + self.server = JavaServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 25565 @@ -170,7 +128,7 @@ def test_ping(self): connection.return_value.__enter__.return_value = self.socket latency = self.server.ping(ping_token=29704774, version=47) - assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD0109010000000001C54246") + assert self.socket.flush() == bytearray.fromhex("0F002F093132372E302E302E3163DD0109010000000001C54246") assert self.socket.remaining() == 0, "Data is pending to be read, but should be empty" assert latency >= 0 @@ -195,7 +153,7 @@ def test_status(self): connection.return_value.__enter__.return_value = self.socket info = self.server.status(version=47) - assert self.socket.flush() == bytearray.fromhex("0F002F096C6F63616C686F737463DD010100") + assert self.socket.flush() == bytearray.fromhex("0F002F093132372E302E302E3163DD010100") assert self.socket.remaining() == 0, "Data is pending to be read, but should be empty" assert info.raw == { "description": "A Minecraft Server", @@ -213,8 +171,9 @@ def test_status_retry(self): assert java_client.call_count == 3 def test_query(self): - self.socket.receive(bytearray.fromhex("090000000035373033353037373800")) - self.socket.receive( + socket = SyncDatagramConnection() + socket.receive(bytearray.fromhex("090000000035373033353037373800")) + socket.receive( bytearray.fromhex( "00000000000000000000000000000000686f73746e616d650041204d696e656372616674205365727665720067616d6574797" "06500534d500067616d655f6964004d494e4543524146540076657273696f6e00312e3800706c7567696e7300006d61700077" @@ -224,38 +183,30 @@ def test_query(self): ) ) - with patch("mcstatus._protocol.connection.Connection.remaining") as mock_remaining: - mock_remaining.side_effect = [15, 208] - - with ( - patch("mcstatus.server.UDPSocketConnection") as connection, - patch.object(self.server.address, "resolve_ip") as resolve_ip, - ): - connection.return_value.__enter__.return_value = self.socket - resolve_ip.return_value = "127.0.0.1" - info = self.server.query() - - conn_bytes = self.socket.flush() - assert conn_bytes[:3] == bytearray.fromhex("FEFD09") - assert info.raw == { - "hostname": "A Minecraft Server", - "gametype": "SMP", - "game_id": "MINECRAFT", - "version": "1.8", - "plugins": "", - "map": "world", - "numplayers": "3", - "maxplayers": "20", - "hostport": "25565", - "hostip": "192.168.56.1", - } + with patch("mcstatus.server.UDPSocketConnection") as connection: + connection.return_value.__enter__.return_value = socket + info = self.server.query() + + conn_bytes = socket.flush() + assert conn_bytes[:3] == bytearray.fromhex("FEFD09") + assert info.raw == { + "hostname": "A Minecraft Server", + "gametype": "SMP", + "game_id": "MINECRAFT", + "version": "1.8", + "plugins": "", + "map": "world", + "numplayers": "3", + "maxplayers": "20", + "hostport": "25565", + "hostip": "192.168.56.1", + } def test_query_retry(self): # Use a blank mock for the connection, we don't want to actually create any connections with patch("mcstatus.server.UDPSocketConnection"), patch("mcstatus.server.QueryClient") as query_client: query_client.side_effect = [RuntimeError, RuntimeError, RuntimeError] - with pytest.raises(RuntimeError, match=r"^$"), patch.object(self.server.address, "resolve_ip") as resolve_ip: # noqa: PT012 - resolve_ip.return_value = "127.0.0.1" + with pytest.raises(RuntimeError, match=r"^$"): self.server.query() assert query_client.call_count == 3 @@ -267,8 +218,8 @@ def test_lookup_constructor(self): class TestLegacyServer: def setup_method(self): - self.socket = Connection() - self.server = LegacyServer("localhost") + self.socket = SyncBufferConnection() + self.server = LegacyServer("127.0.0.1") def test_default_port(self): assert self.server.address.port == 25565 @@ -306,8 +257,8 @@ def test_status(self): class TestAsyncLegacyServer: def setup_method(self): - self.socket = AsyncConnection() - self.server = LegacyServer("localhost") + self.socket = AsyncBufferConnection() + self.server = LegacyServer("127.0.0.1") @pytest.mark.asyncio async def test_async_lookup_constructor(self): diff --git a/tests/utils/test_retry.py b/tests/utils/test_retry.py index 6be65834..e2861735 100644 --- a/tests/utils/test_retry.py +++ b/tests/utils/test_retry.py @@ -1,7 +1,7 @@ import pytest from mcstatus._utils.retry import retry -from tests.protocol.test_java_client_async import async_decorator +from tests.protocol.helpers import async_decorator def test_sync_success(): From e0a335dd6e4a0769cfe44325a8e58692974eeda3 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Thu, 26 Mar 2026 23:08:33 +0100 Subject: [PATCH 03/10] Fix some coverage issues --- .coveragerc | 1 + mcstatus/_protocol/io/base_io.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.coveragerc b/.coveragerc index 51f4bf59..3cb20809 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,6 +1,7 @@ [report] exclude_lines = pragma: no cover + @overload ((t|typing)\.)?TYPE_CHECKING ^\s\.\.\.\s$ def __repr__ diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py index 99746253..898a138c 100644 --- a/mcstatus/_protocol/io/base_io.py +++ b/mcstatus/_protocol/io/base_io.py @@ -461,7 +461,7 @@ async def _read_varuint(self, *, max_bits: int | None = None) -> int: value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") result = 0 - for i in count(): + for i in count(): # pragma: no branch byte = await self.read_value(StructFormat.UBYTE) # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place # then simply add them (OR) as additional 7 most significant bits in our result @@ -630,7 +630,7 @@ def _read_varuint(self, *, max_bits: int | None = None) -> int: value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") result = 0 - for i in count(): + for i in count(): # pragma: no branch byte = self.read_value(StructFormat.UBYTE) # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place # then simply add them (OR) as additional 7 most significant bits in our result From 32e9dd7952f42181b2efa17fd8e2978f903a4f84 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Thu, 26 Mar 2026 23:29:13 +0100 Subject: [PATCH 04/10] Remove str write suport with auto encode This was not used anywhere anyways and it's better to force being explicit than to "magically" encode into utf-8. --- mcstatus/_protocol/io/connection.py | 16 ++++------------ tests/protocol/helpers.py | 8 ++------ 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/mcstatus/_protocol/io/connection.py b/mcstatus/_protocol/io/connection.py index 09aedf69..ef4bdf9d 100644 --- a/mcstatus/_protocol/io/connection.py +++ b/mcstatus/_protocol/io/connection.py @@ -92,10 +92,8 @@ def read(self, length: int, /) -> bytearray: result.extend(new) return result - def write(self, data: str | bytes | bytearray, /) -> None: + def write(self, data: bytes | bytearray, /) -> None: """Send data on :attr:`.socket`.""" - if isinstance(data, str): - data = data.encode("utf-8") self.socket.sendall(data) @@ -128,10 +126,8 @@ def read(self, _length: int, /) -> bytearray: return result @override - def write(self, data: str | bytes | bytearray, /) -> None: + def write(self, data: bytes | bytearray, /) -> None: """Use :attr:`.socket` to send data to :attr:`.addr`.""" - if isinstance(data, str): - data = data.encode("utf-8") self.socket.sendto(data, self.addr) @@ -175,10 +171,8 @@ async def read(self, length: int, /) -> bytearray: return result @override - async def write(self, data: str | bytes | bytearray, /) -> None: + async def write(self, data: bytes | bytearray, /) -> None: """Write data to :attr:`.writer`.""" - if isinstance(data, str): - data = data.encode("utf-8") self.writer.write(data) await self.writer.drain() @@ -225,10 +219,8 @@ async def read(self, _length: int, /) -> bytearray: return bytearray(data) @override - async def write(self, data: str | bytes | bytearray, /) -> None: + async def write(self, data: bytes | bytearray, /) -> None: """Send data with :attr:`.stream`.""" - if isinstance(data, str): - data = data.encode("utf-8") await self.stream.send(data) def close(self) -> None: diff --git a/tests/protocol/helpers.py b/tests/protocol/helpers.py index cdb7e4ce..0853f209 100644 --- a/tests/protocol/helpers.py +++ b/tests/protocol/helpers.py @@ -104,10 +104,8 @@ def read(self, _length: int, /) -> bytes: return self.received.pop(0) @override - def write(self, data: bytes | bytearray | str, /) -> None: + def write(self, data: bytes | bytearray, /) -> None: """Append outgoing datagram payload to the send buffer.""" - if isinstance(data, str): - data = data.encode("utf-8") self.sent.write(data) def receive(self, data: bytes | bytearray) -> None: @@ -140,10 +138,8 @@ async def read(self, _length: int, /) -> bytes: return self.received.pop(0) @override - async def write(self, data: bytes | bytearray | str, /) -> None: + async def write(self, data: bytes | bytearray, /) -> None: """Append outgoing datagram payload to the send buffer.""" - if isinstance(data, str): - data = data.encode("utf-8") self.sent.write(data) def receive(self, data: bytes | bytearray) -> None: From 5bafeb1695e1e56c68d0fbe4043520c5a8358f02 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Fri, 27 Mar 2026 18:12:25 +0100 Subject: [PATCH 05/10] Remove getattr + iscoro test for async connections --- tests/protocol/test_async_support.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 tests/protocol/test_async_support.py diff --git a/tests/protocol/test_async_support.py b/tests/protocol/test_async_support.py deleted file mode 100644 index 0647346b..00000000 --- a/tests/protocol/test_async_support.py +++ /dev/null @@ -1,23 +0,0 @@ -from inspect import iscoroutinefunction - -from mcstatus._protocol.io.connection import TCPAsyncSocketConnection, UDPAsyncSocketConnection - - -def test_is_completely_asynchronous(): - conn = TCPAsyncSocketConnection - assertions = 0 - for attribute in dir(conn): - if attribute.startswith(("read_", "write_")): - assert iscoroutinefunction(getattr(conn, attribute)) - assertions += 1 - assert assertions > 0, "No read_*/write_* attributes were found" - - -def test_query_is_completely_asynchronous(): - conn = UDPAsyncSocketConnection - assertions = 0 - for attribute in dir(conn): - if attribute.startswith(("read_", "write_")): - assert iscoroutinefunction(getattr(conn, attribute)) - assertions += 1 - assert assertions > 0, "No read_*/write_* attributes were found" From 1ea255012f1f967b311c1852369680d1e76d2165 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Fri, 27 Mar 2026 18:34:53 +0100 Subject: [PATCH 06/10] Fix StructFormat.CHAR type & explicitly test it --- mcstatus/_protocol/io/base_io.py | 8 ++++---- tests/protocol/test_base_io.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py index 898a138c..07138ec6 100644 --- a/mcstatus/_protocol/io/base_io.py +++ b/mcstatus/_protocol/io/base_io.py @@ -142,7 +142,7 @@ async def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: . async def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 @overload - async def write_value(self, fmt: Literal[StructFormat.CHAR], value: str, /) -> None: ... + async def write_value(self, fmt: Literal[StructFormat.CHAR], value: bytes, /) -> None: ... async def write_value(self, fmt: StructFormat, value: object, /) -> None: """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" @@ -281,7 +281,7 @@ def write_value(self, fmt: FLOAT_FORMATS_TYPE, value: float, /) -> None: ... def write_value(self, fmt: Literal[StructFormat.BOOL], value: bool, /) -> None: ... # noqa: FBT001 @overload - def write_value(self, fmt: Literal[StructFormat.CHAR], value: str, /) -> None: ... + def write_value(self, fmt: Literal[StructFormat.CHAR], value: bytes, /) -> None: ... def write_value(self, fmt: StructFormat, value: object, /) -> None: """Write a given ``value`` as given struct format (``fmt``) in big-endian mode.""" @@ -424,7 +424,7 @@ async def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... async def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... @overload - async def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> str: ... + async def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> bytes: ... async def read_value(self, fmt: StructFormat, /) -> object: """Read a value as given struct format (``fmt``) in big-endian mode. @@ -593,7 +593,7 @@ def read_value(self, fmt: FLOAT_FORMATS_TYPE, /) -> float: ... def read_value(self, fmt: Literal[StructFormat.BOOL], /) -> bool: ... @overload - def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> str: ... + def read_value(self, fmt: Literal[StructFormat.CHAR], /) -> bytes: ... def read_value(self, fmt: StructFormat, /) -> object: """Read a value as given struct format (``fmt``) in big-endian mode. diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py index 022715e2..19e36c99 100644 --- a/tests/protocol/test_base_io.py +++ b/tests/protocol/test_base_io.py @@ -67,6 +67,20 @@ async def test_write_value_matches_reference(io_type: IO_TYPE, fmt: INT_FORMATS_ assert io.flush() == expected +@pytest.mark.asyncio +async def test_write_value_char_uses_single_byte(io_type: IO_TYPE): + io = io_type() + await maybe_await(io.write_value(StructFormat.CHAR, b"a")) + assert io.flush() == b"a" + + +@pytest.mark.asyncio +async def test_write_value_char_rejects_non_single_byte(io_type: IO_TYPE): + io = io_type() + with pytest.raises(struct.error): + await maybe_await(io.write_value(StructFormat.CHAR, b"ab")) + + @pytest.mark.parametrize( ("fmt", "value"), [ @@ -102,6 +116,14 @@ async def test_read_value_matches_reference(io_type: IO_TYPE, encoded: bytes, fm assert await maybe_await(io.read_value(fmt)) == expected +@pytest.mark.asyncio +async def test_read_value_char_returns_bytes(io_type: IO_TYPE): + io = io_type(b"a") + value = await maybe_await(io.read_value(StructFormat.CHAR)) + assert value == b"a" + assert isinstance(value, bytes) + + @pytest.mark.parametrize( ("number", "expected"), [ From 757fd193b694a943a2e7bfe3451eefb18f52d79d Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Fri, 27 Mar 2026 18:53:14 +0100 Subject: [PATCH 07/10] Reject overlong varuint encodings during read --- mcstatus/_protocol/io/base_io.py | 32 +++++++++++++++++++++++++++----- tests/protocol/test_base_io.py | 22 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py index 07138ec6..9e94493c 100644 --- a/mcstatus/_protocol/io/base_io.py +++ b/mcstatus/_protocol/io/base_io.py @@ -1,9 +1,9 @@ from __future__ import annotations +import math import struct from abc import ABC, abstractmethod from enum import Enum -from itertools import count from typing import Literal, TYPE_CHECKING, TypeAlias, TypeVar, overload if TYPE_CHECKING: @@ -455,13 +455,18 @@ async def _read_varuint(self, *, max_bits: int | None = None) -> int: making the encoding little-endian in 7-bit groups. :param max_bits: Maximum allowed bit width for the decoded value. - :raises OSError: If the decoded value exceeds the allowed bit width. + :raises OSError: + * If the decoded value exceeds the allowed bit width. + * If more bytes are received than can possibly encode a value constrained by ``max_bits``. :return: The decoded unsigned integer. """ value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + # Varints carry 7 value bits per byte, so this is the exact maximum encoded length. + byte_limit = math.ceil(max_bits / 7) if max_bits is not None else None result = 0 - for i in count(): # pragma: no branch + i = 0 + while byte_limit is None or i < byte_limit: byte = await self.read_value(StructFormat.UBYTE) # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place # then simply add them (OR) as additional 7 most significant bits in our result @@ -476,6 +481,12 @@ async def _read_varuint(self, *, max_bits: int | None = None) -> int: if not byte & 0x80: break + i += 1 + else: + raise OSError( + f"Received varint had too many bytes for {max_bits}-bit int (continuation bit set on byte {byte_limit})." + ) + return result async def read_varint(self) -> int: @@ -624,13 +635,18 @@ def _read_varuint(self, *, max_bits: int | None = None) -> int: making the encoding little-endian in 7-bit groups. :param max_bits: Maximum allowed bit width for the decoded value. - :raises OSError: If the decoded value exceeds the allowed bit width. + :raises OSError: + * If the decoded value exceeds the allowed bit width. + * If more bytes are received than can possibly encode a value constrained by ``max_bits``. :return: The decoded unsigned integer. """ value_max = (1 << (max_bits)) - 1 if max_bits is not None else float("inf") + # Varints carry 7 value bits per byte, so this is the exact maximum encoded length. + byte_limit = math.ceil(max_bits / 7) if max_bits is not None else None result = 0 - for i in count(): # pragma: no branch + i = 0 + while byte_limit is None or i < byte_limit: byte = self.read_value(StructFormat.UBYTE) # Read 7 least significant value bits in this byte, and shift them appropriately to be in the right place # then simply add them (OR) as additional 7 most significant bits in our result @@ -645,6 +661,12 @@ def _read_varuint(self, *, max_bits: int | None = None) -> int: if not byte & 0x80: break + i += 1 + else: + raise OSError( + f"Received varint had too many bytes for {max_bits}-bit int (continuation bit set on byte {byte_limit})." + ) + return result def read_varint(self) -> int: diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py index 19e36c99..e9274d85 100644 --- a/tests/protocol/test_base_io.py +++ b/tests/protocol/test_base_io.py @@ -189,6 +189,28 @@ async def test_read_varuint_rejects_out_of_range(io_type: IO_TYPE, encoded: byte await maybe_await(io._read_varuint(max_bits=max_bits)) +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("encoded", "max_bits", "max_bytes"), + [ + (b"\x80\x80\x80\x00", 16, 3), + (b"\x80\x80\x80\x80\x80\x00", 32, 5), + ], +) +async def test_read_varuint_rejects_too_many_bytes( + io_type: IO_TYPE, + encoded: bytes, + max_bits: int, + max_bytes: int, +): + io = io_type(encoded) + with pytest.raises( + OSError, + match=rf"^Received varint had too many bytes for {max_bits}-bit int \(continuation bit set on byte {max_bytes}\)\.$", + ): + await maybe_await(io._read_varuint(max_bits=max_bits)) + + @pytest.mark.parametrize( ("number", "expected"), [ From 3107a7a374443444dd369983d3973b15be157f28 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Fri, 27 Mar 2026 19:13:59 +0100 Subject: [PATCH 08/10] Reject negative lengths --- mcstatus/_protocol/io/base_io.py | 12 ++++++++++++ mcstatus/_protocol/io/buffer.py | 5 +++++ tests/protocol/test_base_io.py | 14 ++++++++++++++ tests/protocol/test_connection.py | 4 ++++ 4 files changed, 35 insertions(+) diff --git a/mcstatus/_protocol/io/base_io.py b/mcstatus/_protocol/io/base_io.py index 9e94493c..c7f99d17 100644 --- a/mcstatus/_protocol/io/base_io.py +++ b/mcstatus/_protocol/io/base_io.py @@ -518,9 +518,12 @@ async def read_varlong(self) -> int: async def read_bytearray(self, /) -> bytes: """Read a sequence of bytes prefixed with its length encoded as a varint. + :raises OSError: If the encoded length prefix is negative. :return: The decoded byte sequence. """ length = await self.read_varint() + if length < 0: + raise OSError(f"Length prefix for byte arrays must be non-negative, got {length}.") return await self.read(length) async def read_ascii(self) -> str: @@ -548,6 +551,7 @@ async def read_utf(self) -> str: the varint encoding overhead. :raises OSError: + * If the length prefix is negative. * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, and the error will be raised immediately after reading the prefix. * If the decoded string contains more than ``32767`` characters. In this case the data is still @@ -556,6 +560,8 @@ async def read_utf(self) -> str: :return: Decoded UTF-8 string. """ length = await self.read_varint() + if length < 0: + raise OSError(f"Length prefix for utf strings must be non-negative, got {length}.") if length > 131068: raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") @@ -698,9 +704,12 @@ def read_varlong(self) -> int: def read_bytearray(self) -> bytes: """Read a sequence of bytes prefixed with its length encoded as a varint. + :raises OSError: If the encoded length prefix is negative. :return: The decoded byte sequence. """ length = self.read_varint() + if length < 0: + raise OSError(f"Length prefix for byte arrays must be non-negative, got {length}.") return self.read(length) def read_ascii(self) -> str: @@ -728,6 +737,7 @@ def read_utf(self) -> str: the varint encoding overhead. :raises OSError: + * If the length prefix is negative. * If the length prefix exceeds the maximum of ``131068``, the string will not be read at all, and the error will be raised immediately after reading the prefix. * If the decoded string contains more than ``32767`` characters. In this case the data is still @@ -736,6 +746,8 @@ def read_utf(self) -> str: :return: Decoded UTF-8 string. """ length = self.read_varint() + if length < 0: + raise OSError(f"Length prefix for utf strings must be non-negative, got {length}.") if length > 131068: raise OSError(f"Maximum read limit for utf strings is 131068 bytes, got {length}.") diff --git a/mcstatus/_protocol/io/buffer.py b/mcstatus/_protocol/io/buffer.py index 772cac8f..e0d594a2 100644 --- a/mcstatus/_protocol/io/buffer.py +++ b/mcstatus/_protocol/io/buffer.py @@ -55,7 +55,12 @@ def read(self, length: int, /) -> bytes: data will still be depleted and the partial data that was read will be a part of the error message in the :exc:`OSError`. This behavior is here to mimic reading from a real socket connection. + + If ``length`` is negative, an :exc:`OSError` will be raised. """ + if length < 0: + raise OSError(f"Requested to read a negative amount of data: {length}.") + end = self.pos + length if end > len(self): diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py index e9274d85..da607d25 100644 --- a/tests/protocol/test_base_io.py +++ b/tests/protocol/test_base_io.py @@ -395,6 +395,13 @@ async def test_read_bytearray_matches_reference(io_type: IO_TYPE, encoded: bytes assert await maybe_await(io.read_bytearray()) == expected +@pytest.mark.asyncio +async def test_read_bytearray_rejects_negative_length(io_type: IO_TYPE): + io = io_type(b"\xff\xff\xff\xff\x0f") + with pytest.raises(OSError, match=r"^Length prefix for byte arrays must be non-negative, got -1\.$"): + await maybe_await(io.read_bytearray()) + + @pytest.mark.parametrize( ("value", "expected"), [ @@ -472,6 +479,13 @@ async def test_read_utf_rejects_too_many_bytes(io_type: IO_TYPE): await maybe_await(io.read_utf()) +@pytest.mark.asyncio +async def test_read_utf_rejects_negative_length(io_type: IO_TYPE): + io = io_type(b"\xff\xff\xff\xff\x0f") + with pytest.raises(OSError, match=r"^Length prefix for utf strings must be non-negative, got -1\.$"): + await maybe_await(io.read_utf()) + + @pytest.mark.asyncio async def test_read_utf_rejects_too_many_characters(io_type: IO_TYPE): text = "a" * 32768 diff --git a/tests/protocol/test_connection.py b/tests/protocol/test_connection.py index a70a5ad4..7819a3cd 100644 --- a/tests/protocol/test_connection.py +++ b/tests/protocol/test_connection.py @@ -256,6 +256,10 @@ def test_read_not_enough(self): ): self.connection.read(2) + def test_read_negative_length(self): + with pytest.raises(IOError, match=r"^Requested to read a negative amount of data: -1\.$"): + self.connection.read(-1) + class TestTCPSocketConnection: @pytest.fixture(scope="class") From 1f913d6db069a114d867b538b089db7d4b322814 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Fri, 27 Mar 2026 19:29:04 +0100 Subject: [PATCH 09/10] Rename io_type to conn_cls in tests --- tests/protocol/test_base_io.py | 272 +++++++++++++++++---------------- 1 file changed, 141 insertions(+), 131 deletions(-) diff --git a/tests/protocol/test_base_io.py b/tests/protocol/test_base_io.py index da607d25..9e0465cf 100644 --- a/tests/protocol/test_base_io.py +++ b/tests/protocol/test_base_io.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: from collections.abc import Awaitable -IO_TYPE = type[SyncBufferConnection] | type[AsyncBufferConnection] +ConnectionClass = type[SyncBufferConnection] | type[AsyncBufferConnection] T = TypeVar("T") @@ -24,8 +24,8 @@ pytest.param(AsyncBufferConnection, id="async"), ] ) -def io_type(request: pytest.FixtureRequest) -> IO_TYPE: - """Provide a parametrized sync and async IO connections to each test.""" +def conn_cls(request: pytest.FixtureRequest) -> ConnectionClass: + """Provide a parametrized sync/async connection class for each test.""" return request.param @@ -61,24 +61,29 @@ async def maybe_await(value: Awaitable[T] | T, /) -> T: ], ) @pytest.mark.asyncio -async def test_write_value_matches_reference(io_type: IO_TYPE, fmt: INT_FORMATS_TYPE, value: int, expected: bytes): - io = io_type() - await maybe_await(io.write_value(fmt, value)) - assert io.flush() == expected +async def test_write_value_matches_reference( + conn_cls: ConnectionClass, + fmt: INT_FORMATS_TYPE, + value: int, + expected: bytes, +): + conn = conn_cls() + await maybe_await(conn.write_value(fmt, value)) + assert conn.flush() == expected @pytest.mark.asyncio -async def test_write_value_char_uses_single_byte(io_type: IO_TYPE): - io = io_type() - await maybe_await(io.write_value(StructFormat.CHAR, b"a")) - assert io.flush() == b"a" +async def test_write_value_char_uses_single_byte(conn_cls: ConnectionClass): + conn = conn_cls() + await maybe_await(conn.write_value(StructFormat.CHAR, b"a")) + assert conn.flush() == b"a" @pytest.mark.asyncio -async def test_write_value_char_rejects_non_single_byte(io_type: IO_TYPE): - io = io_type() +async def test_write_value_char_rejects_non_single_byte(conn_cls: ConnectionClass): + conn = conn_cls() with pytest.raises(struct.error): - await maybe_await(io.write_value(StructFormat.CHAR, b"ab")) + await maybe_await(conn.write_value(StructFormat.CHAR, b"ab")) @pytest.mark.parametrize( @@ -91,10 +96,10 @@ async def test_write_value_char_rejects_non_single_byte(io_type: IO_TYPE): ], ) @pytest.mark.asyncio -async def test_write_value_rejects_out_of_range(io_type: IO_TYPE, fmt: INT_FORMATS_TYPE, value: int): - io = io_type() +async def test_write_value_rejects_out_of_range(conn_cls: ConnectionClass, fmt: INT_FORMATS_TYPE, value: int): + conn = conn_cls() with pytest.raises(struct.error): - await maybe_await(io.write_value(fmt, value)) + await maybe_await(conn.write_value(fmt, value)) @pytest.mark.parametrize( @@ -111,15 +116,20 @@ async def test_write_value_rejects_out_of_range(io_type: IO_TYPE, fmt: INT_FORMA ], ) @pytest.mark.asyncio -async def test_read_value_matches_reference(io_type: IO_TYPE, encoded: bytes, fmt: INT_FORMATS_TYPE, expected: int): - io = io_type(encoded) - assert await maybe_await(io.read_value(fmt)) == expected +async def test_read_value_matches_reference( + conn_cls: ConnectionClass, + encoded: bytes, + fmt: INT_FORMATS_TYPE, + expected: int, +): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_value(fmt)) == expected @pytest.mark.asyncio -async def test_read_value_char_returns_bytes(io_type: IO_TYPE): - io = io_type(b"a") - value = await maybe_await(io.read_value(StructFormat.CHAR)) +async def test_read_value_char_returns_bytes(conn_cls: ConnectionClass): + conn = conn_cls(b"a") + value = await maybe_await(conn.read_value(StructFormat.CHAR)) assert value == b"a" assert isinstance(value, bytes) @@ -136,10 +146,10 @@ async def test_read_value_char_returns_bytes(io_type: IO_TYPE): ], ) @pytest.mark.asyncio -async def test_write_varuint_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): - io = io_type() - await maybe_await(io._write_varuint(number)) - assert io.flush() == expected +async def test_write_varuint_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn._write_varuint(number)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -154,9 +164,9 @@ async def test_write_varuint_matches_reference(io_type: IO_TYPE, number: int, ex ], ) @pytest.mark.asyncio -async def test_read_varuint_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): - io = io_type(encoded) - assert await maybe_await(io._read_varuint()) == expected +async def test_read_varuint_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn._read_varuint()) == expected @pytest.mark.asyncio @@ -169,10 +179,10 @@ async def test_read_varuint_matches_reference(io_type: IO_TYPE, encoded: bytes, (2**32, 32), ], ) -async def test_write_varuint_rejects_out_of_range(io_type: IO_TYPE, number: int, max_bits: int): - io = io_type() +async def test_write_varuint_rejects_out_of_range(conn_cls: ConnectionClass, number: int, max_bits: int): + conn = conn_cls() with pytest.raises(ValueError, match=r"outside of the range of"): - await maybe_await(io._write_varuint(number, max_bits=max_bits)) + await maybe_await(conn._write_varuint(number, max_bits=max_bits)) @pytest.mark.asyncio @@ -183,10 +193,10 @@ async def test_write_varuint_rejects_out_of_range(io_type: IO_TYPE, number: int, (b"\x80\x80\x80\x80\x10", 32), ], ) -async def test_read_varuint_rejects_out_of_range(io_type: IO_TYPE, encoded: bytes, max_bits: int): - io = io_type(encoded) +async def test_read_varuint_rejects_out_of_range(conn_cls: ConnectionClass, encoded: bytes, max_bits: int): + conn = conn_cls(encoded) with pytest.raises(OSError, match=r"outside the range of"): - await maybe_await(io._read_varuint(max_bits=max_bits)) + await maybe_await(conn._read_varuint(max_bits=max_bits)) @pytest.mark.asyncio @@ -198,17 +208,17 @@ async def test_read_varuint_rejects_out_of_range(io_type: IO_TYPE, encoded: byte ], ) async def test_read_varuint_rejects_too_many_bytes( - io_type: IO_TYPE, + conn_cls: ConnectionClass, encoded: bytes, max_bits: int, max_bytes: int, ): - io = io_type(encoded) + conn = conn_cls(encoded) with pytest.raises( OSError, match=rf"^Received varint had too many bytes for {max_bits}-bit int \(continuation bit set on byte {max_bytes}\)\.$", ): - await maybe_await(io._read_varuint(max_bits=max_bits)) + await maybe_await(conn._read_varuint(max_bits=max_bits)) @pytest.mark.parametrize( @@ -221,10 +231,10 @@ async def test_read_varuint_rejects_too_many_bytes( ], ) @pytest.mark.asyncio -async def test_write_varint_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): - io = io_type() - await maybe_await(io.write_varint(number)) - assert io.flush() == expected +async def test_write_varint_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_varint(number)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -237,9 +247,9 @@ async def test_write_varint_matches_reference(io_type: IO_TYPE, number: int, exp ], ) @pytest.mark.asyncio -async def test_read_varint_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): - io = io_type(encoded) - assert await maybe_await(io.read_varint()) == expected +async def test_read_varint_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_varint()) == expected @pytest.mark.parametrize( @@ -252,10 +262,10 @@ async def test_read_varint_matches_reference(io_type: IO_TYPE, encoded: bytes, e ], ) @pytest.mark.asyncio -async def test_write_varlong_matches_reference(io_type: IO_TYPE, number: int, expected: bytes): - io = io_type() - await maybe_await(io.write_varlong(number)) - assert io.flush() == expected +async def test_write_varlong_matches_reference(conn_cls: ConnectionClass, number: int, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_varlong(number)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -268,100 +278,100 @@ async def test_write_varlong_matches_reference(io_type: IO_TYPE, number: int, ex ], ) @pytest.mark.asyncio -async def test_read_varlong_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: int): - io = io_type(encoded) - assert await maybe_await(io.read_varlong()) == expected +async def test_read_varlong_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: int): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_varlong()) == expected @pytest.mark.asyncio @pytest.mark.parametrize("number", [0, 1, 127, 16_384, -1, -(2**31), (2**31) - 1]) -async def test_varint_roundtrip(io_type: IO_TYPE, number: int): - io = io_type() - await maybe_await(io.write_varint(number)) - io.receive(io.flush()) - assert await maybe_await(io.read_varint()) == number +async def test_varint_roundtrip(conn_cls: ConnectionClass, number: int): + conn = conn_cls() + await maybe_await(conn.write_varint(number)) + conn.receive(conn.flush()) + assert await maybe_await(conn.read_varint()) == number @pytest.mark.asyncio @pytest.mark.parametrize("number", [127, 16_384, -128, -16_383, -(2**63), (2**63) - 1]) -async def test_varlong_roundtrip(io_type: IO_TYPE, number: int): - io = io_type() - await maybe_await(io.write_varlong(number)) - io.receive(io.flush()) - assert await maybe_await(io.read_varlong()) == number +async def test_varlong_roundtrip(conn_cls: ConnectionClass, number: int): + conn = conn_cls() + await maybe_await(conn.write_varlong(number)) + conn.receive(conn.flush()) + assert await maybe_await(conn.read_varlong()) == number @pytest.mark.asyncio @pytest.mark.parametrize("number", [-(2**63) - 1, 2**63]) -async def test_write_varlong_rejects_out_of_range(io_type: IO_TYPE, number: int): - io = io_type() +async def test_write_varlong_rejects_out_of_range(conn_cls: ConnectionClass, number: int): + conn = conn_cls() with pytest.raises(ValueError, match=r"out of range"): - await maybe_await(io.write_varlong(number)) + await maybe_await(conn.write_varlong(number)) @pytest.mark.asyncio -async def test_optional_helpers(io_type: IO_TYPE): - io = io_type() +async def test_optional_helpers(conn_cls: ConnectionClass): + conn = conn_cls() - if isinstance(io, AsyncBufferConnection): + if isinstance(conn, AsyncBufferConnection): writer = AsyncMock(return_value="written") - assert await io.write_optional(None, writer) is None + assert await conn.write_optional(None, writer) is None writer.assert_not_awaited() - assert io.flush() == b"\x00" + assert conn.flush() == b"\x00" - assert await io.write_optional("value", writer) == "written" + assert await conn.write_optional("value", writer) == "written" writer.assert_awaited_once_with("value") - assert io.flush() == b"\x01" + assert conn.flush() == b"\x01" reader = AsyncMock(return_value="parsed") - io.receive(b"\x00") - assert await io.read_optional(reader) is None + conn.receive(b"\x00") + assert await conn.read_optional(reader) is None reader.assert_not_awaited() - io.receive(b"\x01") - assert await io.read_optional(reader) == "parsed" + conn.receive(b"\x01") + assert await conn.read_optional(reader) == "parsed" reader.assert_awaited_once_with() return writer = Mock(return_value="written") - assert io.write_optional(None, writer) is None + assert conn.write_optional(None, writer) is None writer.assert_not_called() - assert io.flush() == b"\x00" + assert conn.flush() == b"\x00" - assert io.write_optional("value", writer) == "written" + assert conn.write_optional("value", writer) == "written" writer.assert_called_once_with("value") - assert io.flush() == b"\x01" + assert conn.flush() == b"\x01" reader = Mock(return_value="parsed") - io.receive(b"\x00") - assert io.read_optional(reader) is None + conn.receive(b"\x00") + assert conn.read_optional(reader) is None reader.assert_not_called() - io.receive(b"\x01") - assert io.read_optional(reader) == "parsed" + conn.receive(b"\x01") + assert conn.read_optional(reader) == "parsed" reader.assert_called_once_with() @pytest.mark.asyncio -async def test_write_and_read_ascii(io_type: IO_TYPE): - io = io_type() - await maybe_await(io.write_ascii("hello")) +async def test_write_and_read_ascii(conn_cls: ConnectionClass): + conn = conn_cls() + await maybe_await(conn.write_ascii("hello")) - io.receive(io.flush()) - assert await maybe_await(io.read_ascii()) == "hello" + conn.receive(conn.flush()) + assert await maybe_await(conn.read_ascii()) == "hello" @pytest.mark.asyncio -async def test_write_and_read_bytearray(io_type: IO_TYPE): - io = io_type() +async def test_write_and_read_bytearray(conn_cls: ConnectionClass): + conn = conn_cls() data = b"\x00\x01hello\xff" - await maybe_await(io.write_bytearray(data)) - io.receive(io.flush()) + await maybe_await(conn.write_bytearray(data)) + conn.receive(conn.flush()) - assert await maybe_await(io.read_bytearray()) == data + assert await maybe_await(conn.read_bytearray()) == data @pytest.mark.parametrize( @@ -374,10 +384,10 @@ async def test_write_and_read_bytearray(io_type: IO_TYPE): ], ) @pytest.mark.asyncio -async def test_write_bytearray_matches_reference(io_type: IO_TYPE, data: bytes, expected: bytes): - io = io_type() - await maybe_await(io.write_bytearray(data)) - assert io.flush() == expected +async def test_write_bytearray_matches_reference(conn_cls: ConnectionClass, data: bytes, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_bytearray(data)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -390,16 +400,16 @@ async def test_write_bytearray_matches_reference(io_type: IO_TYPE, data: bytes, ], ) @pytest.mark.asyncio -async def test_read_bytearray_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: bytes): - io = io_type(encoded) - assert await maybe_await(io.read_bytearray()) == expected +async def test_read_bytearray_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: bytes): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_bytearray()) == expected @pytest.mark.asyncio -async def test_read_bytearray_rejects_negative_length(io_type: IO_TYPE): - io = io_type(b"\xff\xff\xff\xff\x0f") +async def test_read_bytearray_rejects_negative_length(conn_cls: ConnectionClass): + conn = conn_cls(b"\xff\xff\xff\xff\x0f") with pytest.raises(OSError, match=r"^Length prefix for byte arrays must be non-negative, got -1\.$"): - await maybe_await(io.read_bytearray()) + await maybe_await(conn.read_bytearray()) @pytest.mark.parametrize( @@ -411,10 +421,10 @@ async def test_read_bytearray_rejects_negative_length(io_type: IO_TYPE): ], ) @pytest.mark.asyncio -async def test_write_ascii_matches_reference(io_type: IO_TYPE, value: str, expected: bytes): - io = io_type() - await maybe_await(io.write_ascii(value)) - assert io.flush() == expected +async def test_write_ascii_matches_reference(conn_cls: ConnectionClass, value: str, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_ascii(value)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -426,9 +436,9 @@ async def test_write_ascii_matches_reference(io_type: IO_TYPE, value: str, expec ], ) @pytest.mark.asyncio -async def test_read_ascii_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: str): - io = io_type(encoded) - assert await maybe_await(io.read_ascii()) == expected +async def test_read_ascii_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: str): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_ascii()) == expected @pytest.mark.parametrize( @@ -441,10 +451,10 @@ async def test_read_ascii_matches_reference(io_type: IO_TYPE, encoded: bytes, ex ], ) @pytest.mark.asyncio -async def test_write_utf_matches_reference(io_type: IO_TYPE, value: str, expected: bytes): - io = io_type() - await maybe_await(io.write_utf(value)) - assert io.flush() == expected +async def test_write_utf_matches_reference(conn_cls: ConnectionClass, value: str, expected: bytes): + conn = conn_cls() + await maybe_await(conn.write_utf(value)) + assert conn.flush() == expected @pytest.mark.parametrize( @@ -457,42 +467,42 @@ async def test_write_utf_matches_reference(io_type: IO_TYPE, value: str, expecte ], ) @pytest.mark.asyncio -async def test_read_utf_matches_reference(io_type: IO_TYPE, encoded: bytes, expected: str): - io = io_type(encoded) - assert await maybe_await(io.read_utf()) == expected +async def test_read_utf_matches_reference(conn_cls: ConnectionClass, encoded: bytes, expected: str): + conn = conn_cls(encoded) + assert await maybe_await(conn.read_utf()) == expected @pytest.mark.asyncio -async def test_write_utf_rejects_too_many_characters(io_type: IO_TYPE): - io = io_type() +async def test_write_utf_rejects_too_many_characters(conn_cls: ConnectionClass): + conn = conn_cls() with pytest.raises(ValueError, match=r"Maximum character limit for writing strings is 32767 characters"): - await maybe_await(io.write_utf("a" * 32768)) + await maybe_await(conn.write_utf("a" * 32768)) @pytest.mark.asyncio -async def test_read_utf_rejects_too_many_bytes(io_type: IO_TYPE): +async def test_read_utf_rejects_too_many_bytes(conn_cls: ConnectionClass): payload = Buffer() payload.write_varint(131069) - io = io_type(payload) + conn = conn_cls(payload) with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 131068 bytes, got 131069"): - await maybe_await(io.read_utf()) + await maybe_await(conn.read_utf()) @pytest.mark.asyncio -async def test_read_utf_rejects_negative_length(io_type: IO_TYPE): - io = io_type(b"\xff\xff\xff\xff\x0f") +async def test_read_utf_rejects_negative_length(conn_cls: ConnectionClass): + conn = conn_cls(b"\xff\xff\xff\xff\x0f") with pytest.raises(OSError, match=r"^Length prefix for utf strings must be non-negative, got -1\.$"): - await maybe_await(io.read_utf()) + await maybe_await(conn.read_utf()) @pytest.mark.asyncio -async def test_read_utf_rejects_too_many_characters(io_type: IO_TYPE): +async def test_read_utf_rejects_too_many_characters(conn_cls: ConnectionClass): text = "a" * 32768 payload = Buffer() payload.write_varint(len(text.encode("utf-8"))) payload.write(text.encode("utf-8")) - io = io_type(payload) + conn = conn_cls(payload) with pytest.raises(OSError, match=r"Maximum read limit for utf strings is 32767 characters, got 32768"): - await maybe_await(io.read_utf()) + await maybe_await(conn.read_utf()) From 2102cc9369e608a2f18484f43deeea1b6eed96d5 Mon Sep 17 00:00:00 2001 From: ItsDrike Date: Thu, 9 Apr 2026 00:57:38 +0200 Subject: [PATCH 10/10] Disable bytes promotions & use bytes instead of bytearray --- mcstatus/_protocol/io/connection.py | 20 ++++++++++---------- mcstatus/responses/forge.py | 4 ++-- pyproject.toml | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mcstatus/_protocol/io/connection.py b/mcstatus/_protocol/io/connection.py index ef4bdf9d..a84e9874 100644 --- a/mcstatus/_protocol/io/connection.py +++ b/mcstatus/_protocol/io/connection.py @@ -82,7 +82,7 @@ def __init__(self, addr: tuple[str | None, int], timeout: float = 3) -> None: self.socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @override - def read(self, length: int, /) -> bytearray: + def read(self, length: int, /) -> bytes: """Return length bytes read from :attr:`.socket`. Raises :exc:`OSError` when server doesn't respond.""" result = bytearray() while len(result) < length: @@ -90,7 +90,7 @@ def read(self, length: int, /) -> bytearray: if len(new) == 0: raise OSError("Server did not respond with any information!") result.extend(new) - return result + return bytes(result) def write(self, data: bytes | bytearray, /) -> None: """Send data on :attr:`.socket`.""" @@ -118,11 +118,11 @@ def remaining(self) -> int: return 65535 @override - def read(self, _length: int, /) -> bytearray: + def read(self, _length: int, /) -> bytes: """Return up to :meth:`.remaining` bytes. Length does nothing here.""" - result = bytearray() + result = b"" while len(result) == 0: - result.extend(self.socket.recvfrom(self.remaining)[0]) + result = self.socket.recvfrom(self.remaining)[0] return result @override @@ -153,7 +153,7 @@ async def connect(self) -> None: sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) @override - async def read(self, length: int, /) -> bytearray: + async def read(self, length: int, /) -> bytes: """Read up to ``length`` bytes from :attr:`.reader`.""" result = bytearray() while len(result) < length: @@ -168,7 +168,7 @@ async def read(self, length: int, /) -> bytearray: f" Partial obtained data: {result!r}" ) result.extend(new) - return result + return bytes(result) @override async def write(self, data: bytes | bytearray, /) -> None: @@ -213,15 +213,15 @@ def remaining(self) -> int: return 65535 @override - async def read(self, _length: int, /) -> bytearray: + async def read(self, _length: int, /) -> bytes: """Read from :attr:`.stream`. Length does nothing here.""" data, _remote_addr = await asyncio.wait_for(self.stream.recv(), timeout=self.timeout) - return bytearray(data) + return data @override async def write(self, data: bytes | bytearray, /) -> None: """Send data with :attr:`.stream`.""" - await self.stream.send(data) + await self.stream.send(bytes(data)) def close(self) -> None: """Close :attr:`.stream`.""" diff --git a/mcstatus/responses/forge.py b/mcstatus/responses/forge.py index 70480860..03cf4b51 100644 --- a/mcstatus/responses/forge.py +++ b/mcstatus/responses/forge.py @@ -137,7 +137,7 @@ def __init__(self, stringio: StringIO) -> None: self.stringio = stringio self.received = bytearray() - def read(self, length: int, /) -> bytearray: + def read(self, length: int, /) -> bytes: """Read length bytes from ``self``, and return a byte array.""" data = bytearray() while self.received and len(data) < length: @@ -149,7 +149,7 @@ def read(self, length: int, /) -> bytearray: data.extend(result.encode("utf-16be")) while len(data) > length: self.received.append(data.pop()) - return data + return bytes(data) def remaining(self) -> int: """Return number of reads remaining.""" diff --git a/pyproject.toml b/pyproject.toml index 67418446..c923e6c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,7 +112,7 @@ pythonPlatform = "All" pythonVersion = "3.10" typeCheckingMode = "standard" -disableBytesTypePromotions = false +disableBytesTypePromotions = true enableTypeIgnoreComments = false reportUnnecessaryTypeIgnoreComment = true