从原始数据清洗到对齐调优的端到端 LLM 教学与研究仓库。
本仓库汇集了 Stanford CS336「大型语言模型」课程的完整工程栈,覆盖数据工程、分词器构建、Transformer 模型实现、训练调度、系统级优化、缩放法则验证以及对齐与安全研究。每一部分都提供可以直接运行的实现、基准脚本和教学配套资料,既适合课堂作业,也能作为构建工业级 LLM pipeline 的参考。
你可以在这里体验从 Common Crawl 文本过滤、MinHash 去重和 fastText 质量评分,到 FlashAttention Triton kernel、分布式梯度同步重叠,再到 SFT/GRPO 对齐和安全强化的全流程实践。
- 端到端数据工程:
data_/pipeline.py,data_/minihash_lsh.py,data_/quality.py等脚本复现语言识别、毒性检测、Gopher 质量规则以及 MinHash LSH 去重。 - 自研 BPE 分词器:
basics/train_bpe.py与basics/tokenizer.py支持自定义特殊 token、并行 BPE 学习与二进制语料导出。 - 现代化 Transformer 栈:
basics/model.py集成 Rotary Embedding、RMSNorm,并可切换 FlashAttention v2(system/flashattentionv2_full_version_forward_backward.py)与 Triton kernel(system/flash_attention_triton.py)。 - 可编排的训练循环:
basics/train.py搭配basics/optimizer.py、basics/nn_utils.py提供torch.compile加速、余弦学习率调度、梯度裁剪以及 Weights & Biases 日志集成。 - 系统级性能优化:
system/ddp.py,system/ddp_overlap_*,system/profile_time.py展示 DDP 通信重叠、端到端 profiler 以及算子级优化;system/cuda/提供手写 CUDA kernel 与 Makefile。 - 缩放法则与吞吐实验:
scaling_Laws/chinchilla_curve.py,system/benchmark_lm.py辅助研究 compute-optimal 策略与算力利用率。 - 对齐与安全研究:
alignment/子仓库落地 SFT、GRPO、奖励模型和安全强化流程,配套uv环境和单元测试。 - 课程资料完备:
handout/,slides/,spring2025-lectures/覆盖讲义、幻灯片、演示代码与依赖配置。
basics/:语言模型的核心实现,包含 Transformer 模块、优化器、数据加载、BPE 工具与训练脚本。data_/:大规模语料的数据清洗与质量控制流水线,涵盖 WARC/WET 解析、语言识别、毒性过滤、MinHash 去重等组件。system/:性能优化与系统实验,包括 FlashAttention 变体、分布式训练、profiling 工具、自定义 CUDA kernel 和性能报告脚本。alignment/:对齐与安全作业,提供 SFT、GRPO、奖励模型训练、submission 工具与uv依赖锁定。scaling_Laws/:Chinchilla 缩放法则和等算力曲线分析脚本,支持实验中对 compute/data 模式的敏感性研究。- 教学资料:
handout/,slides/,spring2025-lectures/集成讲义 PDF、课堂幻灯与示范 notebook。
运行要求 Python 3.10+ 与 PyTorch (MPS/CUDA/CPU 均可)。推荐步骤如下:
python3 -m venv .venv
source .venv/bin/activate # Windows 使用 .venv\Scripts\activate
pip install --upgrade pip
# 安装课程主体依赖
pip install -r spring2025-lectures/requirements.txt
# 数据流水线与作业所需的额外依赖
pip install fastwarc resiliparse uv
# 可选:安装 FlashAttention(需对应 CUDA 环境)
pip install flash-attn --no-build-isolation提示:如果使用 CUDA,请根据硬件选择官方提供的 PyTorch/flash-attn 轮子;对于 Apple Silicon,可直接依赖 MPS 后端。
- 准备数据:将预处理后的二进制语料放置在
data_/<dataset_name>/train.bin与data_/<dataset_name>/valid.bin。basics/data.py默认按uint16读取。 - 运行训练脚本:在
basics/目录下执行train.py或使用scripts/run_train.sh的示例配置。
cd basics
python train.py \
--dataset_name tinystory \
--context_length 256 \
--batch_size 16 \
--vocab_size 10000 \
--d_model 512 \
--num_layers 4 \
--num_heads 16 \
--d_ff 1344 \
--total_iters 20000 \
--max_learning_rate 5e-4 \
--cosine_cycle_iters 20000 \
--weight_decay 1e-3 \
--wandb_logging True \
--wandb_project cs336-assignment1 \
--wandb_run_name tinystories-baseline运行前请确认 WANDB_API_KEY 已设置;如无需日志,可将 --wandb_logging 设为 False。训练过程中会默认使用 torch.compile,并在 data/out/checkpoints/ 下保存权重。
- BPE 词表训练:调整
basics/train_bpe.py底部的配置后执行脚本,可在多进程下统计 pretokens 并生成自定义vocab.json与merge.txt。 - 分词与二进制导出:使用
basics/tokenizer.py的Tokenizer.encode_iterable将文本语料转换为 token id,并基于numpy.memmap写入.bin文件,以供Dataset类高效加载。 - 网页数据清洗:
data_/pipeline.py演示如何读取 Common Crawl WET 文件,串联extract_text_from_html、filter_by_language、filter_by_gopher_rules、filter_by_harmful_content等步骤。将LANG_ID_MODEL_PATH、NSFW_MODEL_PATH、TOXIC_MODEL_PATH指向本地 fastText 模型即可获得真实打分。 - 去重与质量控制:
data_/minihash_lsh.py、data_/quality.py、data_/toxicity.py等模块提供 MinHash LSH、启发式质量指标与毒性评估的参考实现。
scripts/run_train.sh:TinyStories baseline,演示完整指令与日志配置。basics/optimizer.py:实现 AdamW 与余弦退火调度,提供set_lr便于自定义策略。basics/nn_utils.py:包含数值稳定的 softmax、交叉熵、梯度裁剪等基础算子。system/benchmark_lm.py:对不同模型配置进行吞吐量基准测试,便于对比缩放策略。system/profile_time.py:封装 PyTorch Profiler,支持 forward-only 与 forward+backward 分析并输出 flame stack。system/profile_memory.py,system/profile_normlayer.py:追踪显存占用、诊断瓶颈。
- FlashAttention 变体:
system/flash_attention_triton.py提供 Triton 实现;system/flashattentionv2_full_version_forward_backward.py对标官方 FlashAttention v2 的前向与反向核。 - 分布式训练:
system/ddp.py、system/ddp_overlap_bucket.py、system/ddp_overlap_individual.py展示梯度分桶、通信重叠与异步 all-reduce。 - CUDA 示例:
system/cuda/包含matrix_add.cu,vector_add.cu等 kernel 及编译用Makefile,可作为自定义算子模板。 - Profiling:
system/profile_atten_time.py,system/profile_sharding.py与system/profile_memory.py支持算子级时间、显存与 sharding 策略比对。
alignment/ 子仓库提供 Assignment 5 的完整实现:
cd alignment
uv sync --no-install-package flash-attn
uv sync
uv run pytest完成 tests/adapters.py 后即可驱动 SFT、GRPO(alignment/cs336_alignment/grpo)与安全 RLHF 实验,scripts/test_and_make_submission.sh 帮助打包提交。
CS336/
|-- basics/ # Transformer 核心实现(MQA,GQA,MOE)、BPE 工具与训练脚本
|-- data_/ # 数据清洗与质量评估流水线
|-- system/ # 加速、分布式与 profiling 工具
|-- alignment/ # SFT/GRPO 对齐实验与安全作业
|-- scaling_Laws/ # 缩放法则分析脚本
|-- handout/ # 作业讲义 PDF
|-- spring2025-lectures/ # 课堂 demo 与依赖配置
|-- dir/ # 数据集附录示例(如 MATH dataset card)
|-- README.md
- 作业手册:查看
handout/目录内的 PDF。 - Lecture demos:
spring2025-lectures/给出配套 notebook、requirements 与样例代码。 - Alignment 文档:
alignment/README.md与随附 PDF 详细说明 RLHF、安全强化流程。