Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/usb/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ async def main() -> None:
print("Connecting to SMP DUT...", end="", flush=True)
async with SMPClient(
SMPSerialTransport(
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
line_length=line_length,
line_buffers=line_buffers,
fragmentation_strategy=SMPSerialTransport.BufferParams(
line_length=line_length,
line_buffers=line_buffers,
)
),
port_a.device,
) as client:
Expand Down Expand Up @@ -187,9 +188,10 @@ async def ensure_request(request: SMPRequest[TRep, TEr1, TEr2]) -> TRep:
print("Connecting to B SMP DUT...", end="", flush=True)
async with SMPClient(
SMPSerialTransport(
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
line_length=line_length,
line_buffers=line_buffers,
fragmentation_strategy=SMPSerialTransport.BufferParams(
line_length=line_length,
line_buffers=line_buffers,
)
),
port_b.device,
) as client:
Expand Down
107 changes: 84 additions & 23 deletions smpclient/transport/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import math
import time
from enum import IntEnum, unique
from functools import cached_property
from typing import Final
from typing import Final, NamedTuple

from serial import Serial, SerialException
from smp import packet as smppacket
Expand Down Expand Up @@ -69,11 +68,29 @@ def __init__(self) -> None:
self.state = SMPSerialTransport._ReadBuffer.State.SER
"""The state of the read buffer."""

class Auto(NamedTuple):
"""Automatically determine buffer parameters from the SMP server.

On connect, queries the server's MCUMGR_PARAM for `buf_size`
(CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE) and calculates:
- line_length: 127 (standard MTU for uart/usb/shell)
- line_buffers: buf_size / line_length

Falls back to BufferParams() if server doesn't support MCUMGR_PARAM.
"""

class BufferParams(NamedTuple):
"""Buffer parameters for the serial transport."""

line_length: int = 127
"""The maximum SMP packet size."""

line_buffers: int = 1
"""The number of line buffers in the serial buffer."""

def __init__( # noqa: DOC301
self,
max_smp_encoded_frame_size: int = 256,
line_length: int = 128,
line_buffers: int = 2,
fragmentation_strategy: Auto | BufferParams = Auto(),
baudrate: int = 115200,
bytesize: int = 8,
parity: str = "N",
Expand All @@ -89,11 +106,10 @@ def __init__( # noqa: DOC301
"""Initialize the serial transport.

Args:
max_smp_encoded_frame_size: The maximum size of an encoded SMP
frame. The SMP server needs to have a buffer large enough to
receive the encoded frame packets and to store the decoded frame.
line_length: The maximum SMP packet size.
line_buffers: The number of line buffers in the serial buffer.
fragmentation_strategy: The fragmentation strategy to use. Either
`SMPSerialTransport.Auto()` to automatically determine buffer
parameters from the SMP server, or `SMPSerialTransport.BufferParams`
to manually specify buffer parameters.
baudrate: The baudrate of the serial connection. OK to ignore for
USB CDC ACM.
bytesize: The number of data bits.
Expand All @@ -108,18 +124,7 @@ def __init__( # noqa: DOC301
exclusive: The exclusive access timeout.

"""
if max_smp_encoded_frame_size < line_length * line_buffers:
logger.error(
f"{max_smp_encoded_frame_size=} is less than {line_length=} * {line_buffers=}!"
)
elif max_smp_encoded_frame_size != line_length * line_buffers:
logger.warning(
f"{max_smp_encoded_frame_size=} is not equal to {line_length=} * {line_buffers=}!"
)

self._max_smp_encoded_frame_size: Final = max_smp_encoded_frame_size
self._line_length: Final = line_length
self._line_buffers: Final = line_buffers
self._fragmentation_strategy: Final = fragmentation_strategy
self._conn: Final = Serial(
baudrate=baudrate,
bytesize=bytesize,
Expand All @@ -136,6 +141,62 @@ def __init__( # noqa: DOC301
self._buffer = SMPSerialTransport._ReadBuffer()
logger.debug(f"Initialized {self.__class__.__name__}")

@property
def _line_length(self) -> int:
"""The maximum SMP packet size."""
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
return self.BufferParams().line_length
else:
return self._fragmentation_strategy.line_length
Comment on lines +147 to +150
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that we dropped 3.9 support, so this should be replaced with exhaustive match case.


@property
def _line_buffers(self) -> int:
"""The number of line buffers."""
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
if self._smp_server_transport_buffer_size is not None:
return self._smp_server_transport_buffer_size // self.BufferParams().line_length
return self.BufferParams().line_buffers
else:
return self._fragmentation_strategy.line_buffers

@property
def _max_smp_encoded_frame_size(self) -> int:
"""The maximum encoded frame size (line_length * line_buffers)."""
if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
if self._smp_server_transport_buffer_size is not None:
return self._smp_server_transport_buffer_size
return self._line_length * self._line_buffers
else:
return (
self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers
)

@override
def initialize(self, smp_server_transport_buffer_size: int) -> None:
"""Initialize with the server's buffer size from MCUMGR_PARAM.

Args:
smp_server_transport_buffer_size: The server's CONFIG_MCUMGR_TRANSPORT_NETBUF_SIZE
"""
super().initialize(smp_server_transport_buffer_size)

if isinstance(self._fragmentation_strategy, SMPSerialTransport.Auto):
logger.info(
f"Auto-configured from server: {self._line_length=}, "
f"{self._line_buffers=}, mtu={self._max_smp_encoded_frame_size}"
)
else:
# Validate user's BufferParams against server capabilities
calculated_size = (
self._fragmentation_strategy.line_length * self._fragmentation_strategy.line_buffers
)
if calculated_size > smp_server_transport_buffer_size:
logger.warning(
f"BufferParams (line_length={self._fragmentation_strategy.line_length} * "
f"line_buffers={self._fragmentation_strategy.line_buffers} = {calculated_size}) " # noqa: E501
f"exceeds server buffer size ({smp_server_transport_buffer_size})"
)

@override
async def connect(self, address: str, timeout_s: float) -> None:
self._conn.port = address
Expand Down Expand Up @@ -309,7 +370,7 @@ def mtu(self) -> int:
return self._max_smp_encoded_frame_size

@override
@cached_property
@property
def max_unencoded_size(self) -> int:
"""The serial transport encodes each packet instead of sending SMP messages as raw bytes."""

Expand Down
29 changes: 19 additions & 10 deletions tests/test_smp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,14 @@ def aiter(iterable: Any) -> Any:
class SMPMockTransport:
"""Satisfies the `SMPTransport` `Protocol`."""

mtu = PropertyMock()
max_unencoded_size = PropertyMock()

def __init__(self) -> None:
self.connect = AsyncMock()
self.disconnect = AsyncMock()
self.send = AsyncMock()
self.receive = AsyncMock()
self.mtu = PropertyMock()
self.max_unencoded_size = PropertyMock()
self._smp_server_transport_buffer_size: int | None = None
self.initialize = AsyncMock()

Expand Down Expand Up @@ -334,12 +335,16 @@ async def test_upload_hello_world_bin_encoded(
pytest.skip("The line buffer size is too small")

m = SMPSerialTransport(
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
line_length=line_length,
line_buffers=line_buffers,
fragmentation_strategy=SMPSerialTransport.BufferParams(
line_length=line_length,
line_buffers=line_buffers,
)
)
s = SMPClient(m, "address")
assert s._transport.mtu == max_smp_encoded_frame_size
# MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size
# due to integer division
assert s._transport.mtu == line_length * line_buffers
assert s._transport.mtu <= max_smp_encoded_frame_size

packets: List[bytes] = []

Expand Down Expand Up @@ -565,12 +570,16 @@ async def test_file_upload_test_encoded(max_smp_encoded_frame_size: int, line_bu
pytest.skip("The line buffer size is too small")

m = SMPSerialTransport(
max_smp_encoded_frame_size=max_smp_encoded_frame_size,
line_length=line_length,
line_buffers=line_buffers,
fragmentation_strategy=SMPSerialTransport.BufferParams(
line_length=line_length,
line_buffers=line_buffers,
)
)
s = SMPClient(m, "address")
assert s._transport.mtu == max_smp_encoded_frame_size
# MTU is line_length * line_buffers, which may be <= max_smp_encoded_frame_size
# due to integer division
assert s._transport.mtu == line_length * line_buffers
assert s._transport.mtu <= max_smp_encoded_frame_size

packets: List[bytes] = []

Expand Down
67 changes: 64 additions & 3 deletions tests/test_smp_serial_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,23 @@


def test_constructor() -> None:
# Test with Auto() (default)
t = SMPSerialTransport()
assert isinstance(t._conn, Serial)

t = SMPSerialTransport(max_smp_encoded_frame_size=512, line_length=128, line_buffers=4)
assert t.mtu == 127 # Default for Auto without initialize
assert t._line_length == 127
assert t._line_buffers == 1
assert t._max_smp_encoded_frame_size == 127

# Test with BufferParams
t = SMPSerialTransport(
fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=4)
)
assert isinstance(t._conn, Serial)
assert t.mtu == 512
assert t.mtu == 512 # 128 * 4
assert t._line_length == 128
assert t._line_buffers == 4
assert t._max_smp_encoded_frame_size == 512
assert t.max_unencoded_size < 512


Expand Down Expand Up @@ -174,3 +185,53 @@ async def test_send_and_receive() -> None:

t.send.assert_awaited_once_with(b"some data")
t.receive.assert_awaited_once_with()


def test_initialize_with_auto() -> None:
"""Test that Auto mode updates parameters based on server's buffer size."""
t = SMPSerialTransport() # Uses Auto() by default

# Before initialize, uses conservative defaults
assert t._line_length == 127
assert t._line_buffers == 1
assert t._max_smp_encoded_frame_size == 127

# After initialize with server buffer size
t.initialize(512)
assert t._line_length == 127
assert t._line_buffers == 512 // 127 # 4
assert t._max_smp_encoded_frame_size == 512
assert t.mtu == 512


def test_initialize_with_buffer_params() -> None:
"""Test that BufferParams mode doesn't change user-specified parameters."""
t = SMPSerialTransport(
fragmentation_strategy=SMPSerialTransport.BufferParams(line_length=128, line_buffers=2)
)

# Before initialize
assert t._line_length == 128
assert t._line_buffers == 2
assert t._max_smp_encoded_frame_size == 256 # 128 * 2

# After initialize - parameters should NOT change
t.initialize(512)
assert t._line_length == 128
assert t._line_buffers == 2
assert t._max_smp_encoded_frame_size == 256
assert t.mtu == 256


def test_initialize_with_buffer_params_warning(caplog: pytest.LogCaptureFixture) -> None:
"""Test that a warning is logged when user's params exceed server buffer size."""
t = SMPSerialTransport(
fragmentation_strategy=SMPSerialTransport.BufferParams(
line_length=128, line_buffers=4 # 128 * 4 = 512
)
)

with caplog.at_level(logging.WARNING):
t.initialize(256) # Server buffer (256) is smaller than calculated size (512)

assert any("exceeds server buffer size" in record.message for record in caplog.records)
Loading