126 changes: 62 additions & 64 deletions analysis/ceca.py
@@ -2,6 +2,16 @@
# Copyright lowRISC contributors.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
"""A distributed implementation of the correlation-enhanced power analysis
collision attack.

See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
information.

Typical usage:
>>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
"""

import argparse
import enum
@@ -25,22 +35,12 @@
from capture.project_library.project import ProjectConfig # noqa : E402
from capture.project_library.project import SCAProject # noqa : E402

"""A distributed implementation of the correlation-enhanced power analysis
collision attack.

See "Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf) for more
information.

Typical usage:
>>> ./ceca.py -f PROJECT_FILE -n 400000 -w 5 -a 117 127 -d output -s 3
"""


def timer():
"""A customization of the ``codetiming.Timer`` decorator."""

def decorator(func):

@codetiming.Timer(
name=func.__name__,
text=f"{func.__name__} took {{seconds:.1f}}s",
@@ -79,7 +79,8 @@ class TraceWorker:
>>> results = ray.get(tasks)
"""

def __init__(self, project_file, trace_slice, attack_window, attack_direction):
def __init__(self, project_file, trace_slice, attack_window,
attack_direction):
"""Inits a TraceWorker.

Args:
@@ -94,26 +95,27 @@ def __init__(self, project_file, trace_slice, attack_window, attack_direction):
project_type = "ot_trace_library"

# Open the project.
project_cfg = ProjectConfig(
type=project_type, path=project_file, wave_dtype=np.uint16, overwrite=False
)
project_cfg = ProjectConfig(type=project_type,
path=project_file,
wave_dtype=np.uint16,
overwrite=False)
self.project = SCAProject(project_cfg)
self.project.open_project()

# TODO: Consider more efficient formats.
self.num_samples = attack_window.stop - attack_window.start
if attack_direction == AttackDirection.INPUT:
self.texts = np.vstack(
self.project.get_plaintexts(trace_slice.start, trace_slice.stop)
)
self.project.get_plaintexts(trace_slice.start,
trace_slice.stop))
else:
self.texts = np.vstack(
self.project.get_ciphertexts(trace_slice.start, trace_slice.stop)
)
self.project.get_ciphertexts(trace_slice.start,
trace_slice.stop))

self.traces = np.asarray(
self.project.get_waves(trace_slice.start, trace_slice.stop)
)[:, attack_window]
self.project.get_waves(trace_slice.start,
trace_slice.stop))[:, attack_window]

self.project.close(save=False)

@@ -132,7 +134,7 @@ def compute_stats(self):
cnt = self.traces.shape[0]
sum_ = self.traces.sum(axis=0)
mean = sum_ / cnt
sum_dev_prods = ((self.traces - mean) ** 2).sum(axis=0)
sum_dev_prods = ((self.traces - mean)**2).sum(axis=0)
return (cnt, sum_, sum_dev_prods)

def filter_noisy_traces(self, min_trace, max_trace):
@@ -146,8 +148,7 @@ def filter_noisy_traces(self, min_trace, max_trace):
Number of remaining traces.
"""
traces_to_use = np.all(
(self.traces >= min_trace) & (self.traces <= max_trace), axis=1
)
(self.traces >= min_trace) & (self.traces <= max_trace), axis=1)
self.traces = self.traces[traces_to_use]
self.texts = self.texts[traces_to_use]
return self.traces.shape[0]
@@ -214,12 +215,12 @@ def compute_mean_and_std(workers):
running_cnt += cnt
else:
running_sum_dev_prods += sum_dev_prods + (
(cnt * running_sum - running_cnt * sum_) ** 2 /
(cnt * running_cnt * (cnt + running_cnt))
)
(cnt * running_sum - running_cnt * sum_)**2 /
(cnt * running_cnt * (cnt + running_cnt)))
running_sum += sum_
running_cnt += cnt
return running_sum / running_cnt, np.sqrt(running_sum_dev_prods / running_cnt)
return running_sum / running_cnt, np.sqrt(running_sum_dev_prods /
running_cnt)


def filter_noisy_traces(workers, mean_trace, std_trace, max_std):
Expand All @@ -237,7 +238,8 @@ def filter_noisy_traces(workers, mean_trace, std_trace, max_std):
min_trace = mean_trace - max_std * std_trace
max_trace = mean_trace + max_std * std_trace
tasks = [
worker.filter_noisy_traces.remote(min_trace, max_trace) for worker in workers
worker.filter_noisy_traces.remote(min_trace, max_trace)
for worker in workers
]

running_cnt = 0
@@ -392,7 +394,8 @@ def find_best_diffs(pairwise_diffs_scores):
# the most likely differences between key bytes.
G.add_edge(a, b, weight=DiffScore(pairwise_diffs_scores[a, b, 1]))
# Find paths from key byte 0 to all other bytes.
paths = nx.algorithms.shortest_paths.weighted.single_source_dijkstra_path(G, 0)
paths = nx.algorithms.shortest_paths.weighted.single_source_dijkstra_path(
G, 0)
# Recover the paths and corresponding differences from key byte 0 to all
# other bytes.
diffs = np.zeros(16, dtype=np.uint8)
@@ -422,9 +425,11 @@ def recover_key(diffs, attack_direction, plaintext, ciphertext):
# Create a matrix of all possible keys.
keys = np.zeros((256, 16), np.uint8)
for first_byte_val in range(256):
key = np.asarray([diffs[i] ^ first_byte_val for i in range(16)], np.uint8)
key = np.asarray([diffs[i] ^ first_byte_val for i in range(16)],
np.uint8)
if attack_direction == AttackDirection.OUTPUT:
key = np.asarray(cwa.aes_funcs.key_schedule_rounds(key, 10, 0), np.uint8)
key = np.asarray(cwa.aes_funcs.key_schedule_rounds(key, 10, 0),
np.uint8)
keys[first_byte_val] = key
# Encrypt the plaintext using all candidates in parallel.
ciphertexts = scared.aes.base.encrypt(plaintext, keys)
@@ -464,9 +469,8 @@ def compare_diffs(pairwise_diffs_scores, attack_direction, correct_key):


@timer()
def perform_attack(
project_file, num_traces, attack_window, attack_direction, max_std, num_workers
):
def perform_attack(project_file, num_traces, attack_window, attack_direction,
max_std, num_workers):
"""Performs a correlation-enhanced power analysis collision attack.

This function:
@@ -506,9 +510,10 @@ def perform_attack(
project_type = "ot_trace_library"

# Open the project.
project_cfg = ProjectConfig(
type=project_type, path=project_file, wave_dtype=np.uint16, overwrite=False
)
project_cfg = ProjectConfig(type=project_type,
path=project_file,
wave_dtype=np.uint16,
overwrite=False)
project = SCAProject(project_cfg)
project.open_project()

@@ -524,11 +529,11 @@ def perform_attack(
f"Invalid attack window: {attack_window} (must be in [0, {last_sample}])"
)
if max_std <= 0:
raise ValueError(f"Invalid max_std: {max_std} (must be greater than zero)")
raise ValueError(
f"Invalid max_std: {max_std} (must be greater than zero)")
if num_workers <= 0:
raise ValueError(
f"Invalid num_workers: {num_workers} (must be greater than zero)"
)
f"Invalid num_workers: {num_workers} (must be greater than zero)")

# Instantiate workers
def worker_trace_slices():
@@ -539,15 +544,15 @@ def worker_trace_slices():
traces_per_worker = int(num_traces / num_workers)
first_worker_num_traces = traces_per_worker + num_traces % num_workers
yield slice(0, first_worker_num_traces)
for trace_begin in range(
first_worker_num_traces, num_traces, traces_per_worker
):
for trace_begin in range(first_worker_num_traces, num_traces,
traces_per_worker):
yield slice(trace_begin, trace_begin + traces_per_worker)

# Attack window is inclusive.
attack_window = slice(attack_window[0], attack_window[1] + 1)
workers = [
TraceWorker.remote(project_file, trace_slice, attack_window, attack_direction)
TraceWorker.remote(project_file, trace_slice, attack_window,
attack_direction)
for trace_slice in worker_trace_slices()
]
assert len(workers) == num_workers
@@ -556,32 +561,27 @@ def worker_trace_slices():
# Filter noisy traces.
orig_num_traces = num_traces
num_traces = filter_noisy_traces(workers, mean, std_dev, max_std)
logging.info(
f"Will use {num_traces} traces "
f"({100 * num_traces / orig_num_traces:.1f}% of all traces)"
)
logging.info(f"Will use {num_traces} traces "
f"({100 * num_traces / orig_num_traces:.1f}% of all traces)")
# Mean traces for all values of all text bytes.
mean_text_traces = compute_mean_text_traces(workers)
# Guess the differences between key bytes.
pairwise_diffs_scores = compute_pairwise_diffs_and_scores(mean_text_traces)
diffs = find_best_diffs(pairwise_diffs_scores)
logging.info(f"Difference values (delta_0_i): {diffs}")
# Recover the key.
key = recover_key(
diffs, attack_direction, project.get_plaintexts(0), project.get_ciphertexts(0)
)
key = recover_key(diffs, attack_direction, project.get_plaintexts(0),
project.get_ciphertexts(0))
if key is not None:
logging.info(f"Recovered AES key: {bytes(key).hex()}")
else:
logging.error("Failed to recover the AES key")
# Compare differences - both matrices are symmetric and have an all-zero main diagonal.
correct_diffs = compare_diffs(
pairwise_diffs_scores, attack_direction, project.get_keys(0)
)
correct_diffs = compare_diffs(pairwise_diffs_scores, attack_direction,
project.get_keys(0))
logging.info(
f"Recovered {((np.sum(correct_diffs) - 16) / 2).astype(int)}/120 "
"differences between key bytes"
)
"differences between key bytes")
project.close(save=False)
return key

@@ -591,8 +591,7 @@ def parse_args():
parser = argparse.ArgumentParser(
description="""A distributed implementation of the attack described in
"Correlation-Enhanced Power Analysis Collision Attack" by A. Moradi, O.
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf)."""
)
Mischke, and T. Eisenbarth (https://eprint.iacr.org/2010/297.pdf).""")
parser.add_argument(
"-f",
"--project-file",
Expand Down Expand Up @@ -649,8 +648,7 @@ def config_logger():
sh = logging.StreamHandler()
sh.setLevel(logging.INFO)
formatter = logging.Formatter(
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s"
)
"%(asctime)s %(levelname)s %(filename)s:%(lineno)d -- %(message)s")
sh.setFormatter(formatter)
logger.addHandler(sh)
return logger
@@ -665,9 +663,9 @@ def main():
ray.init(
runtime_env={
"working_dir": "../",
"excludes": ["*.db", "*.cwp", "*.npy", "*.bit", "*/lfs/*", "*.pack"],
}
)
"excludes":
["*.db", "*.cwp", "*.npy", "*.bit", "*/lfs/*", "*.pack"],
})

key = perform_attack(**vars(args))
sys.exit(0 if key is not None else 1)
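
Reviewer note, for reference only: the combining step in compute_mean_and_std above is the pairwise update for sums of squared deviations (Chan et al.). The standalone sketch below is not part of this change; it uses synthetic chunks in place of project traces and checks that the per-worker statistics combine to the same mean and standard deviation NumPy computes over the concatenated array.

# Standalone sketch (assumption: synthetic data, not part of analysis/ceca.py).
# Verifies the pairwise combination of per-worker statistics used in
# compute_mean_and_std against a direct computation over all traces.
import numpy as np

rng = np.random.default_rng(0)
chunks = [rng.normal(size=(n, 4)) for n in (100, 250, 57)]

running_cnt = 0
running_sum = None
running_sum_dev_prods = None
for chunk in chunks:
    # Per-chunk statistics, analogous to TraceWorker.compute_stats().
    cnt = chunk.shape[0]
    sum_ = chunk.sum(axis=0)
    sum_dev_prods = ((chunk - sum_ / cnt) ** 2).sum(axis=0)
    if running_cnt == 0:
        running_cnt, running_sum, running_sum_dev_prods = cnt, sum_, sum_dev_prods
    else:
        # Pairwise update: adds n_a * n_b / (n_a + n_b) * (mean_a - mean_b)^2,
        # written in terms of raw sums to match the code in the diff above.
        running_sum_dev_prods += sum_dev_prods + (
            (cnt * running_sum - running_cnt * sum_) ** 2 /
            (cnt * running_cnt * (cnt + running_cnt)))
        running_sum += sum_
        running_cnt += cnt

all_traces = np.vstack(chunks)
assert np.allclose(running_sum / running_cnt, all_traces.mean(axis=0))
assert np.allclose(np.sqrt(running_sum_dev_prods / running_cnt),
                   all_traces.std(axis=0))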