diff --git a/.gitignore b/.gitignore
index 629283090..60b4c7e26 100644
--- a/.gitignore
+++ b/.gitignore
@@ -16,6 +16,8 @@ src/scope/core/pipelines/**/*.mp4
 
 notes/
 
+benchmark_*.json
+
 # Cursor IDE files
 .cursorrules
 .cursorignore
diff --git a/README.md b/README.md
index 58251d985..a5d0872bd 100644
--- a/README.md
+++ b/README.md
@@ -131,6 +131,52 @@ After your first generation you can:
 - Use [LoRAs](./docs/lora.md) to customize the concepts and styles used in your generations.
 - Use [Spout](./docs/spout.md) (Windows only) to share real-time video between Scope and other local applications.
 
+## Benchmarking
+
+Scope includes a comprehensive benchmarking suite for testing pipeline performance across different configurations and hardware setups. This is useful for:
+
+- Understanding the performance characteristics of different GPUs (H100, A6000, 4090, etc.)
+- Determining optimal configurations (resolution, denoising steps) for your hardware
+- Identifying optimization opportunities
+
+### Quick Start
+
+Install the benchmark dependencies:
+
+```bash
+uv sync --group benchmark
+```
+
+Run a comprehensive benchmark (all pipelines, all configurations):
+
+```bash
+uv run benchmark.py
+```
+
+### Example Usage
+
+```bash
+# Benchmark specific pipelines
+uv run benchmark.py --pipelines streamdiffusionv2 longlive
+
+# Custom resolutions
+uv run benchmark.py --resolutions 480x832 768x1344
+
+# Custom iteration counts
+uv run benchmark.py --warmup 10 --iterations 50
+
+# Save results to a specific file
+uv run benchmark.py --output h100_results.json
+```
+
+### Output
+
+The benchmark generates a JSON file with:
+- Hardware specifications (GPU, CPU, memory)
+- Average performance metrics per configuration
+- Peak resource utilization (VRAM, GPU utilization, CPU usage)
+
+
 ## Firewalls
 
 If you run Scope in a cloud environment with restrictive firewall settings (eg. Runpod), Scope supports using [TURN servers](https://webrtc.org/getting-started/turn-server) to establish a connection between your browser and the streaming server.
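The Output subsection above only lists the result fields at a high level. As a small illustration of how the results can be consumed, here is a sketch that loads a results file and prints a per-configuration summary. It assumes the top-level `hardware` / `results` layout and the `fps_avg`, `latency_avg_sec`, and `gpu_memory_allocated_gb_max` metric keys written by `benchmark.py` below; the file name is just the one used in the README example.

```python
import json

# Load a results file produced by benchmark.py (file name is just an example).
with open("h100_results.json") as f:
    data = json.load(f)

# Hardware summary recorded alongside the results.
gpus = data["hardware"]["gpu"]["devices"]
print("GPU:", gpus[0]["name"] if gpus else "none")

# One entry per benchmarked (pipeline, resolution) configuration.
for entry in data["results"]:
    metrics = entry["metrics"]
    if "error" in metrics:
        print(f"{entry['pipeline']} {entry['resolution']}: ERROR - {metrics['error']}")
        continue
    print(
        f"{entry['pipeline']} {entry['resolution']}: "
        f"{metrics['fps_avg']} FPS avg, "
        f"{metrics['latency_avg_sec']} s latency, "
        f"peak VRAM {metrics.get('gpu_memory_allocated_gb_max', 'n/a')} GB"
    )
```

The same `metrics` dictionary also carries jitter (`jitter_sec`) and the min/max/std variants of each resource statistic, so the same loop can be extended to whatever columns a comparison needs.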
diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 000000000..07e5500cb --- /dev/null +++ b/benchmark.py @@ -0,0 +1,607 @@ +import argparse +import gc +import json +import platform +import statistics +import threading +import time +from datetime import datetime +from pathlib import Path +from typing import Any + +import cpuinfo +import psutil +import pynvml +import torch +from omegaconf import OmegaConf + +from scope.core.config import get_model_file_path, get_models_dir +from scope.core.pipelines.registry import PipelineRegistry +from scope.core.pipelines.utils import Quantization +from scope.server.download_models import download_models +from scope.server.models_config import models_are_downloaded + + +class HardwareInfo: + """Collects and stores hardware information.""" + + def __init__(self): + self._info = self._collect_info() + + def _collect_info(self) -> dict[str, Any]: + return { + "gpu": self._get_gpu_info(), + "cpu": self._get_cpu_info(), + "memory": self._get_memory_info(), + "platform": self._get_platform_info(), + } + + def _get_gpu_info(self) -> dict[str, Any]: + gpu_info = {"available": torch.cuda.is_available(), "count": 0, "devices": []} + if not torch.cuda.is_available(): + return gpu_info + + gpu_info["count"] = torch.cuda.device_count() + gpu_info["cuda_version"] = torch.version.cuda + + pynvml.nvmlInit() + for i in range(gpu_info["count"]): + handle = pynvml.nvmlDeviceGetHandleByIndex(i) + name = pynvml.nvmlDeviceGetName(handle) + if isinstance(name, bytes): + name = name.decode("utf-8") + + mem = pynvml.nvmlDeviceGetMemoryInfo(handle) + driver = pynvml.nvmlSystemGetDriverVersion() + if isinstance(driver, bytes): + driver = driver.decode("utf-8") + + gpu_info["devices"].append( + { + "index": i, + "name": name, + "memory_total_gb": mem.total / (1024**3), + "driver_version": driver, + } + ) + pynvml.nvmlShutdown() + + if not gpu_info["devices"]: + for i in range(gpu_info["count"]): + props = torch.cuda.get_device_properties(i) + gpu_info["devices"].append( + { + "index": i, + "name": props.name, + "memory_total_gb": props.total_memory / (1024**3), + } + ) + + return gpu_info + + def _get_cpu_info(self) -> dict[str, Any]: + """Get comprehensive CPU information using py-cpuinfo""" + cpu_data = cpuinfo.get_cpu_info() + flags = cpu_data.get("flags", []) + + cpu_info = { + "model": cpu_data.get("brand_raw", platform.processor() or "Unknown"), + "architecture": cpu_data.get("arch", platform.machine()), + "physical_cores": psutil.cpu_count(logical=False), + "logical_cores": psutil.cpu_count(logical=True), + } + + # SIMD extensions (relevant for ML performance) + simd_flags = { + "avx": "AVX", "avx2": "AVX2", "avx512f": "AVX512F", + "avx512_bf16": "AVX512_BF16", "fma": "FMA", "fma3": "FMA", "neon": "NEON" + } + simd = [name for flag, name in simd_flags.items() if flag in flags] + if simd: + cpu_info["simd_support"] = list(dict.fromkeys(simd)) # dedupe + + # Cache info + cache_keys = [ + ("l1_data_cache_size", "L1d"), ("l1_instruction_cache_size", "L1i"), + ("l2_cache_size", "L2"), ("l3_cache_size", "L3") + ] + cache = {name: self._fmt_bytes(cpu_data[key]) + for key, name in cache_keys if cpu_data.get(key)} + if cache: + cpu_info["cache"] = cache + + cpu_quota = self._read_cgroup("cpu.max", "cpu/cpu.cfs_quota_us") + if cpu_quota: + cpu_info["container_cpu_limit"] = cpu_quota + + return cpu_info + + def _fmt_bytes(self, size: int | str) -> str: + """Format bytes in human-readable format""" + if isinstance(size, str): + if any(unit in size for unit in ["KiB", 
"MiB", "GiB", "KB", "MB", "GB"]): + return size + try: + size = int(size) + except ValueError: + return size + + for unit, threshold in [("GiB", 1024**3), ("MiB", 1024**2), ("KiB", 1024)]: + if size >= threshold: + return f"{size / threshold:.0f} {unit}" + return f"{size} B" + + def _read_cgroup(self, v2_path: str, v1_path: str) -> float | None: + """Read cgroup value (tries v2 then v1)""" + try: + with open(f"/sys/fs/cgroup/{v2_path}", "r") as f: + val = f.read().strip() + if val not in ("max", "unlimited"): + if "cpu" in v2_path: + quota, period = val.split() + return float(quota) / float(period) + return int(val) + except (FileNotFoundError, ValueError): + pass + + try: + with open(f"/sys/fs/cgroup/{v1_path}", "r") as f: + val = int(f.read().strip()) + if "cpu" in v1_path and val > 0: + with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us", "r") as pf: + return val / int(pf.read().strip()) + if "memory" in v1_path and val < 1024**4: + return val + except (FileNotFoundError, ValueError): + pass + + return None + + def _get_memory_info(self) -> dict[str, Any]: + mem = psutil.virtual_memory() + host_total = mem.total / (1024**3) + host_avail = mem.available / (1024**3) + + container_limit = self._read_cgroup("memory.max", "memory/memory.limit_in_bytes") + + if container_limit: + container_gb = container_limit / (1024**3) + return { + "total_gb": container_gb, + "available_gb": min(host_avail, container_gb), + "host_total_gb": host_total, + "is_containerized": True, + } + + return { + "total_gb": host_total, + "available_gb": host_avail, + "is_containerized": False, + } + + def _get_platform_info(self) -> dict[str, Any]: + return { + "system": platform.system(), + "release": platform.release(), + "python_version": platform.python_version(), + "pytorch_version": torch.__version__, + } + + def to_dict(self) -> dict[str, Any]: + return self._info + + def get_primary_gpu_vram_gb(self) -> float: + if not self._info["gpu"]["available"] or not self._info["gpu"]["devices"]: + return 0.0 + return self._info["gpu"]["devices"][0]["memory_total_gb"] + + +class ResourceMonitor: + def __init__(self, interval_ms: int = 100, device_index: int = 0): + self.interval_ms = interval_ms + self.device_index = device_index + self._monitoring = False + self._thread = None + self._samples = [] + self._lock = threading.Lock() + self._process = psutil.Process() + self._pynvml_initialized = False + self._gpu_handle = None + + pynvml.nvmlInit() + self._gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(device_index) + self._pynvml_initialized = True + + def start(self): + if self._monitoring: + return + self._monitoring = True + self._samples = [] + self._thread = threading.Thread(target=self._monitor_loop, daemon=True) + self._thread.start() + + def stop(self): + if not self._monitoring: + return + self._monitoring = False + if self._thread: + self._thread.join(timeout=2.0) + self._thread = None + + def _monitor_loop(self): + while self._monitoring: + sample = self._collect_sample() + with self._lock: + self._samples.append(sample) + time.sleep(self.interval_ms / 1000.0) + + def _collect_sample(self) -> dict[str, Any]: + sample = {} + if torch.cuda.is_available(): + sample["gpu_memory_allocated_gb"] = torch.cuda.memory_allocated( + self.device_index + ) / (1024**3) + if self._pynvml_initialized and self._gpu_handle: + util = pynvml.nvmlDeviceGetUtilizationRates(self._gpu_handle) + sample["gpu_utilization_percent"] = util.gpu + + sample["system_cpu_percent"] = psutil.cpu_percent() + return sample + + def get_statistics(self) -> 
dict[str, float]: + with self._lock: + samples = self._samples.copy() + if not samples: + return {} + + stats = {} + keys = [ + "gpu_memory_allocated_gb", + "gpu_utilization_percent", + "system_cpu_percent", + ] + for key in keys: + values = [s[key] for s in samples if key in s] + if values: + stats[f"{key}_avg"] = sum(values) / len(values) + stats[f"{key}_max"] = max(values) + stats[f"{key}_min"] = min(values) + stats[f"{key}_std"] = statistics.stdev(values) + return stats + + def cleanup(self): + self.stop() + if self._pynvml_initialized: + pynvml.nvmlShutdown() + + +class ConfigurationMatrix: + STANDARD_RESOLUTIONS = [ + (320, 576), + (480, 832), + (512, 512), + (576, 1024), + (768, 1344), + ] + + DEFAULT_PROMPT = "A realistic video of a serene landscape with rolling hills, a clear blue sky, and a gentle stream." + + def __init__(self, pipelines=None, resolutions=None): + self.selected_pipelines = pipelines + self.custom_resolutions = resolutions + + def build(self) -> list[dict]: + all_pipelines = PipelineRegistry.list_pipelines() + + if self.selected_pipelines: + pipelines = [p for p in all_pipelines if p in self.selected_pipelines] + else: + pipelines = [p for p in all_pipelines if p != "passthrough"] + + configurations = [] + for pid in pipelines: + resolutions = self._get_resolutions(pid) + + for h, w in resolutions: + config = { + "pipeline_id": pid, + "height": h, + "width": w, + "prompt": self.DEFAULT_PROMPT, + } + configurations.append(config) + + return configurations + + def _get_resolutions(self, pid: str) -> list[tuple[int, int]]: + if self.custom_resolutions: + return self.custom_resolutions + + pipeline_class = PipelineRegistry.get(pid) + if not pipeline_class: + return [] + default_cfg = pipeline_class.get_config_class()() + + res_set = {(default_cfg.height, default_cfg.width)} + + for h, w in self.STANDARD_RESOLUTIONS: + res_set.add((h, w)) + + return sorted(res_set) + + +class BenchmarkRunner: + def __init__(self, warmup_iterations=5, iterations=30, compile_model=False): + self.warmup_iterations = warmup_iterations + self.iterations = iterations + self.compile_model = compile_model + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def run_config(self, config: dict) -> dict: + pipeline_id = config["pipeline_id"] + print( + f"\n--- Benchmarking {pipeline_id} [{config['height']}x{config['width']}] ---" + ) + + if not models_are_downloaded(pipeline_id): + print(f"Downloading models for {pipeline_id}...") + try: + download_models(pipeline_id) + print(f"Models downloaded successfully for {pipeline_id}") + except Exception as e: + print(f"ERROR: Failed to download models: {e}") + return {"error": f"Model download failed: {str(e)}"} + + pipeline = None + try: + pipeline = self._init_pipeline(config) + inputs = {"prompts": [{"text": config["prompt"], "weight": 100}]} + + if pipeline_id == "streamdiffusionv2": + inputs["video"] = torch.randn( + 1, + 3, + 4, + config["height"], + config["width"], + device=self.device, + dtype=torch.bfloat16, + ) + + print(f"Warmup ({self.warmup_iterations} iterations)...") + try: + for _ in range(self.warmup_iterations): + pipeline(**inputs) + except Exception as e: + raise Exception(f"Warmup failed: {e}") from e + + print(f"Measuring ({self.iterations} iterations)...") + monitor = ResourceMonitor() + latencies = [] + fps_measures = [] + + try: + monitor.start() + output = None + for _ in range(self.iterations): + t0 = time.time() + output = pipeline(**inputs) + latency = time.time() - t0 + latencies.append(latency) 
+ fps_measures.append(output.shape[0] / latency) + + if output is not None: + output = output.cpu() + del output + torch.cuda.synchronize() + torch.cuda.empty_cache() + finally: + try: + monitor.stop() + resource_stats = monitor.get_statistics() + monitor.cleanup() + except Exception: + resource_stats = {} + + if not latencies: + return {"error": "No successful iterations"} + + avg_latency = statistics.mean(latencies) + min_latency = min(latencies) + max_latency = max(latencies) + jitter = statistics.stdev(latencies) if len(latencies) > 1 else 0.0 + + fps_avg = statistics.mean(fps_measures) + fps_min = min(fps_measures) + fps_max = max(fps_measures) + + results = { + "fps_avg": round(fps_avg, 2), + "fps_min": round(fps_min, 2), + "fps_max": round(fps_max, 2), + "latency_avg_sec": round(avg_latency, 4), + "latency_min_sec": round(min_latency, 4), + "latency_max_sec": round(max_latency, 4), + "jitter_sec": round(jitter, 6), + **resource_stats, + } + + print( + f"-> FPS: {results['fps_avg']} | Latency: {results['latency_avg_sec']}s | Jitter: {results['jitter_sec']}s" + ) + return results + + except Exception as e: + print(f"ERROR: {e}") + return {"error": str(e)} + finally: + del pipeline + self._clear_memory() + time.sleep(3.0) + + def _init_pipeline(self, config: dict): + pid = config["pipeline_id"] + pipeline_class = PipelineRegistry.get(pid) + + model_config = OmegaConf.load( + Path(__file__).parent + / "src/scope/core/pipelines" + / pid.replace("-", "_") + / "model.yaml" + ) + pipeline_config = { + "model_dir": str(get_models_dir()), + "model_config": model_config, + "height": config["height"], + "width": config["width"], + } + + def model_path(p): + return str(get_model_file_path(p)) + + wan_enc = model_path("WanVideo_comfy/umt5-xxl-enc-fp8_e4m3fn.safetensors") + wan_tok = model_path("Wan2.1-T2V-1.3B/google/umt5-xxl") + + paths = {} + if pid == "streamdiffusionv2": + paths = { + "generator_path": model_path( + "StreamDiffusionV2/wan_causal_dmd_v2v/model.pt" + ) + } + elif pid == "longlive": + paths = { + "generator_path": model_path("LongLive-1.3B/models/longlive_base.pt"), + "lora_path": model_path("LongLive-1.3B/models/lora.pt"), + } + elif pid == "krea-realtime-video": + paths = { + "generator_path": model_path( + "krea-realtime-video/krea-realtime-video-14b.safetensors" + ), + "vae_path": model_path("Wan2.1-T2V-1.3B/Wan2.1_VAE.pth"), + } + elif pid == "reward-forcing": + paths = { + "generator_path": model_path("Reward-Forcing-T2V-1.3B/rewardforcing.pt") + } + + pipeline_config.update(paths) + if "text_encoder_path" not in pipeline_config: + pipeline_config["text_encoder_path"] = wan_enc + if "tokenizer_path" not in pipeline_config: + pipeline_config["tokenizer_path"] = wan_tok + + quantization = Quantization.FP8_E4M3FN if pid == "krea-realtime-video" else None + args = { + "config": OmegaConf.create(pipeline_config), + "device": self.device, + "dtype": torch.bfloat16, + } + if quantization: + args.update({"quantization": quantization}) + + if pid == "krea-realtime-video": + args["compile"] = self.compile_model + return pipeline_class(**args) + + def _clear_memory(self): + """Aggressively clear GPU and system memory.""" + for _ in range(3): + gc.collect() + torch.cuda.synchronize() + torch.cuda.empty_cache() + + +def main(): + parser = argparse.ArgumentParser(description="Scope Benchmark") + parser.add_argument("--pipelines", nargs="+", help="Specific pipelines to test") + parser.add_argument("--resolutions", nargs="+", help="Resolutions (e.g. 
512x512)") + parser.add_argument( + "--iterations", type=int, default=30, help="Measurement iterations per config" + ) + parser.add_argument( + "--warmup", type=int, default=5, help="Warmup iterations per config" + ) + parser.add_argument( + "--output", default=f"benchmark_{datetime.now().strftime('%Y%m%d_%H%M')}.json" + ) + parser.add_argument( + "--no-tf32", action="store_true", help="Disable TF32 (enabled by default)" + ) + parser.add_argument("--compile", action="store_true", help="Enable torch.compile") + args = parser.parse_args() + + if not args.no_tf32 and torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + print("TF32 Enabled") + + custom_res = [] + if args.resolutions: + for r in args.resolutions: + try: + h, w = map(int, r.split("x")) + custom_res.append((h, w)) + except ValueError: + pass + + hw = HardwareInfo() + hw_dict = hw.to_dict() + cpu, mem, gpu = hw_dict['cpu'], hw_dict['memory'], hw_dict['gpu'] + + print("\n=== Hardware ===") + print(f"CPU: {cpu['model']} ({cpu.get('architecture', 'Unknown')})") + print(f"Cores: {cpu['physical_cores']}P/{cpu['logical_cores']}L" + + (f" | Container limit: {cpu['container_cpu_limit']:.1f}" if 'container_cpu_limit' in cpu else "")) + if 'simd_support' in cpu: + print(f"SIMD: {', '.join(cpu['simd_support'])}") + if 'cache' in cpu: + print(f"Cache: {', '.join(f'{k}={v}' for k, v in cpu['cache'].items())}") + + ram_str = f"{mem['total_gb']:.1f} GB" + if mem.get('is_containerized'): + ram_str += f" (host: {mem['host_total_gb']:.1f} GB)" + print(f"RAM: {ram_str}") + + print(f"GPU: {gpu.get('devices', [{}])[0].get('name', 'None')}") + print(f"VRAM: {hw.get_primary_gpu_vram_gb():.1f} GB") + + matrix = ConfigurationMatrix( + pipelines=args.pipelines, + resolutions=custom_res, + ).build() + + print(f"\nPlanned Configurations: {len(matrix)}") + if not matrix: + return + + runner = BenchmarkRunner(args.warmup, args.iterations, compile_model=args.compile) + results = [] + + try: + for i, config in enumerate(matrix, 1): + print(f"\n[{i}/{len(matrix)}]", end=" ") + metrics = runner.run_config(config) + results.append( + { + "pipeline": config["pipeline_id"], + "resolution": f"{config['height']}x{config['width']}", + "metrics": metrics, + } + ) + except KeyboardInterrupt: + print("\nStopped.") + + data = { + "metadata": {"timestamp": datetime.now().isoformat(), "args": vars(args)}, + "hardware": hw.to_dict(), + "results": results, + } + with open(args.output, "w") as f: + json.dump(data, f, indent=2) + print(f"\nSaved to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 70534ae36..a94c395a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,6 +108,11 @@ dev = [ "pytest>=8.4.2", "freezegun>=1.5.5", ] +benchmark = [ + "psutil>=6.1.0", + "nvidia-ml-py>=12.560.30", + "py-cpuinfo>=9.0.0", +] [tool.ruff] line-length = 88 diff --git a/src/scope/core/pipelines/krea_realtime_video/docs/usage.md b/src/scope/core/pipelines/krea_realtime_video/docs/usage.md index c5ac8953a..fcc097302 100644 --- a/src/scope/core/pipelines/krea_realtime_video/docs/usage.md +++ b/src/scope/core/pipelines/krea_realtime_video/docs/usage.md @@ -89,7 +89,7 @@ Then: ``` # Run from scope directory -uv run -m score.core.pipelines.krea_realtime_video.test +uv run -m scope.core.pipelines.krea_realtime_video.test ``` This will create an `output.mp4` file in the `krea_realtime_video` directory. 
diff --git a/src/scope/core/pipelines/longlive/docs/usage.md b/src/scope/core/pipelines/longlive/docs/usage.md index 9026970dd..9335eedbf 100644 --- a/src/scope/core/pipelines/longlive/docs/usage.md +++ b/src/scope/core/pipelines/longlive/docs/usage.md @@ -73,7 +73,7 @@ Then: ``` # Run from scope directory -uv run -m score.core.pipelines.longlive.test +uv run -m scope.core.pipelines.longlive.test ``` This will create an `output.mp4` file in the `longlive` directory. diff --git a/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md b/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md index 036dd02be..d4c16058b 100644 --- a/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md +++ b/src/scope/core/pipelines/streamdiffusionv2/docs/usage.md @@ -55,7 +55,7 @@ Then: ``` # Run from scope directory -uv run -m score.core.pipelines.streamdiffusionv2.test +uv run -m scope.core.pipelines.streamdiffusionv2.test ``` This will create an `output.mp4` file in the `streamdiffusionv2` directory. diff --git a/uv.lock b/uv.lock index ad0ef91b8..60147016c 100644 --- a/uv.lock +++ b/uv.lock @@ -655,6 +655,11 @@ dependencies = [ ] [package.dev-dependencies] +benchmark = [ + { name = "nvidia-ml-py" }, + { name = "psutil" }, + { name = "py-cpuinfo" }, +] dev = [ { name = "freezegun" }, { name = "imageio" }, @@ -701,6 +706,11 @@ requires-dist = [ ] [package.metadata.requires-dev] +benchmark = [ + { name = "nvidia-ml-py", specifier = ">=12.560.30" }, + { name = "psutil", specifier = ">=6.1.0" }, + { name = "py-cpuinfo", specifier = ">=9.0.0" }, +] dev = [ { name = "freezegun", specifier = ">=1.5.5" }, { name = "imageio", specifier = ">=2.37.0" }, @@ -1953,6 +1963,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/56/79/12978b96bd44274fe38b5dde5cfb660b1d114f70a65ef962bcbbed99b549/nvidia_cusparselt_cu12-0.7.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:f1bb701d6b930d5a7cea44c19ceb973311500847f81b634d802b7b539dc55623", size = 287193691, upload-time = "2025-02-26T00:15:44.104Z" }, ] +[[package]] +name = "nvidia-ml-py" +version = "13.590.44" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/23/3871537f204aee823c574ba25cbeb08cae779979d4d43c01adddda00bab9/nvidia_ml_py-13.590.44.tar.gz", hash = "sha256:b358c7614b0fdeea4b95f046f1c90123bfe25d148ab93bb1c00248b834703373", size = 49737, upload-time = "2025-12-08T14:41:10.872Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/47/4c822bd37a008e72fd5a0eae33524ae3ac97b13f7030f63bae1728b8957e/nvidia_ml_py-13.590.44-py3-none-any.whl", hash = "sha256:18feb54eca7d0e3cdc8d1a040a771eda72d9ec3148e5443087970dbfd7377ecc", size = 50683, upload-time = "2025-12-08T14:41:09.597Z" }, +] + [[package]] name = "nvidia-nccl-cu12" version = "2.27.3" @@ -2294,6 +2313,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/ad/33b2ccec09bf96c2b2ef3f9a6f66baac8253d7565d8839e024a6b905d45d/psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1", size = 244608, upload-time = "2025-11-02T12:26:36.136Z" }, ] +[[package]] +name = "py-cpuinfo" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/a8/d832f7293ebb21690860d2e01d8115e5ff6f2ae8bbdc953f0eb0fa4bd2c7/py-cpuinfo-9.0.0.tar.gz", hash = "sha256:3cdbbf3fac90dc6f118bfd64384f309edeadd902d7c8fb17f02ffa1fc3f49690", size = 104716, upload-time = "2022-10-25T20:38:06.303Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/e0/a9/023730ba63db1e494a271cb018dcd361bd2c917ba7004c3e49d5daf795a2/py_cpuinfo-9.0.0-py3-none-any.whl", hash = "sha256:859625bc251f64e21f077d099d4162689c762b5d6a4c3c97553d56241c9674d5", size = 22335, upload-time = "2022-10-25T20:38:27.636Z" }, +] + [[package]] name = "pycparser" version = "2.23"