312 changes: 166 additions & 146 deletions src/lean_spec/subspecs/containers/state/state.py
@@ -610,56 +610,6 @@ def state_transition(self, block: Block, valid_signatures: bool = True) -> "State":

return new_state

def _aggregate_signatures_from_gossip(
self,
validator_ids: list[Uint64],
data_root: Bytes32,
epoch: Slot,
gossip_signatures: dict[SignatureKey, "Signature"] | None = None,
) -> tuple[AggregatedSignatureProof, set[Uint64]] | None:
"""
Aggregate per-validator XMSS signatures into a single proof.

Returns:
A tuple of (proof, missing_validator_ids) or None if no signatures found.
The proof contains the participants bitfield.

Raises:
AggregationError: If the aggregation fails.
"""
if not gossip_signatures or not validator_ids:
return None

signatures: list[Signature] = []
public_keys: list[PublicKey] = []

included_validator_ids: set[Uint64] = set()
missing_validator_ids: set[Uint64] = set()

for validator_index in validator_ids:
# Attempt to retrieve the signature; track any that are missing.
key = SignatureKey(validator_index, data_root)
if (sig := gossip_signatures.get(key)) is None:
missing_validator_ids.add(validator_index)
continue

included_validator_ids.add(validator_index)
signatures.append(sig)
public_keys.append(self.validators[validator_index].get_pubkey())

if not included_validator_ids:
return None

participants = AggregationBits.from_validator_indices(list(included_validator_ids))
proof = AggregatedSignatureProof.aggregate(
participants=participants,
public_keys=public_keys,
signatures=signatures,
message=data_root,
epoch=epoch,
)
return proof, missing_validator_ids

def build_block(
self,
slot: Slot,
@@ -797,121 +747,191 @@ def compute_aggregated_signatures(
"""
Compute aggregated signatures for a set of attestations.

Tries to aggregate all attestations together. If that fails, splits them greedily to
generate the minimal number of aggregated attestations.
This method implements a two-phase signature collection strategy:

Args:
attestations: The attestations to compute aggregated signatures for.
gossip_signatures: Optional per-validator XMSS signatures learned from gossip.
aggregated_payloads: Optional aggregated signature payloads learned from blocks.
1. **Gossip Phase**: For each attestation group, first attempt to collect
individual XMSS signatures from the gossip network. These are fresh
signatures that validators broadcast when they attest.

2. **Fallback Phase**: For any validators not covered by gossip, fall back
to previously-seen aggregated proofs from blocks. This uses a greedy
set-cover approach to minimize the number of proofs needed.

The result is a list of (attestation, proof) pairs ready for block inclusion.

Parameters
----------
attestations : list[Attestation]
Individual attestations to aggregate and sign.
gossip_signatures : dict[SignatureKey, Signature] | None
Per-validator XMSS signatures learned from the gossip network.
aggregated_payloads : dict[SignatureKey, list[AggregatedSignatureProof]] | None
Aggregated proofs learned from previously-seen blocks.

Returns:
A tuple of `(aggregated_attestations, aggregated_signatures)`.
-------
tuple[list[AggregatedAttestation], list[AggregatedSignatureProof]]
Paired attestations and their corresponding proofs.
"""
final_aggregated_attestations: list[AggregatedAttestation] = []
final_aggregated_proofs: list[AggregatedSignatureProof] = []
# Accumulator for (attestation, proof) pairs.
results: list[tuple[AggregatedAttestation, AggregatedSignatureProof]] = []

# Aggregate all the attestations into a single aggregated attestation.
completely_aggregated_attestations = AggregatedAttestation.aggregate_by_data(attestations)

# Try to compute the aggregated signatures for the single aggregated attestation.
#
# We will try to compute the aggregated signatures for the completely aggregated
# attestations.
# - either we can find per-validator XMSS signatures from gossip, or
# - we can find at least one aggregated payload learned from a block that references
# this validator+data.
# Group individual attestations by data
#
# If the aggregated signatures cannot be computed, we will split the completely aggregated
# attestations in a greedy way.
for completely_aggregated_attestation in completely_aggregated_attestations:
validator_ids = (
completely_aggregated_attestation.aggregation_bits.to_validator_indices()
)
data_root = completely_aggregated_attestation.data.data_root_bytes()
slot = completely_aggregated_attestation.data.slot
# Multiple validators may attest to the same data (slot, head, target, source).
# We aggregate them into groups so each group can share a single proof.
for aggregated in AggregatedAttestation.aggregate_by_data(attestations):
# Extract the common attestation data and its hash.
#
# All validators in this group signed the same message (the data root).
data = aggregated.data
data_root = data.data_root_bytes()

proofs: list[AggregatedSignatureProof] = []
# Get the list of validators who attested to this data.
validator_ids = aggregated.aggregation_bits.to_validator_indices()

# Try to find per validator XMSS signatures from gossip.
gossip_result = self._aggregate_signatures_from_gossip(
validator_ids,
data_root,
slot,
gossip_signatures,
)
# Phase 1: Gossip Collection
#
# When a validator creates an attestation, it broadcasts the
# individual XMSS signature over the gossip network. If we have
# received these signatures, we can aggregate them ourselves.
#
# This is the preferred path: fresh signatures from the network.

if gossip_result is not None:
gossip_proof, remaining_validator_ids = gossip_result
proofs.append(gossip_proof)
# Parallel lists for signatures, public keys, and validator IDs.
gossip_sigs: list[Signature] = []
gossip_keys: list[PublicKey] = []
gossip_ids: list[Uint64] = []

# Track validators we couldn't find signatures for.
#
# These will need to be covered by Phase 2 (existing proofs).
remaining: set[Uint64] = set()

# Attempt to collect each validator's signature from gossip.
#
# Signatures are keyed by (validator ID, data root).
# - If a signature exists, we add it to our collection.
# - Otherwise, we mark that validator as "remaining" for the fallback phase.
if gossip_signatures:
for vid in validator_ids:
key = SignatureKey(vid, data_root)
if (sig := gossip_signatures.get(key)) is not None:
# Found a signature: collect it along with the public key.
gossip_sigs.append(sig)
gossip_keys.append(self.validators[vid].get_pubkey())
gossip_ids.append(vid)
else:
# No signature available: mark for fallback coverage.
remaining.add(vid)
else:
remaining_validator_ids = set(validator_ids)

# Pick existing aggregated proofs to cover remaining validators.
while remaining_validator_ids:
proof, remaining_validator_ids = self._pick_from_aggregated_proofs(
remaining_validator_ids,
data_root,
aggregated_payloads,
# No gossip data at all: all validators need fallback coverage.
remaining = set(validator_ids)

# If we collected any gossip signatures, aggregate them into a proof.
#
# The aggregation combines multiple XMSS signatures into a single
# compact proof that can verify all participants signed the message.
if gossip_ids:
participants = AggregationBits.from_validator_indices(gossip_ids)
proof = AggregatedSignatureProof.aggregate(
participants=participants,
public_keys=gossip_keys,
signatures=gossip_sigs,
message=data_root,
epoch=data.slot,
)
proofs.append(proof)

# TODO: Recursively aggregate the proofs. Since we currently don't support
# recursive aggregation, we just append each proof separately. This is fine
# for now; eventually we will recursively aggregate into one.
for proof in proofs:
final_aggregated_attestations.append(
AggregatedAttestation(
aggregation_bits=proof.participants,
data=completely_aggregated_attestation.data,
results.append(
(
AggregatedAttestation(aggregation_bits=participants, data=data),
proof,
)
)
final_aggregated_proofs.append(proof)

return final_aggregated_attestations, final_aggregated_proofs

def _pick_from_aggregated_proofs(
self,
remaining_validator_ids: set[Uint64],
data_root: Bytes32,
aggregated_payloads: dict[SignatureKey, list[AggregatedSignatureProof]] | None = None,
) -> tuple[AggregatedSignatureProof, set[Uint64]]:
"""
Pick an aggregated proof that covers the most remaining validators.

Args:
remaining_validator_ids: The validator ids still needing coverage.
data_root: The attestation data root.
aggregated_payloads: Previously learned proofs keyed by (validator_id, data_root).
# Phase 2: Fallback to Existing Proofs
#
# Some validators may not have broadcast their signatures over gossip,
# but we might have seen proofs for them in previously-received blocks.
#
# Example scenario:
#
# - We need signatures from validators {0, 1, 2, 3, 4}.
# - Gossip gave us signatures for {0, 1}.
# - Remaining: {2, 3, 4}.
# - From old blocks, we have:
# • Proof A covering {2, 3}
# • Proof B covering {3, 4}
# • Proof C covering {4}
#
# We want to cover {2, 3, 4} with as few proofs as possible.
# A greedy approach: always pick the proof with the largest overlap.
#
# - Iteration 1: Proof A covers {2, 3} (2 validators). Pick it.
# Remaining: {4}.
# - Iteration 2: Proof B covers {4} (1 validator). Pick it.
# Remaining: {} → done.
#
# Result: 2 proofs instead of 3.

Returns:
A tuple of (proof, remaining_validator_ids after this proof is applied).
while remaining and aggregated_payloads:
# Step 1: Find candidate proofs for a remaining validator.
#
# Proofs are indexed by (validator ID, data root). We pick any
# validator still in the remaining set and look up proofs that
# include them.
target_id = next(iter(remaining))
candidates = aggregated_payloads.get(SignatureKey(target_id, data_root), [])

Raises:
ValueError: If no suitable proof is found.
"""
if not remaining_validator_ids:
raise ValueError("remaining validator ids cannot be empty")
# No proofs found for this validator: stop the loop.
if not candidates:
break

if aggregated_payloads is None:
raise ValueError("aggregated payloads required when gossip coverage incomplete")
# Step 2: Pick the proof covering the most remaining validators.
#
# At each step, we select the single proof that eliminates the highest
# number of *currently missing* validators from our list.
#
# The 'score' of a candidate proof is defined as the size of the
# intersection between:
# A. The validators inside the proof (`p.participants`)
# B. The validators we still need (`remaining`)
#
# Example:
# Remaining needed : {Alice, Bob, Charlie}
# Proof 1 covers : {Alice, Dave} -> Score: 1 (Only Alice counts)
# Proof 2 covers : {Bob, Charlie, Eve} -> Score: 2 (Bob & Charlie count)
# -> Result: We pick Proof 2 because it has the highest score.
best, covered = max(
((p, set(p.participants.to_validator_indices())) for p in candidates),
# Calculate the intersection size (A ∩ B) for every candidate.
key=lambda pair: len(pair[1] & remaining),
)

best_proof: AggregatedSignatureProof | None = None
best_overlap: set[Uint64] = set()
best_remaining: set[Uint64] = set()
# Guard: If the best proof has zero overlap with remaining, stop.
if covered.isdisjoint(remaining):
break

representative_validator_id = next(iter(remaining_validator_ids))
key = SignatureKey(representative_validator_id, data_root)
# Step 3: Record the proof and remove covered validators.
#
# TODO: We don't support recursive aggregation yet.
# In the future, we should be able to aggregate the proofs into a single proof.
results.append(
(
AggregatedAttestation(aggregation_bits=best.participants, data=data),
best,
)
)
remaining -= covered

for proof in aggregated_payloads.get(key, []):
participants = set(proof.participants.to_validator_indices())
overlap = participants.intersection(remaining_validator_ids)
if len(overlap) > len(best_overlap):
best_proof = proof
best_overlap = overlap
best_remaining = remaining_validator_ids - overlap
# Final Assembly
#
# - We built a list of (attestation, proof) tuples.
# - Now we unzip them into two parallel lists for the return value.

if best_proof is None:
raise ValueError("Failed to locate aggregated proof for remaining validators")
# Handle the empty case explicitly.
if not results:
return [], []

return best_proof, best_remaining
# Unzip the results into parallel lists.
aggregated_attestations, aggregated_proofs = zip(*results, strict=True)
return list(aggregated_attestations), list(aggregated_proofs)
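
The Phase 2 loop above is a greedy set cover. Below is a self-contained sketch of the same pick rule, using plain ints for validator ids and frozensets for proof participants; unlike the method, it scores every known proof rather than only the candidates indexed by one target validator.

    def greedy_cover(remaining: set[int], proofs: list[frozenset[int]]) -> list[frozenset[int]]:
        """Greedily pick proofs until `remaining` is covered or nothing helps."""
        picked: list[frozenset[int]] = []
        while remaining:
            # Score each proof by how many still-missing validators it covers.
            best = max(proofs, key=lambda p: len(p & remaining), default=None)
            if best is None or not (best & remaining):
                break  # no proof covers anyone we still need
            picked.append(best)
            remaining -= best
        return picked

    # The worked example from the comments: cover {2, 3, 4} with proofs A, B, C.
    a, b, c = frozenset({2, 3}), frozenset({3, 4}), frozenset({4})
    assert greedy_cover({2, 3, 4}, [a, b, c]) == [a, b]  # 2 proofs instead of 3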