Implementation codes for End-to-End Crystal Structure Prediction from Powder X-Ray Diffraction (XtalNet)
Advanced Science [paper][arXiv]
XtalNet is the first equivariant deep generative model for end-to-end crystal structure prediction from Powder X-Ray Diffraction. XtalNet aims to extend the capabilities of deep learning in predicting crystal structures based on PXRD patterns, encompassing more complex structures and specific conditions.
Codes are tested on Ubuntu 20.04
Use xtalnet.yaml
to setup the environment.
conda env create -f xtalnet.yaml
Installation time depends on the internet connection.
Download data and checkpoints from here
Change root_path
in conf/data/hmof_100.yaml
and conf/data/hmof_400.yaml
to the path of the downloaded hmof_100
dir path and hmof_400
dir path.
Rename the .env.template
file into .env
and specify the following variables.
PROJECT_ROOT: the absolute path of this repo
HYDRA_JOBS: the absolute path to save hydra outputs
WABDB_DIR: the absolute path to save wanbdb outputs
export expname=cpcp_training
export model=cpcp
export data_name='hmof_100' # or 'hmof_400'
export freeze=false
export bsz=16 # 4 gpus, 8 for hmof_400
export lr=5e-4 # 2e-4 for hmof_400
export betas='[0.9,0.99]'
export eps=1e-6
export weight_decay=1e-4
bash train.sh
export expname=ccsg_training
export model=ccsg
export data_name='hmof_100' # or 'hmof_400'
export pretrained=<cpcp_ckpt_path>
export freeze=true
export bsz=16 # 4 gpus, 4 for hmof_400
export lr=1e-3 #
export betas='[0.9,0.999]'
export eps=1e-8
export weight_decay=0
bash train.sh
First generate the CPCP model's predictions
python scripts/evaluate_cpcp.py --model_path <ckpt_dir_path> --ckpt_path <ckpt_path> --save_path <save_path> --label <label>
Then compute the CPCP model's metrics
python scripts/compute_cpcp_metrics.py --root_path <results_path>
generate samples from trained model
python scripts/evaluate_ccsg.py \
--ccsg_ckpt_path <ccsg_ckpt_path> \
--cpcp_ckpt_path <clip_ckpt_path> \
--save_path <save_path> \
--label <label> --num_evals <num_evals> \
--begin_idx <begin_idx> --end_idx <end_idx>
compute metrics
python scripts/compute_ccsg_metrics.py --root_path <results_path> \
--save_path <save_path> --multi_eval --label <label> \
python scripts/gradio_demo.py --ccsg_ckpt_path <ccsg_ckpt_path> --cpcp_ckpt_path <cpcp_ckpt_path> --save_path <save_path>
You can use Cu2H8C28N6O8
and example/case.txt
as input example.
The output will be saved in <save_path>
dir and you can also download it from gradio web UI. The GT cif file can be found in example/case_gt.cif
.
For the V100 GPU, it takes nearly 20 seconds to generate a sample.