RL finetuning method that ensures that the inference-time compute for queries is optimized based on query difficulty, leading to significant inference efficiency.

microsoft/FrugalRAG

FrugalRAG

A Retrieval-Augmented Generation approach for efficient multi-hop question answering

Paper: FrugalRAG: Learning to retrieve and reason for multi-hop QA

Overview

Reinforcement learning (RL) based on the final answer’s reward has driven recent progress in small language models (SLMs) on reasoning-heavy tasks such as math and code. However, applying the same techniques to retrieval-augmented generation (RAG) benchmarks like multi-hop QA has yielded limited gains—often trailing supervised or prompting-only baselines. Instead, we argue that a viable path for RL in multi-hop QA is to use test-time scaling judiciously, for optimizing both the final answer accuracy and the efficiency in reaching that answer. We propose FrugalRAG, a two-stage finetuning framework that adaptively reduces the number of retrieval steps based on a question’s difficulty. First, we train an SLM with supervised finetuning on a full-exploration policy that generates broad sub-queries. Then, we apply RL to adaptively prune search depth based on question difficulty, directly rewarding policies that balance correctness with frugality. Unlike prior approaches requiring 100× more data, our method achieves competitive performance with only 1,000 examples. On HotPotQA and other multi-hop QA benchmarks, FrugalRAG attains state-of-the-art efficiency–accuracy tradeoffs, cutting retrieval cost nearly in half. Moreover, on the challenging BrowseCompPlus benchmark, it generalizes zero-shot and surpasses SLM-based and other baselines. These results demonstrate the use of RL—not to increase reasoning steps but to optimize them—as an effective solution for scalable, efficient RAG.
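To make the trade-off between correctness and frugality concrete, here is a toy reward in Python. It is an illustrative assumption, not the reward defined in the paper: the function name `frugal_reward`, the `penalty` weight, and the linear form are all hypothetical.

```python
def frugal_reward(correct: bool, num_searches: int, max_iters: int = 5,
                  penalty: float = 0.5) -> float:
    """Toy reward: credit for a correct answer, minus a cost per search.

    Hypothetical illustration only; not the reward used by FrugalRAG.
    """
    if not 0 <= num_searches <= max_iters:
        raise ValueError("num_searches must be in [0, max_iters]")
    if not correct:
        return 0.0  # wrong answers earn nothing, however cheap they were
    # Correct answers lose up to `penalty` as they approach the search budget.
    return 1.0 - penalty * num_searches / max_iters
```

Under this sketch a correct answer reached in two searches scores higher than one reached in four, while an incorrect answer scores zero regardless of cost, so a policy cannot profit from stopping early on questions it then gets wrong.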

Installation

Prerequisites

  • Python 3.10+
  • CUDA-compatible GPU
  • 16GB+ GPU memory (recommended)

Environment Setup

We recommend using Conda for environment management:

git clone https://github.com/microsoft/FrugalRAG.git
cd FrugalRAG
conda env create -n frag --file environment.yaml
conda activate frag
pip install vllm==0.8.3 --no-deps

See dataset setup and training guide before evaluation.

Quick Start

1. Start Language Model Server

python -m vllm.entrypoints.openai.api_server --model Qwen/Qwen2.5-7B-Instruct --gpu-memory-utilization 0.70 --tensor-parallel-size 1 --port 7501
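Once the server is up, a quick request can confirm that it responds. The sketch below assumes the server from the command above is reachable on localhost and uses the OpenAI-compatible `/v1/chat/completions` route that vLLM serves; the helper names `build_chat_request` and `ask` and the `max_tokens` value are illustrative, not part of this repository.

```python
import json
import urllib.request


def build_chat_request(prompt: str,
                       model: str = "Qwen/Qwen2.5-7B-Instruct",
                       port: int = 7501) -> urllib.request.Request:
    """Build a POST request for the OpenAI-compatible chat endpoint."""
    payload = {
        "model": model,  # must match the --model value passed to the server
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 32,  # assumption: a short reply is enough for a smoke test
    }
    return urllib.request.Request(
        f"http://localhost:{port}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def ask(prompt: str, **kwargs) -> str:
    """Send the request and return the text of the first completion."""
    request = build_chat_request(prompt, **kwargs)
    with urllib.request.urlopen(request, timeout=60) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

Calling `ask("Reply with one word.")` after the model has loaded should return a short string; a connection error usually means the server is not listening on that port yet.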

2. Start Retrieval Backend

Option A: ColBERT (Abstracts)

CUDA_VISIBLE_DEVICES=1 PORT=8000 python -m src.search.serve_colbert --index_root ../data/index/ --index wiki17.nbits.local --colbert_path ../data/colbertv2.0 --collection_path ../data/wiki.abstracts.2017/collection.tsv

Option B: ColBERT (Abstracts and Body)

CUDA_VISIBLE_DEVICES=1 PORT=8000 python -m src.search.serve_colbert --index_root ../data/index/ --index wiki18.nbits.local --colbert_path ../data/colbertv2.0 --collection_path ../data/wiki.2018/collection.tsv

Option C: E5

# convert faiss flat index to pytorch shards for fast execution
python src/search/shard_embeddings.py

# start server
INDEX_DIR=../data/e5-base-v2/pytorch-shards/ E5_MODEL_NAME_OR_PATH="intfloat/e5-base-v2" TOP_K=5 uvicorn src.search.start_e5_server_main:app --port 8001 

Evaluation

Running Evaluation

Ensure all required services are running before evaluation.
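A small pre-flight check can catch a missing service before a long evaluation run starts. The sketch below is not part of the repository; the helper names are hypothetical, and the port numbers in the usage note are the ones used elsewhere in this README.

```python
import socket


def port_is_open(port: int, host: str = "localhost", timeout: float = 1.0) -> bool:
    """Return True if something accepts TCP connections on host:port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unresolvable hosts.
        return False


def missing_services(ports: dict[str, int], host: str = "localhost") -> list[str]:
    """Return the names of services whose ports are not accepting connections."""
    return [name for name, port in ports.items() if not port_is_open(port, host)]
```

For example, `missing_services({"reasoner": 7501, "answer model": 7502, "retriever": 8000})` returns an empty list when all three are listening. An open port only shows that a process is listening; it does not validate the model or index behind it.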

Run evaluation:

python -m src.evaluation.eval_mp --model_name_or_path [MODEL_PATH] --output_path [OUTPUT_PATH] --prompt_path [PROMPT_PATH] --answer_model [BASE_MODEL_NAME] --port 7501 7502 --search_port 8000 --dataset_name [DATASET_NAME] --input_file [DEV_FILE_PATH]

# extract the final answer with CoT prompt
python -m src.evaluation.eval_mp --model_name_or_path [MODEL_PATH] --output_path [OUTPUT_PATH] --prompt_path [PROMPT_PATH] --answer_model [BASE_MODEL_NAME] --port 7501 7502 --search_port 8000 --dataset_name [DATASET_NAME] --input_file [DEV_FILE_PATH] --answer_only True

Run MBE, the LLM-judge metric (set the path in the grade_all.py script first):

python -m src.evaluation.grade_all

Available Evaluation Metrics

The evaluation framework automatically computes:

  • Exact Match (EM): 1 if the predicted answer matches the gold answer exactly, 0 otherwise
  • Match: Whether the gold answer appears anywhere in the generated answer
  • F1 Score: Token-level overlap between predicted and gold answers
  • Cost Efficiency: Number of retrieval operations per query
  • Recall/Support F1: Retrieval performance
  • MBE: LLM judge score
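The three string-based answer metrics can be sketched as follows. This is a minimal illustration that assumes SQuAD-style normalization (lowercasing, stripping punctuation and English articles); it is not the repository's implementation, whose normalization may differ.

```python
import re
import string
from collections import Counter


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(pred: str, gold: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(pred) == normalize(gold))


def match(pred: str, gold: str) -> float:
    """1.0 if the normalized gold answer occurs inside the prediction.

    This is a character-level substring test, so it can also fire on partial words.
    """
    return float(normalize(gold) in normalize(pred))


def f1(pred: str, gold: str) -> float:
    """Harmonic mean of token-level precision and recall."""
    pred_toks, gold_toks = normalize(pred).split(), normalize(gold).split()
    # Multiset intersection counts each shared token at most min(count) times.
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_toks)
    recall = overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)
```

With the prediction "Eiffel Tower in Paris" and the gold answer "Eiffel Tower", EM is 0, Match is 1, and F1 is 2/3 (precision 0.5, recall 1.0), which shows why the three scores are reported separately.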

Configuration

Configuration Files

Configuration files are organized in the configs/ directory:

configs/
├── create_data/          # Data generation configs
│   └── colbert/
│       ├── hotpot_qwen7b_finish.json
│       └── hotpot_qwen7b_nofinish.json
├── sft/                  # Supervised fine-tuning configs
│   └── colbert/
│       ├── hotpot_qwen7b_m5_0.90.json
│       └── hotpot_qwen7b_m5_nofinish.json
├── grpo/                 # GRPO reinforcement learning configs
│   └── colbert/
│       └── hotpot_qwen7b_m5_0.90.json
└── default_config.yaml   # Accelerate configuration

Set the correct paths and port numbers in these files before running; the defaults may not match your setup.

Key Configuration Parameters

| Parameter | Description | Example |
|---|---|---|
| model_name_or_path | Base model for training | "Qwen/Qwen2.5-7B-Instruct" |
| search_port | Retrieval server port | 8000 |
| port | Model server ports (reasoner, answer generator) | [7501, 7502] |
| max_iters | Maximum reasoning iterations | 5 |
| ndocs | Documents retrieved per iteration | 3 or 5 |
| dataset_name | Target dataset (hotpot, 2wiki, musique) | "hotpot" |
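Before launching a run, it can help to sanity-check a config against the parameters above. The sketch below is hypothetical: it assumes a config is a flat JSON object with exactly these six keys and treats all of them as required, which the actual files under configs/ may not do.

```python
import json

# Assumption: these are the keys listed in the parameter table above.
REQUIRED_KEYS = {"model_name_or_path", "search_port", "port",
                 "max_iters", "ndocs", "dataset_name"}
KNOWN_DATASETS = {"hotpot", "2wiki", "musique"}


def check_config(cfg: dict) -> list[str]:
    """Return a list of problems; an empty list means none were found."""
    problems = [f"missing key: {key}" for key in sorted(REQUIRED_KEYS - set(cfg))]
    if problems:
        return problems  # later checks need every key to be present
    if cfg["dataset_name"] not in KNOWN_DATASETS:
        problems.append(f"unknown dataset_name: {cfg['dataset_name']!r}")
    if cfg["search_port"] in cfg["port"]:
        problems.append("search_port collides with a model server port")
    if cfg["max_iters"] < 1 or cfg["ndocs"] < 1:
        problems.append("max_iters and ndocs must be positive")
    return problems


example = json.loads("""
{
  "model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
  "search_port": 8000,
  "port": [7501, 7502],
  "max_iters": 5,
  "ndocs": 5,
  "dataset_name": "hotpot"
}
""")
print(check_config(example))  # prints []
```

The port collision check is the one most likely to save time: pointing the retriever and a model server at the same port produces confusing connection errors at evaluation time.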

Troubleshooting

Common Issues & Solutions

Import Errors in ColBERTv2

Issue: ImportError: cannot import name 'AdamW' from 'transformers'

Solution: This repository uses a newer transformers version, which no longer exports AdamW. Comment out that import in the ColBERT files that raise the error.

Device Mismatch During SFT

Issue: RuntimeError: Expected all tensors to be on the same device

Solution: In the transformers package file loss_utils.py, line 38, use:

if reduction == "sum":
    loss = loss / num_items_in_batch.to(loss.device)
return loss

NCCL Error

Issue: Exception: Call to collective_rpc method failed: Weight update group already initialized. Call close_communicator first.

Solution: Rerun trl vllm-serve. It sometimes does not call close_communicator on its own.

Performance Optimization

  • Memory Usage: Adjust --gpu-memory-utilization based on your GPU memory
  • Training Speed: Use DeepSpeed ZeRO-3 for large model training
  • Inference Speed: Use --tensor-parallel-size for multi-GPU inference

Citation

If you use this repository, please cite our paper:

@misc{java2025frugalraglearningretrievereason,
      title={FrugalRAG: Learning to retrieve and reason for multi-hop QA}, 
      author={Abhinav Java and Srivathsan Koundinyan and Nagarajan Natarajan and Amit Sharma},
      year={2025},
      eprint={2507.07634},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2507.07634}, 
}

Acknowledgments

  • ColBERT for the efficient retrieval framework
  • DSPy for the programming framework for language models
  • vLLM for high-performance LLM inference
