From 7da412bbf02b6018b597d7a5fc857b7911444cb7 Mon Sep 17 00:00:00 2001 From: Royce Rajan Date: Mon, 12 Jan 2026 16:13:51 -0800 Subject: [PATCH] Send many small ERASEs instead of a single, large INITIATE Helps get past transport timeouts. --- examples/htool.c | 4 +- examples/htool_payload_update.c | 10 +++- protocol/BUILD | 2 + protocol/payload_update.c | 53 ++++++++++++++++--- protocol/payload_update.h | 4 +- protocol/payload_update_test.cc | 92 ++++++++++++++++++++++++++------- 6 files changed, 138 insertions(+), 27 deletions(-) diff --git a/examples/htool.c b/examples/htool.c index d2df71b..d2fc86c 100644 --- a/examples/htool.c +++ b/examples/htool.c @@ -933,7 +933,9 @@ static const struct htool_cmd CMDS[] = { .desc = "Perform payload update protocol for Titan images.", .params = (const struct htool_param[]){ - {HTOOL_POSITIONAL, .name = "source-file"}, {}}, + {HTOOL_POSITIONAL, .name = "source-file"}, + {HTOOL_FLAG_BOOL, 's', "skip_erase", "false", .desc = "Skip erasing the staging side."}, + {}}, .func = htool_payload_update, }, { diff --git a/examples/htool_payload_update.c b/examples/htool_payload_update.c index 4336bbb..c26e77b 100644 --- a/examples/htool_payload_update.c +++ b/examples/htool_payload_update.c @@ -39,6 +39,11 @@ int htool_payload_update(const struct htool_invocation* inv) { return -1; } + bool skip_erase; + if (htool_get_param_bool(inv, "skip_erase", &skip_erase)) { + return -1; + } + int fd = open(image_file, O_RDONLY, 0); if (fd == -1) { fprintf(stderr, "Error opening file %s: %s\n", image_file, strerror(errno)); @@ -64,7 +69,7 @@ int htool_payload_update(const struct htool_invocation* inv) { } enum payload_update_err payload_update_status = - libhoth_payload_update(dev, image, statbuf.st_size); + libhoth_payload_update(dev, image, statbuf.st_size, skip_erase); switch (payload_update_status) { case PAYLOAD_UPDATE_OK: fprintf(stderr, "Payload update finished\n"); @@ -82,6 +87,9 @@ int htool_payload_update(const struct htool_invocation* inv) { case PAYLOAD_UPDATE_FINALIZE_FAIL: fprintf(stderr, "Error when finalizing.\n"); break; + case PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED: + fprintf(stderr, "Payload image is not sector-aligned.\n"); + break; default: break; } diff --git a/protocol/BUILD b/protocol/BUILD index b4f7b60..ce055e0 100644 --- a/protocol/BUILD +++ b/protocol/BUILD @@ -102,6 +102,7 @@ cc_library( ":host_cmd", ":payload_info", ":util", + ":progress", "//transports:libhoth_device", ], ) @@ -114,6 +115,7 @@ cc_test( ], deps = [ ":command_version", + ":payload_info", ":payload_update", "//protocol/test:libhoth_device_mock", "@googletest//:gtest", diff --git a/protocol/payload_update.c b/protocol/payload_update.c index 3e4c7aa..c68dc69 100644 --- a/protocol/payload_update.c +++ b/protocol/payload_update.c @@ -19,6 +19,7 @@ #include #include +#include "progress.h" #include "command_version.h" #include "host_cmd.h" #include "payload_info.h" @@ -82,24 +83,64 @@ static int libhoth_payload_update_finalize( return 0; } +static int payload_update_erase(struct libhoth_device* const dev, const size_t offset, const size_t len) { + struct payload_update_packet request; + request.type = PAYLOAD_UPDATE_ERASE; + request.offset = offset; + request.len = len; + return libhoth_hostcmd_exec(dev, HOTH_CMD_BOARD_SPECIFIC_BASE + HOTH_PRV_CMD_HOTH_PAYLOAD_UPDATE, 0, &request, sizeof(request), NULL, 0, NULL); +} + enum payload_update_err libhoth_payload_update(struct libhoth_device* dev, - uint8_t* image, size_t size) { + uint8_t* image, size_t size, bool skip_erase) { if (libhoth_find_image_descriptor(image, size) == NULL) { return PAYLOAD_UPDATE_BAD_IMG; } - fprintf(stderr, "Initiating payload update protocol with libhoth.\n"); - if (send_payload_update_request_with_command(dev, PAYLOAD_UPDATE_INITIATE) != - 0) { - return PAYLOAD_UPDATE_INITIATE_FAIL; + if (!skip_erase) { + struct libhoth_progress_stderr erase_progress; + libhoth_progress_stderr_init(&erase_progress, "Erase staging side"); + + const size_t block_erase = 64 * 1024; + const size_t sector_erase = 4 * 1024; + + const bool is_image_size_sector_aligned = ((size % sector_erase) == 0); + if (!is_image_size_sector_aligned) { + fprintf(stderr, "error: image size (0x%zx) is not sector-aligned.\n", size); + return PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED; + } + + // Erase by blocks as much as possible. + size_t offset = 0; + for (; offset + block_erase <= size; offset += block_erase) { + const int ret = payload_update_erase(dev, offset, block_erase); + if (ret != 0) { + fprintf(stderr, "block erase err: %d\n", ret); + return ret; + } + erase_progress.progress.func(erase_progress.progress.param, offset, size); + } + + // Erase remaining by sectors. + for (; offset + sector_erase <= size; offset += sector_erase) { + const int ret = payload_update_erase(dev, offset, sector_erase); + if (ret != 0) { + fprintf(stderr, "sector erase err: %d\n", ret); + return ret; + } + erase_progress.progress.func(erase_progress.progress.param, offset, size); + } } const size_t max_chunk_size = LIBHOTH_MAILBOX_SIZE - sizeof(struct hoth_host_request) - sizeof(struct payload_update_packet); - fprintf(stderr, "Flashing the image to hoth.\n"); + struct libhoth_progress_stderr program_progress; + libhoth_progress_stderr_init(&program_progress, "Sending payload"); for (size_t offset = 0; offset < size; ++offset) { + program_progress.progress.func(program_progress.progress.param, offset, size); + if (image[offset] == 0xFF) { continue; } diff --git a/protocol/payload_update.h b/protocol/payload_update.h index 78edc32..d6b37c0 100644 --- a/protocol/payload_update.h +++ b/protocol/payload_update.h @@ -20,6 +20,7 @@ extern "C" { #endif #include +#include #include "transports/libhoth_device.h" @@ -55,6 +56,7 @@ enum payload_update_err { PAYLOAD_UPDATE_FLASH_FAIL, PAYLOAD_UPDATE_FINALIZE_FAIL, PAYLOAD_UPDATE_READ_FAIL, + PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED, }; struct payload_update_packet { @@ -72,7 +74,7 @@ struct payload_update_finalize_response_v1 { } __attribute__((packed)); enum payload_update_err libhoth_payload_update(struct libhoth_device* dev, - uint8_t* image, size_t len); + uint8_t* image, size_t len, bool skip_erase); int libhoth_payload_update_getstatus( struct libhoth_device* dev, struct payload_update_status* update_status); enum payload_update_err libhoth_payload_update_read_chunk(struct libhoth_device* dev, int fd, size_t len, size_t offset); diff --git a/protocol/payload_update_test.cc b/protocol/payload_update_test.cc index 4e6abcd..93bd479 100644 --- a/protocol/payload_update_test.cc +++ b/protocol/payload_update_test.cc @@ -17,6 +17,7 @@ #include #include "command_version.h" +#include "payload_info.h" #include "test/libhoth_device_mock.h" #include #include @@ -33,13 +34,23 @@ constexpr int64_t kMagic = 0x5F435344474D495F; constexpr int64_t kAlign = 1 << 16; constexpr int64_t kDummy = 0; +MATCHER_P2(IsEraseRequest, offset, len, "") { + const uint8_t* data = static_cast(arg); + const struct payload_update_packet* p = + reinterpret_cast( + data + sizeof(struct hoth_host_request)); + return p->type == PAYLOAD_UPDATE_ERASE && + p->offset == static_cast(offset) && + p->len == static_cast(len); +} + TEST_F(LibHothTest, payload_update_bad_image_test) { EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) .WillRepeatedly(Return(LIBHOTH_OK)); uint8_t bad_buffer[100] = {0}; - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, bad_buffer, sizeof(bad_buffer)), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, bad_buffer, sizeof(bad_buffer), false), PAYLOAD_UPDATE_BAD_IMG); } @@ -77,7 +88,7 @@ TEST_F(LibHothTest, payload_update_test) { std::unique_ptr buffer = std::make_unique(2 * kAlign); std::memcpy(buffer.get() + kAlign, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer.get(), 2 * kAlign), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer.get(), 2 * kAlign, false), PAYLOAD_UPDATE_OK); } @@ -86,7 +97,6 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) { InSequence s; EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) - .WillOnce(Return(LIBHOTH_OK)) .WillOnce(Return(LIBHOTH_OK)); EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _)) .WillOnce(Return(LIBHOTH_OK)); @@ -96,7 +106,6 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) { static constexpr uint32_t kVersionMask = 0x1; EXPECT_CALL(mock_, receive) - .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)), Return(LIBHOTH_OK))) @@ -105,33 +114,32 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) { uint8_t buffer[100] = {0}; std::memcpy(buffer, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true), PAYLOAD_UPDATE_OK); } -TEST_F(LibHothTest, payload_update_initiate_fail) { +TEST_F(LibHothTest, payload_update_erase_fail) { EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) .WillRepeatedly(Return(LIBHOTH_OK)); EXPECT_CALL(mock_, receive).WillOnce(Return(-1)); - uint8_t buffer[100] = {0}; + uint8_t buffer[4096] = {0}; std::memcpy(buffer, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)), - PAYLOAD_UPDATE_INITIATE_FAIL); + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), false), + -1); } TEST_F(LibHothTest, payload_update_flash_fail) { EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) .WillRepeatedly(Return(LIBHOTH_OK)); EXPECT_CALL(mock_, receive) - .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) - .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(-1))); + .WillOnce(Return(-1)); uint8_t buffer[100] = {0}; std::memcpy(buffer, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true), PAYLOAD_UPDATE_FLASH_FAIL); } @@ -140,21 +148,19 @@ TEST_F(LibHothTest, payload_update_command_version_fail) { InSequence s; EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) - .WillOnce(Return(LIBHOTH_OK)) .WillOnce(Return(LIBHOTH_OK)); EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _)) .WillOnce(Return(LIBHOTH_OK)); } EXPECT_CALL(mock_, receive) - .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(-1))); uint8_t buffer[100] = {0}; std::memcpy(buffer, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true), PAYLOAD_UPDATE_FINALIZE_FAIL); } @@ -163,7 +169,6 @@ TEST_F(LibHothTest, payload_update_finalize_fail) { InSequence s; EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) - .WillOnce(Return(LIBHOTH_OK)) .WillOnce(Return(LIBHOTH_OK)); EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _)) .WillOnce(Return(LIBHOTH_OK)); @@ -173,7 +178,6 @@ TEST_F(LibHothTest, payload_update_finalize_fail) { static constexpr uint32_t kVersionMask = 0x1; EXPECT_CALL(mock_, receive) - .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))) .WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)), Return(LIBHOTH_OK))) @@ -182,7 +186,7 @@ TEST_F(LibHothTest, payload_update_finalize_fail) { uint8_t buffer[100] = {0}; std::memcpy(buffer, &kMagic, sizeof(kMagic)); - EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)), + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true), PAYLOAD_UPDATE_FINALIZE_FAIL); } @@ -203,3 +207,55 @@ TEST_F(LibHothTest, payload_update_status) { EXPECT_EQ(exp_us.a_valid, us.a_valid); EXPECT_EQ(exp_us.active_half, us.active_half); } + +TEST_F(LibHothTest, payload_update_erase_test) { + constexpr size_t kBlockErase = 64 * 1024; + constexpr size_t kSectorErase = 4 * 1024; + constexpr size_t kSize = kBlockErase + kSectorErase; + uint8_t buffer[kSize]; + std::memset(buffer, 0xFF, kSize); + + struct image_descriptor desc = {}; + desc.descriptor_magic = TITAN_IMAGE_DESCRIPTOR_MAGIC; + desc.descriptor_area_size = sizeof(desc); + std::memcpy(buffer, &desc, sizeof(desc)); + + { + InSequence s; + + // Block Erase + EXPECT_CALL(mock_, send(_, IsEraseRequest(0, kBlockErase), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))); + + // Sector Erase + EXPECT_CALL(mock_, send(_, IsEraseRequest(kBlockErase, kSectorErase), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))); + + // Flash + EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))); + + // Finalize version check + static constexpr uint32_t kVersionMask = 0; + EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)), + Return(LIBHOTH_OK))); + + // Finalize + EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _)) + .WillOnce(Return(LIBHOTH_OK)); + EXPECT_CALL(mock_, receive) + .WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK))); + } + + EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, kSize, false), + PAYLOAD_UPDATE_OK); +}