Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions src/lean_spec/subspecs/xmss/subtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ def new(
cls,
hasher: TweakHasher,
rand: Rand,
lowest_layer: int,
depth: int,
lowest_layer: Uint64,
depth: Uint64,
start_index: Uint64,
parameter: Parameter,
lowest_layer_nodes: list[HashDigestVector],
Expand Down Expand Up @@ -146,7 +146,7 @@ def new(
A `HashSubTree` containing all computed layers from `lowest_layer` to root.
"""
# Validate: nodes must fit in available positions at this layer.
max_positions = 1 << (depth - lowest_layer)
max_positions = 1 << int(depth - lowest_layer)
if int(start_index) + len(lowest_layer_nodes) > max_positions:
raise ValueError(
f"Overflow at layer {lowest_layer}: "
Expand All @@ -163,12 +163,11 @@ def new(
parent_start = current.start_index // Uint64(2)

# Hash each pair of siblings into their parent using zip for cleaner indexing.
parent_start_int = int(parent_start)
node_pairs = zip(current.nodes[::2], current.nodes[1::2], strict=True)
parents = [
hasher.apply(
parameter,
TreeTweak(level=level + 1, index=Uint64(parent_start_int + i)),
TreeTweak(level=level + 1, index=parent_start + Uint64(i)),
[left, right],
)
for i, (left, right) in enumerate(node_pairs)
Expand All @@ -179,8 +178,8 @@ def new(
layers.append(current)

return cls(
depth=Uint64(depth),
lowest_layer=Uint64(lowest_layer),
depth=depth,
lowest_layer=lowest_layer,
layers=HashTreeLayers(data=layers),
)

Expand Down Expand Up @@ -233,8 +232,8 @@ def new_top_tree(
return cls.new(
hasher=hasher,
rand=rand,
lowest_layer=depth // 2,
depth=depth,
lowest_layer=Uint64(depth // 2),
depth=Uint64(depth),
start_index=start_bottom_tree_index,
parameter=parameter,
lowest_layer_nodes=bottom_tree_roots,
Expand Down Expand Up @@ -300,8 +299,8 @@ def new_bottom_tree(
full_tree = cls.new(
hasher=hasher,
rand=rand,
lowest_layer=0,
depth=depth,
lowest_layer=Uint64(0),
depth=Uint64(depth),
start_index=bottom_tree_index * Uint64(leafs_per_tree),
parameter=parameter,
lowest_layer_nodes=leaves,
Expand Down
4 changes: 2 additions & 2 deletions tests/lean_spec/subspecs/xmss/test_merkle_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ def _run_commit_open_verify_roundtrip(
tree = HashSubTree.new(
hasher=hasher,
rand=rand,
lowest_layer=0,
depth=depth,
lowest_layer=Uint64(0),
depth=Uint64(depth),
start_index=Uint64(start_index),
parameter=parameter,
lowest_layer_nodes=leaf_hashes,
Expand Down
Loading