From 945ba2e0b577e6d0717a4e1bbede38987030533c Mon Sep 17 00:00:00 2001 From: Kauna <16511995+klei22@users.noreply.github.com> Date: Sat, 7 Feb 2026 20:29:51 -0800 Subject: [PATCH 1/2] Expand vector PTQ demo with noise sweep plots --- demos/fake_ptq_vector_eval_demo_minipile.sh | 281 +++++++++++++----- .../ptq/embedding_gaussian_noise_ckpt.py | 240 +++++++++++++++ 2 files changed, 439 insertions(+), 82 deletions(-) create mode 100644 quantizations/ptq/embedding_gaussian_noise_ckpt.py diff --git a/demos/fake_ptq_vector_eval_demo_minipile.sh b/demos/fake_ptq_vector_eval_demo_minipile.sh index 0d29ff7473..ff4ca79603 100755 --- a/demos/fake_ptq_vector_eval_demo_minipile.sh +++ b/demos/fake_ptq_vector_eval_demo_minipile.sh @@ -3,17 +3,18 @@ # # Runs the fake PTQ pipeline using per-vector quantization heuristics inspired by # the JL transform initialization script. For each bit-width in the sweep the -# script quantizes the checkpoint with both per-vector and per-tensor -# granularities, evaluates the model on minipile, and records validation loss -# plus angle statistics relative to the fp32 baseline checkpoint. +# script quantizes the checkpoint with per-vector granularity, evaluates the +# model on minipile, and records validation loss plus angle statistics relative +# to the fp32 baseline checkpoint. It also sweeps embedding Gaussian noise +# perturbation magnitudes to compare loss/angle distortion vs alpha. set -euo pipefail EVAL_DATASET_DIR="data/minipile" OUT_DIR="out_fake_ptq_minipile" VECTOR_SWEEP_ROOT="${OUT_DIR}_vector_sweep" -TENSOR_SWEEP_ROOT="${OUT_DIR}_tensor_sweep" SUMMARY_ROOT="${OUT_DIR}_quantization_summaries" +NOISE_SWEEP_ROOT="${OUT_DIR}_gaussian_noise_sweep" EVAL_ITERS=200 BATCH_SIZE=64 BLOCK_SIZE=256 @@ -22,6 +23,8 @@ BIT_START=8 BIT_STOP=3 BIT_STEP=-1 +ALPHAS=(0.05 0.1 0.2 0.3) + usage() { cat <<'USAGE' Usage: demos/fake_ptq_vector_eval_demo_minipile.sh [--bit-start N] [--bit-stop N] [--bit-step N] @@ -83,7 +86,7 @@ else fi popd > /dev/null -mkdir -p "$VECTOR_SWEEP_ROOT" "$TENSOR_SWEEP_ROOT" "$SUMMARY_ROOT" +mkdir -p "$VECTOR_SWEEP_ROOT" "$SUMMARY_ROOT" "$NOISE_SWEEP_ROOT" echo "=== Step 2: Train a reference model on minipile (if needed) ===" if [ ! -f "$OUT_DIR/ckpt.pt" ]; then @@ -119,63 +122,77 @@ PATTERN='transformer\.h\.[0-9]+\.(attn\.(c_attn|c_proj)|mlp\.(c_fc|c_proj))\.wei step=4 for bit in "${BITS[@]}"; do - for granularity in vector tensor; do - case "$granularity" in - vector) - SWEEP_ROOT="$VECTOR_SWEEP_ROOT" - ANGLE_LABEL="per_vector" - ;; - tensor) - SWEEP_ROOT="$TENSOR_SWEEP_ROOT" - ANGLE_LABEL="per_tensor" - ;; - esac - - QUANT_OUT_DIR="${SWEEP_ROOT}/${bit}bit" - mkdir -p "$QUANT_OUT_DIR" - - echo "=== Step ${step}: Quantize to ${bit}-bit weights (${granularity}) ===" - if [ ! -f "$QUANT_OUT_DIR/ckpt.pt" ]; then - if [ "$granularity" = "vector" ]; then - python3 quantizations/ptq/fake_quantize_ckpt.py "$OUT_DIR" \ - --out_dir "$QUANT_OUT_DIR" \ - --num_bits "$bit" \ - --granularity vector - else - python3 quantizations/ptq/fake_quantize_ckpt.py "$OUT_DIR" \ - --out_dir "$QUANT_OUT_DIR" \ - --num_bits "$bit" - fi - else - echo "Found existing ${bit}-bit checkpoint at $QUANT_OUT_DIR/ckpt.pt; skipping quantization." - fi - - step=$((step + 1)) - - echo "=== Step ${step}: Evaluate the ${bit}-bit checkpoint (${granularity}) ===" - python3 sample.py \ + QUANT_OUT_DIR="${VECTOR_SWEEP_ROOT}/${bit}bit" + mkdir -p "$QUANT_OUT_DIR" + + echo "=== Step ${step}: Quantize to ${bit}-bit weights (vector) ===" + if [ ! 
-f "$QUANT_OUT_DIR/ckpt.pt" ]; then + python3 quantizations/ptq/fake_quantize_ckpt.py "$OUT_DIR" \ --out_dir "$QUANT_OUT_DIR" \ - --eval_only \ - --eval_dataset minipile - - step=$((step + 1)) - - echo "=== Step ${step}: Compare ${granularity} angles against baseline ===" - ANGLE_DIR="${QUANT_OUT_DIR}/angle_reports" - mkdir -p "$ANGLE_DIR" - python3 analysis/checkpoint_analysis/checkpoint_regex_explorer.py \ - "$OUT_DIR/ckpt.pt" \ - "$PATTERN" \ - --compare-ckpt "$QUANT_OUT_DIR/ckpt.pt" \ - --comparison-csv "${ANGLE_DIR}/${ANGLE_LABEL}_angles.csv" \ - --angle-units degrees \ - --no-colorize - - step=$((step + 1)) - done + --num_bits "$bit" \ + --granularity vector + else + echo "Found existing ${bit}-bit checkpoint at $QUANT_OUT_DIR/ckpt.pt; skipping quantization." + fi + + step=$((step + 1)) + + echo "=== Step ${step}: Evaluate the ${bit}-bit checkpoint (vector) ===" + python3 sample.py \ + --out_dir "$QUANT_OUT_DIR" \ + --eval_only \ + --eval_dataset minipile + + step=$((step + 1)) + + echo "=== Step ${step}: Compare vector angles against baseline ===" + ANGLE_DIR="${QUANT_OUT_DIR}/angle_reports" + mkdir -p "$ANGLE_DIR" + python3 analysis/checkpoint_analysis/checkpoint_regex_explorer.py \ + "$OUT_DIR/ckpt.pt" \ + "$PATTERN" \ + --compare-ckpt "$QUANT_OUT_DIR/ckpt.pt" \ + --comparison-csv "${ANGLE_DIR}/per_vector_angles.csv" \ + --angle-units degrees \ + --no-colorize + + step=$((step + 1)) +done + +echo "=== Step ${step}: Sweep embedding Gaussian noise perturbations ===" +python3 quantizations/ptq/embedding_gaussian_noise_ckpt.py "$OUT_DIR" \ + --out_dir "$NOISE_SWEEP_ROOT" \ + --alphas "${ALPHAS[@]}" + +step=$((step + 1)) + +for alpha in "${ALPHAS[@]}"; do + alpha_tag="${alpha//./p}" + NOISE_OUT_DIR="${NOISE_SWEEP_ROOT}/alpha_${alpha_tag}" + + echo "=== Step ${step}: Evaluate noise checkpoint (alpha=${alpha}) ===" + python3 sample.py \ + --out_dir "$NOISE_OUT_DIR" \ + --eval_only \ + --eval_dataset minipile + + step=$((step + 1)) + + echo "=== Step ${step}: Compare noise angles against baseline (alpha=${alpha}) ===" + ANGLE_DIR="${NOISE_OUT_DIR}/angle_reports" + mkdir -p "$ANGLE_DIR" + python3 analysis/checkpoint_analysis/checkpoint_regex_explorer.py \ + "$OUT_DIR/ckpt.pt" \ + "$PATTERN" \ + --compare-ckpt "$NOISE_OUT_DIR/ckpt.pt" \ + --comparison-csv "${ANGLE_DIR}/per_vector_angles.csv" \ + --angle-units degrees \ + --no-colorize + + step=$((step + 1)) done -python3 - "$OUT_DIR" "$VECTOR_SWEEP_ROOT" "$TENSOR_SWEEP_ROOT" "$SUMMARY_ROOT" "${BITS[@]}" <<'PY' +python3 - "$OUT_DIR" "$VECTOR_SWEEP_ROOT" "$NOISE_SWEEP_ROOT" "$SUMMARY_ROOT" "${BITS[@]}" <<'PY' import csv import json import math @@ -185,7 +202,7 @@ import sys out_dir = os.path.abspath(sys.argv[1]) vector_root = os.path.abspath(sys.argv[2]) -tensor_root = os.path.abspath(sys.argv[3]) +noise_root = os.path.abspath(sys.argv[3]) summary_root = os.path.abspath(sys.argv[4]) sweep_bits = [int(arg) for arg in sys.argv[5:]] @@ -207,7 +224,7 @@ def load_sweep(root: str, granularity: str) -> list[dict[str, object]]: raise SystemExit(f"Expected sweep root at {root}") entries: list[dict[str, object]] = [] - angle_suffix = "per_vector_angles.csv" if granularity == "vector" else "per_tensor_angles.csv" + angle_suffix = "per_vector_angles.csv" for bit in sweep_bits: loss_path = os.path.join(root, f"{bit}bit", "eval_loss.txt") @@ -263,6 +280,62 @@ def load_sweep(root: str, granularity: str) -> list[dict[str, object]]: return entries +def load_noise_sweep(root: str) -> list[dict[str, object]]: + entries: list[dict[str, object]] = [] + if not 
os.path.isdir(root): + raise SystemExit(f"Expected noise sweep root at {root}") + + for name in sorted(os.listdir(root)): + if not name.startswith("alpha_"): + continue + alpha_str = name.split("alpha_", 1)[1].replace("p", ".") + try: + alpha_val = float(alpha_str) + except ValueError: + continue + + loss_path = os.path.join(root, name, "eval_loss.txt") + if not os.path.exists(loss_path): + raise SystemExit(f"Missing evaluation summary at {loss_path}") + with open(loss_path, encoding="utf-8") as fh: + eval_data = json.load(fh) + loss = eval_data.get("val") + if loss is None: + raise SystemExit(f"No 'val' key found in {loss_path}") + + angle_csv = os.path.join(root, name, "angle_reports", "per_vector_angles.csv") + angle_summary = None + if os.path.exists(angle_csv): + angles: list[float] = [] + with open(angle_csv, newline="", encoding="utf-8") as csv_file: + reader = csv.DictReader(csv_file) + for row in reader: + try: + angle_val = float(row.get("angle", "nan")) + except (TypeError, ValueError): + continue + if math.isfinite(angle_val): + angles.append(angle_val) + if angles: + angle_summary = { + "mean_angle": statistics.mean(angles), + "median_angle": statistics.median(angles), + } + + entries.append( + { + "alpha": alpha_val, + "label": f"alpha={alpha_val:g}", + "val_loss": float(loss), + "mean_angle": None if angle_summary is None else angle_summary["mean_angle"], + "median_angle": None if angle_summary is None else angle_summary["median_angle"], + } + ) + + entries.sort(key=lambda item: item["alpha"]) + return entries + + baseline_entry = { "bits": 32, "granularity": "fp32", @@ -275,7 +348,7 @@ baseline_entry = { all_entries = [baseline_entry] all_entries.extend(load_sweep(vector_root, "vector")) -all_entries.extend(load_sweep(tensor_root, "tensor")) +noise_entries = load_noise_sweep(noise_root) all_entries.sort(key=lambda item: (item["granularity"] != "fp32", -item["bits"])) @@ -305,6 +378,28 @@ with open(csv_path, "w", newline="", encoding="utf-8") as csv_out: } ) +noise_csv_path = os.path.join(summary_root, "gaussian_noise_eval_summary.csv") +with open(noise_csv_path, "w", newline="", encoding="utf-8") as csv_out: + fieldnames = [ + "alpha", + "label", + "val_loss", + "mean_angle_deg", + "median_angle_deg", + ] + writer = csv.DictWriter(csv_out, fieldnames=fieldnames) + writer.writeheader() + for entry in noise_entries: + writer.writerow( + { + "alpha": f"{entry['alpha']:.6f}", + "label": entry["label"], + "val_loss": f"{entry['val_loss']:.8f}", + "mean_angle_deg": "" if entry["mean_angle"] is None else f"{entry['mean_angle']:.8f}", + "median_angle_deg": "" if entry["median_angle"] is None else f"{entry['median_angle']:.8f}", + } + ) + try: import matplotlib.pyplot as plt except Exception as exc: # pragma: no cover - plotting dependency issues @@ -312,46 +407,68 @@ except Exception as exc: # pragma: no cover - plotting dependency issues plt.style.use("seaborn-v0_8") -fig, (ax_loss, ax_angle) = plt.subplots(1, 2, figsize=(12, 5)) - -granularities = ["vector", "tensor"] -markers = {"vector": "o", "tensor": "s"} -colors = {"vector": "tab:blue", "tensor": "tab:orange"} +fig, ((ax_loss, ax_angle), (ax_noise_loss, ax_noise_angle)) = plt.subplots( + 2, 2, figsize=(13, 9) +) -for granularity in granularities: - subset = [entry for entry in all_entries if entry["granularity"] == granularity] - if not subset: - continue - subset.sort(key=lambda item: item["bits"], reverse=True) - bits = [entry["bits"] for entry in subset] - losses = [entry["val_loss"] for entry in subset] - 
ax_loss.plot(bits, losses, marker=markers[granularity], color=colors[granularity], label=f"{granularity} quant") +subset = [entry for entry in all_entries if entry["granularity"] == "vector"] +subset.sort(key=lambda item: item["bits"], reverse=True) +bits = [entry["bits"] for entry in subset] +losses = [entry["val_loss"] for entry in subset] +ax_loss.plot(bits, losses, marker="o", color="tab:blue", label="vector quant") - angles = [entry["mean_angle"] for entry in subset] - valid_pairs = [(b, a) for b, a in zip(bits, angles) if a is not None] - if valid_pairs: - vb, va = zip(*valid_pairs) - ax_angle.plot(vb, va, marker=markers[granularity], color=colors[granularity], label=f"{granularity} quant") +angles = [entry["mean_angle"] for entry in subset] +valid_pairs = [(b, a) for b, a in zip(bits, angles) if a is not None] +if valid_pairs: + vb, va = zip(*valid_pairs) + ax_angle.plot(vb, va, marker="o", color="tab:blue", label="vector quant") ax_loss.axhline(baseline_entry["val_loss"], color="tab:green", linestyle="--", label="fp32 baseline") ax_loss.set_xlabel("Bits") ax_loss.set_ylabel("Validation loss") -ax_loss.set_title("Validation loss vs. bit-width") +ax_loss.set_title("Validation loss vs. bit-width (vector PTQ)") ax_loss.legend() ax_loss.grid(True, which="both", linestyle=":", linewidth=0.5) ax_angle.set_xlabel("Bits") ax_angle.set_ylabel("Mean angle (degrees)") -ax_angle.set_title("Mean angle vs. bit-width") +ax_angle.set_title("Mean angle vs. bit-width (vector PTQ)") ax_angle.legend() ax_angle.grid(True, which="both", linestyle=":", linewidth=0.5) +if noise_entries: + alphas = [entry["alpha"] for entry in noise_entries] + noise_losses = [entry["val_loss"] for entry in noise_entries] + noise_angles = [entry["mean_angle"] for entry in noise_entries] + + ax_noise_loss.plot(alphas, noise_losses, marker="o", color="tab:purple", label="gaussian noise") + ax_noise_loss.axhline(baseline_entry["val_loss"], color="tab:green", linestyle="--", label="fp32 baseline") + ax_noise_loss.set_xlabel("Alpha") + ax_noise_loss.set_ylabel("Validation loss") + ax_noise_loss.set_title("Validation loss vs. noise alpha") + ax_noise_loss.legend() + ax_noise_loss.grid(True, which="both", linestyle=":", linewidth=0.5) + + valid_noise_pairs = [(a, ang) for a, ang in zip(alphas, noise_angles) if ang is not None] + if valid_noise_pairs: + na, nv = zip(*valid_noise_pairs) + ax_noise_angle.plot(na, nv, marker="o", color="tab:purple", label="gaussian noise") + ax_noise_angle.set_xlabel("Alpha") + ax_noise_angle.set_ylabel("Mean angle (degrees)") + ax_noise_angle.set_title("Mean angle vs. 
noise alpha") + ax_noise_angle.legend() + ax_noise_angle.grid(True, which="both", linestyle=":", linewidth=0.5) +else: + ax_noise_loss.set_visible(False) + ax_noise_angle.set_visible(False) + fig.tight_layout() plot_path = os.path.join(summary_root, "quantization_eval_summary.png") fig.savefig(plot_path, dpi=200) print(f"Wrote summary CSV to {csv_path}") +print(f"Wrote noise summary CSV to {noise_csv_path}") print(f"Wrote comparison plot to {plot_path}") PY diff --git a/quantizations/ptq/embedding_gaussian_noise_ckpt.py b/quantizations/ptq/embedding_gaussian_noise_ckpt.py new file mode 100644 index 0000000000..6ecc0aaef5 --- /dev/null +++ b/quantizations/ptq/embedding_gaussian_noise_ckpt.py @@ -0,0 +1,240 @@ +import argparse +import os +import shutil +from typing import Iterable, List, Optional + +import torch + + +EPS = 1e-6 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Apply embedding-style Gaussian vector noise to all weights in a checkpoint " + "(vector mode only)." + ) + ) + parser.add_argument( + "ckpt_dir", + type=str, + help="Directory containing ckpt.pt and meta.pkl from a previous training run", + ) + parser.add_argument( + "--out_dir", + type=str, + default=None, + help=( + "Directory to write the noisy checkpoint(s) (defaults to " + "_gaussian_noise). If multiple alphas are provided, " + "subdirectories are created under this path." + ), + ) + parser.add_argument( + "--alphas", + type=str, + nargs="+", + default=["0.2"], + help=( + "Comma- or space-separated alpha values to scale the noise (default: 0.2). " + "Example: --alphas 0.1 0.2 or --alphas 0.1,0.2" + ), + ) + parser.add_argument( + "--seed", + type=int, + default=1337, + help="Random seed for the Gaussian noise", + ) + return parser.parse_args() + + +def parse_alpha_list(items: Iterable[str]) -> List[float]: + alphas: List[float] = [] + for item in items: + for part in str(item).split(","): + value = part.strip() + if not value: + continue + alphas.append(float(value)) + if not alphas: + raise ValueError("At least one alpha value must be provided") + return alphas + + +def iter_state_items(state_dict): + if isinstance(state_dict, torch.nn.Module): + iterable = state_dict.state_dict().items() + elif isinstance(state_dict, dict): + iterable = state_dict.items() + else: + iterable = getattr(state_dict, "state_dict", lambda: {})().items() + + for key, value in iterable: + if torch.is_tensor(value): + yield key, value + + +def infer_embedding_dimension(checkpoint, state_dict) -> Optional[int]: + for container_name in ("model_args", "config"): + container = getattr(checkpoint, "get", None) + if callable(container): + container = checkpoint.get(container_name) + else: + container = None + if isinstance(container, dict): + value = container.get("n_embd") + if isinstance(value, int): + return value + + state_get = getattr(state_dict, "get", None) + for search_key in ( + "transformer.wte.weight", + "wte.weight", + "tok_embeddings.weight", + ): + tensor = state_get(search_key) if callable(state_get) else None + if torch.is_tensor(tensor) and tensor.ndim == 2: + return int(tensor.shape[1]) + + for name, tensor in iter_state_items(state_dict): + if name.endswith("wte.weight") and torch.is_tensor(tensor) and tensor.ndim == 2: + return int(tensor.shape[1]) + + return None + + +def _vector_norm(tensor: torch.Tensor) -> torch.Tensor: + return torch.linalg.vector_norm(tensor, dim=-1, keepdim=True) + + +def apply_noise_to_vectors( + vectors: torch.Tensor, + alphas: torch.Tensor, + *, + 
generator: torch.Generator, +) -> torch.Tensor: + alphas = alphas.to(dtype=vectors.dtype) + noise = torch.randn_like(vectors, generator=generator) + noise = noise / (_vector_norm(noise) + EPS) + weight_norm = _vector_norm(vectors) + scaled_noise = noise.unsqueeze(0) * alphas.view(-1, *([1] * vectors.ndim)) + scaled_noise = scaled_noise * weight_norm.unsqueeze(0) + perturbed = vectors.unsqueeze(0) + scaled_noise + perturbed_norm = _vector_norm(perturbed) + perturbed = perturbed / (perturbed_norm + EPS) * weight_norm.unsqueeze(0) + return perturbed + + +def apply_noise_per_vector( + tensor: torch.Tensor, + alphas: torch.Tensor, + embedding_dim: int, + *, + generator: torch.Generator, +) -> Optional[List[torch.Tensor]]: + if tensor.ndim >= 1 and tensor.shape[-1] == embedding_dim: + perturbed = apply_noise_to_vectors(tensor, alphas, generator=generator) + return [perturbed[idx] for idx in range(perturbed.shape[0])] + + if tensor.ndim > 1 and tensor.shape[0] == embedding_dim: + moved = torch.movedim(tensor, 0, -1) + perturbed = apply_noise_to_vectors(moved, alphas, generator=generator) + return [torch.movedim(perturbed[idx], -1, 0) for idx in range(perturbed.shape[0])] + + return None + + +def build_noisy_state_dicts( + state_dict, + alphas: List[float], + embedding_dim: int, + *, + generator: torch.Generator, +) -> List[dict]: + alpha_tensor = torch.tensor(alphas, dtype=torch.float32) + noisy_state_dicts = [dict() for _ in alphas] + for key, value in state_dict.items(): + if not torch.is_tensor(value) or not torch.is_floating_point(value): + for idx in range(len(alphas)): + noisy_state_dicts[idx][key] = value + continue + outputs = apply_noise_per_vector( + value, alpha_tensor, embedding_dim, generator=generator + ) + if outputs is None: + for idx in range(len(alphas)): + noisy_state_dicts[idx][key] = value + continue + for idx, noisy in enumerate(outputs): + noisy_state_dicts[idx][key] = noisy + return noisy_state_dicts + + +def format_alpha(alpha: float) -> str: + return f"{alpha:g}".replace(".", "p") + + +def main() -> None: + args = parse_args() + alphas = parse_alpha_list(args.alphas) + + ckpt_path = os.path.join(args.ckpt_dir, "ckpt.pt") + checkpoint = torch.load(ckpt_path, map_location="cpu") + + if isinstance(checkpoint, dict) and "model" in checkpoint: + state_obj = checkpoint["model"] + else: + state_obj = checkpoint + + if isinstance(state_obj, dict): + state_dict = state_obj + else: + to_state_dict = getattr(state_obj, "state_dict", None) + if callable(to_state_dict): + state_dict = to_state_dict() + if isinstance(checkpoint, dict) and "model" in checkpoint: + checkpoint["model"] = state_dict + else: + checkpoint = state_dict + else: + raise TypeError( + "Unsupported checkpoint format: expected a mapping for the model state" + ) + + embedding_dim = infer_embedding_dimension(checkpoint, state_dict) + if embedding_dim is None: + raise ValueError("Could not determine n_embd from checkpoint") + + g = torch.Generator() + g.manual_seed(args.seed) + + noisy_state_dicts = build_noisy_state_dicts( + state_dict, alphas, embedding_dim, generator=g + ) + + base_out_dir = args.out_dir or f"{args.ckpt_dir}_gaussian_noise" + for alpha, noisy_state in zip(alphas, noisy_state_dicts): + if len(alphas) == 1: + out_dir = base_out_dir + else: + out_dir = os.path.join(base_out_dir, f"alpha_{format_alpha(alpha)}") + os.makedirs(out_dir, exist_ok=True) + + if isinstance(checkpoint, dict) and "model" in checkpoint: + checkpoint["model"] = noisy_state + out_checkpoint = checkpoint + else: + out_checkpoint = 
noisy_state + + torch.save(out_checkpoint, os.path.join(out_dir, "ckpt.pt")) + + meta_in = os.path.join(args.ckpt_dir, "meta.pkl") + meta_out = os.path.join(out_dir, "meta.pkl") + if os.path.exists(meta_in): + shutil.copy(meta_in, meta_out) + + +if __name__ == "__main__": + main() From 255cf29aad02ad073f688e898ee1802922a949b4 Mon Sep 17 00:00:00 2001 From: Kauna <16511995+klei22@users.noreply.github.com> Date: Sat, 7 Feb 2026 21:12:14 -0800 Subject: [PATCH 2/2] Fix gaussian noise sampling for older torch --- quantizations/ptq/embedding_gaussian_noise_ckpt.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/quantizations/ptq/embedding_gaussian_noise_ckpt.py b/quantizations/ptq/embedding_gaussian_noise_ckpt.py index 6ecc0aaef5..c14850588c 100644 --- a/quantizations/ptq/embedding_gaussian_noise_ckpt.py +++ b/quantizations/ptq/embedding_gaussian_noise_ckpt.py @@ -116,7 +116,12 @@ def apply_noise_to_vectors( generator: torch.Generator, ) -> torch.Tensor: alphas = alphas.to(dtype=vectors.dtype) - noise = torch.randn_like(vectors, generator=generator) + noise = torch.randn( + vectors.shape, + generator=generator, + device=vectors.device, + dtype=vectors.dtype, + ) noise = noise / (_vector_norm(noise) + EPS) weight_norm = _vector_norm(vectors) scaled_noise = noise.unsqueeze(0) * alphas.view(-1, *([1] * vectors.ndim))
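
A minimal standalone sketch (not part of the patch) of the perturbation that
apply_noise_to_vectors applies to each embedding-dimension row: draw a Gaussian
direction, normalize it, scale it by alpha times the row's norm, add it, and
rescale the sum back to the original norm so only the direction changes. The
tensor v, the seed, and the printed angle below are illustrative assumptions;
the actual script operates on whole checkpoint tensors and a list of alphas.

import torch

# Sketch of the per-vector noise model used by embedding_gaussian_noise_ckpt.py.
eps = 1e-6
g = torch.Generator().manual_seed(1337)

v = torch.randn(8, generator=g)        # hypothetical embedding-dim row vector
alpha = 0.2                            # one entry from the demo's ALPHAS sweep

noise = torch.randn(v.shape, generator=g)
noise = noise / (noise.norm() + eps)   # unit-norm noise direction
v_norm = v.norm()

perturbed = v + alpha * v_norm * noise                      # offset scales with ||v||
perturbed = perturbed / (perturbed.norm() + eps) * v_norm   # restore ||v||

# The angle between v and its perturbed copy grows with alpha; this is the
# quantity summarized (in degrees) by the demo's angle_reports CSVs.
cos = torch.dot(v, perturbed) / (v.norm() * perturbed.norm() + eps)
print(f"alpha={alpha}: angle = {torch.rad2deg(torch.arccos(cos)).item():.2f} deg")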