diff --git a/examples/usb/upgrade.py b/examples/usb/upgrade.py index f5b9d15..518d652 100644 --- a/examples/usb/upgrade.py +++ b/examples/usb/upgrade.py @@ -108,9 +108,10 @@ async def main() -> None: print("Connecting to SMP DUT...", end="", flush=True) async with SMPClient( SMPSerialTransport( - max_smp_encoded_frame_size=max_smp_encoded_frame_size, - line_length=line_length, - line_buffers=line_buffers, + fragmentation_strategy=SMPSerialTransport.BufferParams( + line_length=line_length, + line_buffers=line_buffers, + ) ), port_a.device, ) as client: @@ -187,9 +188,10 @@ async def ensure_request(request: SMPRequest[TRep, TEr1, TEr2]) -> TRep: print("Connecting to B SMP DUT...", end="", flush=True) async with SMPClient( SMPSerialTransport( - max_smp_encoded_frame_size=max_smp_encoded_frame_size, - line_length=line_length, - line_buffers=line_buffers, + fragmentation_strategy=SMPSerialTransport.BufferParams( + line_length=line_length, + line_buffers=line_buffers, + ) ), port_b.device, ) as client: diff --git a/smpclient/transport/serial.py b/smpclient/transport/serial.py index 3f6f976..baa03d4 100644 --- a/smpclient/transport/serial.py +++ b/smpclient/transport/serial.py @@ -10,8 +10,7 @@ import math import time from enum import IntEnum, unique -from functools import cached_property -from typing import Final +from typing import Final, NamedTuple from serial import Serial, SerialException from smp import packet as smppacket @@ -69,11 +68,29 @@ def __init__(self) -> None: self.state = SMPSerialTransport._ReadBuffer.State.SER """The state of the read buffer.""" + class Auto(NamedTuple): + """Automatically determine buffer parameters from the SMP server. + + On connect, queries the server's MCUMGR_PARAM for `buf_size` + (CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE) and calculates: + - line_length: 127 (standard MTU for uart/usb/shell) + - line_buffers: buf_size / line_length + + Falls back to BufferParams() if server doesn't support MCUMGR_PARAM. + """ + + class BufferParams(NamedTuple): + """Buffer parameters for the serial transport.""" + + line_length: int = 127 + """The maximum SMP packet size.""" + + line_buffers: int = 1 + """The number of line buffers in the serial buffer.""" + def __init__( # noqa: DOC301 self, - max_smp_encoded_frame_size: int = 256, - line_length: int = 128, - line_buffers: int = 2, + fragmentation_strategy: Auto | BufferParams = Auto(), baudrate: int = 115200, bytesize: int = 8, parity: str = "N", @@ -89,11 +106,10 @@ def __init__( # noqa: DOC301 """Initialize the serial transport. Args: - max_smp_encoded_frame_size: The maximum size of an encoded SMP - frame. The SMP server needs to have a buffer large enough to - receive the encoded frame packets and to store the decoded frame. - line_length: The maximum SMP packet size. - line_buffers: The number of line buffers in the serial buffer. + fragmentation_strategy: The fragmentation strategy to use. Either + `SMPSerialTransport.Auto()` to automatically determine buffer + parameters from the SMP server, or `SMPSerialTransport.BufferParams` + to manually specify buffer parameters. baudrate: The baudrate of the serial connection. OK to ignore for USB CDC ACM. bytesize: The number of data bits. @@ -108,18 +124,7 @@ def __init__( # noqa: DOC301 exclusive: The exclusive access timeout. """ - if max_smp_encoded_frame_size < line_length * line_buffers: - logger.error( - f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!" - ) - elif max_smp_encoded_frame_size != line_length * line_buffers: - logger.warning( - f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!" - ) - - self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size - self._line_length: Final = line_length - self._line_buffers: Final = line_buffers + self._fragmentation_strategy: Final = fragmentation_strategy self._conn: Final = Serial( baudrate=baudrate, bytesize=bytesize, @@ -136,6 +141,62 @@ def __init__( # noqa: DOC301 self._buffer = SMPSerialTransport._ReadBuffer() logger.debug(f"Initialized {self.__class__.__name__}") + @property + def _line_length(self) -> int: + """The maximum SMP packet size.""" + if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto): + return self.BufferParams().line_length + else: + return self._fragmentation_strategy.line_length + + @property + def _line_buffers(self) -> int: + """The number of line buffers.""" + if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto): + if self._smp_server_transport_buffer_size is not None: + return self._smp_server_transport_buffer_size // self.BufferParams().line_length + return self.BufferParams().line_buffers + else: + return self._fragmentation_strategy.line_buffers + + @property + def _max_smp_encoded_frame_size(self) -> int: + """The maximum encoded frame size (line_length * line_buffers).""" + if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto): + if self._smp_server_transport_buffer_size is not None: + return self._smp_server_transport_buffer_size + return self._line_length * self._line_buffers + else: + return ( + self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers + ) + + @override + def initialize(self, smp_server_transport_buffer_size: int) -> None: + """Initialize with the server's buffer size from MCUMGR_PARAM. + + Args: + smp_server_transport_buffer_size: The server's CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE + """ + super().initialize(smp_server_transport_buffer_size) + + if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto): + logger.info( + f"Auto-configured from server: {self._line_length=}, " + f"{self._line_buffers=}, mtu={self._max_smp_encoded_frame_size}" + ) + else: + # Validate user's BufferParams against server capabilities + calculated_size = ( + self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers + ) + if calculated_size > smp_server_transport_buffer_size: + logger.warning( + f"BufferParams (line_length={self._fragmentation_strategy.line_length} * " + f"line_buffers={self._fragmentation_strategy.line_buffers} = {calculated_size}) " # noqa: E501 + f"exceeds server buffer size ({smp_server_transport_buffer_size})" + ) + @override async def connect(self, address: str, timeout_s: float) -> None: self._conn.port = address @@ -309,7 +370,7 @@ def mtu(self) -> int: return self._max_smp_encoded_frame_size @override - @cached_property + @property def max_unencoded_size(self) -> int: """The serial transport encodes each packet instead of sending SMP messages as raw bytes.""" diff --git a/tests/test_smp_client.py b/tests/test_smp_client.py index fd0dcf3..ae332cb 100644 --- a/tests/test_smp_client.py +++ b/tests/test_smp_client.py @@ -63,13 +63,14 @@ def aiter(iterable: Any) -> Any: class SMPMockTransport: """Satisfies the `SMPTransport` `Protocol`.""" + mtu = PropertyMock() + max_unencoded_size = PropertyMock() + def __init__(self) -> None: self.connect = AsyncMock() self.disconnect = AsyncMock() self.send = AsyncMock() self.receive = AsyncMock() - self.mtu = PropertyMock() - self.max_unencoded_size = PropertyMock() self._smp_server_transport_buffer_size: int | None = None self.initialize = AsyncMock() @@ -334,12 +335,16 @@ async def test_upload_hello_world_bin_encoded( pytest.skip("The line buffer size is too small") m = SMPSerialTransport( - max_smp_encoded_frame_size=max_smp_encoded_frame_size, - line_length=line_length, - line_buffers=line_buffers, + fragmentation_strategy=SMPSerialTransport.BufferParams( + line_length=line_length, + line_buffers=line_buffers, + ) ) s = SMPClient(m, "address") - assert s._transport.mtu == max_smp_encoded_frame_size + # MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size + # due to integer division + assert s._transport.mtu == line_length * line_buffers + assert s._transport.mtu <= max_smp_encoded_frame_size packets: List[bytes] = [] @@ -565,12 +570,16 @@ async def test_file_upload_test_encoded(max_smp_encoded_frame_size: int, line_bu pytest.skip("The line buffer size is too small") m = SMPSerialTransport( - max_smp_encoded_frame_size=max_smp_encoded_frame_size, - line_length=line_length, - line_buffers=line_buffers, + fragmentation_strategy=SMPSerialTransport.BufferParams( + line_length=line_length, + line_buffers=line_buffers, + ) ) s = SMPClient(m, "address") - assert s._transport.mtu == max_smp_encoded_frame_size + # MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size + # due to integer division + assert s._transport.mtu == line_length * line_buffers + assert s._transport.mtu <= max_smp_encoded_frame_size packets: List[bytes] = [] diff --git a/tests/test_smp_serial_transport.py b/tests/test_smp_serial_transport.py index def3ed5..fcd4e37 100644 --- a/tests/test_smp_serial_transport.py +++ b/tests/test_smp_serial_transport.py @@ -14,12 +14,23 @@ def test_constructor() -> None: + # Test with Auto() (default) t = SMPSerialTransport() assert isinstance(t._conn, Serial) - - t = SMPSerialTransport(max_smp_encoded_frame_size=512, line_length=128, line_buffers=4) + assert t.mtu == 127 # Default for Auto without initialize + assert t._line_length == 127 + assert t._line_buffers == 1 + assert t._max_smp_encoded_frame_size == 127 + + # Test with BufferParams + t = SMPSerialTransport( + fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=4) + ) assert isinstance(t._conn, Serial) - assert t.mtu == 512 + assert t.mtu == 512 # 128 * 4 + assert t._line_length == 128 + assert t._line_buffers == 4 + assert t._max_smp_encoded_frame_size == 512 assert t.max_unencoded_size < 512 @@ -174,3 +185,53 @@ async def test_send_and_receive() -> None: t.send.assert_awaited_once_with(b"some data") t.receive.assert_awaited_once_with() + + +def test_initialize_with_auto() -> None: + """Test that Auto mode updates parameters based on server's buffer size.""" + t = SMPSerialTransport() # Uses Auto() by default + + # Before initialize, uses conservative defaults + assert t._line_length == 127 + assert t._line_buffers == 1 + assert t._max_smp_encoded_frame_size == 127 + + # After initialize with server buffer size + t.initialize(512) + assert t._line_length == 127 + assert t._line_buffers == 512 // 127 # 4 + assert t._max_smp_encoded_frame_size == 512 + assert t.mtu == 512 + + +def test_initialize_with_buffer_params() -> None: + """Test that BufferParams mode doesn't change user-specified parameters.""" + t = SMPSerialTransport( + fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=2) + ) + + # Before initialize + assert t._line_length == 128 + assert t._line_buffers == 2 + assert t._max_smp_encoded_frame_size == 256 # 128 * 2 + + # After initialize - parameters should NOT change + t.initialize(512) + assert t._line_length == 128 + assert t._line_buffers == 2 + assert t._max_smp_encoded_frame_size == 256 + assert t.mtu == 256 + + +def test_initialize_with_buffer_params_warning(caplog: pytest.LogCaptureFixture) -> None: + """Test that a warning is logged when user's params exceed server buffer size.""" + t = SMPSerialTransport( + fragmentation_strategy=SMPSerialTransport.BufferParams( + line_length=128, line_buffers=4 # 128 * 4 = 512 + ) + ) + + with caplog.at_level(logging.WARNING): + t.initialize(256) # Server buffer (256) is smaller than calculated size (512) + + assert any("exceeds server buffer size" in record.message for record in caplog.records)