Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/htool.c
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,9 @@ static const struct htool_cmd CMDS[] = {
.desc = "Perform payload update protocol for Titan images.",
.params =
(const struct htool_param[]){
{HTOOL_POSITIONAL, .name = "source-file"}, {}},
{HTOOL_POSITIONAL, .name = "source-file"},
{HTOOL_FLAG_BOOL, 's', "skip_erase", "false", .desc = "Skip erasing the staging side."},
{}},
.func = htool_payload_update,
},
{
Expand Down
10 changes: 9 additions & 1 deletion examples/htool_payload_update.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ int htool_payload_update(const struct htool_invocation* inv) {
return -1;
}

bool skip_erase;
if (htool_get_param_bool(inv, "skip_erase", &skip_erase)) {
return -1;
}

int fd = open(image_file, O_RDONLY, 0);
if (fd == -1) {
fprintf(stderr, "Error opening file %s: %s\n", image_file, strerror(errno));
Expand All @@ -64,7 +69,7 @@ int htool_payload_update(const struct htool_invocation* inv) {
}

enum payload_update_err payload_update_status =
libhoth_payload_update(dev, image, statbuf.st_size);
libhoth_payload_update(dev, image, statbuf.st_size, skip_erase);
switch (payload_update_status) {
case PAYLOAD_UPDATE_OK:
fprintf(stderr, "Payload update finished\n");
Expand All @@ -82,6 +87,9 @@ int htool_payload_update(const struct htool_invocation* inv) {
case PAYLOAD_UPDATE_FINALIZE_FAIL:
fprintf(stderr, "Error when finalizing.\n");
break;
case PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED:
fprintf(stderr, "Payload image is not sector-aligned.\n");
break;
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions protocol/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ cc_library(
":host_cmd",
":payload_info",
":util",
":progress",
"//transports:libhoth_device",
],
)
Expand All @@ -114,6 +115,7 @@ cc_test(
],
deps = [
":command_version",
":payload_info",
":payload_update",
"//protocol/test:libhoth_device_mock",
"@googletest//:gtest",
Expand Down
53 changes: 47 additions & 6 deletions protocol/payload_update.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <string.h>
#include <unistd.h>

#include "progress.h"
#include "command_version.h"
#include "host_cmd.h"
#include "payload_info.h"
Expand Down Expand Up @@ -82,24 +83,64 @@ static int libhoth_payload_update_finalize(
return 0;
}

static int payload_update_erase(struct libhoth_device* const dev, const size_t offset, const size_t len) {
struct payload_update_packet request;
request.type = PAYLOAD_UPDATE_ERASE;
request.offset = offset;
request.len = len;
return libhoth_hostcmd_exec(dev, HOTH_CMD_BOARD_SPECIFIC_BASE + HOTH_PRV_CMD_HOTH_PAYLOAD_UPDATE, 0, &request, sizeof(request), NULL, 0, NULL);
}

enum payload_update_err libhoth_payload_update(struct libhoth_device* dev,
uint8_t* image, size_t size) {
uint8_t* image, size_t size, bool skip_erase) {
if (libhoth_find_image_descriptor(image, size) == NULL) {
return PAYLOAD_UPDATE_BAD_IMG;
}

fprintf(stderr, "Initiating payload update protocol with libhoth.\n");
if (send_payload_update_request_with_command(dev, PAYLOAD_UPDATE_INITIATE) !=
0) {
return PAYLOAD_UPDATE_INITIATE_FAIL;
if (!skip_erase) {
struct libhoth_progress_stderr erase_progress;
libhoth_progress_stderr_init(&erase_progress, "Erase staging side");

const size_t block_erase = 64 * 1024;
const size_t sector_erase = 4 * 1024;

const bool is_image_size_sector_aligned = ((size % sector_erase) == 0);
if (!is_image_size_sector_aligned) {
fprintf(stderr, "error: image size (0x%zx) is not sector-aligned.\n", size);
return PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED;
}

// Erase by blocks as much as possible.
size_t offset = 0;
for (; offset + block_erase <= size; offset += block_erase) {
const int ret = payload_update_erase(dev, offset, block_erase);
if (ret != 0) {
fprintf(stderr, "block erase err: %d\n", ret);
return ret;
}
erase_progress.progress.func(erase_progress.progress.param, offset, size);
}

// Erase remaining by sectors.
for (; offset + sector_erase <= size; offset += sector_erase) {
const int ret = payload_update_erase(dev, offset, sector_erase);
if (ret != 0) {
fprintf(stderr, "sector erase err: %d\n", ret);
return ret;
}
erase_progress.progress.func(erase_progress.progress.param, offset, size);
}
}

const size_t max_chunk_size = LIBHOTH_MAILBOX_SIZE -
sizeof(struct hoth_host_request) -
sizeof(struct payload_update_packet);

fprintf(stderr, "Flashing the image to hoth.\n");
struct libhoth_progress_stderr program_progress;
libhoth_progress_stderr_init(&program_progress, "Sending payload");
for (size_t offset = 0; offset < size; ++offset) {
program_progress.progress.func(program_progress.progress.param, offset, size);

if (image[offset] == 0xFF) {
continue;
}
Expand Down
4 changes: 3 additions & 1 deletion protocol/payload_update.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ extern "C" {
#endif

#include <stdint.h>
#include <stdbool.h>

#include "transports/libhoth_device.h"

Expand Down Expand Up @@ -55,6 +56,7 @@ enum payload_update_err {
PAYLOAD_UPDATE_FLASH_FAIL,
PAYLOAD_UPDATE_FINALIZE_FAIL,
PAYLOAD_UPDATE_READ_FAIL,
PAYLOAD_UPDATE_IMAGE_NOT_SECTOR_ALIGNED,
};

struct payload_update_packet {
Expand All @@ -72,7 +74,7 @@ struct payload_update_finalize_response_v1 {
} __attribute__((packed));

enum payload_update_err libhoth_payload_update(struct libhoth_device* dev,
uint8_t* image, size_t len);
uint8_t* image, size_t len, bool skip_erase);
int libhoth_payload_update_getstatus(
struct libhoth_device* dev, struct payload_update_status* update_status);
enum payload_update_err libhoth_payload_update_read_chunk(struct libhoth_device* dev, int fd, size_t len, size_t offset);
Expand Down
92 changes: 74 additions & 18 deletions protocol/payload_update_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdint>

#include "command_version.h"
#include "payload_info.h"
#include "test/libhoth_device_mock.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand All @@ -33,13 +34,23 @@ constexpr int64_t kMagic = 0x5F435344474D495F;
constexpr int64_t kAlign = 1 << 16;
constexpr int64_t kDummy = 0;

MATCHER_P2(IsEraseRequest, offset, len, "") {
const uint8_t* data = static_cast<const uint8_t*>(arg);
const struct payload_update_packet* p =
reinterpret_cast<const struct payload_update_packet*>(
data + sizeof(struct hoth_host_request));
return p->type == PAYLOAD_UPDATE_ERASE &&
p->offset == static_cast<uint32_t>(offset) &&
p->len == static_cast<uint32_t>(len);
}

TEST_F(LibHothTest, payload_update_bad_image_test) {
EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillRepeatedly(Return(LIBHOTH_OK));

uint8_t bad_buffer[100] = {0};

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, bad_buffer, sizeof(bad_buffer)),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, bad_buffer, sizeof(bad_buffer), false),
PAYLOAD_UPDATE_BAD_IMG);
}

Expand Down Expand Up @@ -77,7 +88,7 @@ TEST_F(LibHothTest, payload_update_test) {
std::unique_ptr<uint8_t[]> buffer = std::make_unique<uint8_t[]>(2 * kAlign);
std::memcpy(buffer.get() + kAlign, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer.get(), 2 * kAlign),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer.get(), 2 * kAlign, false),
PAYLOAD_UPDATE_OK);
}

Expand All @@ -86,7 +97,6 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) {
InSequence s;

EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillOnce(Return(LIBHOTH_OK))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _))
.WillOnce(Return(LIBHOTH_OK));
Expand All @@ -96,7 +106,6 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) {

static constexpr uint32_t kVersionMask = 0x1;
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)),
Return(LIBHOTH_OK)))
Expand All @@ -105,33 +114,32 @@ TEST_F(LibHothTest, payload_update_command_version_unsupported) {
uint8_t buffer[100] = {0};
std::memcpy(buffer, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true),
PAYLOAD_UPDATE_OK);
}

TEST_F(LibHothTest, payload_update_initiate_fail) {
TEST_F(LibHothTest, payload_update_erase_fail) {
EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillRepeatedly(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive).WillOnce(Return(-1));

uint8_t buffer[100] = {0};
uint8_t buffer[4096] = {0};
std::memcpy(buffer, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)),
PAYLOAD_UPDATE_INITIATE_FAIL);
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), false),
-1);
}

TEST_F(LibHothTest, payload_update_flash_fail) {
EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillRepeatedly(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(-1)));
.WillOnce(Return(-1));

uint8_t buffer[100] = {0};
std::memcpy(buffer, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true),
PAYLOAD_UPDATE_FLASH_FAIL);
}

Expand All @@ -140,21 +148,19 @@ TEST_F(LibHothTest, payload_update_command_version_fail) {
InSequence s;

EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillOnce(Return(LIBHOTH_OK))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _))
.WillOnce(Return(LIBHOTH_OK));
}

EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(-1)));

uint8_t buffer[100] = {0};
std::memcpy(buffer, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true),
PAYLOAD_UPDATE_FINALIZE_FAIL);
}

Expand All @@ -163,7 +169,6 @@ TEST_F(LibHothTest, payload_update_finalize_fail) {
InSequence s;

EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillOnce(Return(LIBHOTH_OK))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _))
.WillOnce(Return(LIBHOTH_OK));
Expand All @@ -173,7 +178,6 @@ TEST_F(LibHothTest, payload_update_finalize_fail) {

static constexpr uint32_t kVersionMask = 0x1;
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)))
.WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)),
Return(LIBHOTH_OK)))
Expand All @@ -182,7 +186,7 @@ TEST_F(LibHothTest, payload_update_finalize_fail) {
uint8_t buffer[100] = {0};
std::memcpy(buffer, &kMagic, sizeof(kMagic));

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer)),
EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, sizeof(buffer), true),
PAYLOAD_UPDATE_FINALIZE_FAIL);
}

Expand All @@ -203,3 +207,55 @@ TEST_F(LibHothTest, payload_update_status) {
EXPECT_EQ(exp_us.a_valid, us.a_valid);
EXPECT_EQ(exp_us.active_half, us.active_half);
}

TEST_F(LibHothTest, payload_update_erase_test) {
constexpr size_t kBlockErase = 64 * 1024;
constexpr size_t kSectorErase = 4 * 1024;
constexpr size_t kSize = kBlockErase + kSectorErase;
uint8_t buffer[kSize];
std::memset(buffer, 0xFF, kSize);

struct image_descriptor desc = {};
desc.descriptor_magic = TITAN_IMAGE_DESCRIPTOR_MAGIC;
desc.descriptor_area_size = sizeof(desc);
std::memcpy(buffer, &desc, sizeof(desc));

{
InSequence s;

// Block Erase
EXPECT_CALL(mock_, send(_, IsEraseRequest(0, kBlockErase), _))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)));

// Sector Erase
EXPECT_CALL(mock_, send(_, IsEraseRequest(kBlockErase, kSectorErase), _))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)));

// Flash
EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)));

// Finalize version check
static constexpr uint32_t kVersionMask = 0;
EXPECT_CALL(mock_, send(_, UsesCommand(HOTH_CMD_GET_CMD_VERSIONS), _))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kVersionMask, sizeof(kVersionMask)),
Return(LIBHOTH_OK)));

// Finalize
EXPECT_CALL(mock_, send(_, UsesCommand(kCmd), _))
.WillOnce(Return(LIBHOTH_OK));
EXPECT_CALL(mock_, receive)
.WillOnce(DoAll(CopyResp(&kDummy, 0), Return(LIBHOTH_OK)));
}

EXPECT_EQ(libhoth_payload_update(&hoth_dev_, buffer, kSize, false),
PAYLOAD_UPDATE_OK);
}
Loading