-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathtrain_zhen.sh
36 lines (28 loc) · 937 Bytes
/
train_zhen.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Environment and path setup for the zh->en Transformer run launched with
# ${binFile}/t2t-trainer below.

# Put the repository root first on PYTHONPATH so local packages (e.g. the
# ./tensor2tensor checkout) are importable. ${PYTHONPATH:+...} avoids leaving
# a dangling ':' (an extra empty/implicit-cwd entry) when PYTHONPATH is unset.
export PYTHONPATH=./${PYTHONPATH:+:${PYTHONPATH}}

# Expose only physical GPU 3 to the training process.
export CUDA_VISIBLE_DEVICES=3

binFile=./tensor2tensor/bin   # directory holding the t2t-trainer script
PROBLEM=translate_zhen_wmt17
MODEL=transformer

# Hyper-parameter set for this run. The commented lines are alternative
# settings for the same experiment; keep exactly one assignment active.
HPARAMS=zhen_wmt17_transformer_rl_delta_setting
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random
# HPARAMS=zhen_wmt17_transformer_rl_total_setting
# HPARAMS=zhen_wmt17_transformer_rl_total_setting_random
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random_baseline
# HPARAMS=zhen_wmt17_transformer_rl_delta_setting_random_mle

DATA_DIR=../transformer_data/zhen

# One output directory per hparams set, so different settings do not share
# checkpoints.
TRAIN_DIR=./model/${HPARAMS}

# Quote the path, and stop here if it cannot be created rather than starting
# a long training run with no usable output directory.
mkdir -p -- "$TRAIN_DIR" || exit 1
${binFile}/t2t-trainer \
--t2t_usr_dir=./zhen_wmt17 \
--data_dir=$DATA_DIR \
--problems=$PROBLEM \
--model=$MODEL \
--hparams_set=$HPARAMS \
--output_dir=$TRAIN_DIR \
--train_steps=300000 \
--save_checkpoints_steps=500 \
--keep_checkpoint_max=50 \
--local_eval_frequency=1000000 \
--hparams='batch_size=1024,learning_rate=0.0001' \
--eval_steps=3 \
--worker_gpu=1 \