Skip to content

Charon-yzc/CS336

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

10 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

CS336: 大型语言模型工程栈 (Spring 2025)

从原始数据清洗到对齐调优的端到端 LLM 教学与研究仓库。

目录

项目概览

本仓库汇集了 Stanford CS336「大型语言模型」课程的完整工程栈,覆盖数据工程、分词器构建、Transformer 模型实现、训练调度、系统级优化、缩放法则验证以及对齐与安全研究。每一部分都提供可以直接运行的实现、基准脚本和教学配套资料,既适合课堂作业,也能作为构建工业级 LLM pipeline 的参考。

你可以在这里体验从 Common Crawl 文本过滤、MinHash 去重和 fastText 质量评分,到 FlashAttention Triton kernel、分布式梯度同步重叠,再到 SFT/GRPO 对齐和安全强化的全流程实践。

核心亮点

  • 端到端数据工程data_/pipeline.py, data_/minihash_lsh.py, data_/quality.py 等脚本复现语言识别、毒性检测、Gopher 质量规则以及 MinHash LSH 去重。
  • 自研 BPE 分词器basics/train_bpe.pybasics/tokenizer.py 支持自定义特殊 token、并行 BPE 学习与二进制语料导出。
  • 现代化 Transformer 栈basics/model.py 集成 Rotary Embedding、RMSNorm,并可切换 FlashAttention v2(system/flashattentionv2_full_version_forward_backward.py)与 Triton kernel(system/flash_attention_triton.py)。
  • 可编排的训练循环basics/train.py 搭配 basics/optimizer.pybasics/nn_utils.py 提供 torch.compile 加速、余弦学习率调度、梯度裁剪以及 Weights & Biases 日志集成。
  • 系统级性能优化system/ddp.py, system/ddp_overlap_*, system/profile_time.py 展示 DDP 通信重叠、端到端 profiler 以及算子级优化;system/cuda/ 提供手写 CUDA kernel 与 Makefile。
  • 缩放法则与吞吐实验scaling_Laws/chinchilla_curve.py, system/benchmark_lm.py 辅助研究 compute-optimal 策略与算力利用率。
  • 对齐与安全研究alignment/ 子仓库落地 SFT、GRPO、奖励模型和安全强化流程,配套 uv 环境和单元测试。
  • 课程资料完备handout/, slides/, spring2025-lectures/ 覆盖讲义、幻灯片、演示代码与依赖配置。

功能模块

  • basics/:语言模型的核心实现,包含 Transformer 模块、优化器、数据加载、BPE 工具与训练脚本。
  • data_/:大规模语料的数据清洗与质量控制流水线,涵盖 WARC/WET 解析、语言识别、毒性过滤、MinHash 去重等组件。
  • system/:性能优化与系统实验,包括 FlashAttention 变体、分布式训练、profiling 工具、自定义 CUDA kernel 和性能报告脚本。
  • alignment/:对齐与安全作业,提供 SFT、GRPO、奖励模型训练、submission 工具与 uv 依赖锁定。
  • scaling_Laws/:Chinchilla 缩放法则和等算力曲线分析脚本,支持实验中对 compute/data 模式的敏感性研究。
  • 教学资料handout/, slides/, spring2025-lectures/ 集成讲义 PDF、课堂幻灯与示范 notebook。

环境准备

运行要求 Python 3.10+ 与 PyTorch (MPS/CUDA/CPU 均可)。推荐步骤如下:

python3 -m venv .venv
source .venv/bin/activate    # Windows 使用 .venv\Scripts\activate
pip install --upgrade pip

# 安装课程主体依赖
pip install -r spring2025-lectures/requirements.txt

# 数据流水线与作业所需的额外依赖
pip install fastwarc resiliparse uv

# 可选:安装 FlashAttention(需对应 CUDA 环境)
pip install flash-attn --no-build-isolation

提示:如果使用 CUDA,请根据硬件选择官方提供的 PyTorch/flash-attn 轮子;对于 Apple Silicon,可直接依赖 MPS 后端。

快速开始

  1. 准备数据:将预处理后的二进制语料放置在 data_/<dataset_name>/train.bindata_/<dataset_name>/valid.binbasics/data.py 默认按 uint16 读取。
  2. 运行训练脚本:在 basics/ 目录下执行 train.py 或使用 scripts/run_train.sh 的示例配置。
cd basics
python train.py \
  --dataset_name tinystory \
  --context_length 256 \
  --batch_size 16 \
  --vocab_size 10000 \
  --d_model 512 \
  --num_layers 4 \
  --num_heads 16 \
  --d_ff 1344 \
  --total_iters 20000 \
  --max_learning_rate 5e-4 \
  --cosine_cycle_iters 20000 \
  --weight_decay 1e-3 \
  --wandb_logging True \
  --wandb_project cs336-assignment1 \
  --wandb_run_name tinystories-baseline

运行前请确认 WANDB_API_KEY 已设置;如无需日志,可将 --wandb_logging 设为 False。训练过程中会默认使用 torch.compile,并在 data/out/checkpoints/ 下保存权重。

数据与分词流水线

  • BPE 词表训练:调整 basics/train_bpe.py 底部的配置后执行脚本,可在多进程下统计 pretokens 并生成自定义 vocab.jsonmerge.txt
  • 分词与二进制导出:使用 basics/tokenizer.pyTokenizer.encode_iterable 将文本语料转换为 token id,并基于 numpy.memmap 写入 .bin 文件,以供 Dataset 类高效加载。
  • 网页数据清洗data_/pipeline.py 演示如何读取 Common Crawl WET 文件,串联 extract_text_from_htmlfilter_by_languagefilter_by_gopher_rulesfilter_by_harmful_content 等步骤。将 LANG_ID_MODEL_PATHNSFW_MODEL_PATHTOXIC_MODEL_PATH 指向本地 fastText 模型即可获得真实打分。
  • 去重与质量控制data_/minihash_lsh.pydata_/quality.pydata_/toxicity.py 等模块提供 MinHash LSH、启发式质量指标与毒性评估的参考实现。

训练与实验工具

  • scripts/run_train.sh:TinyStories baseline,演示完整指令与日志配置。
  • basics/optimizer.py:实现 AdamW 与余弦退火调度,提供 set_lr 便于自定义策略。
  • basics/nn_utils.py:包含数值稳定的 softmax、交叉熵、梯度裁剪等基础算子。
  • system/benchmark_lm.py:对不同模型配置进行吞吐量基准测试,便于对比缩放策略。
  • system/profile_time.py:封装 PyTorch Profiler,支持 forward-only 与 forward+backward 分析并输出 flame stack。
  • system/profile_memory.py, system/profile_normlayer.py:追踪显存占用、诊断瓶颈。

系统级优化工具

  • FlashAttention 变体system/flash_attention_triton.py 提供 Triton 实现;system/flashattentionv2_full_version_forward_backward.py 对标官方 FlashAttention v2 的前向与反向核。
  • 分布式训练system/ddp.pysystem/ddp_overlap_bucket.pysystem/ddp_overlap_individual.py 展示梯度分桶、通信重叠与异步 all-reduce。
  • CUDA 示例system/cuda/ 包含 matrix_add.cu, vector_add.cu 等 kernel 及编译用 Makefile,可作为自定义算子模板。
  • Profilingsystem/profile_atten_time.py, system/profile_sharding.pysystem/profile_memory.py 支持算子级时间、显存与 sharding 策略比对。

对齐与安全实验

alignment/ 子仓库提供 Assignment 5 的完整实现:

cd alignment
uv sync --no-install-package flash-attn
uv sync
uv run pytest

完成 tests/adapters.py 后即可驱动 SFT、GRPO(alignment/cs336_alignment/grpo)与安全 RLHF 实验,scripts/test_and_make_submission.sh 帮助打包提交。

仓库结构

CS336/
|-- basics/                 # Transformer 核心实现(MQA,GQA,MOE)、BPE 工具与训练脚本
|-- data_/                  # 数据清洗与质量评估流水线
|-- system/                 # 加速、分布式与 profiling 工具
|-- alignment/              # SFT/GRPO 对齐实验与安全作业
|-- scaling_Laws/           # 缩放法则分析脚本
|-- handout/                # 作业讲义 PDF
|-- spring2025-lectures/    # 课堂 demo 与依赖配置
|-- dir/                    # 数据集附录示例(如 MATH dataset card)
|-- README.md

资料与参考

  • 作业手册:查看 handout/ 目录内的 PDF。
  • Lecture demos:spring2025-lectures/ 给出配套 notebook、requirements 与样例代码。
  • Alignment 文档:alignment/README.md 与随附 PDF 详细说明 RLHF、安全强化流程。

About

assignments

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published