From ce965610646a3b5572443e6cf8add0b5138a42c4 Mon Sep 17 00:00:00 2001 From: namgyu-youn Date: Fri, 16 May 2025 04:45:46 +0900 Subject: [PATCH 1/3] Built single-GPU enchmark for preconditioners (#157) The benchmark compares the performance of various preconditioners (SGD, AdaGrad, Root Inverse Shampoo, Eigendecomposed Shampoo, and Eigenvalue-Corrected Shampoo) using rich console and PyTorch profiler. In rich console, you can check the following: - Total time taken for each preconditioner - Average time taken per epoch - Memory usage in MB - GPU utilization percentage (if applicable) In PyTorch profiler, you can check the following: - Most time-consuming operations (5-th) - Bottleneck analysis for each preconditioner Requested by @tsunghsienlee in #157 for developers experience. Co-authored-by: Tsung-Hsien Lee --- benchmarks/README.md | 67 ++++++ benchmarks/preconditioners.py | 398 ++++++++++++++++++++++++++++++++++ 2 files changed, 465 insertions(+) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/preconditioners.py diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 00000000..de0a35dc --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,67 @@ +# Preconditioner Benchmark + +This benchmark compares different preconditioners in [shampoo_preconditioner_list.py](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py). It illustrates total time, average time, GPU-usage, and bottleneck using **rich Console and PyTorch profiler**. 
+ + +## Benchmark List + +- `SGDPreconditionerList` : SGD (no preconditioning) +- `AdagradPreconditionerList` : AdaGrad +- `RootInvShampooPreconditionerList` : Root Inverse Shampoo +- `EigendecomposedShampooPreconditionerList` : Eigendecomposed Shampoo +- `EigenvalueCorrectedShampooPreconditionerList` : Eigenvalue-Corrected Shampoo + + +## Example + +```zsh +➜ optimizers git: ✗ uv run benchmarks/preconditioner.py + +Preconditioner Benchmark +Device: cuda | Param Shape: (2048, 2048) | Epochs: 200 + + Performance Summary +┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ Preconditioner ┃ Total (s) ┃ Avg/Epoch (ms) ┃ Memory (MB) ┃ Relative ┃ GPU Util % ┃ +┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ SGD │ 0.021 │ 0.11 │ 0.0 │ 1.0x │ 1.6 │ +│ AdaGrad │ 0.064 │ 0.32 │ 16.0 │ 3.0x │ 164.7 │ +│ EigendecomposedShampoo │ 1.260 │ 6.30 │ 64.0 │ 58.8x │ 428.9 │ +│ EigenvalueCorrectedShampoo │ 1.602 │ 8.01 │ 64.0 │ 74.7x │ 595.3 │ +└────────────────────────────┴───────────┴────────────────┴─────────────┴──────────┴────────────┘ + +Bottleneck Analysis + +SGD top operations: + 1. aten::to: 0.08ms + 2. aten::_to_copy: 0.08ms + 3. aten::copy_: 0.08ms + 4. Memcpy HtoD (Pageable -> Device): 0.08ms + 5. aten::empty: 0.00ms + +AdaGrad top operations: + 1. ## AdagradPreconditionerList:update_preconditioners ##: 26.43ms + 2. aten::_foreach_addcmul_: 26.43ms + 3. void at::native::(anonymous namespace)::multi_tensor_apply_kernel, at::native::(anonymous namespace)::PointwiseOpScalarFunctor, std::multiplies, float>(at::native::(anonymous namespace)::TensorListMetadata<3>, at::native::(anonymous namespace)::PointwiseOpScalarFunctor, std::multiplies, +float): 26.43ms + 4. ## AdagradPreconditionerList:update_preconditioners ##: 26.43ms + 5. aten::to: 0.08ms + +EigendecomposedShampoo top operations: + 1. ## EigendecomposedShampooPreconditionerList:update_preconditioners ##: 899.27ms + 2. 
aten::mm: 675.26ms + 3. ## EigendecomposedShampooPreconditionerList:_update_factor_matrices ##: 668.67ms + 4. ## EigendecomposedShampooPreconditionerList:_update_factor_matrices ##: 654.35ms + 5. aten::tensordot: 601.84ms + +EigenvalueCorrectedShampoo top operations: + 1. ## EigenvalueCorrectedShampooPreconditionerList:update_preconditioners ##: 2367.36ms + 2. aten::mm: 1210.46ms + 3. aten::tensordot: 1136.00ms + 4. void cutlass::Kernel2(cutlass_80_simt_sgemm_256x128_8x4_nt_align1::Params): 824.67ms + 5. ## EigenvalueCorrectedShampooPreconditionerList:_update_factor_matrices ##: 652.20ms + +Quick Comparison (SGD = 1.0x) +SGD: 1.0x | AdaGrad: 3.0x | EigendecomposedShampoo: 58.8x | EigenvalueCorrectedShampoo: 74.7x +``` \ No newline at end of file diff --git a/benchmarks/preconditioners.py b/benchmarks/preconditioners.py new file mode 100644 index 00000000..947bf116 --- /dev/null +++ b/benchmarks/preconditioners.py @@ -0,0 +1,398 @@ +import time +from dataclasses import dataclass +from typing import Any, Optional + +import torch + +from distributed_shampoo.shampoo_types import ( + EigenvalueCorrectedShampooPreconditionerConfig, + ShampooPreconditionerConfig, +) +from distributed_shampoo.utils.shampoo_block_info import BlockInfo +from distributed_shampoo.utils.shampoo_preconditioner_list import ( + AdagradPreconditionerList, + EigendecomposedShampooPreconditionerList, + EigenvalueCorrectedShampooPreconditionerList, + PreconditionerList, + RootInvShampooPreconditionerList, + SGDPreconditionerList, +) +from matrix_functions_types import QREigendecompositionConfig + +# Note: This is a workaround for mypy not recognizing the rich library +# Since only benchmarks requires it, we can ignore the errors here +# If you want to fundamentally change this, you can use a stub file (i.e., .pyi, .ini) +from rich.console import Console # type: ignore +from rich.progress import BarColumn, Progress, SpinnerColumn, TextColumn # type: ignore +from rich.table import Table # type: ignore 
+from torch.profiler import profile, ProfilerActivity + + +@dataclass +class BenchmarkResult: + """Container for benchmark results. + + Attributes: + preconditioner: Type of the preconditioner. (e.g., Root Inverse Shampoo, Eigendecomposed Shampoo) + total_time: Total time taken to run all epochs (in seconds). + avg_time_per_epoch: Average time per epoch (in seconds). + memory_usage: Memory usage in bytes. + gpu_utilization: GPU utilization percentage (only for CUDA devices). + top_operations: List of tuples with (operation_name, time_ms) for most time-consuming operations. + """ + + preconditioner: str + total_time: float + avg_time_per_epoch: float + memory_usage: float + gpu_utilization: Optional[float] = None + top_operations: Optional[list[tuple[str, float]]] = None + + +class PreconditionerBenchmark: + """Benchmark different preconditioners. + + This class provides utilities to benchmark various preconditioners + for optimization algorithms. It measures performance metrics like + execution time, memory usage, and GPU utilization. + """ + + # Mapping of preconditioner names to their respective implementation classes + PRECONDITIONER_TYPES = { + "SGD": SGDPreconditionerList, + "AdaGrad": AdagradPreconditionerList, + "Shampoo": RootInvShampooPreconditionerList, + "EigendecomposedShampoo": EigendecomposedShampooPreconditionerList, + "EigenvalueCorrectedShampoo": EigenvalueCorrectedShampooPreconditionerList, + } + + def __init__(self, param_shapes: list[tuple[int, ...]], device: str): + """Initialize benchmark with parameter shapes and device. + + Args: + param_shapes: List of tensor shapes to benchmark. + device: Device to run benchmark on. If None, automatically selects + "cuda" if available, otherwise "cpu". + """ + self.param_shapes = param_shapes + self.device = device + self.console = Console() + self.config = self._get_default_config() + self._setup_parameters() + + def _get_default_config(self): + """Create default configuration for preconditioners. 
+ + Returns: + A config object with default settings for the benchmark. + """ + return type( + "Config", + (), + { + "beta2": 1.0, + "epsilon": 1e-12, + "use_bias_correction": True, + "factor_matrix_dtype": torch.float, + }, + )() + + def _setup_parameters(self): + """Initialize parameters, blocks, and block_infos. + + Sets up the following instance attributes: + - self.state: Dictionary mapping parameters to their states + - self.blocks: List of parameter tensors + - self.block_infos: List of BlockInfo objects + """ + self.state: dict[torch.Tensor, Any] = {} + self.blocks: list[torch.Tensor] = [] + self.block_infos: list[BlockInfo] = [] + + for i, shape in enumerate(self.param_shapes): + param = torch.nn.Parameter(torch.randn(shape, device=self.device)) + self.state[param] = {} + self.blocks.append(param.data) + self.block_infos.append( + BlockInfo(param=param, composable_block_ids=(i, "block_0")) + ) + + def _get_preconditioner_config( + self, preconditioner_type: str + ) -> ShampooPreconditionerConfig | EigenvalueCorrectedShampooPreconditionerConfig: + """Get configuration for specific preconditioner types. + + Args: + preconditioner_type: The type of preconditioner to configure. + + Returns: + A configuration object appropriate for the specified preconditioner type. + """ + eigen_config = QREigendecompositionConfig() + + if preconditioner_type == "EigendecomposedShampoo": + return ShampooPreconditionerConfig( + amortized_computation_config=eigen_config + ) + elif preconditioner_type == "EigenvalueCorrectedShampoo": + return EigenvalueCorrectedShampooPreconditionerConfig( + amortized_computation_config=eigen_config + ) + return ShampooPreconditionerConfig() + + def _create_preconditioner(self, preconditioner_type: str) -> PreconditionerList: + """Create a preconditioner of the specified type. + + Args: + preconditioner_type: Name of the preconditioner type to create. + + Returns: + An initialized preconditioner instance. 
+ """ + preconditioner_class = self.PRECONDITIONER_TYPES[preconditioner_type] + + # SGD only needs block_list + if preconditioner_type == "SGD": + return preconditioner_class(block_list=tuple(self.blocks)) + + # Common kwargs for other preconditioners + kwargs = { + "block_list": tuple(self.blocks), + "state": self.state, + "block_info_list": tuple(self.block_infos), + "beta2": self.config.beta2, + "epsilon": self.config.epsilon, + } + + # AdaGrad doesn't use bias correction + if preconditioner_type == "AdaGrad": + kwargs["use_bias_correction"] = False + else: + kwargs["use_bias_correction"] = self.config.use_bias_correction + + # Shampoo variants need additional config + if preconditioner_type in [ + "Shampoo", + "EigendecomposedShampoo", + "EigenvalueCorrectedShampoo", + ]: + kwargs.update( + { + "preconditioner_config": self._get_preconditioner_config( + preconditioner_type + ), + "factor_matrix_dtype": self.config.factor_matrix_dtype, + } + ) + + return preconditioner_class(**kwargs) + + def _benchmark_single( + self, + preconditioner_type: str, + num_epochs: int = 200, + precondition_frequency: int = 40, + ) -> BenchmarkResult: + """Benchmark a single preconditioner with profiler.""" + preconditioner = self._create_preconditioner(preconditioner_type) + gradients = self._generate_gradients(num_epochs) + memory_usage = getattr(preconditioner, "num_bytes", lambda: 0)() + + # Configure profiler + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + # Build PyTorch profiler + with profile( + activities=activities, + record_shapes=True, + profile_memory=True, + with_stack=True, + ) as prof: + start_time = time.time() + + # Show progress + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + transient=True, + ) as progress: + task = progress.add_task( + f"[green]Benchmarking {preconditioner_type}...", total=num_epochs + ) + + for epoch in 
range(num_epochs): + step = torch.tensor( + epoch + 1, dtype=torch.int64, device=self.device + ) + grad_list = gradients[epoch] + + # Perform amortized computation per the frequency epochs (40) + perform_amortized = (epoch + 1) % precondition_frequency == 0 + + preconditioner.update_preconditioners( + masked_grad_list=tuple(grad_list), + step=step, + perform_amortized_computation=perform_amortized, + ) + progress.update(task, advance=1) + + total_time = time.time() - start_time + + # Extract profiling results + key_averages = prof.key_averages() + + # Select time attribute based on device + time_attr = "device_time_total" if self.device == "cuda" else "cpu_time_total" + + # Sort key_averages by time + sorted_events = sorted( + key_averages, key=lambda x: getattr(x, time_attr, 0), reverse=True + ) + + top_ops: list[tuple[str, float]] = [] + for event in sorted_events: + if len(top_ops) >= 5: + break + if "randn" not in event.key and "normal" not in event.key: + time_ms = getattr(event, time_attr, 0) / 1000 + top_ops.append((event.key, time_ms)) + + # Calculate GPU-usage if device is CUDA + gpu_utilization = None + if self.device == "cuda": + device_time = sum(getattr(event, time_attr, 0) for event in key_averages) + total_time_us = total_time * 1e6 + gpu_utilization = ( + (device_time / total_time_us) * 100 if total_time_us > 0 else 0 + ) + + return BenchmarkResult( + preconditioner=preconditioner_type, + total_time=total_time, + avg_time_per_epoch=total_time / num_epochs, + memory_usage=memory_usage, + gpu_utilization=gpu_utilization, + top_operations=top_ops, + ) + + def _generate_gradients(self, num_epochs: int) -> list[list[torch.Tensor]]: + """Pre-generate gradients for benchmarking. + + This method pre-generates random gradients for all epochs to ensure + fair comparison between different preconditioners. + + Args: + num_epochs: Number of epochs to generate gradients for. + + Returns: + List of gradient lists, one per epoch. 
+ """ + gradients = [] + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TextColumn("[progress.percentage]{task.percentage:>3.0f}%"), + transient=True, + ) as progress: + task = progress.add_task("[cyan]Generating gradients...", total=num_epochs) + for _ in range(num_epochs): + grad_list = [torch.randn_like(block) * 0.01 for block in self.blocks] + gradients.append(grad_list) + progress.update(task, advance=1) + return gradients + + def _display_results(self, results: dict[str, BenchmarkResult]): + """Display benchmark results in formatted tables. + + This method creates and displays: + 1. A performance summary table + 2. Bottleneck analysis for each preconditioner + 3. A quick comparison of relative performance + + Args: + results: Dictionary mapping preconditioner names to BenchmarkResult objects. + """ + # Main performance table + table = Table(title="Performance Summary") + table.add_column("Preconditioner", style="cyan", no_wrap=True) + table.add_column("Total (s)", justify="right", style="magenta") + table.add_column("Avg/Epoch (ms)", justify="right", style="yellow") + table.add_column("Memory (MB)", justify="right", style="green") + table.add_column("Relative", justify="right", style="red") + + if self.device == "cuda": + table.add_column("GPU Util %", justify="right", style="blue") + + sgd_time = results["SGD"].avg_time_per_epoch + + for preconditioner_type, result in results.items(): + relative_time = result.avg_time_per_epoch / sgd_time + row = [ + preconditioner_type, + f"{result.total_time:.3f}", + f"{result.avg_time_per_epoch * 1000:.2f}", + f"{result.memory_usage / 1024 / 1024:.1f}", + f"{relative_time:.1f}x", + ] + + if self.device == "cuda": + gpu_util = result.gpu_utilization or 0 + row.append(f"{gpu_util:.1f}") + + table.add_row(*row) + + self.console.print(table) + + # Show bottleneck analysis + self.console.print("\n[bold]Bottleneck Analysis[/bold]") + for preconditioner_type, result 
in results.items(): + if result.top_operations: + self.console.print( + f"\n[cyan]{preconditioner_type}[/cyan] top operations:" + ) + for i, (op_name, time_ms) in enumerate(result.top_operations[:10]): + self.console.print(f" {i+1}. {op_name}: {time_ms:.2f}ms") + + # Quick comparison + self.console.print("\n[bold]Quick Comparison (SGD = 1.0x)[/bold]") + comparisons = [] + for preconditioner_type, result in results.items(): + relative_time = result.avg_time_per_epoch / sgd_time + comparisons.append( + f"{preconditioner_type}: [bold red]{relative_time:.1f}x[/bold red]" + ) + self.console.print(" | ".join(comparisons)) + + def run_all_benchmarks(self, num_epochs: int = 100) -> dict[str, BenchmarkResult]: + """Run benchmarks for all preconditioners. + + Args: + num_epochs: Number of epochs to run for each benchmark. + + Returns: + Dictionary mapping preconditioner names to their benchmark results. + """ + self.console.print("\n[bold cyan]Preconditioner Benchmark[/bold cyan]") + self.console.print( + f"[dim]Device: {self.device} | Param Shape: {self.param_shapes[0]} | " + f"Epochs: {num_epochs}[/dim]\n" + ) + + results = {} + for preconditioner_type in self.PRECONDITIONER_TYPES: + result = self._benchmark_single(preconditioner_type, num_epochs) + results[preconditioner_type] = result + + self._display_results(results) + return results + + +if __name__ == "__main__": + # Run benchmark with a single 2048x2048 parameter matrix + benchmark = PreconditionerBenchmark( + [(2048, 2048)], device="cuda" if torch.cuda.is_available() else "cpu" + ) + benchmark.run_all_benchmarks(num_epochs=200) From 590bb9f2b72b43d907ddf46a41be1c11348d98e7 Mon Sep 17 00:00:00 2001 From: namgyu-youn Date: Wed, 4 Jun 2025 20:48:46 +0900 Subject: [PATCH 2/3] Update preconditioner benchmarks 1. Hardcode the device to "cuda" and basic configurations for benchmarks. 2. Enhance sorting logics for profiling results. 3. Fix typo in rich Console output. 
--- benchmarks/README.md | 61 +---------- benchmarks/preconditioners.py | 200 ++++++++++++++-------------------- 2 files changed, 87 insertions(+), 174 deletions(-) diff --git a/benchmarks/README.md b/benchmarks/README.md index de0a35dc..555804ec 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -1,9 +1,10 @@ -# Preconditioner Benchmark +# Preconditioner Benchmarks -This benchmark compares different preconditioners in [shampoo_preconditioner_list.py](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py). It illustrates total time, average time, GPU-usage, and bottleneck using **rich Console and PyTorch profiler**. +This benchmark compares different preconditioners in [shampoo_preconditioner_list.py](https://github.com/facebookresearch/optimizers/blob/main/distributed_shampoo/utils/shampoo_preconditioner_list.py). It illustrates **time consumption, CPU-usage, and GPU-usage using PyTorch Profiler** and rich Console. 
+Note: You should be available to CUDA for the benchmarks -## Benchmark List +### Benchmark List - `SGDPreconditionerList` : SGD (no preconditioning) - `AdagradPreconditionerList` : AdaGrad @@ -12,56 +13,6 @@ This benchmark compares different preconditioners in [shampoo_preconditioner_lis - `EigenvalueCorrectedShampooPreconditionerList` : Eigenvalue-Corrected Shampoo -## Example +### Example -```zsh -➜ optimizers git: ✗ uv run benchmarks/preconditioner.py - -Preconditioner Benchmark -Device: cuda | Param Shape: (2048, 2048) | Epochs: 200 - - Performance Summary -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━┓ -┃ Preconditioner ┃ Total (s) ┃ Avg/Epoch (ms) ┃ Memory (MB) ┃ Relative ┃ GPU Util % ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━┩ -│ SGD │ 0.021 │ 0.11 │ 0.0 │ 1.0x │ 1.6 │ -│ AdaGrad │ 0.064 │ 0.32 │ 16.0 │ 3.0x │ 164.7 │ -│ EigendecomposedShampoo │ 1.260 │ 6.30 │ 64.0 │ 58.8x │ 428.9 │ -│ EigenvalueCorrectedShampoo │ 1.602 │ 8.01 │ 64.0 │ 74.7x │ 595.3 │ -└────────────────────────────┴───────────┴────────────────┴─────────────┴──────────┴────────────┘ - -Bottleneck Analysis - -SGD top operations: - 1. aten::to: 0.08ms - 2. aten::_to_copy: 0.08ms - 3. aten::copy_: 0.08ms - 4. Memcpy HtoD (Pageable -> Device): 0.08ms - 5. aten::empty: 0.00ms - -AdaGrad top operations: - 1. ## AdagradPreconditionerList:update_preconditioners ##: 26.43ms - 2. aten::_foreach_addcmul_: 26.43ms - 3. void at::native::(anonymous namespace)::multi_tensor_apply_kernel, at::native::(anonymous namespace)::PointwiseOpScalarFunctor, std::multiplies, float>(at::native::(anonymous namespace)::TensorListMetadata<3>, at::native::(anonymous namespace)::PointwiseOpScalarFunctor, std::multiplies, -float): 26.43ms - 4. ## AdagradPreconditionerList:update_preconditioners ##: 26.43ms - 5. aten::to: 0.08ms - -EigendecomposedShampoo top operations: - 1. 
## EigendecomposedShampooPreconditionerList:update_preconditioners ##: 899.27ms - 2. aten::mm: 675.26ms - 3. ## EigendecomposedShampooPreconditionerList:_update_factor_matrices ##: 668.67ms - 4. ## EigendecomposedShampooPreconditionerList:_update_factor_matrices ##: 654.35ms - 5. aten::tensordot: 601.84ms - -EigenvalueCorrectedShampoo top operations: - 1. ## EigenvalueCorrectedShampooPreconditionerList:update_preconditioners ##: 2367.36ms - 2. aten::mm: 1210.46ms - 3. aten::tensordot: 1136.00ms - 4. void cutlass::Kernel2(cutlass_80_simt_sgemm_256x128_8x4_nt_align1::Params): 824.67ms - 5. ## EigenvalueCorrectedShampooPreconditionerList:_update_factor_matrices ##: 652.20ms - -Quick Comparison (SGD = 1.0x) -SGD: 1.0x | AdaGrad: 3.0x | EigendecomposedShampoo: 58.8x | EigenvalueCorrectedShampoo: 74.7x -``` \ No newline at end of file +![image](https://github.com/user-attachments/assets/c92edaf0-234f-464d-b661-8dc28da84118) \ No newline at end of file diff --git a/benchmarks/preconditioners.py b/benchmarks/preconditioners.py index 947bf116..e74b9a39 100644 --- a/benchmarks/preconditioners.py +++ b/benchmarks/preconditioners.py @@ -1,6 +1,15 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. + +""" + import time from dataclasses import dataclass -from typing import Any, Optional +from typing import Optional import torch @@ -38,7 +47,6 @@ class BenchmarkResult: avg_time_per_epoch: Average time per epoch (in seconds). memory_usage: Memory usage in bytes. gpu_utilization: GPU utilization percentage (only for CUDA devices). - top_operations: List of tuples with (operation_name, time_ms) for most time-consuming operations. 
""" preconditioner: str @@ -46,7 +54,7 @@ class BenchmarkResult: avg_time_per_epoch: float memory_usage: float gpu_utilization: Optional[float] = None - top_operations: Optional[list[tuple[str, float]]] = None + profiling_table: Optional[str] = None class PreconditionerBenchmark: @@ -57,7 +65,7 @@ class PreconditionerBenchmark: execution time, memory usage, and GPU utilization. """ - # Mapping of preconditioner names to their respective implementation classes + # Mapping preconditioner names PRECONDITIONER_TYPES = { "SGD": SGDPreconditionerList, "AdaGrad": AdagradPreconditionerList, @@ -67,36 +75,11 @@ class PreconditionerBenchmark: } def __init__(self, param_shapes: list[tuple[int, ...]], device: str): - """Initialize benchmark with parameter shapes and device. - - Args: - param_shapes: List of tensor shapes to benchmark. - device: Device to run benchmark on. If None, automatically selects - "cuda" if available, otherwise "cpu". - """ + """Initialize benchmark with parameter shapes and device.""" self.param_shapes = param_shapes - self.device = device self.console = Console() - self.config = self._get_default_config() self._setup_parameters() - def _get_default_config(self): - """Create default configuration for preconditioners. - - Returns: - A config object with default settings for the benchmark. - """ - return type( - "Config", - (), - { - "beta2": 1.0, - "epsilon": 1e-12, - "use_bias_correction": True, - "factor_matrix_dtype": torch.float, - }, - )() - def _setup_parameters(self): """Initialize parameters, blocks, and block_infos. 
@@ -105,12 +88,12 @@ def _setup_parameters(self): - self.blocks: List of parameter tensors - self.block_infos: List of BlockInfo objects """ - self.state: dict[torch.Tensor, Any] = {} + self.state: dict[torch.Tensor] = {} self.blocks: list[torch.Tensor] = [] self.block_infos: list[BlockInfo] = [] for i, shape in enumerate(self.param_shapes): - param = torch.nn.Parameter(torch.randn(shape, device=self.device)) + param = torch.nn.Parameter(torch.randn(shape, device="cuda")) self.state[param] = {} self.blocks.append(param.data) self.block_infos.append( @@ -151,7 +134,7 @@ def _create_preconditioner(self, preconditioner_type: str) -> PreconditionerList """ preconditioner_class = self.PRECONDITIONER_TYPES[preconditioner_type] - # SGD only needs block_list + # SGD only needs block_list (no preconditioner config) if preconditioner_type == "SGD": return preconditioner_class(block_list=tuple(self.blocks)) @@ -160,17 +143,10 @@ def _create_preconditioner(self, preconditioner_type: str) -> PreconditionerList "block_list": tuple(self.blocks), "state": self.state, "block_info_list": tuple(self.block_infos), - "beta2": self.config.beta2, - "epsilon": self.config.epsilon, + "beta2": 0.999, + "epsilon": 1e-8, + "use_bias_correction": preconditioner_type != "AdaGrad", } - - # AdaGrad doesn't use bias correction - if preconditioner_type == "AdaGrad": - kwargs["use_bias_correction"] = False - else: - kwargs["use_bias_correction"] = self.config.use_bias_correction - - # Shampoo variants need additional config if preconditioner_type in [ "Shampoo", "EigendecomposedShampoo", @@ -180,30 +156,23 @@ def _create_preconditioner(self, preconditioner_type: str) -> PreconditionerList { "preconditioner_config": self._get_preconditioner_config( preconditioner_type - ), - "factor_matrix_dtype": self.config.factor_matrix_dtype, + ) } ) return preconditioner_class(**kwargs) def _benchmark_single( - self, - preconditioner_type: str, - num_epochs: int = 200, - precondition_frequency: int = 40, + self, 
preconditioner_type: str, num_epochs=200 ) -> BenchmarkResult: """Benchmark a single preconditioner with profiler.""" preconditioner = self._create_preconditioner(preconditioner_type) gradients = self._generate_gradients(num_epochs) memory_usage = getattr(preconditioner, "num_bytes", lambda: 0)() - # Configure profiler - activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] - - # Build PyTorch profiler + torch.cuda.empty_cache() with profile( - activities=activities, + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True, with_stack=True, @@ -223,16 +192,12 @@ def _benchmark_single( ) for epoch in range(num_epochs): - step = torch.tensor( - epoch + 1, dtype=torch.int64, device=self.device - ) - grad_list = gradients[epoch] - + step = torch.tensor(epoch + 1, dtype=torch.float, device="cuda") # Perform amortized computation per the frequency epochs (40) - perform_amortized = (epoch + 1) % precondition_frequency == 0 + perform_amortized = (epoch + 1) % 40 == 0 preconditioner.update_preconditioners( - masked_grad_list=tuple(grad_list), + masked_grad_list=tuple(gradients[epoch]), step=step, perform_amortized_computation=perform_amortized, ) @@ -242,39 +207,48 @@ def _benchmark_single( # Extract profiling results key_averages = prof.key_averages() + top_ops = sorted( + key_averages, key=lambda x: getattr(x, "cuda_time_total", 0), reverse=True + )[:8] + + # Calculate total CPU and GPU times + total_cpu_time = sum(x.cpu_time_total for x in key_averages) + total_device_time = sum(x.device_time_total for x in key_averages) + + lines = [ + "Name Self CPU % CPU Time Self GPU % GPU Time Calls", + "-" * 85, + ] + + # Format the top operations + for op in top_ops: + name = op.key[:35].ljust(35) + self_cpu_pct = ( + f"{op.self_cpu_time_total / total_cpu_time * 100:.1f}%".rjust(9) + ) + cpu_time = f"{op.cpu_time_total / 1000:.1f}ms".rjust(9) + self_gpu_pct = ( + f"{op.self_device_time_total / total_device_time * 100:.1f}%".rjust(9) 
+ ) + gpu_time = f"{op.device_time_total / 1000:.1f}ms".rjust(8) + calls = f"{op.count}".rjust(9) - # Select time attribute based on device - time_attr = "device_time_total" if self.device == "cuda" else "cpu_time_total" + lines.append( + f"{name} {self_cpu_pct} {cpu_time} {self_gpu_pct} {gpu_time} {calls}" + ) - # Sort key_averages by time - sorted_events = sorted( - key_averages, key=lambda x: getattr(x, time_attr, 0), reverse=True + top_ops = "\n".join(lines) + gpu_utilization = ( + (total_device_time / (total_time * 1e6)) * 100 if total_time > 0 else 0 ) - top_ops: list[tuple[str, float]] = [] - for event in sorted_events: - if len(top_ops) >= 5: - break - if "randn" not in event.key and "normal" not in event.key: - time_ms = getattr(event, time_attr, 0) / 1000 - top_ops.append((event.key, time_ms)) - - # Calculate GPU-usage if device is CUDA - gpu_utilization = None - if self.device == "cuda": - device_time = sum(getattr(event, time_attr, 0) for event in key_averages) - total_time_us = total_time * 1e6 - gpu_utilization = ( - (device_time / total_time_us) * 100 if total_time_us > 0 else 0 - ) - return BenchmarkResult( preconditioner=preconditioner_type, total_time=total_time, avg_time_per_epoch=total_time / num_epochs, memory_usage=memory_usage, gpu_utilization=gpu_utilization, - top_operations=top_ops, + profiling_table=top_ops, ) def _generate_gradients(self, num_epochs: int) -> list[list[torch.Tensor]]: @@ -299,8 +273,9 @@ def _generate_gradients(self, num_epochs: int) -> list[list[torch.Tensor]]: ) as progress: task = progress.add_task("[cyan]Generating gradients...", total=num_epochs) for _ in range(num_epochs): - grad_list = [torch.randn_like(block) * 0.01 for block in self.blocks] - gradients.append(grad_list) + gradients.append( + [torch.randn_like(block) * 0.01 for block in self.blocks] + ) progress.update(task, advance=1) return gradients @@ -317,56 +292,46 @@ def _display_results(self, results: dict[str, BenchmarkResult]): """ # Main performance 
table table = Table(title="Performance Summary") - table.add_column("Preconditioner", style="cyan", no_wrap=True) - table.add_column("Total (s)", justify="right", style="magenta") + table.add_column("Preconditioner", style="cyan", justify="left") + table.add_column("Total time (s)", justify="right", style="magenta") table.add_column("Avg/Epoch (ms)", justify="right", style="yellow") table.add_column("Memory (MB)", justify="right", style="green") table.add_column("Relative", justify="right", style="red") - - if self.device == "cuda": - table.add_column("GPU Util %", justify="right", style="blue") + table.add_column("GPU Util %", justify="right", style="blue") sgd_time = results["SGD"].avg_time_per_epoch for preconditioner_type, result in results.items(): - relative_time = result.avg_time_per_epoch / sgd_time row = [ preconditioner_type, - f"{result.total_time:.3f}", - f"{result.avg_time_per_epoch * 1000:.2f}", + f"{result.total_time:.1f}", + f"{result.avg_time_per_epoch * 1000:.1f}", f"{result.memory_usage / 1024 / 1024:.1f}", - f"{relative_time:.1f}x", + f"{result.avg_time_per_epoch / sgd_time:.1f}x", + f"{result.gpu_utilization or 0:.0f}", ] - if self.device == "cuda": - gpu_util = result.gpu_utilization or 0 - row.append(f"{gpu_util:.1f}") - table.add_row(*row) self.console.print(table) - # Show bottleneck analysis + # Show bottleneck analysis in rich Console self.console.print("\n[bold]Bottleneck Analysis[/bold]") for preconditioner_type, result in results.items(): - if result.top_operations: - self.console.print( - f"\n[cyan]{preconditioner_type}[/cyan] top operations:" - ) - for i, (op_name, time_ms) in enumerate(result.top_operations[:10]): - self.console.print(f" {i+1}. 
{op_name}: {time_ms:.2f}ms") + if result.profiling_table: + self.console.print(f"\n[cyan]{preconditioner_type}[/cyan]:") + lines = result.profiling_table.split("\n") + for line in lines: + self.console.print(f" {line}") - # Quick comparison - self.console.print("\n[bold]Quick Comparison (SGD = 1.0x)[/bold]") comparisons = [] for preconditioner_type, result in results.items(): - relative_time = result.avg_time_per_epoch / sgd_time - comparisons.append( - f"{preconditioner_type}: [bold red]{relative_time:.1f}x[/bold red]" - ) - self.console.print(" | ".join(comparisons)) + ratio = result.avg_time_per_epoch / sgd_time + comparisons.append(f"{preconditioner_type}: {ratio:.1f}x") + + self.console.print(f"Total time relative to SGD: {' | '.join(comparisons)}") - def run_all_benchmarks(self, num_epochs: int = 100) -> dict[str, BenchmarkResult]: + def run_all_benchmarks(self, num_epochs: int = 200) -> dict[str, BenchmarkResult]: """Run benchmarks for all preconditioners. Args: @@ -377,8 +342,7 @@ def run_all_benchmarks(self, num_epochs: int = 100) -> dict[str, BenchmarkResult """ self.console.print("\n[bold cyan]Preconditioner Benchmark[/bold cyan]") self.console.print( - f"[dim]Device: {self.device} | Param Shape: {self.param_shapes[0]} | " - f"Epochs: {num_epochs}[/dim]\n" + f"[dim]Device: cuda | Shape: {self.param_shapes[0]} | Epochs: {num_epochs}[/dim]\n" ) results = {} @@ -392,7 +356,5 @@ def run_all_benchmarks(self, num_epochs: int = 100) -> dict[str, BenchmarkResult if __name__ == "__main__": # Run benchmark with a single 2048x2048 parameter matrix - benchmark = PreconditionerBenchmark( - [(2048, 2048)], device="cuda" if torch.cuda.is_available() else "cpu" - ) + benchmark = PreconditionerBenchmark([(2048, 2048)], device="cuda") benchmark.run_all_benchmarks(num_epochs=200) From 418fda69c746a2500cc40286042e5395373d1998 Mon Sep 17 00:00:00 2001 From: namgyu-youn Date: Wed, 4 Jun 2025 20:58:25 +0900 Subject: [PATCH 3/3] Fix mypy type[annotation-unchecked] complain - 
`top_ops` is a misleading name for this variable (it holds the formatted profiling table, not the operations); rename it to `profiling_table`
---
 benchmarks/preconditioners.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/benchmarks/preconditioners.py b/benchmarks/preconditioners.py
index e74b9a39..01286cfa 100644
--- a/benchmarks/preconditioners.py
+++ b/benchmarks/preconditioners.py
@@ -237,7 +237,7 @@ def _benchmark_single(
                 f"{name} {self_cpu_pct} {cpu_time} {self_gpu_pct} {gpu_time} {calls}"
             )
 
-        top_ops = "\n".join(lines)
+        profiling_table = "\n".join(lines)
         gpu_utilization = (
             (total_device_time / (total_time * 1e6)) * 100 if total_time > 0 else 0
         )
@@ -248,7 +248,7 @@ def _benchmark_single(
             avg_time_per_epoch=total_time / num_epochs,
             memory_usage=memory_usage,
             gpu_utilization=gpu_utilization,
-            profiling_table=top_ops,
+            profiling_table=profiling_table,
         )
 
     def _generate_gradients(self, num_epochs: int) -> list[list[torch.Tensor]]: