xRAG

Official repo for xRAG: Extreme Context Compression for Retrieval-augmented Generation with One Token

Get Started

See the Dockerfile for the required packages.

Configure wandb and accelerate:

wandb login
accelerate config

Pretrained Checkpoints

HuggingFace

| Model    | Backbone                             | Download        |
|----------|--------------------------------------|-----------------|
| xRAG-7b  | mistralai/Mistral-7B-Instruct-v0.2   | 🤗 Hugging Face |
| xRAG-MoE | mistralai/Mixtral-8x7B-Instruct-v0.1 | 🤗 Hugging Face |
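
As a minimal sketch, a checkpoint can be fetched ahead of time with the huggingface_hub library. The repo id Hannibal046/xrag-7b is the one used in the evaluation commands in this README; the helper name fetch_checkpoint and the CHECKPOINTS mapping are illustrative and not part of this repo. The xRAG-MoE repo id is not stated here, so it is left out.

```python
# Hypothetical helper, not part of the xRAG repo.
CHECKPOINTS = {
    # Repo id taken from the evaluation commands in this README.
    "xRAG-7b": "Hannibal046/xrag-7b",
}


def fetch_checkpoint(name: str) -> str:
    """Download a checkpoint into the local Hugging Face cache and return its path."""
    # Imported lazily so the mapping can be inspected without the dependency installed.
    from huggingface_hub import snapshot_download

    return snapshot_download(repo_id=CHECKPOINTS[name])


if __name__ == "__main__":
    print(fetch_checkpoint("xRAG-7b"))
```

Passing the repo id directly to --model_name_or_path, as the evaluation commands do, should make a separate download step unnecessary in most setups; pre-fetching is mainly useful on machines where the cache has to be populated in advance.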

Tutorial

We provide a tutorial for xRAG in tutorial.ipynb. Check it out!

Data

  • Download enwiki-dec2021 as the pretraining data and retrieval corpus
  • Prepare the instruction-tuning data with prepare_data.ipynb
  • Download TriviaQA
  • Run retrieval with ColBERT-v2

Training

The training scripts are in scripts/. For example, to train Mistral-7b with SFR:

accelerate launch \
    --mixed_precision bf16 \
    --num_machines 1 \
    --num_processes 8 \
    --main_process_port 29666 \
    -m \
    src.language_modeling.train \
    --config config/language_modeling/pretrain.yaml

Evaluation

The evaluation code is in src/eval. For example, to evaluate on TriviaQA:

without retrieval augmentation:

CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
        --data triviaqa \
        --model_name_or_path Hannibal046/xrag-7b

with retrieval augmentation:

CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
        --data triviaqa \
        --model_name_or_path Hannibal046/xrag-7b \
        --use_rag

with xRAG:

CUDA_VISIBLE_DEVICES=0 python -m src.eval.run_eval \
        --data triviaqa \
        --model_name_or_path Hannibal046/xrag-7b \
        --retriever_name_or_path Salesforce/SFR-Embedding-Mistral \
        --use_rag
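
The three settings above differ only in their trailing flags, so they can be generated from one base command. The sketch below builds the argument list for each setting; the function name build_eval_command and the mode labels "no_rag", "rag", and "xrag" are illustrative assumptions, while the flags and model names are the ones shown in the commands above.

```python
import shlex
from typing import List

# Arguments shared by all three settings (copied from the commands above).
BASE = [
    "python", "-m", "src.eval.run_eval",
    "--data", "triviaqa",
    "--model_name_or_path", "Hannibal046/xrag-7b",
]


def build_eval_command(mode: str) -> List[str]:
    """Return the argv for one evaluation setting (mode labels are hypothetical)."""
    if mode == "no_rag":
        return list(BASE)
    if mode == "rag":
        return BASE + ["--use_rag"]
    if mode == "xrag":
        return BASE + [
            "--retriever_name_or_path", "Salesforce/SFR-Embedding-Mistral",
            "--use_rag",
        ]
    raise ValueError(f"unknown mode: {mode!r}")


if __name__ == "__main__":
    # Print the commands so they can be inspected before running them.
    for mode in ("no_rag", "rag", "xrag"):
        print("CUDA_VISIBLE_DEVICES=0", shlex.join(build_eval_command(mode)))
```

To run the commands instead of printing them, each list could be passed to subprocess.run from the repository root, with CUDA_VISIBLE_DEVICES set in the environment.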

Benchmark

The benchmarking code is in src/language_modeling/profiler.py. The two commands below run the profiler with and without the --use_xrag flag:

python -m src.language_modeling.profiler --instruction_length 54 --generation_length 30 --dataset triviaqa --use_xrag
python -m src.language_modeling.profiler --instruction_length 54 --generation_length 30 --dataset triviaqa