From 4a151cef3aea56ed27cda2261c2d36c7aa65982f Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Wed, 31 Dec 2025 22:52:50 +0100 Subject: [PATCH 1/2] signature aggregation: unify into a single function --- .../subspecs/containers/state/state.py | 309 +++++++++--------- .../containers/test_state_aggregation.py | 183 ----------- 2 files changed, 163 insertions(+), 329 deletions(-) diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index cd700638..dc9859bd 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -610,56 +610,6 @@ def state_transition(self, block: Block, valid_signatures: bool = True) -> "Stat return new_state - def _aggregate_signatures_from_gossip( - self, - validator_ids: list[Uint64], - data_root: Bytes32, - epoch: Slot, - gossip_signatures: dict[SignatureKey, "Signature"] | None = None, - ) -> tuple[AggregatedSignatureProof, set[Uint64]] | None: - """ - Aggregate per-validator XMSS signatures into a single proof. - - Returns: - A tuple of (proof, missing_validator_ids) or None if no signatures found. - The proof contains the participants bitfield. - - Raises: - AggregationError: If the aggregation fails. - """ - if not gossip_signatures or not validator_ids: - return None - - signatures: list[Signature] = [] - public_keys: list[PublicKey] = [] - - included_validator_ids: set[Uint64] = set() - missing_validator_ids: set[Uint64] = set() - - for validator_index in validator_ids: - # Attempt to retrieve the signature; fail fast if any are missing. - key = SignatureKey(validator_index, data_root) - if (sig := gossip_signatures.get(key)) is None: - missing_validator_ids.add(validator_index) - continue - - included_validator_ids.add(validator_index) - signatures.append(sig) - public_keys.append(self.validators[validator_index].get_pubkey()) - - if not included_validator_ids: - return None - - participants = AggregationBits.from_validator_indices(list(included_validator_ids)) - proof = AggregatedSignatureProof.aggregate( - participants=participants, - public_keys=public_keys, - signatures=signatures, - message=data_root, - epoch=epoch, - ) - return proof, missing_validator_ids - def build_block( self, slot: Slot, @@ -797,121 +747,188 @@ def compute_aggregated_signatures( """ Compute aggregated signatures for a set of attestations. - Tries to aggregate all attestations together. If that fails, splits them greedily to - generate the minimal number of aggregated attestations. + This method implements a two-phase signature collection strategy: - Args: - attestations: The attestations to compute aggregated signatures for. - gossip_signatures: Optional per-validator XMSS signatures learned from gossip. - aggregated_payloads: Optional aggregated signature payloads learned from blocks. + 1. **Gossip Phase**: For each attestation group, first attempt to collect + individual XMSS signatures from the gossip network. These are fresh + signatures that validators broadcast when they attest. + + 2. **Fallback Phase**: For any validators not covered by gossip, fall back + to previously-seen aggregated proofs from blocks. This uses a greedy + set-cover approach to minimize the number of proofs needed. + + The result is a list of (attestation, proof) pairs ready for block inclusion. + + Parameters + ---------- + attestations : list[Attestation] + Individual attestations to aggregate and sign. + gossip_signatures : dict[SignatureKey, Signature] | None + Per-validator XMSS signatures learned from the gossip network. + aggregated_payloads : dict[SignatureKey, list[AggregatedSignatureProof]] | None + Aggregated proofs learned from previously-seen blocks. Returns: - A tuple of `(aggregated_attestations, aggregated_signatures)`. + ------- + tuple[list[AggregatedAttestation], list[AggregatedSignatureProof]] + Paired attestations and their corresponding proofs. """ - final_aggregated_attestations: list[AggregatedAttestation] = [] - final_aggregated_proofs: list[AggregatedSignatureProof] = [] - - # Aggregate all the attestations into a single aggregated attestation. - completely_aggregated_attestations = AggregatedAttestation.aggregate_by_data(attestations) + # Accumulator for (attestation, proof) pairs. + results: list[tuple[AggregatedAttestation, AggregatedSignatureProof]] = [] - # Try to compute the aggregated signatures for the single aggregated attestation. + # Group individual attestations by data # - # We will try to compute the aggregated signatures for the completely aggregated - # attestations. - # - either we can find per validator XMSS signatures from gossip, or - # - we can find at least one aggregated payload learned from a block that references - # this validator+data. - # - # If the aggregated signatures cannot be computed, we will split the completely aggregated - # attestations in a greedy way. - for completely_aggregated_attestation in completely_aggregated_attestations: - validator_ids = ( - completely_aggregated_attestation.aggregation_bits.to_validator_indices() - ) - data_root = completely_aggregated_attestation.data.data_root_bytes() - slot = completely_aggregated_attestation.data.slot + # Multiple validators may attest to the same data (slot, head, target, source). + # We aggregate them into groups so each group can share a single proof. + for aggregated in AggregatedAttestation.aggregate_by_data(attestations): + # Extract the common attestation data and its hash. + # + # All validators in this group signed the same message (the data root). + data = aggregated.data + data_root = data.data_root_bytes() - proofs: list[AggregatedSignatureProof] = [] + # Get the list of validators who attested to this data. + validator_ids = aggregated.aggregation_bits.to_validator_indices() - # Try to find per validator XMSS signatures from gossip. - gossip_result = self._aggregate_signatures_from_gossip( - validator_ids, - data_root, - slot, - gossip_signatures, - ) + # Phase 1: Gossip Collection + # + # When a validator creates an attestation, it broadcasts the + # individual XMSS signature over the gossip network. If we have + # received these signatures, we can aggregate them ourselves. + # + # This is the preferred path: fresh signatures from the network. + + # Parallel lists for signatures, public keys, and validator IDs. + gossip_sigs: list[Signature] = [] + gossip_keys: list[PublicKey] = [] + gossip_ids: list[Uint64] = [] + + # Track validators we couldn't find signatures for. + # + # These will need to be covered by Phase 2 (existing proofs). + remaining: set[Uint64] = set() - if gossip_result is not None: - gossip_proof, remaining_validator_ids = gossip_result - proofs.append(gossip_proof) + # Attempt to collect each validator's signature from gossip. + # + # Signatures are keyed by (validator ID, data root). + # - If a signature exists, we add it to our collection. + # - Otherwise, we mark that validator as "remaining" for the fallback phase. + if gossip_signatures: + for vid in validator_ids: + key = SignatureKey(vid, data_root) + if (sig := gossip_signatures.get(key)) is not None: + # Found a signature: collect it along with the public key. + gossip_sigs.append(sig) + gossip_keys.append(self.validators[vid].get_pubkey()) + gossip_ids.append(vid) + else: + # No signature available: mark for fallback coverage. + remaining.add(vid) else: - remaining_validator_ids = set(validator_ids) - - # Pick existing aggregated proofs to cover remaining validators. - while remaining_validator_ids: - proof, remaining_validator_ids = self._pick_from_aggregated_proofs( - remaining_validator_ids, - data_root, - aggregated_payloads, + # No gossip data at all: all validators need fallback coverage. + remaining = set(validator_ids) + + # If we collected any gossip signatures, aggregate them into a proof. + # + # The aggregation combines multiple XMSS signatures into a single + # compact proof that can verify all participants signed the message. + if gossip_ids: + participants = AggregationBits.from_validator_indices(gossip_ids) + proof = AggregatedSignatureProof.aggregate( + participants=participants, + public_keys=gossip_keys, + signatures=gossip_sigs, + message=data_root, + epoch=data.slot, ) - proofs.append(proof) - - # TODO: Recursively aggregate the proofs. Since we currently don't support - # recursive aggregation, we just append each proof separately. This is fine - # for now, eventually we will recursively aggregate into one. - for proof in proofs: - final_aggregated_attestations.append( - AggregatedAttestation( - aggregation_bits=proof.participants, - data=completely_aggregated_attestation.data, + results.append( + ( + AggregatedAttestation(aggregation_bits=participants, data=data), + proof, ) ) - final_aggregated_proofs.append(proof) - - return final_aggregated_attestations, final_aggregated_proofs - def _pick_from_aggregated_proofs( - self, - remaining_validator_ids: set[Uint64], - data_root: Bytes32, - aggregated_payloads: dict[SignatureKey, list[AggregatedSignatureProof]] | None = None, - ) -> tuple[AggregatedSignatureProof, set[Uint64]]: - """ - Pick an aggregated proof that covers the most remaining validators. - - Args: - remaining_validator_ids: The validator ids still needing coverage. - data_root: The attestation data root. - aggregated_payloads: Previously learned proofs keyed by (validator_id, data_root). + # Phase 2: Fallback to existing proofs + # + # Some validators may not have broadcast their signatures over gossip, + # but we might have seen proofs for them in previously-received blocks. + # + # Example scenario: + # + # - We need signatures from validators {0, 1, 2, 3, 4}. + # - Gossip gave us signatures for {0, 1}. + # - Remaining: {2, 3, 4}. + # - From old blocks, we have: + # • Proof A covering {2, 3} + # • Proof B covering {3, 4} + # • Proof C covering {4} + # + # We want to cover {2, 3, 4} with as few proofs as possible. + # A greedy approach: always pick the proof with the largest overlap. + # + # - Iteration 1: Proof A covers {2, 3} (2 validators). Pick it. + # Remaining: {4}. + # - Iteration 2: Proof B covers {4} (1 validator). Pick it. + # Remaining: {} → done. + # + # Result: 2 proofs instead of 3. - Returns: - A tuple of (proof, remaining_validator_ids after this proof is applied). + while remaining and aggregated_payloads: + # Step 1: Find candidate proofs for a remaining validator. + # + # Proofs are indexed by (validator ID, data root). We pick any + # validator still in the remaining set and look up proofs that + # include them. + target_id = next(iter(remaining)) + candidates = aggregated_payloads.get(SignatureKey(target_id, data_root), []) - Raises: - ValueError: If no suitable proof is found. - """ - if not remaining_validator_ids: - raise ValueError("remaining validator ids cannot be empty") + # No proofs found for this validator: stop the loop. + if not candidates: + break - if aggregated_payloads is None: - raise ValueError("aggregated payloads required when gossip coverage incomplete") + # Step 2: Pick the proof covering the most remaining validators. + # + # At each step, we select the single proof that eliminates the highest + # number of *currently missing* validators from our list. + # + # The 'score' of a candidate proof is defined as the size of the + # intersection between: + # A. The validators inside the proof (`p.participants`) + # B. The validators we still need (`remaining`) + # + # Example: + # Remaining needed : {Alice, Bob, Charlie} + # Proof 1 covers : {Alice, Dave} -> Score: 1 (Only Alice counts) + # Proof 2 covers : {Bob, Charlie, Eve} -> Score: 2 (Bob & Charlie count) + # -> Result: We pick Proof 2 because it has the highest score. + best, covered = max( + ((p, set(p.participants.to_validator_indices())) for p in candidates), + # Calculate the intersection size (A ∩ B) for every candidate. + key=lambda pair: len(pair[1] & remaining), + ) - best_proof: AggregatedSignatureProof | None = None - best_overlap: set[Uint64] = set() - best_remaining: set[Uint64] = set() + # Guard: If the best proof has zero overlap with remaining, stop. + if covered.isdisjoint(remaining): + break - representative_validator_id = next(iter(remaining_validator_ids)) - key = SignatureKey(representative_validator_id, data_root) + # Step 3: Record the proof and remove covered validators. + results.append( + ( + AggregatedAttestation(aggregation_bits=best.participants, data=data), + best, + ) + ) + remaining -= covered - for proof in aggregated_payloads.get(key, []): - participants = set(proof.participants.to_validator_indices()) - overlap = participants.intersection(remaining_validator_ids) - if len(overlap) > len(best_overlap): - best_proof = proof - best_overlap = overlap - best_remaining = remaining_validator_ids - overlap + # Final Assembly + # + # - We built a list of (attestation, proof) tuples. + # - Now we unzip them into two parallel lists for the return value. - if best_proof is None: - raise ValueError("Failed to locate aggregated proof for remaining validators") + # Handle the empty case explicitly. + if not results: + return [], [] - return best_proof, best_remaining + # Unzip the results into parallel lists. + aggregated_attestations, aggregated_proofs = zip(*results, strict=True) + return list(aggregated_attestations), list(aggregated_proofs) diff --git a/tests/lean_spec/subspecs/containers/test_state_aggregation.py b/tests/lean_spec/subspecs/containers/test_state_aggregation.py index bae36e83..3ca9e455 100644 --- a/tests/lean_spec/subspecs/containers/test_state_aggregation.py +++ b/tests/lean_spec/subspecs/containers/test_state_aggregation.py @@ -106,143 +106,6 @@ def make_attestation_data( ) -def test_gossip_aggregation_succeeds_with_all_signatures() -> None: - state = make_state(2) - data_root = b"\x11" * 32 - validator_ids = [Uint64(0), Uint64(1)] - gossip_signatures = { - SignatureKey(Uint64(0), Bytes32(data_root)): make_signature(0), - SignatureKey(Uint64(1), Bytes32(data_root)): make_signature(1), - } - - result = state._aggregate_signatures_from_gossip( - validator_ids, - Bytes32(data_root), - Slot(3), - gossip_signatures, - ) - - assert result is not None - proof, remaining = result - assert set(proof.participants.to_validator_indices()) == set(validator_ids) - assert remaining == set() - - -def test_gossip_aggregation_returns_partial_result_when_some_missing() -> None: - state = make_state(2) - data_root = b"\x22" * 32 - gossip_signatures = {SignatureKey(Uint64(0), Bytes32(data_root)): make_signature(0)} - - result = state._aggregate_signatures_from_gossip( - [Uint64(0), Uint64(1)], - Bytes32(data_root), - Slot(2), - gossip_signatures, - ) - - assert result is not None - proof, remaining = result - assert proof.participants.to_validator_indices() == [Uint64(0)] - assert remaining == {Uint64(1)} - - -def test_gossip_aggregation_returns_none_if_no_signature_matches() -> None: - state = make_state(2) - data_root = b"\x33" * 32 - # Gossip data exists but for a different validator key, so no signatures match - gossip_signatures = {SignatureKey(Uint64(9), Bytes32(data_root)): make_signature(0)} - - result = state._aggregate_signatures_from_gossip( - [Uint64(0), Uint64(1)], - Bytes32(data_root), - Slot(2), - gossip_signatures, - ) - - assert result is None - - -def test_pick_from_aggregated_proofs_prefers_widest_overlap() -> None: - state = make_state(3) - data_root = b"\x44" * 32 - remaining_validator_ids = {Uint64(0), Uint64(1)} - - narrow_proof = make_test_proof([Uint64(0)], b"narrow") - best_proof = make_test_proof([Uint64(0), Uint64(1)], b"best") - - aggregated_payloads = { - SignatureKey(Uint64(0), Bytes32(data_root)): [narrow_proof, best_proof], - SignatureKey(Uint64(1), Bytes32(data_root)): [best_proof, narrow_proof], - } - - proof, remaining = state._pick_from_aggregated_proofs( - remaining_validator_ids=remaining_validator_ids, - data_root=Bytes32(data_root), - aggregated_payloads=aggregated_payloads, - ) - - assert set(proof.participants.to_validator_indices()) == {Uint64(0), Uint64(1)} - assert remaining == set() - - -def test_pick_from_aggregated_proofs_returns_remaining_for_partial_payload() -> None: - state = make_state(2) - data_root = b"\x45" * 32 - remaining_validator_ids = {Uint64(0), Uint64(1)} - - partial_proof_0 = make_test_proof([Uint64(0)], b"partial-0") - partial_proof_1 = make_test_proof([Uint64(1)], b"partial-1") - - aggregated_payloads = { - SignatureKey(Uint64(0), Bytes32(data_root)): [partial_proof_0], - SignatureKey(Uint64(1), Bytes32(data_root)): [partial_proof_1], - } - - proof, remaining = state._pick_from_aggregated_proofs( - remaining_validator_ids=remaining_validator_ids, - data_root=Bytes32(data_root), - aggregated_payloads=aggregated_payloads, - ) - - covered_validators = set(proof.participants.to_validator_indices()) - assert covered_validators <= {Uint64(0), Uint64(1)} - assert remaining == remaining_validator_ids - covered_validators - - -def test_pick_from_aggregated_proofs_requires_payloads() -> None: - state = make_state(1) - - with pytest.raises(ValueError, match="aggregated payloads required"): - state._pick_from_aggregated_proofs( - remaining_validator_ids={Uint64(0)}, - data_root=Bytes32(b"\x55" * 32), - aggregated_payloads=None, - ) - - -def test_pick_from_aggregated_proofs_errors_on_empty_remaining() -> None: - state = make_state(1) - - with pytest.raises(ValueError, match="remaining validator ids cannot be empty"): - state._pick_from_aggregated_proofs( - remaining_validator_ids=set(), - data_root=Bytes32(b"\x66" * 32), - aggregated_payloads={}, - ) - - -def test_pick_from_aggregated_proofs_errors_when_no_candidates() -> None: - state = make_state(1) - data_root = b"\x77" * 32 - - with pytest.raises(ValueError, match="Failed to locate aggregated proof"): - state._pick_from_aggregated_proofs( - remaining_validator_ids={Uint64(0)}, - data_root=Bytes32(data_root), - aggregated_payloads={SignatureKey(Uint64(0), Bytes32(data_root)): []}, - ) - - def test_compute_aggregated_signatures_prefers_full_gossip_payload() -> None: state = make_state(2) source = Checkpoint(root=make_bytes32(1), slot=Slot(0)) @@ -373,52 +236,6 @@ def test_build_block_skips_attestations_without_signatures() -> None: assert list(block.body.attestations.data) == [] -def test_gossip_aggregation_with_empty_validator_list() -> None: - """Empty validator list should return None.""" - state = make_state(2) - data_root = b"\x99" * 32 - gossip_signatures = {SignatureKey(Uint64(0), Bytes32(data_root)): make_signature(0)} - - result = state._aggregate_signatures_from_gossip( - [], # empty validator list - Bytes32(data_root), - Slot(1), - gossip_signatures, - ) - - assert result is None - - -def test_gossip_aggregation_with_none_gossip_signatures() -> None: - """None gossip_signatures should return None.""" - state = make_state(2) - data_root = b"\x88" * 32 - - result = state._aggregate_signatures_from_gossip( - [Uint64(0), Uint64(1)], - Bytes32(data_root), - Slot(1), - None, # None gossip_signatures - ) - - assert result is None - - -def test_gossip_aggregation_with_empty_gossip_signatures() -> None: - """Empty gossip_signatures dict should return None.""" - state = make_state(2) - data_root = b"\x77" * 32 - - result = state._aggregate_signatures_from_gossip( - [Uint64(0), Uint64(1)], - Bytes32(data_root), - Slot(1), - {}, # empty dict - ) - - assert result is None - - def test_compute_aggregated_signatures_with_empty_attestations() -> None: """Empty attestations list should return empty results.""" state = make_state(2) From b742640ceaa3b03797f22c578774e2532e37c8cb Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Wed, 31 Dec 2025 22:58:18 +0100 Subject: [PATCH 2/2] cleanup --- src/lean_spec/subspecs/containers/state/state.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lean_spec/subspecs/containers/state/state.py b/src/lean_spec/subspecs/containers/state/state.py index dc9859bd..d24a4d95 100644 --- a/src/lean_spec/subspecs/containers/state/state.py +++ b/src/lean_spec/subspecs/containers/state/state.py @@ -912,6 +912,9 @@ def compute_aggregated_signatures( break # Step 3: Record the proof and remove covered validators. + # + # TODO: We don't support recursive aggregation yet. + # In the future, we should be able to aggregate the proofs into a single proof. results.append( ( AggregatedAttestation(aggregation_bits=best.participants, data=data),