 Related works:
+"An Empirical Exploration of Local Ordering Pre-training for Structured Prediction": [TODO](??)
 "A Two-Step Approach for Implicit Event Argument Detection": [details](docs/iarg.md)
 Some other parsers for interested readers: [details](docs/sop.md)
new file mode 100644
@@ -0,0 +1,68 @@
+### For the Local Bag pre-training method
+Please refer to the paper for more details: [[paper]](TODO) [[bib]](TODO)
+### Repo
+### Environment
+Same as those of the main `msp` package:
+	dependencies: pytorch>=1.0.0, numpy, scipy, gensim, cython, transformers, ...
+### Data
+### Running
+- Step 0: Setup
+Assume we are at a new DIR. Please download this repo into a DIR called `src`: `git clone https://github.com/zzsfornlp/zmsp src`, and specify some ENV variables (for convenience):
+	SRC_DIR: Root dir of this repo
+	CUR_LANG: Lang id of the current language (for example en)
+	WIKI_PRETRAIN_SIZE: Pretraining size
+	UD_TRAIN_SIZE: Task training size
+- Step 1: Build the dictionary from the pre-training data
+Before this, the data (for pre-training and task training) should be prepared.
+Assume that we have UD files at `data/UD_RUN/ud24s/${CUR_LANG}_train.${UD_TRAIN_SIZE}.conllu`, and pre-training (wiki) files at `data/UD_RUN/wikis/wiki_${CUR_LANG}.${WIKI_PRETRAIN_SIZE}.txt`.
+Assuming we are now at DIR `data/UD_RUN/vocabs/voc_${CUR_LANG}`, we first create the vocabulary for this setting with:
+	PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zmlm.main.vocab_utils train:../../wikis/wiki_${CUR_LANG}.${WIKI_PRETRAIN_SIZE}.txt input_format:plain norm_digit:1 >vv.list
+The `vv_*` files at this dir will be the vocabularies that will be utilized for the remaining steps.
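The directory layout assumed in Steps 0 and 1 can be spelled out as a small helper. This is only a sketch: `expected_paths` is a hypothetical name that does not exist in the repo, and the values `en`, `1m` and `10k` are placeholder settings chosen for illustration.

```python
import os


def expected_paths(cur_lang, wiki_size, ud_size, root="data/UD_RUN"):
    """Return the data locations that the README steps assume (hypothetical helper)."""
    return {
        "ud_train": os.path.join(root, "ud24s", f"{cur_lang}_train.{ud_size}.conllu"),
        "wiki": os.path.join(root, "wikis", f"wiki_{cur_lang}.{wiki_size}.txt"),
        "vocab_dir": os.path.join(root, "vocabs", f"voc_{cur_lang}"),
    }


# Report which of the expected inputs are already in place.
for name, path in expected_paths("en", "1m", "10k").items():
    print(f"{name}: {path} (exists={os.path.exists(path)})")
```

Running such a check before Step 1 makes it easier to spot a mismatch between the ENV variables and the actual file names.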
+- Step 2: Do pre-training
+Assume we are now at DIR `data/..` (the parent of `data`).
+Simply use the script `${SRC_DIR}/scripts/lbag/run.py` for pre-training:
+	python3 ${SRC_DIR}/scripts/lbag/run.py -l ${CUR_LANG} --rgpu 0 --run_dir run_orp_${CUR_LANG} --enc_type trans --run_mode pre --pre_mode orp --train_size ${WIKI_PRETRAIN_SIZE} --do_test 0
+Note that by default the data dirs are pre-set to the ones from Step 1. The paths can also be specified explicitly; please refer to the script for more details.
+There are various modes for pre-training. The most typical ones are: orp (or lbag, our local reordering strategy), mlm (masked LM), and om (orp+mlm). Please use `--pre_mode` to specify one.
+This may take a while (it took us three days to pretrain with 1M data on a single GPU). After this, we get the pre-trained models at `run_orp_${CUR_LANG}`.
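As a rough sanity check on the training length, the schedule hard-coded in `scripts/lbag/run.py` (batch size 480, 1000 updates per pseudo-epoch via `--pre_upe`, about 200 pseudo-epochs) can be converted into passes over the corpus. The sketch below assumes that the batch size counts sentences and that "1M" means one million sentences; both are assumptions, and `passes_over_data` is an illustrative name, not a function from the repo.

```python
def passes_over_data(num_sentences, batch_size=480,
                     updates_per_pseudo_epoch=1000, pseudo_epochs=200):
    """Approximate number of full passes over the corpus for a given update budget."""
    sentences_seen = batch_size * updates_per_pseudo_epoch * pseudo_epochs
    return sentences_seen / num_sentences


print(passes_over_data(1_000_000))  # 96.0
```

Under these assumptions the result matches the "~96 epochs" note in the script's comments.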
+- Step 3: Fine-tuning on specific tasks
+Finally, train (fine-tune) on specific tasks (here Dep+Pos with UD data) with the pre-trained model. We can still use the script `${SRC_DIR}/scripts/lbag/run.py`; simply change `--run_mode` to `ppp1` and supply the other options:
+	python3 ${SRC_DIR}/scripts/lbag/run.py -l ${CUR_LANG} --rgpu 0 --cur_run 1 --run_dir run_ppp1_${CUR_LANG} --run_mode ppp1 --train_size ${UD_TRAIN_SIZE} --preload_prefix ../run_orp_${CUR_LANG}/zmodel.c200
+Here, we use the `checkpoint@200` model from the pre-trained dir; other checkpoints can also be specified (providing an unambiguous model prefix is enough).
+Again, paths default to the ones set up in Step 1; other paths can be specified with the various `--*_dir` options.
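The prefix lookup behind `--preload_prefix` can be sketched as follows. This is a simplified stand-in for the `find_model` helper in `scripts/lbag/run.py`, not a copy of it: the script keeps the last matching file, whereas this version raises an error when the prefix is ambiguous. The name `resolve_model_prefix` and the file names in the demo are made up for illustration.

```python
import os
import tempfile


def resolve_model_prefix(model_prefix):
    """Return the unique model path matching the prefix, judged by its `.pr.json` file."""
    model_dir = os.path.dirname(model_prefix) or "."
    name_prefix = os.path.basename(model_prefix)
    suffix = ".pr.json"
    matches = sorted(
        f[:-len(suffix)] for f in os.listdir(model_dir)
        if f.startswith(name_prefix) and f.endswith(suffix)
    )
    if len(matches) != 1:
        raise ValueError(f"Expected exactly one match for {model_prefix!r}, got {matches}")
    return os.path.join(model_dir, matches[0])


# Demo with made-up checkpoint names in a temporary directory.
with tempfile.TemporaryDirectory() as d:
    for name in ["zmodel.c100-demo.pr.json", "zmodel.c200-demo.pr.json"]:
        open(os.path.join(d, name), "w").close()
    print(resolve_model_prefix(os.path.join(d, "zmodel.c200")))
```

Failing loudly on an ambiguous prefix is a design choice of this sketch; it avoids silently loading a different checkpoint than intended.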
@@ -38,7 +38,7 @@ def zwarn(s):
 # =====
-LANG_NAME_MAP = {'es': 'spanish', 'zh': 'chinese', 'en': 'english'}
+LANG_NAME_MAP = {'es': 'spanish', 'zh': 'chinese', 'en': 'english', 'ru': 'russian', 'uk': 'ukrainian'}
 # map CTB pos tags to UD (partially reference "Developing Universal Dependencies for Mandarin Chinese")
new file mode 100644
index 0000000..720100b
+# (updated one) a single script that handles runs in multiple modes
+import argparse
+import sys
+import os
+SCRIPT_ABSPATH = os.path.abspath(__file__)
+def printing(x, end='\n', file=sys.stderr, flush=True):
+    print(x, end=end, file=file, flush=flush)
+def system(cmd, pp=True, ass=True, popen=False):
+    if pp:
+        printing("Executing cmd: %s" % cmd)
+    n = os.system(cmd)
+    if ass:
+        assert n==0, f"Executing previous cmd returns error {n}."
+    return n
+def parse():
+    parser = argparse.ArgumentParser(description='Process running confs.')
+    # =====
+    # -- dirs
+    # todo(note): make sure base_dir is relative to run_dir and the others are relative to base_dir
+    parser.add_argument("--run_dir", type=str, required=True)
+    # --
+    parser.add_argument("--base_dir", type=str, default=None)
+    parser.add_argument("--src_dir", type=str, default="src")
+    parser.add_argument("--dict_dir", type=str, default="data/UD_RUN/vocabs")
+    parser.add_argument("--embed_dir", type=str, default="data/UD_RUN/embeds")
+    parser.add_argument("--ud_dir", type=str, default="data/UD_RUN/ud24s")
+    parser.add_argument("--ner_dir", type=str, default="data/ner")
+    parser.add_argument("--wiki_dir", type=str, default="data/UD_RUN/wikis")
+    # =====
+    # -- basics
+    # running
+    parser.add_argument("--run_mode", type=str, required=True)  # pre/ppp0/ppp1/...
+    parser.add_argument("--pre_mode", type=str, default="orp")  # plm/mlm/orp/om
+    parser.add_argument("--do_test", type=int, default=1, help="testing when training parser")
+    parser.add_argument("--preload_prefix", type=str)  # pre-trained model for "par1"
+    parser.add_argument("--pre_upe", type=int, default=1000, help="Updates per pseudo epoch (also the validation interval)")
+    # data
+    parser.add_argument("-l", "--lang", type=str, required=True)
+    parser.add_argument("--train_size", type=str, default="10k", help="Training size infix")
+    # wvec and bert
+    parser.add_argument("--use_wvec", type=int, default=0)
+    parser.add_argument("--use_bert", type=int, default=0)
+    parser.add_argument("--aug_wvec2", type=int, default=0)
+    # model
+    parser.add_argument("--model_dim", type=int, default=512)
+    parser.add_argument("--head_count", type=int, default=8, help="Head count in MATT for trans/vrec")
+    parser.add_argument("--enc_type", type=str, default="trans", choices=["vrec", "trans", "rnn"])
+    parser.add_argument("--lm_pred_size", type=int, default=100000)
+    parser.add_argument("--lm_tie_inputs", type=int, default=1)
+    parser.add_argument("--use_biaffine", type=int, default=1)
+    # env
+    parser.add_argument("--cur_run", type=int, default=1)
+    parser.add_argument("--rgpu", type=int, required=True)  # -1 for CPU, >=0 for GPU
+    parser.add_argument("--debug", type=int, default=0)
+    # extras
+    parser.add_argument("--extras", type=str, default="")
+    # -----
+    args = parser.parse_args()
+    return args
+# --
+def step0_dir(args):
+    printing("# =====\n=> Step 0: determine dirs")
+    run_dir = args.run_dir
+    system(f"mkdir -p {run_dir}")
+    os.chdir(run_dir)
+    # check out dirs based on src
+    base_dir = args.base_dir
+    if base_dir is None:
+        # search for src dir
+        base_dir = ""
+        _MAX_TIME = 4
+        while not os.path.isdir(os.path.join(base_dir, args.src_dir)) and _MAX_TIME > 0:
+            base_dir = os.path.join(base_dir, "..")  # upper one layer
+            _MAX_TIME -= 1
+        src_dir = os.path.join(base_dir, args.src_dir)
+    else:
+        src_dir = os.path.join(base_dir, args.src_dir)
+    assert os.path.isdir(src_dir), f"Failed to find src dir: {args.src_dir}"
+    printing(f"Finally run_dir={os.getcwd()}, base_dir={base_dir}, src_dir={src_dir}")
+    dict_dir = os.path.join(base_dir, args.dict_dir)
+    embed_dir = os.path.join(base_dir, args.embed_dir)
+    ud_dir = os.path.join(base_dir, args.ud_dir)
+    wiki_dir = os.path.join(base_dir, args.wiki_dir)
+    ner_dir = os.path.join(base_dir, args.ner_dir)
+    return run_dir, src_dir, dict_dir, embed_dir, ud_dir, wiki_dir, ner_dir
+# -----
+# todo(note): by default for par0
+def get_basic_options(args):
+    MODEL_DIM = args.model_dim
+    ENC_TYPE = args.enc_type
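+    HEAD_COUNT = args.head_count
+    # HEAD_DIM is used below but was never defined; assumed here to be the per-head size
+    HEAD_DIM = MODEL_DIM // HEAD_COUNT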
+    # =====
+    base_opt = "conf_output:_conf"
+    # => training and lrate schedule
+    base_opt += " lrate.val:0.0004 lrate_warmup:250 lrate.min_val:0.000001 lrate.m:0.75"  # default lrate
+    base_opt += " min_save_epochs:0 max_epochs:250 patience:8 anneal_times:10"
+    base_opt += " train_min_length:5 train_max_length:80 train_batch_size:80 split_batch:5 test_batch_size:8"
+    # => dropout
+    base_opt += " drop_hidden:0.1 gdrop_rnn:0.33 idrop_rnn:0.33 fix_drop:0"
+    # => model
+    # embedder
+    base_opt += " ec_word.comp_dim:300 ec_word.comp_drop:0.1 ec_word.comp_init_scale:1. ec_word.comp_init_from_pretrain:0"
+    base_opt += " ec_char.comp_dim:50 ec_char.comp_drop:0.1 ec_char.comp_init_scale:1."
+    # (NOPOS!!) base_opt+=" ec_pos.comp_dim:50 ec_pos.comp_drop:0.1 ec_pos.comp_init_scale:1."
+    base_opt += f" emb_proj_dim:{MODEL_DIM} emb_proj_act:linear"  # proj after embedder
+    # encoder
+    if ENC_TYPE == "rnn":
+        # RNN baseline
+        base_opt += " matt_conf.head_count:256 default_attn_count:256 prepr_choice:rdist"  # this is not used, simply avoid err
+        base_opt += f" enc_choice:original oenc_conf.enc_hidden:{MODEL_DIM} oenc_conf.enc_rnn_layer:3 oenc_conf.enc_rnn_sep_bidirection:1"
+    else:
+        # use VRec
+        base_opt += " enc_choice:vrec"
+        base_opt += " venc_conf.num_layer:6 venc_conf.attn_ranges:"  # vrec layers
+        base_opt += " use_pre_norm:0 use_post_norm:1"  # layer norm
+        # # no scale now!!
+        # base_opt += " scorer_conf.param_init_scale:0.5 collector_conf.param_init_scale:0.5"  # other init
+        # base_opt += " ec_word.comp_init_scale:0.5 ec_char.comp_init_scale:0.5"  # other init
+        # use MAtt
+        base_opt += f" matt_conf.head_count:{HEAD_COUNT}"  # head count
+        base_opt += f" scorer_conf.d_qk:{HEAD_DIM} scorer_conf.use_rel_dist:1 scorer_conf.no_self_loop:1"  # matt scorer
+        base_opt += f" normer_conf.use_noop:1 normer_conf.noop_fixed_val:0. normer_conf.norm_mode:cand normer_conf.attn_dropout:0.1"  # matt norm
+        base_opt += f" collector_conf.d_v:{HEAD_DIM} collector_conf.collect_mode:ch collect_reducer_mode:aff collect_red_aff_out:{MODEL_DIM}"  # matt col
+        if ENC_TYPE == "trans":
+            base_opt += f" venc_conf.share_layers:0 comb_mode:affine ff_dim:{MODEL_DIM*2}"  # no share
+        elif ENC_TYPE == "vrec":
+            base_opt += " venc_conf.share_layers:1 comb_mode:lstm ff_dim:0"  # share layers
+        else:
+            raise NotImplementedError()
+    # => prediction & loss
+    # vocab
+    base_opt += " lower_case:0 norm_digit:1 word_rthres:1000000 word_fthres:1"  # vocab
+    base_opt += " ec_word.comp_rare_unk:0.5 ec_word.comp_rare_thr:10"  # word rare unk in training
+    # masklm
+    base_opt += " mlm_conf.loss_lambda.val:0. mlm_conf.max_pred_rank:100"  # masklm
+    base_opt += " mlm_conf.hid_dim:300 mlm_conf.hid_act:elu"  # hid
+    base_opt += f" mlm_conf.tie_input_embeddings:{args.lm_tie_inputs}"  # tie embeddings
+    # plainlm
+    base_opt += " plm_conf.loss_lambda.val:0. plm_conf.max_pred_rank:100"  # plainlm
+    base_opt += " plm_conf.hid_dim:300 plm_conf.hid_act:elu"  # hid
+    base_opt += f" plm_conf.tie_input_embeddings:{args.lm_tie_inputs}"  # tie embeddings
+    # orderpr
+    base_opt += " orp_conf.loss_lambda.val:0. orp_conf.disturb_mode:abs_bin orp_conf.bin_blur_bag_keep_rate:0."
+    # base_opt += " orp_conf.disturb_ranges:8,10 orp_conf.bin_blur_bag_lbound:-2"
+    base_opt += " orp_conf.disturb_ranges:7,8"  # original one
+    # base_opt += " orp_conf.pred_range:1 orp_conf.cand_range:4 orp_conf.lambda_n1:0. orp_conf.lambda_n2:1."
+    base_opt += " orp_conf.pred_range:1 orp_conf.cand_range:1000 orp_conf.lambda_n1:0. orp_conf.lambda_n2:1."  # AllCand
+    base_opt += " orp_conf.ps_conf.use_input_pair:0"  # scorer
+    # better ones for orp
+    base_opt += " orp_conf.bin_blur_bag_keep_rate:0.5 orp_conf.pred_range:2" \
+                " orp_conf.ps_conf.use_biaffine:1 orp_conf.ps_conf.use_ff2:0"
+    # dpar
+    base_opt += " dpar_conf.loss_lambda.val:0. dpar_conf.label_neg_rate:0.1"  # overall
+    base_opt += " dpar_conf.pre_dp_space:512 dpar_conf.pre_dp_act:elu"  # pre-scoring
+    base_opt += " dpar_conf.dps_conf.use_input_pair:0"  # scorer
+    if args.use_biaffine:
+        base_opt += " dpar_conf.dps_conf.use_biaffine:1 dpar_conf.dps_conf.use_ff1:1 dpar_conf.dps_conf.use_ff2:0"
+    # pos
+    base_opt += " upos_conf.loss_lambda.val:0."
+    # --
+    return base_opt
+# ----
+def find_model(model_prefix):
+    model_dir = os.path.dirname(model_prefix)
+    model_name_prefix = os.path.basename(model_prefix)
+    model_name_full = None
+    for f in os.listdir(model_dir):
+        if f.endswith(".pr.json") and f.startswith(model_name_prefix):
+            model_name_full = f[:-8]
+    assert model_name_full is not None
+    model_path = os.path.join(model_dir, model_name_full)
+    printing(f"Find model at {model_path}")
+    return model_path
+# =====
+def step_run(args, DIRs):
+    run_dir, src_dir, dict_dir, embed_dir, ud_dir, wiki_dir, ner_dir = DIRs
+    printing(f"# =====\n=> Step 1: running!!")
+    # basic
+    cur_opt = get_basic_options(args)
+    # =====
+    # specific mode
+    CUR_LANG = args.lang
+    RUN_MODE = args.run_mode
+    PRE_MODE = args.pre_mode
+    TRAIN_SIZE = args.train_size
+    ENC_TYPE = args.enc_type
+    # wvec & bert?
+    if args.use_wvec:
+        cur_opt += " read_from_pretrain:1 ec_word.comp_init_from_pretrain:1" \
+                   f" pretrain_file:{embed_dir}/wiki.{CUR_LANG}.vec"
+    if args.use_bert:
+        cur_opt += " ec_word.comp_dim:0"  # no word if using bert (by default mbert)
+        cur_opt += " ec_bert.comp_dim:768 ec_bert.comp_drop:0.1 ec_bert.comp_init_scale:1."
+        cur_opt += " bert2_output_layers:4,5,6,7,8 bert2_training_mask_rate:0."
+    if args.aug_wvec2:  # special mode!!
+        cur_opt += f" aug_word2:1 aug_word2_pretrain:{embed_dir}/wiki.{CUR_LANG}.vec"
+    # modes
+    # -- pre-training
+    if RUN_MODE == "pre":
+        cur_opt += " min_save_epochs:0 max_epochs:200 patience:1000 anneal_times:5"
+        cur_opt += f" train:{wiki_dir}/wiki_{CUR_LANG}.{TRAIN_SIZE}.txt dev: test: input_format:plain" \
+                   f" cache_data:0 no_build_dict:1 dict_dir:{dict_dir}/voc_{CUR_LANG}/"
+        # different pre-training tasks for different archs
+        if PRE_MODE == "plm":
+            cur_opt += f" plm_conf.loss_lambda.val:1. plm_conf.max_pred_rank:{args.lm_pred_size}"
+            if ENC_TYPE != "rnn":
+                cur_opt += f" plm_conf.split_input_blm:0"  # input list for vrec
+        elif PRE_MODE == "mlm":
+            cur_opt += f" mlm_conf.loss_lambda.val:1. mlm_conf.max_pred_rank:{args.lm_pred_size}"
+            cur_opt += f" mlm_conf.mask_rate:0.15"
+        elif PRE_MODE == "orp":
+            cur_opt += " orp_conf.loss_lambda.val:1."
+        elif PRE_MODE == "om":  # orp + mlm
+            cur_opt += " orp_conf.loss_lambda.val:0.5 orp_conf.bin_blur_bag_keep_rate:0.5"  # keep for mlm
+            cur_opt += f" mlm_conf.loss_lambda.val:0.5 mlm_conf.max_pred_rank:{args.lm_pred_size}"
+            cur_opt += f" mlm_conf.mask_rate:0.15"
+        else:
+            raise NotImplementedError()
+        # -----
+        # # schedule
+        # # (OldRun) when bs=80: (UPE*80 as one epoch): 600*CHECKPOINT seems to be enough (~50 epochs for 1m)
+        # # (NewRun) when bs=240: (UPE*240 as one epoch): use 300*CHECKPOINT (~72 epochs)
+        # UPE = 1000  # number of Update per Pseudo Epoch
+        # cur_opt += " train_batch_size:240 split_batch:15"  # larger batch size for pre-training
+        # cur_opt += f" report_freq:{UPE} valid_freq:{UPE} max_updates:{304*UPE}"
+        # cur_opt += " save_freq:50 validate_epoch:0"  # save more points
+        # cur_opt += f" lrate_warmup:{5*UPE} lrate.which_idx:uidx lrate.start_bias:{50*UPE} lrate.scale:{50*UPE}"
+        # -----
+        # schedule2
+        # (NewRun2) when bs=480: (UPE*480 as one epoch): use 200*CHECKPOINT (~96 epochs)
+        UPE = args.pre_upe  # number of Update per Pseudo Epoch
+        cur_opt += " train_batch_size:480 split_batch:30"  # larger batch size for pre-training
+        cur_opt += f" report_freq:{UPE} valid_freq:{UPE} max_updates:{201*UPE} save_freq:40 validate_epoch:0"
+        cur_opt += f" lrate_warmup:{5*UPE} lrate.which_idx:uidx lrate.start_bias:{40*UPE} lrate.scale:{40*UPE} lrate.m:0.5"
+    # -- parser/tagger training
+    else:
+        assert RUN_MODE in ["par0", "par1", "pos0", "pos1", "ppp0", "ppp1", "ner0", "ner1"]
+        if RUN_MODE.startswith("ner"):
+            cur_opt += " do_ner:1 input_format:ner div_by_tok:0"  # remember to add the extra flag!!
+            cur_opt += f" train:{ner_dir}/{CUR_LANG}/train.{TRAIN_SIZE}.txt dev:{ner_dir}/{CUR_LANG}/dev.txt" \
+                       f" test:{ner_dir}/{CUR_LANG}/test.txt"
+            # extra settings for ner
+            # cur_opt += " train_batch_size:80 split_batch:2 drop_hidden:0.5 max_epochs:10000000 validate_epoch:0 valid_freq:200" \
+            #            " lrate_warmup:500 max_updates:30000 div_by_tok:0 ec_word.comp_drop:0.5 ec_char.comp_drop:0.5"
+            # cur_opt += " drop_hidden:0.5 ec_word.comp_drop:0.5 ec_char.comp_drop:0.5"
+        else:
+            cur_opt += f" train:{ud_dir}/{CUR_LANG}_train.{TRAIN_SIZE}.conllu dev:{ud_dir}/{CUR_LANG}_dev.conllu" \
+                       f" test:{ud_dir}/{CUR_LANG}_test.conllu"
+        task_comp_names = {"par": ["dpar"], "pos": ["upos"], "ppp": ["dpar", "upos"], "ner": ["ner"]}[RUN_MODE[:3]]
+        for one_tcn in task_comp_names:
+            cur_opt += f" {one_tcn}_conf.loss_lambda.val:1."
+        if RUN_MODE[-1] == "1":
+            # use pretrain vocab
+            cur_opt += f" no_build_dict:1 dict_dir:{dict_dir}/voc_{CUR_LANG}/"
+            # find preload model
+            cur_opt += f" load_pretrain_model_name:{find_model(args.preload_prefix)}"
+            # use smaller lrate
+            cur_opt += " lrate.val:0.0002"
+        if ENC_TYPE == "rnn":
+            # for rnn, better to use no warmup and larger dropout
+            cur_opt += " lrate_warmup:0 drop_hidden:0.33"
+    # gpu and seed
+    cur_opt += f" device:{-1 if args.rgpu<0 else 0}" \
+               f" niconf.random_seed:9347{args.cur_run} niconf.random_cuda_seed:9349{args.cur_run}"
+    # finally extras
+    cur_opt += " " + args.extras
+    # =====
+    # training
+    CMD = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 " + (" -m pdb " if args.debug else "") \
+          + f"{src_dir}/tasks/cmd.py zmlm.main.train {cur_opt} " + ("" if args.debug else f"2>&1 | tee _log")
+    system(CMD)
+    # testing
+    if not args.debug and args.do_test and (not RUN_MODE.startswith("pre")):
+        for wset in ["dev", "test"]:
+            if RUN_MODE.startswith("ner"):
+                CMD2 = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 {src_dir}/tasks/cmd.py zmlm.main.test" \
+                       f" _conf input_format:ner do_ner:1 test:{ner_dir}/{CUR_LANG}/{wset}.txt output_file:zout.{wset}" \
+                       f" model_load_name:zmodel.best"
+            else:
+                CMD2 = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 {src_dir}/tasks/cmd.py zmlm.main.test" \
+                       f" _conf input_format:conllu test:{ud_dir}/{CUR_LANG}_{wset}.conllu output_file:zout.{wset}" \
+                       f" model_load_name:zmodel.best"
+            system(CMD2)
+    # running with pre-train
+    if not args.debug and args.do_test and (RUN_MODE == "pre"):
+        CMD = f"python3 {SCRIPT_ABSPATH} -l {args.lang} --rgpu {args.rgpu} --run_dir _ppp1 --run_mode ppp1 --preload_prefix ../zmodel.c200"
+        system(CMD)
+    # =====
+# =====
+def main():
+    printing(f"Run with {sys.argv}")
+    args = parse()
+    for k in sorted(dir(args)):
+        if not k.startswith("_"):
+            printing(f"--{k} = {getattr(args, k)}")
+    # --
+    DIRs = step0_dir(args)
+    step_run(args, DIRs)
+# --
+if __name__ == '__main__':
+    main()
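The `--run_mode` handling in `step_run` above packs two pieces of information into one string: a three-letter task prefix that selects the loss components, and a trailing `0` or `1` that says whether a pre-trained model and vocabulary are loaded. A minimal sketch of that decoding, using a hypothetical helper named `describe_run_mode` (the mapping is taken from the script, the function is not part of it):

```python
# Mapping from task prefix to the loss components it enables (as in step_run).
TASK_COMPONENTS = {"par": ["dpar"], "pos": ["upos"], "ppp": ["dpar", "upos"], "ner": ["ner"]}


def describe_run_mode(run_mode):
    """Split a task run mode such as 'ppp1' into loss components and a pre-training flag."""
    task, flag = run_mode[:3], run_mode[3:]
    if task not in TASK_COMPONENTS or flag not in ("0", "1"):
        raise ValueError(f"Unknown run mode: {run_mode}")
    return {"components": TASK_COMPONENTS[task], "uses_pretrained": flag == "1"}


print(describe_run_mode("ppp1"))
```

For `ppp1` this yields both `dpar` and `upos` with the pre-training flag set, which is the configuration used in Step 3 of the README.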
diff --git a/tasks/zmlm/data/insts/__init__.py b/tasks/zmlm/data/insts/__init__.py
new file mode 100644
index 0000000..a097c09
--- /dev/null
+++ b/tasks/zmlm/data/insts/__init__.py
@@ -0,0 +1,13 @@
+# -----
+# for data
+# - Generally two layers: instance, field
+# -- instance can be automatically recursive (but get specific classes for common ones, like "Sentence", "Paragraph")
+# -- field must specify details by themselves (usually pieces of instances, like "word_seq", "pos_seq")
+from .base import BaseDataItem
+from .field import DataField, SeqField, NpArrField, DepTreeField
+from .general import GeneralInstance, GeneralSentence, GeneralParagraph, GeneralDocument
new file mode 100644
index 0000000..dca97b8
--- /dev/null
+++ b/tasks/zmlm/data/insts/base.py
@@ -0,0 +1,41 @@
+# base protocol for data io
+class BaseDataItem:
+    # =====
+    # from name to cls
+    _registered_mappings = {}  # name to class mapping
+    @staticmethod
+    def register(cls):
+        name = cls.__name__
+        mapping = BaseDataItem._registered_mappings
+        assert name not in mapping, f"Repeated entry name: {cls}({name})"
+        mapping[name] = cls
+        return cls
+    # =====
+    # IO: wrapped versions are useful for optional fields/components
+    def to_builtin(self, *args, **kwargs):
+        raise NotImplementedError()
+    @classmethod
+    def from_builtin(cls, d):
+        raise NotImplementedError()
+    def to_builtin_wrapped(self, *args, **kwargs):
+        class_name = self.__class__.__name__
+        assert class_name in BaseDataItem._registered_mappings, f"Entry not found: {self.__class__}({class_name})"
+        return {"_type": class_name, "_v": self.to_builtin(*args, **kwargs)}
+    @staticmethod
+    def from_builtin_wrapped(d):
+        class_name, val = d["_type"], d["_v"]
+        class_type = BaseDataItem._registered_mappings[class_name]
+        return class_type.from_builtin(val)
+# short-cut
+data_type_reg = BaseDataItem.register
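The registry and the "wrapped" IO in `BaseDataItem` work together: a wrapped value records the class name, so the matching `from_builtin` can be found again on load. The following self-contained sketch reproduces the pattern with stand-in classes (`Item` and `Words` are illustrative names, not classes from the repo).

```python
class Item:
    _registry = {}  # class name -> class

    @staticmethod
    def register(cls):
        name = cls.__name__
        assert name not in Item._registry, f"Repeated entry name: {name}"
        Item._registry[name] = cls
        return cls

    def to_builtin(self):
        raise NotImplementedError()

    @classmethod
    def from_builtin(cls, d):
        raise NotImplementedError()

    def to_builtin_wrapped(self):
        # store the class name next to the payload
        return {"_type": type(self).__name__, "_v": self.to_builtin()}

    @staticmethod
    def from_builtin_wrapped(d):
        # look the class up by name, then let it rebuild itself
        return Item._registry[d["_type"]].from_builtin(d["_v"])


@Item.register
class Words(Item):
    def __init__(self, vals):
        self.vals = vals

    def to_builtin(self):
        return self.vals

    @classmethod
    def from_builtin(cls, d):
        return cls(d)


wrapped = Words(["a", "b"]).to_builtin_wrapped()
restored = Item.from_builtin_wrapped(wrapped)
print(wrapped, type(restored).__name__, restored.vals)
```

Note that a class has to be registered before its instances can be wrapped or restored by name.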
new file mode 100644
index 0000000..e3ecfba
--- /dev/null
+++ b/tasks/zmlm/data/insts/field.py
@@ -0,0 +1,177 @@
+from typing import List
+import numpy as np
+from msp.utils import zwarn
+from msp.data import Vocab
+from .base import BaseDataItem, data_type_reg
+# =====
+# fields are those that must specify IO by themselves
+#   todo(note): Field need to deal with empties (usually denoted as None in init)
+# base class
+class DataField(BaseDataItem):
+    pass
+# general sequence field: words, poses, ...
+class SeqField(DataField):
+    def __init__(self, vals: List, idxes: List[int] = None):
+        self.vals = vals
+        self.idxes = idxes
+    def __len__(self):
+        return len(self.vals)
+    def has_vals(self):
+        return self.vals is not None
+    # reset the vals
+    def reset(self, vals: List, idxes: List[int] = None):
+        self.vals = vals
+        self.idxes = idxes
+    # set indexes (special situations)
+    def set_idxes(self, idxes: List[int]):
+        assert len(idxes)==len(self.vals), "Unmatched length of input idxes."
+        self.idxes = idxes
+    # two directions: val <=> idx
+    # init-val -> idx
+    def build_idxes(self, voc: Vocab):
+        self.idxes = [voc.get_else_unk(w) for w in self.vals]
+    # set idx & val
+    def build_vals(self, idxes: List[int], voc: Vocab):
+        self.idxes = idxes
+        self.vals = [voc.idx2word(i) for i in idxes]
+    def __repr__(self):
+        return str(self.vals)
+    # =====
+    # io: no idxes, need to rebuild if needed
+    def to_builtin(self, *args, **kwargs):
+        return self.vals
+    @classmethod
+    def from_builtin(cls, vals: List):
+        return SeqField(vals)
+# initiated from words (built from words)
+class InputCharSeqField(SeqField):
+    def __init__(self, words: List[str]):
+        super().__init__(words)
+    def build_idxes(self, c_voc: Vocab):
+        self.idxes = [[c_voc.get_else_unk(c) for c in w] for w in self.vals]
+    # set idx & val
+    def build_vals(self, idxes: List[int], voc: Vocab):
+        raise RuntimeError("No need to build vals for CharSeq.")
+    def __repr__(self):
+        return str(self.vals)
+# NumpyArrField
+class NpArrField(DataField):
+    def __init__(self, arr: np.ndarray, float_decimal=None):
+        self.arr = arr
+        self.float_decimal: int = float_decimal
+    @staticmethod
+    def _round_float(x, d):
+        r = 10 ** d
+        return int(r*x) / r
+    def __repr__(self):
+        return f"Array: type={self.arr.dtype}, shape={self.arr.shape}"
+    def to_builtin(self, *args, **kwargs):
+        flattened_arr = self.arr.flatten()
+        flattened_list = flattened_arr.tolist()
+        if self.float_decimal is not None:
+            flattened_list = [NpArrField._round_float(x, self.float_decimal) for x in flattened_list]
+        return (self.arr.shape, self.arr.dtype.name, flattened_list)
+    @classmethod
+    def from_builtin(cls, v):
+        shape, dtype, vlist = v
+        # wrap the array so that from_builtin returns a field, mirroring to_builtin
+        return cls(np.asarray(vlist, dtype=dtype).reshape(shape))
+# DepTreeField: todo(note): remember everything is considering offset=1 for ARTI_ROOT
+class DepTreeField(DataField):
+    def __init__(self, heads: List[int], labels: List[str], label_idxes: List[int] = None):
+        # todo(note): ARTI_ROOT as 0 should be included here, added from the outside!
+        self.heads = heads
+        self.labels = labels
+        self.label_idxes = label_idxes
+        # == some cache values
+        self._dep_dists = None  # m-h
+        self._label_matrix = None  # Arr[m,h] of int
+        # -----
+        if self.heads is not None:
+            if self.heads[0] != 0 or self.labels[0] != "":
+                zwarn("Bad values for ARTI_ROOT!!")
+            self._build_tree()  # build and check
+    def __len__(self):
+        return len(self.heads)
+    def __repr__(self):
+        return f"DepTree of len={len(self)}"
+    def has_vals(self):
+        return self.heads is not None
+    # build the extra info for the trees
+    def _build_tree(self):
+        # TODO(!)
+        pass
+    # For labels: two directions: val <=> idx
+    # init-val -> idx
+    def build_label_idxes(self, l_voc: Vocab):
+        self.label_idxes = [l_voc.get_else_unk(w) for w in self.labels]
+    # set idx & val
+    @staticmethod
+    def build_label_vals(label_idxes: List[int], l_voc: Vocab):
+        labels = [l_voc.idx2word(i) for i in label_idxes]
+        labels[0] = ""  # ARTI_ROOT
+        return labels
+    # =====
+    def to_builtin(self, *args, **kwargs):
+        return (self.heads, self.labels, self.label_idxes)
+    @classmethod
+    def from_builtin(cls, v):
+        heads, labels, label_idxes = v
+        return DepTreeField(heads, labels, label_idxes)
+    # =====
+    @property
+    def dep_dists(self):
+        if self._dep_dists is None:
+            self._dep_dists = [m-h for m, h in enumerate(self.heads)]
+        return self._dep_dists
+    @property
+    def label_matrix(self):
+        if self._label_matrix is None:
+            # todo(+N): currently only mh_dep, may be extended to other relations or multi-hop relations
+            cur_heads, cur_lidxs = self.heads, self.label_idxes  # [Alen=1+Rlen]
+            cur_alen = len(cur_heads)
+            _mat = np.zeros([cur_alen, cur_alen], dtype=np.int64)  # [Alen, Alen]; np.long is not available in all NumPy versions
+            _mat[np.arange(cur_alen), cur_heads] = cur_lidxs  # assign mh_dep
+            _mat[0] = 0  # no head for ARTI_ROOT
+            self._label_matrix = _mat
+        return self._label_matrix
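To make the two cached properties of `DepTreeField` concrete, here is a pure-Python sketch of what they compute. It uses nested lists instead of the NumPy fancy-indexing assignment in the class, and the head and label values are made up for illustration; index 0 plays the role of ARTI_ROOT.

```python
def dep_dists(heads):
    """Signed distance from each modifier to its head (m - h)."""
    return [m - h for m, h in enumerate(heads)]


def label_matrix(heads, label_idxes):
    """Build an [n, n] matrix with mat[m][h] = label index of the arc from head h to modifier m."""
    n = len(heads)
    assert len(label_idxes) == n, "heads and label_idxes must have the same length"
    mat = [[0] * n for _ in range(n)]
    for m in range(1, n):  # row 0 stays empty: no head for ARTI_ROOT
        mat[m][heads[m]] = label_idxes[m]
    return mat


heads = [0, 2, 0, 2]    # token 2 is attached to ARTI_ROOT; tokens 1 and 3 depend on token 2
labels = [0, 5, 7, 6]   # made-up label indices
print(dep_dists(heads))
print(label_matrix(heads, labels))
```

Each row of the matrix has at most one non-zero entry, since every token has exactly one head.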
new file mode 100644
index 0000000..50f74f2
--- /dev/null
+++ b/tasks/zmlm/data/insts/general.py
@@ -0,0 +1,175 @@
+from typing import Dict, List, Tuple, Union
+from msp.utils import Helper
+from .base import BaseDataItem, data_type_reg
+from .field import DataField, SeqField, InputCharSeqField, DepTreeField
+# =====
+# instances may auto recursively judge IO in some way
+# -- two ways of building: cls.create(...) or cls.from_builtin(d)
+# todo(note): how to deal with cycles?
+#  => no cycles in the "General" structures but put cycle links inside "Field" and specifically handle them
+# recursive general instances
+class GeneralInstance(BaseDataItem):
+    def __init__(self):
+        self.items: Dict = {}  # main items (serializable)
+        self.info: Dict = {}  # other contents (serializable)
+        # -----
+        self.features = {}  # running/extra/tmp/cached features (non-serializable)
+    # common
+    def _add_sth(self, name: str, sth, v, assert_non_exist: bool):
+        if assert_non_exist:
+            assert not hasattr(self, name), f"Inside the class: {name} already exists"
+            assert name not in v, f"Inside the collection: {name} already exists!"
+        v[name] = sth
+        setattr(self, name, sth)
+    # dynamically adding items
+    def add_item(self, name: str, item: Union[BaseDataItem, List, Tuple], assert_non_exist=True):
+        self._add_sth(name, item, self.items, assert_non_exist)
+        return item
+    # adding info (should be builtin-type)
+    def add_info(self, name: str, info, assert_non_exist=True):
+        self._add_sth(name, info, self.info, assert_non_exist)
+        return info
+    # =====
+    # creating
+    # general creating
+    @classmethod
+    def create(cls, *args, **kwargs):
+        x = cls()
+        x._init(*args, **kwargs)
+        x._finish_create()
+        return x
+    # real creating (to be overridden, replacing original __init__)
+    def _init(self, *args, **kwargs):
+        pass
+    # finish up init (to be overridden, common routine used in both create and from_builtin)
+    def _finish_create(self):
+        pass
+    # =====
+    # general IO
+    @classmethod
+    def get_unwrap_name_mappings(cls):
+        return {}  # name to class to avoid wrapping (str -> FieldType)
+    @classmethod
+    def get_unwrap_middle_mappings(cls):
+        return {}  # name to extra middle builtin types, usually List (str -> type)
+    def to_builtin(self, *args, **kwargs):
+        unwrap_name_mappings = self.__class__.get_unwrap_name_mappings()
+        # only consider items and info
+        ret = {}
+        if len(self.info) > 0:
+            ret["_info"] = self.info
+        for k, v in self.items.items():
+            # same inside the whole layer
+            if k in unwrap_name_mappings:  # no need to record type info
+                _f = lambda _v, *_args, **_kwargs: _v.to_builtin(*_args, **_kwargs)
+            else:
+                _f = lambda _v, *_args, **_kwargs: _v.to_builtin_wrapped(*_args, **_kwargs)
+            # can be different types; todo(note): can be only one layer, typically usage is List
+            if isinstance(v, BaseDataItem):
+                d = _f(v, *args, **kwargs)
+            elif isinstance(v, (List, Tuple)):
+                d = [_f(v2, *args, **kwargs) for v2 in v]
+            else:
+                raise NotImplementedError(f"Unknown class for 'to_builtin': {v.__class__}")
+            ret[k] = d
+        return ret
+    @classmethod
+    def from_builtin(cls, data: Dict):
+        unwrap_name_mappings = cls.get_unwrap_name_mappings()
+        unwrap_middle_mappings = cls.get_unwrap_middle_mappings()
+        ret = cls()  # blank init
+        # recursively add items (corresponding to _init)
+        for k, v in data.items():
+            if k == "_info":
+                for k2, v2 in v.items():
+                    ret.add_info(k2, v2)
+            else:
+                # same inside the whole layer
+                if k in unwrap_name_mappings:  # no need to record type info
+                    trg_cls = unwrap_name_mappings[k]
+                    _f = trg_cls.from_builtin
+                else:
+                    _f = BaseDataItem.from_builtin_wrapped
+                middle_type = unwrap_middle_mappings.get(k)
+                # can be different types
+                # todo(+N): more middle types?
+                if middle_type is not None and issubclass(middle_type, (List, Tuple)):
+                    item = [_f(v2) for v2 in v]
+                else:
+                    item = _f(v)
+                # else:
+                #     raise NotImplementedError(f"Unknown type for 'from_builtin': {v.__class__}")
+                ret.add_item(k, item)
+        # finish up
+        ret._finish_create()
+        return ret
+class GeneralSentence(GeneralInstance):
+    @classmethod
+    def get_unwrap_name_mappings(cls):
+        return {"word_seq": SeqField, "pos_seq": SeqField, "pred_pos_seq": SeqField,
+                "dep_tree": DepTreeField, "pred_dep_tree": DepTreeField}
+    def _init(self, words: List[str]):
+        self.add_item("word_seq", SeqField(words))
+    def __len__(self):
+        return len(self.word_seq.vals)  # length of original input
+    @property
+    def char_seq(self):
+        if not hasattr(self, "_char_seq"):  # create from "word_seq" on demand
+            self._char_seq = InputCharSeqField(self.word_seq.vals)
+        return self._char_seq
+class GeneralParagraph(GeneralInstance):
+    @classmethod
+    def get_unwrap_name_mappings(cls):
+        return {"sents": GeneralSentence}
+    @classmethod
+    def get_unwrap_middle_mappings(cls):
+        return {"sents": list}
+    def _init(self, sents: List[GeneralSentence]):
+        self.add_item("sents", sents)
+class GeneralDocument(GeneralInstance):
+    @classmethod
+    def get_unwrap_name_mappings(cls):
+        return {"paras": GeneralParagraph}
+    @classmethod
+    def get_unwrap_middle_mappings(cls):
+        return {"paras": list}
+    def _init(self, paras: List[GeneralParagraph]):
+        self.add_item("paras", paras)
+    @property
+    def sents(self):
+        # on-the-fly joining all (no cache)
+        return Helper.join_list(z.sents for z in self.paras)
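Reviewer note: the `to_builtin` / `from_builtin` pair above relies on a per-class name mapping so that nested items can be rebuilt without storing type information. A minimal sketch of that idea follows. The classes `Sentence` and `Paragraph` and the attribute `unwrap_names` are hypothetical stand-ins, not the classes from this patch, and the sketch ignores `_info`, wrapped items and `_finish_create`.

```python
import json
from typing import Dict, List


class Sentence:
    """Hypothetical stand-in for a leaf item: holds only a word list."""

    def __init__(self, words: List[str]):
        self.words = list(words)

    def to_builtin(self) -> Dict:
        # Plain dict of builtin types, ready for json.dumps
        return {"word_seq": self.words}

    @classmethod
    def from_builtin(cls, data: Dict) -> "Sentence":
        return cls(data["word_seq"])


class Paragraph:
    """Hypothetical stand-in for a container item: a list of sentences."""

    # name -> item class; plays the role of a name mapping that lets the
    # reader skip per-item type tags (an assumption of this sketch)
    unwrap_names = {"sents": Sentence}

    def __init__(self, sents: List[Sentence]):
        self.sents = list(sents)

    def to_builtin(self) -> Dict:
        return {"sents": [s.to_builtin() for s in self.sents]}

    @classmethod
    def from_builtin(cls, data: Dict) -> "Paragraph":
        item_cls = cls.unwrap_names["sents"]
        return cls([item_cls.from_builtin(d) for d in data["sents"]])


def round_trip(para: Paragraph) -> Paragraph:
    # Serialize to a JSON string and rebuild the object tree from it
    return Paragraph.from_builtin(json.loads(json.dumps(para.to_builtin())))
```

Because the container knows which class each named field holds, the serialized form stays a plain nested dict with no class names embedded in it.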
new file mode 100644
index 0000000..0c8fe81
--- /dev/null
+++ b/tasks/zmlm/data/io/__init__.py
@@ -0,0 +1,4 @@
+from .base import BaseDataReader, BaseDataWriter, PlainTextReader
+from .conllu import ConlluParseReader, ConllNerReader
diff --git a/tasks/zmlm/data/io/base.py b/tasks/zmlm/data/io/base.py
new file mode 100644
index 0000000..29214af
--- /dev/null
+++ b/tasks/zmlm/data/io/base.py
@@ -0,0 +1,150 @@
+# base IO with BaseDataItem using "to_builtin" and "from_builtin"
+import os
+import re
+import json
+from typing import Iterable
+from msp.utils import zopen
+from msp.data import Streamer, FileOrFdStreamer
+from ..insts import BaseDataItem, GeneralSentence
+# =====
+# an improved FileStreamer: read files for dir_mode, read lines for file_mode
+class PathOrFdStreamer(Streamer):
+    def __init__(self, path_or_fd: str, is_dir=False, re_pat=""):
+        super().__init__()
+        self.path_or_fd = path_or_fd
+        # for dir mode
+        self.is_dir = is_dir
+        if is_dir:
+            file_list0 = os.listdir(path_or_fd)
+            if re_pat != "":
+                if isinstance(re_pat, str):
+                    re_pat = re.compile(re_pat)
+                file_list0 = [f for f in file_list0 if re.fullmatch(re_pat, f)]
+            file_list0.sort()  # sort by string name
+            file_list = [os.path.join(path_or_fd, f) for f in file_list0]
+        else:
+            file_list = None
+        self.dir_file_list = file_list
+        self.dir_file_ptr = 0
+        # -----
+        self.input_is_fd = not isinstance(path_or_fd, str)
+        if self.input_is_fd:
+            self.fd = path_or_fd
+        else:
+            self.fd = None
+    def __del__(self):
+        if (self.fd is not None) and (not self.input_is_fd) and (not self.fd.closed):
+            self.fd.close()
+    def _restart(self):
+        if self.is_dir:
+            self.dir_file_ptr = 0
+        elif not self.input_is_fd:
+            if self.fd is not None:
+                self.fd.close()
+            self.fd = zopen(self.path_or_fd)
+        else:
+            assert self.restart_times_==0, "Cannot restart (read multiple times) a FdStreamer"
+    def _next(self):
+        if self.is_dir:
+            cur_ptr = self.dir_file_ptr
+            if cur_ptr >= len(self.dir_file_list):
+                return None
+            else:
+                with zopen(self.dir_file_list[cur_ptr]) as fd:
+                    ss = fd.read()
+                self.dir_file_ptr = cur_ptr + 1
+                return ss
+        else:
+            line = self.fd.readline()
+            if len(line) == 0:
+                return None
+            else:
+                return line
+# further using json to load
+class BaseDataReader(PathOrFdStreamer):
+    def __init__(self, cls, path_or_fd: str, is_dir=False, re_pat="", cut=-1):
+        super().__init__(path_or_fd, is_dir, re_pat)
+        #
+        assert issubclass(cls, BaseDataItem)
+        self.cls = cls  # the main BaseDataItem type
+        self.cut = cut
+    def _next(self):
+        if self.count_ == self.cut:
+            return None
+        ss = super()._next()
+        if ss is None:
+            return None
+        else:
+            d = json.loads(ss)
+            return self.cls.from_builtin(d)
+# writing
+class BaseDataWriter:
+    def __init__(self, path_or_fd: str, is_dir=False, fname_getter=None, suffix=""):
+        self.path_or_fd = path_or_fd
+        # dir mode
+        self.is_dir = is_dir
+        self.anon_counter = -1
+        self.fname_getter = fname_getter if fname_getter else self.anon_name_getter
+        self.suffix = suffix
+        # otherwise (file mode): open the path or use the given fd; dir mode keeps no shared fd
+        if is_dir:
+            self.fd = None
+        else:
+            self.fd = zopen(path_or_fd, "w") if isinstance(path_or_fd, str) else path_or_fd
+    def anon_name_getter(self, inst=None):  # accepts inst since write_one calls fname_getter(inst)
+        self.anon_counter += 1
+        return f"anon.{self.anon_counter}{self.suffix}"
+    def write_list(self, insts: Iterable[BaseDataItem]):
+        for inst in insts:
+            self.write_one(inst)
+    def write_one(self, inst: BaseDataItem):
+        if self.is_dir:
+            fname = self.fname_getter(inst)
+            with zopen(os.path.join(self.path_or_fd, fname), 'w') as fd:
+                json.dump(inst.to_builtin(), fd)
+        else:
+            self.fd.write(json.dumps(inst.to_builtin()))
+            self.fd.write("\n")
+    def finish(self):
+        if self.fd is not None and (not self.fd.closed):
+            self.fd.close()
+            self.fd = None
+    def __del__(self):
+        self.finish()
+# =====
+# plain reader
+class PlainTextReader(FileOrFdStreamer):
+    def __init__(self, file_or_fd, aug_code="", cut=-1, sep=None):
+        super().__init__(file_or_fd)
+        self.aug_code = aug_code
+        self.cut = cut
+        self.sep = sep
+    def _next(self):
+        if self.count_ == self.cut:
+            return None
+        line = self.fd.readline()
+        if len(line) == 0:
+            return None
+        tokens = line.rstrip("\n").split(self.sep)
+        one = GeneralSentence.create(tokens)
+        one.add_info("sid", self.count())
+        one.add_info("aug_code", self.aug_code)
+        return one
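Reviewer note: in file mode, `BaseDataWriter` and `BaseDataReader` above amount to a one-JSON-object-per-line format with an optional `cut` on the number of instances read. The following is a rough, self-contained sketch of that protocol using only the standard library. The function names `write_jsonl` and `read_jsonl` are hypothetical and do not exist in this patch.

```python
import io
import json
from typing import Iterable, Iterator


def write_jsonl(fd, items: Iterable[dict]) -> None:
    # One JSON document per line, mirroring the file-mode branch of write_one
    for item in items:
        fd.write(json.dumps(item))
        fd.write("\n")


def read_jsonl(fd, cut: int = -1) -> Iterator[dict]:
    # Stop at end of input or after `cut` items; cut=-1 means no limit,
    # because the counter never equals -1
    count = 0
    while count != cut:
        line = fd.readline()
        if len(line) == 0:  # empty string signals end of file
            return
        yield json.loads(line)
        count += 1


def demo() -> list:
    # In-memory round trip; a real run would use file handles instead
    buf = io.StringIO()
    write_jsonl(buf, [{"sid": 0}, {"sid": 1}, {"sid": 2}])
    buf.seek(0)
    return list(read_jsonl(buf, cut=2))
```

Checking the counter before reading the next line matches the `count_ == cut` test in the readers above, so a `cut` of 0 yields nothing.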
new file mode 100644
index 0000000..94465ca
--- /dev/null
+++ b/tasks/zmlm/data/io/conllu.py
@@ -0,0 +1,118 @@
+# read CoNLL's column format
+from msp.utils import Conf, FileHelper
+from msp.data import FileOrFdStreamer
+from msp.zext.dpar import ConlluReader
+from ..insts import GeneralSentence, SeqField, DepTreeField
+# read from column styled file and return Dict
+class ColumnReader(FileOrFdStreamer):
+    def __init__(self, file_or_fd, sep_line_f=lambda x: len(x.strip()) == 0, sep_field_t="\t"):
+        super().__init__(file_or_fd)
+        self.sep_line_f = sep_line_f
+        self.sep_field_t = sep_field_t
+    #
+    def _get(self, fd):
+        lines = FileHelper.read_multiline(fd, self.sep_line_f)
+        if lines is None:
+            return None
+        ret = {}
+        headlines = []
+        for one_line in lines:
+            if one_line.startswith("#"):
+                headlines.append(one_line)
+                continue
+            fields = one_line.strip("\n").split(self.sep_field_t)
+            if len(ret) == 0:
+                for i in range(len(fields)):
+                    ret[i] = []  # the start
+            else:
+                assert len(ret) == len(fields), "Field counts do not match!"
+            for i, x in enumerate(fields):
+                ret[i].append(x)
+        ret["headlines"] = headlines
+        return ret
+    # not from self.fd
+    def yield_insts(self, file_or_fd):
+        if isinstance(file_or_fd, str):
+            with open(file_or_fd) as fd:
+                yield from self.yield_insts(fd)
+        else:
+            while True:
+                z = self._get(file_or_fd)
+                if z is None:
+                    break
+                yield z
+    def _next(self):
+        return self._get(self.fd)
+class ConlluParseReader(FileOrFdStreamer):
+    def __init__(self, file_or_fd, aug_code="", use_xpos=False, use_la0=True, cut=-1):
+        super().__init__(file_or_fd)
+        self.aug_code = aug_code
+        self.reader = ConlluReader()
+        self.cut = cut
+        #
+        if use_xpos:
+            self.pos_f = lambda t: t.xpos
+        else:
+            self.pos_f = lambda t: t.upos
+        if use_la0:
+            self.label_f = lambda t: t.label0
+        else:
+            self.label_f = lambda t: t.label
+    def _next(self):
+        if self.count_ == self.cut:
+            return None
+        parse = self.reader.read_one(self.fd)
+        if parse is None:
+            return None
+        else:
+            tokens = parse.get_tokens()
+            # one = ParseInstance([t.word for t in tokens], [self.pos_f(t) for t in tokens],
+            #                     [t.head for t in tokens], [self.label_f(t) for t in tokens], code=self.aug_code)
+            # one.init_idx = self.count()
+            one = GeneralSentence.create([t.word for t in tokens])
+            one.add_item("pos_seq", SeqField([self.pos_f(t) for t in tokens]))
+            one.add_item("dep_tree", DepTreeField([0]+[t.head for t in tokens], ['']+[self.label_f(t) for t in tokens]))
+            one.add_info("sid", self.count())
+            one.add_info("aug_code", self.aug_code)
+            return one
+# BIO NER tags reader
+class ConllNerReader(FileOrFdStreamer):
+    def __init__(self, file_or_fd, aug_code="", cut=-1):
+        super().__init__(file_or_fd)
+        self.aug_code = aug_code
+        self.cut = cut
+    def _next(self):
+        if self.count_ == self.cut:
+            return None
+        lines = FileHelper.read_multiline(self.fd)
+        if lines is None:
+            return None
+        else:
+            tokens, tags = [], []
+            for line in lines:
+                one_tok, one_tag = line.split()
+                if one_tag != "O":  # check for a valid BIO prefix
+                    t0, t1 = one_tag.split("-", 1)  # the label part may itself contain "-"
+                    assert t0 in ("B", "I"), f"Bad BIO tag: {one_tag}"
+                tokens.append(one_tok)
+                tags.append(one_tag)
+            # one = ParseInstance([t.word for t in tokens], [self.pos_f(t) for t in tokens],
+            #                     [t.head for t in tokens], [self.label_f(t) for t in tokens], code=self.aug_code)
+            # one.init_idx = self.count()
+            one = GeneralSentence.create(tokens)
+            one.add_item("ner_seq", SeqField(tags))
+            one.add_info("sid", self.count())
+            one.add_info("aug_code", self.aug_code)
+            return one
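Reviewer note: the tag check in `ConllNerReader` can be exercised in isolation. The sketch below parses one blank-line-delimited block of `token tag` lines and validates the BIO prefix. The function name `parse_bio_block` is hypothetical, and raising `ValueError` instead of asserting is a choice made for this sketch, not something the patch does.

```python
from typing import List, Tuple


def parse_bio_block(lines: List[str]) -> Tuple[List[str], List[str]]:
    """Split 'token tag' lines into parallel lists, validating BIO prefixes."""
    tokens: List[str] = []
    tags: List[str] = []
    for line in lines:
        tok, tag = line.split()  # exactly two whitespace-separated columns
        if tag != "O":
            # Split only on the first "-" so labels such as "B-GPE-LOC" survive
            parts = tag.split("-", 1)
            if len(parts) != 2 or parts[0] not in ("B", "I"):
                raise ValueError(f"Bad BIO tag: {tag}")
        tokens.append(tok)
        tags.append(tag)
    return tokens, tags
```

Testing membership in the tuple `("B", "I")` rather than in the string `"BI"` rejects an empty prefix and the two-character prefix `BI`, both of which a substring test would accept.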
new file mode 100644
index 0000000..b5ef009
--- /dev/null
+++ b/tasks/zmlm/main/test.py
@@ -0,0 +1,96 @@
+import msp
+from msp import utils
+from msp.utils import Helper, zlog, zwarn
+from msp.data import MultiHelper, WordVectors, VocabBuilder
+from ..run.confs import OverallConf, init_everything, build_model
+from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner, MltTestingRunner
+from ..run.vocab import MLMVocabPackage
+def iter_hit_words(parse_streamer, emb):
+    words = set()
+    hit_count = 0
+    #
+    for inst in parse_streamer:
+        for one in inst.word_seq.vals:  # GeneralSentence stores its tokens under "word_seq"
+            # todo(note): only those hit in the embeddings are added
+            if emb.has_key(one):
+                yield one
+                # count how many hit in embeddings
+                if one not in words:
+                    hit_count += 1
+            words.add(one)
+    zlog(f"Iter hit words: all={len(words)}, hit={hit_count}")
+# simply use an external one as here
+def aug_words_and_embs(model, aug_vocab, aug_wv):
+    orig_vocab = model.word_vocab
+    word_emb_node = model.embedder.get_node('word')
+    if word_emb_node is not None:
+        orig_arr = word_emb_node.E.detach().cpu().numpy()
+        # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
+        # todo(warn): here every word in aug_vocab should be found in aug_wv
+        aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True)
+        new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True)
+        # assign
+        model.word_vocab = new_vocab  # there should be more to replace? but since only testing maybe no need...
+        word_emb_node.replace_weights(new_arr)
+    else:
+        zwarn("No need to aug vocab since delexicalized model!!")
+        new_vocab = orig_vocab
+    return new_vocab
+def prepare_test(args, ConfType=None):
+    # conf
+    conf = init_everything(args, ConfType)
+    dconf, mconf = conf.dconf, conf.mconf
+    # vocab
+    vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
+    # prepare data
+    test_streamer = PreprocessStreamer(get_data_reader(dconf.test, dconf.input_format),
+                                       lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+    # model
+    model = build_model(conf, vpack)
+    if dconf.model_load_name != "":
+        model.load(dconf.model_load_name)
+    else:
+        zwarn("No model to load, Debugging mode??")
+    # -----
+    # augment with extra embeddings for test stream?
+    extra_embed_files = dconf.vconf.test_extra_pretrain_files
+    if len(extra_embed_files) > 0:
+        # get embeddings
+        extra_codes = dconf.vconf.test_extra_pretrain_codes
+        if len(extra_codes) == 0:
+            extra_codes = [""] * len(extra_embed_files)
+        extra_embedding = WordVectors.load(extra_embed_files[0], aug_code=extra_codes[0])
+        extra_embedding.merge_others([WordVectors.load(one_file, aug_code=one_code) for one_file, one_code in
+                                      zip(extra_embed_files[1:], extra_codes[1:])])
+        # get extra dictionary (only those words hit in extra-embed)
+        extra_vocab = VocabBuilder.build_from_stream(iter_hit_words(test_streamer, extra_embedding),
+                                                     sort_by_count=True, pre_list=(), post_list=())
+        # give them to the model
+        new_vocab = aug_words_and_embs(model, extra_vocab, extra_embedding)
+        vpack.put_voc("word", new_vocab)
+    # =====
+    # No Cache!!
+    test_inst_preparer = model.get_inst_preper(False)
+    backoff_pos_idx = dconf.backoff_pos_idx
+    test_iter = batch_stream(index_stream(test_streamer, vpack, False, False, test_inst_preparer, backoff_pos_idx),
+                             mconf.test_batch_size, mconf, False)
+    return conf, model, vpack, test_iter
+def main(args):
+    conf, model, vpack, test_iter = prepare_test(args)
+    dconf = conf.dconf
+    # go
+    rr = MltTestingRunner(model, vpack, dconf.output_file, dconf.test, dconf.output_format)
+    x = rr.run(test_iter)
+    utils.printing("The end.")
+# b tasks/zmlm/main/test.py:46
new file mode 100644
index 0000000..e6ffa8b
--- /dev/null
+++ b/tasks/zmlm/main/train.py
@@ -0,0 +1,60 @@
+from msp import utils
+from msp.data import MultiCatStreamer, InstCacher
+from ..run.confs import OverallConf, init_everything, build_model
+from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner
+from ..run.vocab import MLMVocabPackage
+def main(args):
+    conf = init_everything(args)
+    dconf, mconf = conf.dconf, conf.mconf
+    # dev/test can be non-existing!
+    if not dconf.dev and dconf.test:
+        utils.zwarn("No dev set but a test set is given; using test as dev (for early stopping)!!")
+    dt_golds, dt_cuts = [], []
+    for file, one_cut in [(dconf.dev, dconf.cut_dev), (dconf.test, "")]:  # no cut for test!
+        if len(file)>0:
+            utils.zlog(f"Add file `{file}(cut={one_cut})' as dt-file #{len(dt_golds)}.")
+            dt_golds.append(file)
+            dt_cuts.append(one_cut)
+    if len(dt_golds) == 0:
+        utils.zwarn("No dev set, then please specify static lrate schedule!!")
+    # data
+    train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format, cut=dconf.cut_train),
+                                        lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+    dt_streamers = [PreprocessStreamer(get_data_reader(f, dconf.dev_input_format, cut=one_cut),
+                                       lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+                    for f, one_cut in zip(dt_golds, dt_cuts)]
+    # vocab
+    if mconf.no_build_dict:
+        vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
+    else:
+        # include dev/test only for convenience of including words hit in pre-trained embeddings
+        vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, MultiCatStreamer(dt_streamers))
+        vpack.save(dconf.dict_dir)
+    # aug2
+    if mconf.aug_word2:
+        vpack.aug_word2_vocab(train_streamer, MultiCatStreamer(dt_streamers), mconf.aug_word2_pretrain)
+        vpack.save(mconf.aug_word2_save_dir)
+    # model
+    model = build_model(conf, vpack)
+    # index the data
+    train_inst_preparer = model.get_inst_preper(True)
+    test_inst_preparer = model.get_inst_preper(False)
+    to_cache = dconf.cache_data
+    to_cache_shuffle = dconf.to_cache_shuffle
+    # todo(note): make sure to cache both train and dev to save time for cached computation
+    backoff_pos_idx = dconf.backoff_pos_idx
+    train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle, train_inst_preparer, backoff_pos_idx), mconf.train_batch_size, mconf, True)
+    dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer, backoff_pos_idx), mconf.test_batch_size, mconf, False) for z in dt_streamers]
+    # training runner
+    tr = MltTrainingRunner(mconf.rconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds, dev_out_format=dconf.output_format)
+    if mconf.train_preload_model:
+        tr.load(dconf.model_load_name, mconf.train_preload_process)
+    # go
+    tr.run(train_iter, dt_iters)
+    utils.zlog("The end of Training.")
new file mode 100644
index 0000000..78c0243
--- /dev/null
+++ b/tasks/zmlm/main/train_linear.py
@@ -0,0 +1,144 @@
+# train a simple linear prober on the attentions
+from msp import utils
+from msp.data import MultiCatStreamer, InstCacher
+from ..run.confs import OverallConf, init_everything, build_model
+from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner
+from ..run.vocab import MLMVocabPackage
+from typing import List
+from copy import deepcopy
+from msp.nn import BK
+from ..model.mtl import BaseModel, BaseModelConf, MtlMlmModel, GeneralSentence
+from ..model.mods.dpar import DparG1DecoderConf, DparG1Decoder, PairScorerConf
+class LinearProbeConf(BaseModelConf):
+    def __init__(self):
+        super().__init__()
+class LinearProbeModel(BaseModel):
+    def __init__(self, model: MtlMlmModel):
+        super().__init__(BaseModelConf())
+        # -----
+        self.model = model
+        model.refresh_batch(False)
+        # simple linear layer
+        dpar_conf = deepcopy(model.conf.dpar_conf)
+        dpar_conf.pre_dp_space = 0
+        dpar_conf.dps_conf = PairScorerConf().init_from_kwargs(use_input0=False, use_input1=False, use_input_pair=True,
+                                                               use_biaffine=False, use_ff1=False, use_ff2=True,
+                                                               ff2_hid_layer=0, use_bias=False)
+        dpar_conf.optim.optim = model.conf.main_optim.optim
+        inputp_dim = len(model.encoder.layers) * model.enc_attn_count  # concat all attns
+        self.dpar = self.add_component("dpar", DparG1Decoder(self.pc, None, inputp_dim, dpar_conf, model.inputter))
+    def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
+                    rand_gen=None, assign_attns=False, **kwargs):
+        self.refresh_batch(training)
+        # get inputs with models
+        with BK.no_grad_env():
+            input_map = self.model.inputter(insts)
+            emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True)
+            input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
+        losses = [self.dpar.loss(insts, BK.zeros([1,1]), input_t, mask_t)]
+        # -----
+        info = self.collect_loss_and_backward(losses, training, loss_factor)
+        info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
+        return info
+    def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
+        conf = self.conf
+        self.refresh_batch(False)
+        # print(f"{len(insts)}: {insts[0].sid}")
+        with BK.no_grad_env():
+            # decode for dpar
+            input_map = self.model.inputter(insts)
+            emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False)
+            input_t = BK.concat(cache.list_attn, -1)  # [bs, slen, slen, L*H]
+            self.dpar.predict(insts, BK.zeros([1,1]), input_t, mask_t)
+        return {}
+def main(args):
+    conf = init_everything(args)
+    dconf, mconf = conf.dconf, conf.mconf
+    # dev/test can be non-existing!
+    if not dconf.dev and dconf.test:
+        utils.zwarn("No dev set but a test set is given; using test as dev (for early stopping)!!")
+    dt_golds, dt_cuts = [], []
+    for file, one_cut in [(dconf.dev, dconf.cut_dev), (dconf.test, "")]:  # no cut for test!
+        if len(file)>0:
+            utils.zlog(f"Add file `{file}(cut={one_cut})' as dt-file #{len(dt_golds)}.")
+            dt_golds.append(file)
+            dt_cuts.append(one_cut)
+    if len(dt_golds) == 0:
+        utils.zwarn("No dev set, then please specify static lrate schedule!!")
+    # data
+    train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format, cut=dconf.cut_train),
+                                        lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+    dt_streamers = [PreprocessStreamer(get_data_reader(f, dconf.dev_input_format, cut=one_cut),
+                                       lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+                    for f, one_cut in zip(dt_golds, dt_cuts)]
+    # vocab
+    if mconf.no_build_dict:
+        vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
+    else:
+        # include dev/test only for convenience of including words hit in pre-trained embeddings
+        vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, MultiCatStreamer(dt_streamers))
+        vpack.save(dconf.dict_dir)
+    # model
+    model = build_model(conf, vpack)
+    # index the data
+    train_inst_preparer = model.get_inst_preper(True)
+    test_inst_preparer = model.get_inst_preper(False)
+    to_cache = dconf.cache_data
+    to_cache_shuffle = dconf.to_cache_shuffle
+    # todo(note): make sure to cache both train and dev to save time for cached computation
+    backoff_pos_idx = dconf.backoff_pos_idx
+    train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle, train_inst_preparer, backoff_pos_idx), mconf.train_batch_size, mconf, True)
+    dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer, backoff_pos_idx), mconf.test_batch_size, mconf, False) for z in dt_streamers]
+    # training runner
+    tr = MltTrainingRunner(mconf.rconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds, dev_out_format=dconf.output_format)
+    if mconf.train_preload_model:
+        tr.load(dconf.model_load_name, mconf.train_preload_process)
+    # =====
+    # switch with the linear model
+    linear_model = LinearProbeModel(model)
+    tr.model = linear_model
+    # =====
+    # go
+    tr.run(train_iter, dt_iters)
+    utils.zlog("The end of Training.")
+def lookat_weights():
+    import torch
+    m = torch.load("zmodel2.best")
+    assert len(m) == 1
+    map_w = list(m.values())[0]  # [output, input]
+    map_w = map_w[1:]
+    # -----
+    num_step, num_head = 6, 8
+    input_labels = [f"seeS{s}H{h}" for s in range(num_step) for h in range(num_head)]
+    output_labels = ["punct", "case", "nsubj", "det", "root", "nmod", "advmod", "obj", "obl", "amod", "compound", "aux", "conj", "mark", "cc", "cop", "advcl", "acl", "xcomp", "nummod", "ccomp", "appos", "flat", "parataxis", "discourse", "expl", "fixed", "list", "iobj", "csubj", "goeswith", "vocative", "reparandum", "orphan", "dep", "dislocated", "clf"]
+    assert tuple(map_w.shape) == (len(output_labels), len(input_labels))
+    #
+    labels = [input_labels, output_labels]
+    other_labels = [output_labels, input_labels]
+    weights = [map_w.T, map_w]
+    for one_label, one_other_label, one_weight in zip(labels, other_labels, weights):
+        topk_vals, topk_idxes = one_weight.topk(5, dim=-1)  # [?, 5]
+        assert len(one_label) == len(topk_vals)
+        for i, lab in enumerate(one_label):
+            print(f"{lab}: " + ", ".join([f"{one_other_label[o_idx]}({o_val})" for o_val, o_idx in zip(topk_vals[i], topk_idxes[i])]))
+# run
+# PYTHONPATH=${SRC_DIR} CUDA_VISIBLE_DEVICES=2 python3 -m pdb ${SRC_DIR}/tasks/cmd.py zmlm.main.train_linear ${MODEL_DIR}/_conf2 train_preload_model:1 model_load_name:${MODEL_DIR}/zmodel2.best lrate.val:0.5 main_optim.optim:sgd lrate_warmup:0 max_epochs:20 lrate.start_bias:1 lrate.scale:5 lrate.m:0.5 dpar_conf.label_neg_rate:0.
+# b tasks/zmlm/main/train_linear:50
new file mode 100644
index 0000000..0ea2ca8
--- /dev/null
+++ b/tasks/zmlm/main/vocab_utils.py
@@ -0,0 +1,76 @@
+# to build vocab: first step of training
+from collections import Counter
+from msp import utils
+from msp.data import MultiCatStreamer, InstCacher
+from ..run.confs import OverallConf, init_everything, build_model
+from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner
+from ..run.vocab import MLMVocabPackage
+def print_vocab_txt(vocab, dev_counter=None):
+    lines = []
+    lines.append(f"# Details of Vocab {vocab.name}")
+    if dev_counter is None:
+        dev_counter = {}
+    # iterate through vocab
+    all_count_voc = sum(z for z in vocab.final_vals if z is not None)
+    all_count_dev = sum(dev_counter.values())
+    all_hit_dev = sum(1 for z in dev_counter.values() if z>0)
+    accu_count_voc, accu_count_dev = 0, 0
+    accu_hit_dev_set = set()
+    for idx, key in enumerate(vocab.final_words):
+        # in vocab
+        count_voc = vocab.final_vals[idx]
+        if count_voc is None:
+            count_voc = 0
+        accu_count_voc += count_voc
+        # in dev
+        count_dev = dev_counter.get(key, 0)
+        accu_count_dev += count_dev
+        if count_dev > 0:
+            accu_hit_dev_set.add(key)
+        # -----
+        added_line = [
+            str(idx), key,
+            f"{count_voc}({count_voc/max(1, all_count_voc):.4f})", f"{accu_count_voc}({accu_count_voc/max(1, all_count_voc):.4f})",
+            f"{count_dev}({count_dev/max(1, all_count_dev):.4f})", f"{accu_count_dev}({accu_count_dev/max(1, all_count_dev):.4f})",
+        ]
+        lines.append("\t".join(added_line))
+    lines.append(f"Coverage: type={len(accu_hit_dev_set)}/{all_hit_dev}={len(accu_hit_dev_set)/max(1, all_hit_dev)}, "
+                 f"count={accu_count_dev}/{all_count_dev}={accu_count_dev/max(1, all_count_dev)}")
+    unhit_dev_counter = Counter({k:v for k,v in dev_counter.items() if k not in accu_hit_dev_set})
+    lines.append(f"Unhit dev top-50: {unhit_dev_counter.most_common(50)}")
+    return "\n".join(lines)
+# -----
+def main(args):
+    conf = init_everything(args)
+    dconf, mconf = conf.dconf, conf.mconf
+    # =====
+    if dconf.train:  # build
+        train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format),
+                                            lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+        vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, [])
+        vpack.save(dconf.dict_dir)
+    else:  # read
+        vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir)
+    # check dev
+    if dconf.dev:
+        dev_streamer = PreprocessStreamer(get_data_reader(dconf.dev, dconf.dev_input_format),
+                                          lower_case=dconf.lower_case, norm_digit=dconf.norm_digit)
+        dev_counter = Counter()
+        for inst in dev_streamer:
+            for w in inst.word_seq.vals:
+                dev_counter[w] += 1
+        print(print_vocab_txt(vpack.get_voc("word"), dev_counter))
+# examples
+# PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zmlm.main.vocab_utils train:../wikis/wiki_en.100k.txt dev:../ud24s/en_train.-1.conllu input_format:plain dev_input_format:conllu norm_digit:1 >vv.list
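Reviewer note: the "Coverage" line printed by `print_vocab_txt` reports how much of the dev set the vocabulary covers, by distinct word types and by token counts. A small sketch of that computation on plain Python containers follows. The function name `coverage` is hypothetical, and the vocabulary is assumed to be a simple list of words rather than the `vocab` object used in this patch.

```python
from collections import Counter
from typing import Iterable, Tuple


def coverage(vocab_words: Iterable[str], dev_counter: Counter) -> Tuple[float, float]:
    """Return (type coverage, token coverage) of dev words by the vocabulary."""
    vocab = set(vocab_words)
    # Types: distinct dev words that appear in the vocabulary
    hit_types = sum(1 for w in dev_counter if w in vocab)
    # Tokens: dev occurrences whose word appears in the vocabulary
    hit_tokens = sum(c for w, c in dev_counter.items() if w in vocab)
    total_types = len(dev_counter)
    total_tokens = sum(dev_counter.values())
    # max(1, ...) guards against an empty dev set
    return hit_types / max(1, total_types), hit_tokens / max(1, total_tokens)
```

Token coverage is usually the more informative of the two numbers, since rare out-of-vocabulary types can lower type coverage while affecting few tokens.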
new file mode 100644
index 0000000..23c47cf
--- /dev/null
+++ b/tasks/zmlm/model/base.py
@@ -0,0 +1,222 @@
+# base models with msp.nn
+from typing import List, Dict, Union
+from collections import OrderedDict
+from msp.model import Model
+from msp.nn import BK
+from msp.nn import refresh as nn_refresh
+from msp.nn.layers import RefreshOptions, BasicNode
+from msp.utils import Conf, zlog, JsonRW, Helper
+from msp.zext.process_train import RConf, SVConf, ScheduledValue, OptimConf
+# =====
+# conf
+class BaseModelConf(Conf):
+    def __init__(self):
+        self.rconf = RConf()
+        # overall conf
+        # optimizer
+        self.main_optim = OptimConf()
+        # lrate factor (<0 means not activated)
+        self.main_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.)
+        # dropouts
+        # todo(+N): should have distributed dropouts into each Node, but ...
+        self.drop_hidden = 0.1  # hdrop
+        self.gdrop_rnn = 0.33  # gdrop (always fixed for recurrent connections)
+        self.idrop_rnn = 0.33  # idrop for rnn
+        self.fix_drop = False  # fix drop for one run for each dropout
+# flesh out common routines, with sub-module components
+class BaseModel(Model):
+    def __init__(self, conf: BaseModelConf):
+        self.conf = conf
+        # ===== Model =====
+        self.pc = BK.ParamCollection()
+        self.main_lrf = ScheduledValue("main:lrf", conf.main_lrf)
+        self._scheduled_values = [self.main_lrf]
+        # -----
+        self.nodes: Dict[str, BasicNode] = OrderedDict()
+        self.components: Dict[str, BaseModule] = OrderedDict()
+        # for refreshing dropouts
+        self.previous_refresh_training = True
+    # called before each mini-batch
+    def refresh_batch(self, training: bool):
+        # refresh graph
+        # todo(warn): remember to clear this one
+        nn_refresh()
+        # refresh nodes
+        if not training:
+            if not self.previous_refresh_training:
+                # todo(+1): currently no need to refresh testing mode multiple times
+                return
+            self.previous_refresh_training = False
+            rop = RefreshOptions(training=False)  # default no dropout
+        else:
+            conf = self.conf
+            rop = RefreshOptions(training=True, hdrop=conf.drop_hidden, idrop=conf.idrop_rnn,
+                                 gdrop=conf.gdrop_rnn, fix_drop=conf.fix_drop)
+            self.previous_refresh_training = True
+        for node in self.nodes.values():
+            node.refresh(rop)
+        for node in self.components.values():
+            node.refresh(rop)
+    def update(self, lrate, grad_factor):
+        self.pc.optimizer_update(lrate, grad_factor)
+    def add_scheduled_value(self, sv):
+        self._scheduled_values.append(sv)
+        return sv
+    def get_scheduled_values(self):
+        return self._scheduled_values + Helper.join_list(z.get_scheduled_values() for z in self.components.values())
+    # == load and save models
+    # todo(warn): no need to load confs here
+    def load(self, path, strict=True):
+        self.pc.load(path, strict)
+        # self.conf = JsonRW.load_from_file(path+".json")
+        zlog(f"Load {self.__class__.__name__} model from {path}.", func="io")
+    def save(self, path):
+        self.pc.save(path)
+        JsonRW.to_file(self.conf, path + ".json")
+        zlog(f"Save {self.__class__.__name__} model to {path}.", func="io")
+    # =====
+    def add_component(self, name: str, node: 'BaseModule'):
+        assert name not in self.components
+        self.components[name] = node
+        # set up optim
+        self.pc.optimizer_set(node.conf.optim.optim, node.lrf, node.conf.optim, params=node.get_parameters(),
+                              check_repeat=True, check_full=True)
+        # pop off
+        self.pc.nnc_pop(node.name)
+        zlog(f"Add component module, {name}: {node}")
+        return node
+    def add_node(self, name: str, node: BasicNode):
+        assert name not in self.nodes
+        self.nodes[name] = node
+        # set up optim
+        self.pc.optimizer_set(self.conf.main_optim.optim, self.main_lrf, self.conf.main_optim,
+                              params=node.get_parameters(), check_repeat=True, check_full=True)
+        # pop off
+        self.pc.nnc_pop(node.name)
+        zlog(f"Add node, {name}: {node}")
+        return node
+    # collect all loss
+    def collect_loss_and_backward(self, loss_info_cols: List[Dict], training: bool, loss_factor: float):
+        final_loss_dict = LossHelper.combine_multiple(loss_info_cols)  # loss_name -> {}
+        if len(final_loss_dict) <= 0:
+            return {}  # no loss!
+        final_losses = []
+        ret_info_vals = OrderedDict()
+        for loss_name, loss_info in final_loss_dict.items():
+            final_losses.append(loss_info['sum']/(loss_info['count']+1e-5))
+            for k in loss_info.keys():
+                one_item = loss_info[k]
+                ret_info_vals[f"loss:{loss_name}_{k}"] = one_item.item() if hasattr(one_item, "item") else float(one_item)
+        final_loss = BK.stack(final_losses).sum()
+        if training and final_loss.requires_grad:
+            BK.backward(final_loss, loss_factor)
+        return ret_info_vals
+# =====
+class BaseModuleConf(Conf):
+    def __init__(self):
+        # optimizer
+        self.optim = OptimConf()
+        # lrate factor (<0 means not activated)
+        self.lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.)
+        # margin (<0 means not activated)
+        self.margin = SVConf().init_from_kwargs(val=0.)
+        # loss lambda (for the whole module)
+        self.loss_lambda = SVConf().init_from_kwargs(val=1.)
+        # my hdrop (<0 means not activated)
+        self.hdrop = -1.
+class BaseModule(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, conf: BaseModuleConf, name=None, init_rop=None, output_dims=None):
+        super().__init__(pc, name, init_rop)
+        self.conf = conf
+        self.output_dims = output_dims
+        # scheduled values
+        self.lrf = ScheduledValue(f"{self.name}:lrf", conf.lrf)
+        self.margin = ScheduledValue(f"{self.name}:margin", conf.margin)
+        self.loss_lambda = ScheduledValue(f"{self.name}:lambda", conf.loss_lambda)
+        self.hdrop = conf.hdrop
+        if self.hdrop >= 0.:
+            zlog(f"Forcing different hdrop at {self.name}: {self.hdrop}")
+    def refresh(self, rop=None):
+        local_changed = (self.hdrop>=0. and rop is not None)
+        if local_changed:  # todo(note): modified inplace
+            old_hdrop = rop.hdrop
+            rop.hdrop = self.hdrop
+        super().refresh(rop)
+        if local_changed:
+            rop.hdrop = old_hdrop
+    def get_scheduled_values(self):
+        return [self.lrf, self.margin, self.loss_lambda]
+    def get_output_dims(self, *input_dims):
+        return self.output_dims
+    # input list of leafs
+    def _compile_component_loss(self, comp_name, losses: List):
+        cur_loss_lambda = self.loss_lambda.value
+        if cur_loss_lambda <= 0.:
+            losses = []  # no loss since closed by component lambda
+        ret = LossHelper.compile_component_info(comp_name, losses, loss_lambda=cur_loss_lambda)
+        return ret
+    # =====
+    def loss(self, *args, **kwargs):
+        raise NotImplementedError()
+    def predict(self, *args, **kwargs):
+        raise NotImplementedError()
+# =====
+# loss helper
+# common format of loss_info is: {'name1': {'sum', 'count', ...}, 'name2': ...}
+class LossHelper:
+    @staticmethod
+    def compile_leaf_info(name: str, loss_sum: BK.Expr, loss_count: BK.Expr, loss_lambda=1., **other_values):
+        local_dict = {"sum0": loss_sum, "sum": loss_sum*loss_lambda, "count": loss_count, "run": 1}
+        local_dict.update(other_values)
+        return OrderedDict({name: local_dict})
+    @staticmethod
+    def compile_component_info(name: str, sub_losses: List[Dict], loss_lambda=1.):
+        ret_dict = OrderedDict()
+        for one_sub_loss in sub_losses:
+            for k, v in one_sub_loss.items():
+                ret_dict[f"{name}.{k}"] = v
+        for one_item in ret_dict.values():
+            one_item["sum"] = one_item["sum"] * loss_lambda
+        return ret_dict
+    @staticmethod
+    def combine_multiple(inputs: List[Dict]):
+        # each input is a flattened loss Dict
+        ret = OrderedDict()
+        for one_input in inputs:
+            if one_input is None:  # skip None
+                continue
+            for name, leaf_info in one_input.items():
+                if name in ret:
+                    # adding together
+                    target_info = ret[name]
+                    for k, v in leaf_info.items():
+                        target_info[k] = target_info.get(k, 0.) + v
+                else:
+                    # direct set
+                    ret[name] = leaf_info
+        return ret
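Review note on `LossHelper` (tasks/zmlm/model/base.py): a minimal sketch of how the loss-info dictionaries appear to flow from leaf to component to the combined dict, assuming plain Python floats stand in for `BK.Expr` tensors. The function names mirror the static methods above; the batch values and the defensive `dict(leaf)` copy in `combine_multiple` are illustrative assumptions and are not part of the patch.

```python
from collections import OrderedDict

def compile_leaf_info(name, loss_sum, loss_count, loss_lambda=1.0, **other_values):
    # "sum0" keeps the unweighted sum; "sum" is scaled by the leaf-level lambda
    info = {"sum0": loss_sum, "sum": loss_sum * loss_lambda, "count": loss_count, "run": 1}
    info.update(other_values)
    return OrderedDict({name: info})

def compile_component_info(name, sub_losses, loss_lambda=1.0):
    # prefix each leaf name with the component name, then apply the component lambda
    ret = OrderedDict()
    for one_sub_loss in sub_losses:
        for k, v in one_sub_loss.items():
            ret[f"{name}.{k}"] = v
    for one_item in ret.values():
        one_item["sum"] = one_item["sum"] * loss_lambda
    return ret

def combine_multiple(inputs):
    # add up entries that share a name; None inputs are skipped
    ret = OrderedDict()
    for one_input in inputs:
        if one_input is None:
            continue
        for name, leaf in one_input.items():
            if name in ret:
                target = ret[name]
                for k, v in leaf.items():
                    target[k] = target.get(k, 0.0) + v
            else:
                ret[name] = dict(leaf)  # assumption: copy so the inputs stay untouched
    return ret

# hypothetical values for two mini-batches of a "dp" component
batch1 = compile_component_info("dp", [
    compile_leaf_info("label", 6.0, 3.0),
    compile_leaf_info("head", 4.0, 2.0, loss_lambda=0.5),
])
batch2 = compile_component_info("dp", [compile_leaf_info("label", 2.0, 1.0)])
combined = combine_multiple([batch1, None, batch2])

# same normalization as collect_loss_and_backward: sum / (count + 1e-5)
final = {k: v["sum"] / (v["count"] + 1e-5) for k, v in combined.items()}
```

Keeping `sum0` next to `sum` lets the logs report the raw loss even when a lambda scales the value that is actually back-propagated.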
diff --git a/tasks/zmlm/model/helper.py b/tasks/zmlm/model/helper.py
index 0000000..2f1cf99
--- /dev/null
+++ b/tasks/zmlm/model/helper.py
@@ -0,0 +1,106 @@
+# some helper nodes
+from typing import List, Dict
+import numpy as np
+from msp.utils import zwarn, Random, Helper, Conf
+from msp.data import VocabPackage
+from msp.nn import BK
+from msp.nn.layers import BasicNode, PosiEmbedding2
+from .mods.vrec import VRecEncoderConf, VRecEncoder, VRecConf, VRecNode
+from .base import BaseModuleConf, BaseModule
+# =====
+# the representation preparation node
+# (forward a vrec multiple times with fixed attns)
+class RPrepConf(Conf):
+    def __init__(self):
+        # outside
+        self.rprep_num_update = 0  # calling times
+        self.rprep_lambda_emb = 0.  # emb_t*L+enc_t*(1-L)
+        self.rprep_lambda_accu = 0.  # accu_attn*L+last_attn*(1-L)
+        # the hidden vrec: maybe overall make it simple and less parameterized
+        self._rprep_vr_conf = VRecConf()  # the fixed-attn vrec
+        # -- only open certain options
+        self.rprep_d_v = 128
+        self.rprep_v_act = "elu"
+        self.rprep_v_drop = 0.1
+        self.rprep_collect_mode = "ch"
+        self.rprep_collect_reducer_mode = "max"
+    def do_validate(self):
+        # make fewer options; all parameters include: combiner(lstm), collector(v-affine)
+        self._rprep_vr_conf.comb_mode = "lstm"  # just make it lstm
+        self._rprep_vr_conf.matt_conf.collector_conf.d_v = self.rprep_d_v
+        self._rprep_vr_conf.matt_conf.collector_conf.v_act = self.rprep_v_act
+        self._rprep_vr_conf.matt_conf.collector_conf.v_drop = self.rprep_v_drop
+        self._rprep_vr_conf.matt_conf.collector_conf.collect_mode = self.rprep_collect_mode
+        self._rprep_vr_conf.matt_conf.collector_conf.collect_reducer_mode = self.rprep_collect_reducer_mode
+class RPrepNode(BasicNode):
+    def __init__(self, pc, dim: int, conf: RPrepConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        # -----
+        self.dim = dim
+        if conf.rprep_num_update > 0:
+            self.vrec_node = self.add_sub_node("v", VRecNode(pc, dim, conf._rprep_vr_conf))
+        else:
+            self.vrec_node = None
+    def get_output_dims(self, *input_dims):
+        return (self.dim, )
+    @property
+    def active(self):
+        return self.conf.rprep_num_update>0
+    # input from emb(very-first input), enc(encoder's output) and cache(vrec-enc's cache)
+    def __call__(self, emb_t, enc_t, cache):
+        if not self.active:
+            return enc_t
+        conf = self.conf
+        rprep_lambda_emb, rprep_lambda_accu= conf.rprep_lambda_emb, conf.rprep_lambda_accu
+        # input_t: [*, len_q, D]
+        if rprep_lambda_emb == 0.:
+            input_t = enc_t  # by default, only enc_t
+        else:
+            input_t = emb_t * rprep_lambda_emb + enc_t * (1. - rprep_lambda_emb)
+        cur_hidden = input_t
+        # another encoder
+        if conf.rprep_num_update > 0:
+            # attn_t: [*, len_q, len_k, head]
+            attn_t = cache.list_accu_attn[-1] * rprep_lambda_accu + cache.list_attn[-1] * (1. - rprep_lambda_accu)
+            # call it
+            cur_cache = self.vrec_node.init_call(input_t)
+            for _ in range(conf.rprep_num_update):
+                cur_hidden = self.vrec_node.update_call(cur_cache, forced_attn=attn_t)
+        return cur_hidden
+# =====
+# various entropy
+class EntropyHelper:
+    @staticmethod
+    def cross_entropy(p, q, dim=-1):
+        return - (p * (q + 1e-10).log()).sum(dim)
+    @staticmethod
+    def kl_divergence(p, q, dim=-1):
+        return EntropyHelper.cross_entropy(p, q, dim) - EntropyHelper.cross_entropy(p, p, dim)
+    @staticmethod
+    def kl_divergence_bd(p, q, dim=-1):  # for both directions
+        return 0.5 * (EntropyHelper.kl_divergence(p, q, dim) + EntropyHelper.kl_divergence(q, p, dim))
+    @staticmethod
+    def js_divergence(p, q, dim=-1):
+        m = 0.5 * (p+q)
+        return 0.5 * (EntropyHelper.kl_divergence(p, m, dim) + EntropyHelper.kl_divergence(q, m, dim))
+    @staticmethod
+    def get_method(m):
+        return {"cross": EntropyHelper.cross_entropy, "kl": EntropyHelper.kl_divergence, "js": EntropyHelper.js_divergence,
+                "kl2": EntropyHelper.kl_divergence_bd, "klbd": EntropyHelper.kl_divergence_bd}[m]
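Review note on `EntropyHelper` (tasks/zmlm/model/helper.py): a small sketch of the same divergences on plain Python lists, assuming each input is a single probability vector rather than a batched tensor. The names mirror the static methods above; the example distributions `p` and `q` are made up for illustration.

```python
import math

EPS = 1e-10  # same smoothing constant as in the patch

def cross_entropy(p, q):
    # H(p, q) = -sum_i p_i * log(q_i)
    return -sum(pi * math.log(qi + EPS) for pi, qi in zip(p, q))

def kl_divergence(p, q):
    # KL(p || q) = H(p, q) - H(p, p)
    return cross_entropy(p, q) - cross_entropy(p, p)

def kl_divergence_bd(p, q):
    # symmetric average of both KL directions
    return 0.5 * (kl_divergence(p, q) + kl_divergence(q, p))

def js_divergence(p, q):
    # Jensen-Shannon: average KL of p and q against their midpoint m
    m = [0.5 * (pi + qi) for pi, qi in zip(p, q)]
    return 0.5 * (kl_divergence(p, m) + kl_divergence(q, m))

# hypothetical distributions
p = [0.7, 0.2, 0.1]
q = [0.1, 0.2, 0.7]
print(kl_divergence(p, q), kl_divergence_bd(p, q), js_divergence(p, q))
```

`kl_divergence` is asymmetric in general, which is why the patch also registers the bidirectional variant under both `kl2` and `klbd`; `js_divergence` is symmetric and bounded by `log(2)`.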
diff --git a/tasks/zmlm/model/mods/dpar.py b/tasks/zmlm/model/mods/dpar.py
new file mode 100644
index 0000000..672132b
--- /dev/null
+++ b/tasks/zmlm/model/mods/dpar.py
@@ -0,0 +1,188 @@
+# dependency parser (currently simply a first-order graph one)
+from typing import List, Dict
+import numpy as np
+from msp.data import Vocab
+from msp.utils import Conf, Random, Constants
+from msp.nn import BK
+from msp.nn.layers import PairScorerConf, PairScorer, Affine
+from ..base import BaseModuleConf, BaseModule, LossHelper, SVConf, ScheduledValue
+from .embedder import Inputter
+from ...data.insts import GeneralSentence, DepTreeField
+from tasks.zdpar.algo import nmst_unproj
+# conf
+class DparG1DecoderConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        # --
+        # space transferring for individual inputs (0 means no such layer)
+        self.pre_dp_space = 0
+        self.pre_dp_act = "elu"
+        # dependency relation pairwise conf
+        self.dps_conf = PairScorerConf().init_from_kwargs(use_input_pair=True, use_biaffine=False, use_ff1=False, use_ff2=True)
+        # fix idx=0 score?
+        self.fix_s0 = True  # whether fix idx=0(NOPE) scores
+        self.fix_s0_val = 0.  # if fix, use what fix val?
+        # minus idx=0 score?
+        self.minus_s0 = True  # whether minus idx=0(NOPE) scores
+        # loss lambdas
+        self.lambda_label = 1.  # on the last dim
+        self.label_neg_rate = 0.1  # sampling rate for negative examples (label=0=nope)
+        self.lambda_head = 1.  # simply max on last dim and on -2 dim
+        # detach input? (for example, warmup for dec?)
+        # "dpar_conf.no_detach_input.mode:linear dpar_conf.no_detach_input.start_bias:2"
+        self.no_detach_input = SVConf().init_from_kwargs(val=1.)
+# module
+class DparG1Decoder(BaseModule):
+    def __init__(self, pc, input_dim: int, inputp_dim: int, conf: DparG1DecoderConf, inputter: Inputter):
+        super().__init__(pc, conf, name="dp")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        self.inputp_dim = inputp_dim
+        # checkout and assign vocab
+        self._check_vocab()
+        # -----
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        assert self.add_root_token, "Currently assert this one!!"  # todo(+N)
+        # -----
+        # transform dp space
+        if conf.pre_dp_space > 0:
+            dp_space = conf.pre_dp_space
+            self.pre_aff_m = self.add_sub_node("pm", Affine(pc, input_dim, dp_space, act=conf.pre_dp_act))
+            self.pre_aff_h = self.add_sub_node("ph", Affine(pc, input_dim, dp_space, act=conf.pre_dp_act))
+        else:
+            dp_space = input_dim
+            self.pre_aff_m = self.pre_aff_h = lambda x: x
+        # dep pairwise scorer: output includes [0, r1) -> [non]+valid_words
+        self.dps_node = self.add_sub_node("dps", PairScorer(pc, dp_space, dp_space, self.dlab_r1,
+                                                            conf=conf.dps_conf, in_size_pair=inputp_dim))
+        self.dps_s0_mask = np.array([1.]+[0.]*(self.dlab_r1-1))  # [1, 0, ..., 0]
+        # whether detach input?
+        self.no_detach_input = ScheduledValue(f"dpar:no_detach", conf.no_detach_input)
+    def get_scheduled_values(self):
+        return super().get_scheduled_values() + [self.no_detach_input]
+    # check vocab at init
+    def _check_vocab(self):
+        # check range
+        _l_vocab: Vocab = self.inputter.vpack.get_voc("deplabel")
+        r0, r1 = _l_vocab.nonspecial_idx_range()  # [r0, r1) is the valid range
+        assert _l_vocab.non==0 and r0 == 1, "Should have preserved only idx=0 as NO-Relation"
+        # assign
+        self.dlab_vocab: Vocab = _l_vocab
+        self.dlab_r0, self.dlab_r1 = r0, r1  # no unk or others
+    # -----
+    # scoring
+    # [bs, slen, D], [bs, len_q, len_k, D'], [bs, slen]
+    def _score(self, repr_t, attn_t, mask_t):
+        conf = self.conf
+        # -----
+        repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
+        repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
+        scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
+        # mask at outside
+        slen = BK.get_shape(mask_t, -1)
+        score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
+        score_mask *= (1.-BK.eye(slen))  # no diag
+        score_mask *= mask_t.unsqueeze(-1)  # input mask at len_k
+        score_mask *= mask_t.unsqueeze(-2)  # input mask at len_q
+        NEG = Constants.REAL_PRAC_MIN
+        scores1 = scores0 + NEG * (1.-score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
+        # add fixed idx0 scores if set
+        if conf.fix_s0:
+            fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
+            scores1 = (1.-fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
+        # minus s0
+        if conf.minus_s0:
+            scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
+        return scores1, score_mask
+    # get ranged label scores for head
+    def _ranged_label_scores(self, label_scores):
+        return label_scores.narrow(-1, self.dlab_r0, self.dlab_r1-self.dlab_r0)
+    # loss
+    def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
+        conf = self.conf
+        # detach input?
+        if self.no_detach_input.value <= 0.:
+            repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
+        # scoring
+        label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
+        # -----
+        # get golds
+        bsize, max_len = BK.get_shape(mask_t)
+        shape_lidxes = [bsize, max_len, max_len]
+        gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
+        gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
+        for bidx, inst in enumerate(insts):
+            cur_dep_tree = inst.dep_tree
+            cur_len = len(cur_dep_tree)
+            gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
+            gold_heads[bidx, :cur_len] = cur_dep_tree.heads
+        # -----
+        margin = self.margin.value
+        all_losses = []
+        # first is loss_labels
+        lambda_label = conf.lambda_label
+        if lambda_label > 0.:
+            gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
+            label_losses = BK.loss_nll(label_scores, gold_lidxes_t, margin=margin)  # [bs, mlen, mlen]
+            positive_mask_t = (gold_lidxes_t>0).float()  # [bs, mlen, mlen]
+            negative_mask_t = (BK.rand(shape_lidxes) < conf.label_neg_rate).float()  # [bs, mlen, mlen]
+            loss_mask_t = score_masks * (positive_mask_t + negative_mask_t)  # [bs, mlen, mlen]
+            loss_mask_t.clamp_(max=1.)
+            masked_label_losses = label_losses * loss_mask_t
+            # compile loss
+            final_label_loss = LossHelper.compile_leaf_info(f"label", masked_label_losses.sum(), loss_mask_t.sum(),
+                                                            loss_lambda=lambda_label, npos=positive_mask_t.sum())
+            all_losses.append(final_label_loss)
+        # then head loss
+        lambda_head = conf.lambda_head
+        if lambda_head > 0.:
+            # get head score simply by argmax on ranges
+            head_scores, _ = self._ranged_label_scores(label_scores).max(-1)  # [bs, mlen, mlen]; max over real labels
+            gold_heads_t = BK.input_idx(gold_heads)
+            head_losses = BK.loss_nll(head_scores, gold_heads_t, margin=margin)  # [bs, mlen]
+            # mask
+            head_mask_t = BK.copy(mask_t)
+            head_mask_t[:, 0] = 0  # not for ARTI_ROOT
+            masked_head_losses = head_losses * head_mask_t
+            # compile loss
+            final_head_loss = LossHelper.compile_leaf_info(f"head", masked_head_losses.sum(), head_mask_t.sum(),
+                                                           loss_lambda=lambda_head)
+            all_losses.append(final_head_loss)
+        # --
+        return self._compile_component_loss("dp", all_losses)
+    # decode
+    def predict(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
+        conf = self.conf
+        # scoring
+        label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
+        # get ranged label scores
+        ranged_label_scores = self._ranged_label_scores(label_scores)  # [bs, m, h, RealLab]
+        # decode
+        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
+        mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
+        mst_heads_arr, mst_labels_arr, mst_scores_arr = \
+            nmst_unproj(ranged_label_scores, mask_t, mst_lengths_arr, labeled=True, ret_arr=True)
+        mst_labels_arr += self.dlab_r0  # add back offset
+        # assign; todo(+2): record scores?
+        for one_idx, one_inst in enumerate(insts):
+            cur_length = mst_lengths[one_idx]
+            cur_heads = [0] + mst_heads_arr[one_idx][1:cur_length].tolist()
+            cur_labels = DepTreeField.build_label_vals(mst_labels_arr[one_idx][:cur_length], self.dlab_vocab)
+            pred_dep_tree = DepTreeField(cur_heads, cur_labels, mst_labels_arr[one_idx][:cur_length].tolist())
+            one_inst.add_item("pred_dep_tree", pred_dep_tree, assert_non_exist=False)
+        return
+# b tasks/zmlm/model/mods/dpar.py:?
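Review note on `DparG1Decoder._score` (tasks/zmlm/model/mods/dpar.py): a rough single-sentence sketch of the masking and idx=0 handling, assuming nested Python lists in place of `[bs, len_q, len_k, 1+N]` tensors. The helper names `build_score_mask`, `normalize_label_scores`, and `head_score` are hypothetical and only illustrate the logic; they do not exist in the patch.

```python
def build_score_mask(token_mask):
    # 1.0 where both tokens are valid and the pair is off-diagonal, else 0.0
    n = len(token_mask)
    return [[float(bool(token_mask[q]) and bool(token_mask[k]) and q != k)
             for k in range(n)]
            for q in range(n)]

def normalize_label_scores(cell, fix_s0=True, fix_s0_val=0.0, minus_s0=True):
    # cell holds the 1+N label scores of one (q, k) pair; idx=0 is the NOPE label
    scores = list(cell)
    if fix_s0:
        scores[0] = fix_s0_val  # overwrite the idx=0 score with a constant
    if minus_s0:
        s0 = scores[0]
        scores = [s - s0 for s in scores]  # make every score relative to idx=0
    return scores

def head_score(cell, r0, r1):
    # head score of a pair: best score among the real labels in [r0, r1)
    return max(cell[r0:r1])

# hypothetical sentence: ROOT plus two words, then one padding position
token_mask = [1, 1, 1, 0]
score_mask = build_score_mask(token_mask)

raw_cell = [5.0, 2.0, 3.0]  # made-up scores for [NOPE, label1, label2]
fixed_cell = normalize_label_scores(raw_cell)
relative_cell = normalize_label_scores(raw_cell, fix_s0=False)
```

With `fix_s0` on and `fix_s0_val=0.`, the NOPE score is pinned at zero, so `minus_s0` changes nothing further; with `fix_s0` off, subtracting the idx=0 score is what anchors the remaining labels.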
diff --git a/tasks/zmlm/model/mods/embedder.py b/tasks/zmlm/model/mods/embedder.py
new file mode 100644
index 0000000..61add05
--- /dev/null
+++ b/tasks/zmlm/model/mods/embedder.py
@@ -0,0 +1,452 @@
+import numpy as np
+from typing import Tuple, Iterable, Dict, List, Set
+from collections import OrderedDict
+from copy import deepcopy
+from msp.utils import Conf, zcheck, zwarn, zlog
+from msp.data import VocabPackage
+from msp.zext.seq_helper import DataPadder
+from msp.nn import BK
+from msp.nn.layers import BasicNode, PosiEmbedding2, Embedding, CnnLayer, NoDropRop, Dropout, Affine, Sequential, LayerNorm
+from msp.nn.modules import Berter2, Berter2Conf, Berter2Seq
+from ...data.insts import GeneralSentence
+# confs for one component: some may be effective only for certain cases
+class EmbedderCompConf(Conf):
+    def __init__(self):
+        self.comp_dim = 0  # by default 0 meaning this comp is not active
+        self.comp_drop = 0.  # by default 0.
+        self.comp_init_scale = 0.  # init-scale for params
+        self.comp_init_from_pretrain = False  # try to init from pretrain?
+        # mainly for word
+        self.comp_rare_unk = 0.5  # rate of replacing rare words with UNK when training
+        self.comp_rare_thr = 0  # only replace those if freq<=this
+# overall conf
+class EmbedderNodeConf(Conf):
+    def __init__(self):
+        # embedder components
+        self.ec_word = EmbedderCompConf().init_from_kwargs(comp_dim=100)  # we mostly need words
+        self.ec_pos = EmbedderCompConf()
+        self.ec_char = EmbedderCompConf()
+        self.ec_posi = EmbedderCompConf()
+        self.ec_bert = EmbedderCompConf()
+        # add special node for root
+        self.add_root_token = True  # add the special bos/root/cls/...
+        # the padding idx 0, should it be zero? note: nope since there may be NAN problems with pre-norm
+        self.embed_fix_row0 = False
+        # final projection layer
+        self.emb_proj_dim = 0  # 0 means no, and thus simply concat
+        self.emb_proj_act = "linear"  # proj_act
+        self.emb_proj_init_scale = 1.
+        self.emb_proj_norm = False  # cannot add&norm since possible dim mismatch
+        # special case: cnn-char encoding
+        self.char_cnn_hidden = 50  # split by windows
+        self.char_cnn_windows = [3, 5]
+        # special case: bert setting
+        self.bert_conf = Berter2Conf().init_from_kwargs(bert2_retinc_cls=True, bert2_training_mask_rate=0.,
+                                                        bert2_output_mode="weighted")
+    @property
+    def ec_dict(self):
+        return OrderedDict([('word', self.ec_word), ('pos', self.ec_pos), ('char', self.ec_char),
+                            ('posi', self.ec_posi), ('bert', self.ec_bert)])
+    def do_validate(self):
+        assert self.bert_conf.bert2_output_mode != "layered", "Not applicable here!"
+        if self.add_root_token:  # setup for bert (CLS as the special node)
+            self.bert_conf.bert2_retinc_cls = True
+        else:
+            self.bert_conf.bert2_retinc_cls = False
+# =====
+# ModelNode and Helper for each type of input
+# -----
+# Input helper is used to prepare and transform inputs
+class InputHelper:
+    def __init__(self, comp_name, vpack: VocabPackage):
+        self.comp_name = comp_name
+        self.comp_seq_name = f"{comp_name}_seq"
+        self.voc = vpack.get_voc(comp_name)
+        self.padder = DataPadder(2, pad_vals=0)  # pad 0
+    # return batched arr
+    def prepare(self, insts: List[GeneralSentence]):
+        cur_input_list = [getattr(z, self.comp_seq_name).idxes for z in insts]
+        cur_input_arr, _ = self.padder.pad(cur_input_list)
+        return cur_input_arr
+    def mask(self, v, erase_mask):
+        IDX_MASK = self.voc.err  # todo(note): this one is unused, thus just take it!
+        ret_arr = v * (1-erase_mask) + IDX_MASK * erase_mask
+        return ret_arr
+    @staticmethod
+    def get_input_helper(name, *args, **kwargs):
+        helper_type = {"char": CharCnnHelper, "posi": PosiHelper, "bert": BertInputHelper}.get(name, PlainSeqHelper)
+        return helper_type(*args, **kwargs)
+# InputEmbedNode is used for embedding corresponding to the Input
+class InputEmbedNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
+                 conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, None, None)
+        # -----
+        self.ec_conf = ec_conf
+        self.conf = conf
+        self.comp_name, self.comp_dim, self.comp_dropout, self.comp_init_scale = \
+            comp_name, ec_conf.comp_dim, ec_conf.comp_drop, ec_conf.comp_init_scale
+        self.voc = vpack.get_voc(comp_name, None)
+        self.output_dim = self.comp_dim  # by default the input size (may be changed later)
+        self.dropout = None  # created later (after deciding output shape)
+    def create_dropout_node(self):
+        if self.comp_dropout > 0.:
+            self.dropout = self.add_sub_node(
+                f"D{self.comp_name}", Dropout(self.pc, (self.output_dim,), fix_rate=self.comp_dropout))
+        else:
+            self.dropout = lambda x: x
+        return self.dropout
+    def get_output_dims(self, *input_dims):
+        return (self.output_dim, )
+    @staticmethod
+    def get_input_embed_node(name, *args, **kwargs):
+        node_type = {"char": CharCnnNode, "posi": PosiNode, "bert": BertInputNode}.get(name, PlainSeqNode)
+        return node_type(*args, **kwargs)
+# -----
+# plain seq
+# the basic one is the plain one
+PlainSeqHelper = InputHelper
+class PlainSeqNode(InputEmbedNode):
+    def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
+                 conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, comp_name, ec_conf, conf, vpack)
+        # -----
+        # get embeddings
+        npvec = None
+        if self.ec_conf.comp_init_from_pretrain:
+            npvec = vpack.get_emb(comp_name)
+            zlog(f"Try to init InputEmbedNode {comp_name} with npvec.shape={npvec.shape if (npvec is not None) else None}")
+            if npvec is None:
+                zwarn("Warn: cannot get pre-trained embeddings to init!!")
+        # get rare unk range
+        # - get freq vals, make sure special ones will not be pruned; todo(note): directly use that field
+        voc_rare_mask = [float(z is not None and z<=ec_conf.comp_rare_thr) for z in self.voc.final_vals]
+        self.rare_mask = BK.input_real(voc_rare_mask)
+        self.use_rare_unk = (ec_conf.comp_rare_unk>0. and ec_conf.comp_rare_thr>0)
+        # --
+        # dropout outside explicitly
+        self.E = self.add_sub_node(f"E{self.comp_name}", Embedding(
+            pc, len(self.voc), self.comp_dim, fix_row0=conf.embed_fix_row0, npvec=npvec, name=comp_name,
+            init_rop=NoDropRop(), init_scale=self.comp_init_scale))
+        self.create_dropout_node()
+    # [*, slen] -> [*, 1+slen, D]
+    def __call__(self, input, add_root_token: bool):
+        voc = self.voc
+        # todo(note): append a [cls/root] idx, currently use "bos"
+        input_t = BK.input_idx(input)  # [*, slen]
+        # rare unk in training
+        if self.rop.training and self.use_rare_unk:
+            rare_unk_rate = self.ec_conf.comp_rare_unk
+            cur_unk_imask = (self.rare_mask[input_t] * (BK.rand(BK.get_shape(input_t))<rare_unk_rate)).detach().long()
+            input_t = input_t * (1-cur_unk_imask) + self.voc.unk * cur_unk_imask
+        # root
+        if add_root_token:
+            input_t_p0 = BK.constants(BK.get_shape(input_t)[:-1]+[1], voc.bos, dtype=input_t.dtype)  # [*, 1]
+            input_t_p1 = BK.concat([input_t_p0, input_t], -1)  # [*, 1+slen]
+        else:
+            input_t_p1 = input_t
+        expr = self.E(input_t_p1)  # [*, 1?+slen, D]
+        return self.dropout(expr)
+# -----
+# char-cnn seq
+class CharCnnHelper(InputHelper):
+    def __init__(self, comp_name, vpack: VocabPackage):
+        super().__init__(comp_name, vpack)
+        self.padder = DataPadder(3, pad_vals=0)  # replace the padder
+    def mask(self, v, erase_mask):
+        return super().mask(v, np.expand_dims(erase_mask, axis=-1))  # todo(note): for simplicity, simply allow all-unk
+class CharCnnNode(PlainSeqNode):
+    def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
+                 conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, comp_name, ec_conf, conf, vpack)
+        # -----
+        per_cnn_size = conf.char_cnn_hidden // len(conf.char_cnn_windows)
+        self.char_cnns = [self.add_sub_node("char_cnn", CnnLayer(
+            self.pc, self.comp_dim, per_cnn_size, z, pooling="max", act="tanh", init_rop=NoDropRop()))
+                          for z in conf.char_cnn_windows]
+        self.output_dim = conf.char_cnn_hidden
+        self.create_dropout_node()
+    # [*, slen, wlen] -> [*, 1+slen, wlen, D]
+    def __call__(self, char_input, add_root_token: bool):
+        char_input_t = BK.input_idx(char_input)  # [*, slen, wlen]
+        if add_root_token:
+            slice_shape = BK.get_shape(char_input_t)
+            slice_shape[-2] = 1
+            char_input_t0 = BK.constants(slice_shape, 0, dtype=char_input_t.dtype)  # todo(note): simply put 0 here!
+            char_input_t1 = BK.concat([char_input_t0, char_input_t], -2)  # [*, 1?+slen, wlen]
+        else:
+            char_input_t1 = char_input_t
+        char_embeds = self.E(char_input_t1)  # [*, 1?+slen, wlen, D]
+        char_cat_expr = BK.concat([z(char_embeds) for z in self.char_cnns])
+        return self.dropout(char_cat_expr)  # todo(note): only final dropout
+# -----
+# positional seq (absolute position)
+class PosiHelper(InputHelper):
+    def __init__(self, comp_name, vpack: VocabPackage):
+        super().__init__(comp_name, vpack)
+    # return batched arr
+    def prepare(self, insts: List[GeneralSentence]):
+        shape = (len(insts), max(len(z) for z in insts))
+        return shape
+    def mask(self, v, erase_mask):
+        return v  # todo(note): generally no need to change posi info for word mask
+class PosiNode(InputEmbedNode):
+    def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
+                 conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, comp_name, ec_conf, conf, vpack)
+        # -----
+        self.node = self.add_sub_node("posi", PosiEmbedding2(pc, self.comp_dim, init_scale=self.comp_init_scale))
+        self.create_dropout_node()
+    def __call__(self, input_v, add_root_token: bool):
+        if isinstance(input_v, np.ndarray):
+            # direct use this [batch_size, slen] as input
+            posi_idxes = BK.input_idx(input_v)
+            expr = self.node(posi_idxes)  # [batch_size, slen, D]
+        else:
+            # input is a shape as prepared by "PosiHelper"
+            batch_size, max_len = input_v
+            if add_root_token:
+                max_len += 1
+            posi_idxes = BK.arange_idx(max_len)  # [1?+slen] add root=0 here
+            expr = self.node(posi_idxes).unsqueeze(0).expand(batch_size, -1, -1)
+        return self.dropout(expr)
+# -----
+# for bert
+class BertInputHelper(InputHelper):
+    def __init__(self, comp_name, vpack: VocabPackage):
+        super().__init__(comp_name, vpack)
+        # ----
+        self.berter: Berter2 = None
+        self.mask_id = None
+    def set_berter(self, berter):
+        self.berter = berter  # todo(note): need to set from outside!
+        self.mask_id = self.berter.tokenizer.mask_token_id
+    # return List[Bert]
+    def prepare(self, insts: List[GeneralSentence]):
+        _key = "_bert2seq"
+        rets = []
+        for z in insts:
+            _v = z.features.get(_key)
+            if _v is None:  # no cache, rebuild it
+                _v = self.berter.subword_tokenize2(z.word_seq.vals, True)  # here use orig strs
+                z.features[_key] = _v
+            rets.append(_v)
+        return rets
+    def mask(self, v, erase_mask):
+        ret = [one_b2s.apply_mask_new(one_mask, self.mask_id) for one_b2s, one_mask in zip(v, erase_mask)]
+        return ret
+class BertInputNode(InputEmbedNode):
+    def __init__(self, pc: BK.ParamCollection, comp_name: str, ec_conf: EmbedderCompConf,
+                 conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, comp_name, ec_conf, conf, vpack)
+        # -----
+        if conf.add_root_token:
+            assert conf.bert_conf.bert2_retinc_cls, "Require add root token by include CLS for Berter2"
+        self.berter = self.add_sub_node("bert", Berter2(pc, conf.bert_conf))
+        self.output_dim = self.berter.get_output_dims()[0]  # fix size
+        self.create_dropout_node()
+    def __call__(self, input_list: List[Berter2Seq], add_root_token: bool):
+        # no need to put add_root_token since already setup for berter
+        expr = self.berter(input_list)
+        return self.dropout(expr)  # [bs, 1?+slen, D]
+# =====
+# the overall manager
+class EmbedderNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, conf: EmbedderNodeConf, vpack: VocabPackage):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        self.vpack = vpack
+        self.add_root_token = conf.add_root_token
+        # -----
+        self.nodes = []  # params
+        self.comp_names = []
+        self.comp_dims = []  # real dims
+        self.berter: Berter2 = None
+        for comp_name, comp_conf in conf.ec_dict.items():
+            if comp_conf.comp_dim > 0:
+                # directly get the nodes
+                one_node = InputEmbedNode.get_input_embed_node(comp_name, pc, comp_name, comp_conf, conf, vpack)
+                comp_dim = one_node.get_output_dims()[0]  # fix dim
+                # especially for berter
+                if comp_name == "bert":
+                    assert self.berter is None
+                    self.berter = one_node.berter
+                # general steps
+                self.comp_names.append(comp_name)
+                self.nodes.append(self.add_sub_node(f"EC{comp_name}", one_node))
+                self.comp_dims.append(comp_dim)
+        # final projection?
+        self.has_proj = (conf.emb_proj_dim > 0)
+        if self.has_proj:
+            proj_layer = Affine(self.pc, sum(self.comp_dims), conf.emb_proj_dim,
+                                act=conf.emb_proj_act, init_scale=conf.emb_proj_init_scale)
+            if conf.emb_proj_norm:
+                norm_layer = LayerNorm(self.pc, conf.emb_proj_dim)
+                self.final_layer = self.add_sub_node("fl", Sequential(self.pc, [proj_layer, norm_layer]))
+            else:
+                self.final_layer = self.add_sub_node("fl", proj_layer)
+            self.output_dim = conf.emb_proj_dim
+        else:
+            self.final_layer = None
+            self.output_dim = sum(self.comp_dims)
+    def get_node(self, name, df=None):
+        try:
+            _idx = self.comp_names.index(name)
+            return self.nodes[_idx]
+        except ValueError:  # name is not a registered component
+            return df
+    def get_output_dims(self, *input_dims):
+        return (self.output_dim, )
+    def __repr__(self):
+        comp_ss = [f"{n}({d})" for n, d in zip(self.comp_names, self.comp_dims)]
+        return f"# EmbedderNode: {comp_ss} -> {self.output_dim}"
+    # mainly need inputs of [*, slen] or other stuffs specific to the nodes
+    # return reprs, masks
+    def __call__(self, input_map: Dict):
+        exprs = []
+        # get masks: this mask marks the valid (non-padded) positions of the batched insts
+        final_masks = BK.input_real(input_map["mask"])  # [*, slen]
+        if self.add_root_token:  # prepend 1 for the root token
+            slice_t = BK.constants(BK.get_shape(final_masks)[:-1]+[1], 1.)
+            final_masks = BK.concat([slice_t, final_masks], -1)  # [*, 1+slen]
+        # -----
+        # for each component
+        for idx, name in enumerate(self.comp_names):
+            cur_node = self.nodes[idx]
+            cur_input = input_map[name]
+            cur_expr = cur_node(cur_input, self.add_root_token)
+            exprs.append(cur_expr)
+        # -----
+        concated_exprs = BK.concat(exprs, dim=-1)
+        # optional proj
+        if self.has_proj:
+            final_expr = self.final_layer(concated_exprs)
+        else:
+            final_expr = concated_exprs
+        return final_expr, final_masks
+# =====
+# the overall input helper
+class Inputter:
+    def __init__(self, embedder: EmbedderNode, vpack: VocabPackage):
+        self.vpack = vpack
+        self.embedder = embedder
+        # -----
+        # prepare the inputter
+        self.comp_names = embedder.comp_names
+        self.comp_helpers = []
+        for comp_name in self.comp_names:
+            one_helper = InputHelper.get_input_helper(comp_name, comp_name, vpack)
+            self.comp_helpers.append(one_helper)
+            if comp_name == "bert":
+                assert embedder.berter is not None
+                one_helper.set_berter(berter=embedder.berter)
+        # ====
+        self.mask_padder = DataPadder(2, pad_vals=0., mask_range=2)
+    def __call__(self, insts: List[GeneralSentence]):
+        # first pad words to get masks
+        _, masks_arr = self.mask_padder.pad([z.word_seq.idxes for z in insts])
+        # then get each one
+        ret_map = {"mask": masks_arr}
+        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
+            ret_map[comp_name] = comp_helper.prepare(insts)
+        return ret_map
+    # todo(note): return new masks (input is read only!!)
+    def mask_input(self, input_map: Dict, input_erase_mask, nomask_names_set: Set):
+        ret_map = {"mask": input_map["mask"]}
+        for comp_name, comp_helper in zip(self.comp_names, self.comp_helpers):
+            if comp_name in nomask_names_set:  # direct borrow that one
+                ret_map[comp_name] = input_map[comp_name]
+            else:
+                ret_map[comp_name] = comp_helper.mask(input_map[comp_name], input_erase_mask)
+        return ret_map
+# =====
+# aug word2: both inputter and embedder
+class AugWord2Node(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, ref_conf: EmbedderNodeConf, ref_vpack: VocabPackage,
+                 comp_name: str, comp_dim: int, output_dim: int):
+        super().__init__(pc, None, None)
+        # =====
+        self.ref_conf = ref_conf
+        self.ref_vpack = ref_vpack
+        self.add_root_token = ref_conf.add_root_token
+        # modify conf
+        _tmp_ec_conf = deepcopy(ref_conf.ec_word)
+        _tmp_ec_conf.comp_dim = comp_dim
+        _tmp_ec_conf.comp_init_from_pretrain = True
+        # -----
+        self.w_node = self.add_sub_node("nw", InputEmbedNode.get_input_embed_node(
+            comp_name, pc, comp_name, _tmp_ec_conf, ref_conf, ref_vpack))
+        self.c_node = self.add_sub_node("nc", InputEmbedNode.get_input_embed_node(
+            "char", pc, "char", deepcopy(ref_conf.ec_char), ref_conf, ref_vpack))
+        # final proj
+        _dims = [self.w_node.get_output_dims()[0], self.c_node.get_output_dims()[0]]
+        self.final_layer = self.add_sub_node("fl", Affine(
+            self.pc, _dims, output_dim, act=ref_conf.emb_proj_act, init_scale=ref_conf.emb_proj_init_scale))
+        self.output_dim = output_dim
+        # =====
+        # inputter
+        self.w_helper = InputHelper.get_input_helper(comp_name, comp_name, ref_vpack)
+        self.c_helper = InputHelper.get_input_helper("char", "char", ref_vpack)
+    def get_output_dims(self, *input_dims):
+        return (self.output_dim, )
+    def __call__(self, insts):
+        # inputter
+        input_w_arr = self.w_helper.prepare(insts)
+        input_c_arr = self.c_helper.prepare(insts)
+        # embedder
+        output_w_t = self.w_node(input_w_arr, self.add_root_token)
+        output_c_t = self.c_node(input_c_arr, self.add_root_token)
+        final_t = self.final_layer([output_w_t, output_c_t])
+        return final_t
+# b tasks/zmlm/model/mods/embedder:342
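A minimal NumPy sketch of what `EmbedderNode.__call__` does with its component outputs and mask, assuming plain arrays instead of `BK` tensors. The helper name `combine_components` and the toy shapes are hypothetical and not part of the repo.

```python
import numpy as np

def combine_components(comp_reprs, mask, add_root_token):
    """Concatenate per-component representations and extend the mask for the root.

    comp_reprs: list of arrays shaped [bs, 1?+slen, D_i]
    mask: array shaped [bs, slen], 1. for valid tokens and 0. for padding
    """
    if add_root_token:
        # the artificial root is always valid, so prepend a column of ones
        root_col = np.ones(mask.shape[:-1] + (1,), dtype=mask.dtype)
        mask = np.concatenate([root_col, mask], axis=-1)  # [bs, 1+slen]
    combined = np.concatenate(comp_reprs, axis=-1)  # [bs, 1?+slen, sum(D_i)]
    return combined, mask

bs, slen = 2, 3
word = np.zeros((bs, 1 + slen, 4))  # stand-in for the "word" component output
char = np.zeros((bs, 1 + slen, 2))  # stand-in for the "char" component output
mask = np.array([[1., 1., 1.], [1., 1., 0.]])
combined, full_mask = combine_components([word, char], mask, add_root_token=True)
print(combined.shape, full_mask.shape)
```

The output dimension is the sum of the component dimensions, which matches how `self.output_dim` is set when no final projection is configured.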
new file mode 100644
index 0000000..250a2dc
--- /dev/null
+++ b/tasks/zmlm/model/mods/masklm.py
@@ -0,0 +1,188 @@
+# the module for masked lm
+from typing import List, Dict, Tuple
+import numpy as np
+from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn
+from msp.data import VocabPackage
+from msp.zext.seq_helper import DataPadder
+from msp.nn import BK
+from msp.nn.layers import Affine, NoDropRop
+from ..base import BaseModuleConf, BaseModule, LossHelper
+from .embedder import Inputter
+# -----
+class MaskLMNodeConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        self.loss_lambda.val = 0.  # by default no such loss
+        # ----
+        # hidden layer
+        self.hid_dim = 0
+        self.hid_act = "elu"
+        # mask/pred options
+        self.mask_rate = 0.
+        self.min_mask_rank = 1  # min word idx to mask
+        self.min_pred_rank = 1  # min word idx to pred for the masked ones
+        self.max_pred_rank = 1  # max word idx to pred for the masked ones
+        self.tie_input_embeddings = False  # tie all preds with input embeddings
+        self.init_pred_from_pretrain = False
+        # lambdas for word/pos
+        self.lambda_word = 1.
+        self.lambda_pos = 0.
+        # sometimes maybe we don't want to mask some fields, for example, predicting words with full pos
+        self.nomask_names = ["pos"]
+        # =====
+        # special loss considering multiple layers
+        self.loss_layers = [-1]  # by default use only the last layer
+        self.loss_weights = [1.]  # weights for each layer?
+        self.loss_comb_method = "min"  # sum/avg/min/max
+        self.score_comb_method = "sum"  # how to combine scores (logprobs)
+class MaskLMNode(BaseModule):
+    def __init__(self, pc: BK.ParamCollection, input_dim: int, conf: MaskLMNodeConf, inputter: Inputter):
+        super().__init__(pc, conf, name="MLM")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        # vocab and padder
+        vpack = inputter.vpack
+        vocab_word, vocab_pos = vpack.get_voc("word"), vpack.get_voc("pos")
+        # no mask fields
+        self.nomask_names_set = set(conf.nomask_names)
+        # models
+        if conf.hid_dim <= 0:  # no hidden layer
+            self.hid_layer = None
+            self.pred_input_dim = input_dim
+        else:
+            self.hid_layer = self.add_sub_node("hid", Affine(pc, input_dim, conf.hid_dim, act=conf.hid_act))
+            self.pred_input_dim = conf.hid_dim
+        # todo(note): unk is the first one above real words
+        self.pred_word_size = min(conf.max_pred_rank+1, vocab_word.unk)
+        self.pred_pos_size = vocab_pos.unk
+        if conf.tie_input_embeddings:
+            zwarn("Tie all preds in mlm with input embeddings!!")
+            self.pred_word_layer = self.pred_pos_layer = None
+            self.inputter_word_node = self.inputter.embedder.get_node("word")
+            self.inputter_pos_node = self.inputter.embedder.get_node("pos")
+        else:
+            self.inputter_word_node, self.inputter_pos_node = None, None
+            self.pred_word_layer = self.add_sub_node("pw", Affine(pc, self.pred_input_dim, self.pred_word_size, init_rop=NoDropRop()))
+            self.pred_pos_layer = self.add_sub_node("pp", Affine(pc, self.pred_input_dim, self.pred_pos_size, init_rop=NoDropRop()))
+            if conf.init_pred_from_pretrain:
+                npvec = vpack.get_emb("word")
+                if npvec is None:
+                    zwarn("Pretrained vector not provided, skip init pred embeddings!!")
+                else:
+                    with BK.no_grad_env():
+                        self.pred_word_layer.ws[0].copy_(BK.input_real(npvec[:self.pred_word_size].T))
+                    zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_word_size}).")
+        # =====
+        COMBINE_METHOD_FS = {
+            "sum": lambda xs: BK.stack(xs, -1).sum(-1), "avg": lambda xs: BK.stack(xs, -1).mean(-1),
+            "min": lambda xs: BK.stack(xs, -1).min(-1)[0], "max": lambda xs: BK.stack(xs, -1).max(-1)[0],
+        }
+        self.loss_comb_f = COMBINE_METHOD_FS[conf.loss_comb_method]
+        self.score_comb_f = COMBINE_METHOD_FS[conf.score_comb_method]
+    # ----
+    # mask randomly certain part for each inst
+    def mask_input(self, input_map: Dict, rand_gen=None, extra_mask_arr=None):
+        if rand_gen is None:
+            rand_gen = Random
+        # -----
+        conf = self.conf
+        # original valid mask
+        input_mask = input_map["mask"]  # [*, len]
+        # create mask for input words
+        input_word = input_map.get("word")  # [*, len]
+        # prepare 'input_erase_mask' and 'output_pred_mask'
+        input_erase_mask = (rand_gen.random_sample(input_mask.shape) < conf.mask_rate) & (input_mask>0.)  # by mask-rate
+        if input_word is not None:  # min word to mask
+            input_erase_mask &= (input_word >= conf.min_mask_rank)
+        if extra_mask_arr is not None:  # extra mask for valid points
+            input_erase_mask &= (extra_mask_arr>0.)
+        # get the new masked inputs (mask for all input components)
+        new_map = self.inputter.mask_input(input_map, input_erase_mask.astype(int), self.nomask_names_set)  # np.int was removed in NumPy 1.24
+        return new_map, input_erase_mask.astype(np.float32)
+    # helper method: p self.arr2words(seq_idx_t); p self.arr2words(argmax_idxes)
+    def arr2words(self, arr):
+        vocab_word = self.inputter.vpack.get_voc("word")
+        flattened_arr = arr.reshape(-1)
+        words_list = [vocab_word.idx2word(i) for i in flattened_arr]
+        words_arr = np.asarray(words_list).reshape(arr.shape)
+        return words_arr
+    # [bsize, slen, *]
+    # todo(note): be careful that repr_t can be root-appended!!
+    def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
+        conf = self.conf
+        _tie_input_embeddings = conf.tie_input_embeddings
+        # prepare idxes for the masked ones
+        if self.add_root_token:  # offset for the special root added in embedder
+            mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
+            repr_mask_idxes = mask_idxes + 1
+            mask_idxes.clamp_(min=0)
+        else:
+            mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
+            repr_mask_idxes = mask_idxes
+        # get the losses
+        if BK.get_shape(mask_idxes, -1) == 0:  # no loss
+            return self._compile_component_loss("mlm", [])
+        else:
+            if not isinstance(repr_ts, (List, Tuple)):
+                repr_ts = [repr_ts]
+            target_word_scores = []
+            target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
+            for layer_idx in conf.loss_layers:
+                # calculate scores
+                target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
+                if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
+                    target_hids = self.hid_layer(target_reprs)
+                else:
+                    target_hids = target_reprs
+                if _tie_input_embeddings:
+                    pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
+                    target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
+                else:
+                    target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
+            # gather the losses
+            all_losses = []
+            for pred_name, target_scores, loss_lambda, range_min, range_max in \
+                    zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
+                        [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
+                if loss_lambda > 0.:
+                    seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
+                    target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
+                    ranged_mask_valids = mask_valids * (target_idx_t>=range_min).float() * (target_idx_t<=range_max).float()
+                    target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones in range
+                    # calculate for each layer
+                    all_layer_losses, all_layer_scores = [], []
+                    for one_layer_idx, one_target_scores in enumerate(target_scores):
+                        # get loss: [bsize, ?]
+                        one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
+                        all_layer_losses.append(one_pred_losses)
+                        # get scores
+                        one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
+                        all_layer_scores.append(one_pred_scores)
+                    # combine all layers
+                    pred_losses = self.loss_comb_f(all_layer_losses)
+                    pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
+                    pred_loss_count = ranged_mask_valids.sum()
+                    # argmax
+                    _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
+                    pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
+                    pred_corr_count = pred_corrs.sum()
+                    # compile leaf loss
+                    r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count,
+                                                          loss_lambda=loss_lambda, corr=pred_corr_count)
+                    all_losses.append(r_loss)
+            return self._compile_component_loss("mlm", all_losses)
# b tasks/zmlm/model/mods/masklm:170
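The erase-mask sampling in `MaskLMNode.mask_input` can be sketched in plain NumPy as follows. This is an illustration under assumptions: the function name `sample_erase_mask` and the toy arrays are hypothetical, and `np.random.RandomState` stands in for the repo's `Random` helper.

```python
import numpy as np

def sample_erase_mask(valid_mask, word_idxes, mask_rate, min_mask_rank, rng):
    """Pick the positions to mask.

    A position is erased only if it is sampled under mask_rate, it is a valid
    (non-padded) token, and its word index is at least min_mask_rank.
    """
    erase = (rng.random_sample(valid_mask.shape) < mask_rate) & (valid_mask > 0.)
    erase &= (word_idxes >= min_mask_rank)
    return erase

rng = np.random.RandomState(0)
valid_mask = np.array([[1., 1., 1., 0.], [1., 1., 0., 0.]])
word_idxes = np.array([[5, 0, 7, 0], [3, 9, 0, 0]])
# mask_rate=1.0 selects every sampled position, so only the two filters remain
erase = sample_erase_mask(valid_mask, word_idxes, mask_rate=1.0, min_mask_rank=1, rng=rng)
print(erase.astype(int))
```

With `mask_rate=1.0` the random draw never rejects a position, which makes it easy to see the effect of the validity and rank filters on their own.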
+# the module for order predictor
+from typing import Dict
+import numpy as np
+from msp.utils import Conf, Random
+from msp.nn import BK
+from msp.nn.layers import PairScorerConf, PairScorer, Affine, NoDropRop
+from ..base import BaseModuleConf, BaseModule, LossHelper
+from .embedder import Inputter
+# -----
+# conf
+class OrderPredNodeConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        self.loss_lambda.val = 0.  # by default no such loss
+        # --
+        self.disturb_mode = "abs_bin"  # abs_noise/abs_bin/rel_noise/rel_bin
+        self.disturb_ranges = [0, 1]  # [a, b) range for either abs or rel
+        self.abs_noise_sort = True  # for "abs_noise", sort the idxes, making it like local shuffling
+        self.abs_keep_cls = True  # if using "abs_*", keep cls/arti_root/... as it is
+        self.bin_rand_offset = True  # add random offsets [0, R) to idxes (only in blur1)
+        self.bin_blur_method = 2  # which blurring method for bin
+        self.bin_blur_bag_lbound = 100  # <=0 means [R+this, R], >0 means (auto) [R//2+1, R]
+        self.bin_blur_bag_keep_rate = 0.  # keep the original order in what percentage of the bags
+        # pairwise scorer
+        self.ps_conf = PairScorerConf().init_from_kwargs(use_bias=False, use_input_pair=False,
+                                                         use_biaffine=False, use_ff1=False, use_ff2=True)
+        self.pred_abs = False  # predicting only absolute distance?
+        self.pred_range = 0  # [-pr, 0) + (0, pr], not predicting self
+        self.cand_range = -1  # at least as pred_range
+        # simple classification losses
+        self.lambda_n1 = 1.  # reduce on dim=-1(2*R): classify dist for each cand
+        self.lambda_n2 = 1.  # reduce on dim=-2(cand): classify cand for each dist
+    def do_validate(self):
+        self.disturb_ranges = [int(z) for z in self.disturb_ranges]
+        assert len(self.disturb_ranges) == 2 and self.disturb_ranges[1]>self.disturb_ranges[0]
+        self.cand_range = max(self.cand_range, self.pred_range)  # at least as pred_range
+# module
+class OrderPredNode(BaseModule):
+    def __init__(self, pc: BK.ParamCollection, input_dim: int, inputp_dim: int, conf: OrderPredNodeConf, inputter: Inputter):
+        super().__init__(pc, conf, name="ORP")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        self.inputp_dim = inputp_dim
+        # -----
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        # --
+        self._disturb_f = getattr(self, "_disturb_"+conf.disturb_mode)  # shortcut
+        self.disturb_range_low, self.disturb_range_high = conf.disturb_ranges
+        self._blur_idxes_f = [None, self._blur_idxes1, self._blur_idxes2][conf.bin_blur_method]
+        self._bin_blur_bag_lbound = conf.bin_blur_bag_lbound
+        self._bin_blur_bag_keep_rate = conf.bin_blur_bag_keep_rate
+        # -----
+        # simple dist predicting: inputs -> [*, len_q, len_k, 2*PredRange]
+        self.ps_node = self.add_sub_node("ps", PairScorer(pc, input_dim, input_dim, 2*conf.pred_range,
+                                                          conf=conf.ps_conf, in_size_pair=inputp_dim))
+        # --
+        self.speical_hid_layer = None
+    # =====
+    # different disturbing methods
+    # add_noise then bin: either stay at this bag or jump to the next
+    def _disturb_local_shuffle(self, input_mask_arr, disturb_range, rand_gen):
+        conf = self.conf
+        assert self._bin_blur_bag_keep_rate == 0., "Currently not supporting this one for this mode!"
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range
+        # first assign ordinary range idxes
+        orig_idxes = np.arange(max_len)[np.newaxis, :]  # [1, max_len]
+        # add noise: either stay at this bag or jump to the next one
+        noised_idxes = orig_idxes + R * rand_gen.randint(2, size=(batch_size, max_len))  # [bs, len]
+        blur_idxes = noised_idxes // R * R  # [bs, len]
+        torank_values = blur_idxes + rand_gen.random_sample(size=(batch_size, max_len))  # sample small from [0,1)
+        if conf.abs_keep_cls and self.add_root_token:
+            torank_values[:, 0] = -R  # make it rank 0
+        torank_values += (R*100) * (1.-input_mask_arr)  # make invalid ones stay there
+        ret_abs_idxes = torank_values.argsort(axis=-1)  # [bs, len]
+        # [bs, 1, len] - [bs, len, 1] = [bs, len, len]
+        ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis]
+        return ret_abs_idxes, ret_rel_dists
+    # first split into local bags and then shuffle for local bags
+    def _disturb_local_shuffle2(self, input_mask_arr, disturb_range, rand_gen):
+        conf = self.conf
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range
+        arange_arr = np.arange(batch_size)[:, np.newaxis]  # [bs, 1]
+        arange_arr2 = np.arange(max_len)[np.newaxis, :]  # [1, len]
+        orig_size = (batch_size, max_len)
+        # sample bag sizes
+        if self._bin_blur_bag_lbound <= 0:
+            _L = max(1, R + self._bin_blur_bag_lbound)  # at least bag_size >= 1
+        else:
+            _L = R // 2 + 1  # automatically
+        bag_sizes = rand_gen.randint(_L, 1+R, size=orig_size)  # [bs, len]
+        mark_idxes = np.cumsum(bag_sizes, -1)  # [bs, len]
+        mark_idxes = mark_idxes * (mark_idxes<max_len)  # [bs, len], make invalid ones 0
+        # marked seq
+        marked_seq = np.zeros(orig_size, dtype=int)  # [bs, len]
+        marked_seq[arange_arr, mark_idxes] = 1  # [bs, len]
+        marked_seq[:,0] = 1  # make them all start at 1, since some invalid ones can have idx=0
+        # again cumsum to get values
+        cumsum_values = np.cumsum(marked_seq.reshape(orig_size), -1) - 1  # [bs, len], minus 1 to make it start with 0
+        torank_values = cumsum_values.copy()
+        if conf.abs_keep_cls and self.add_root_token:
+            torank_values[:, 0] = -R  # make it rank 0
+        # keep some bags?
+        _bin_blur_bag_keep_rate = self._bin_blur_bag_keep_rate
+        if _bin_blur_bag_keep_rate>0.:
+            bag_keeps = (rand_gen.random_sample(size=orig_size)<_bin_blur_bag_keep_rate).astype(float)  # [bs, len], whether keep bag
+            keep_valid = bag_keeps[arange_arr, cumsum_values]  # [bs, len], whether keep token
+            keep_valid_plus = np.clip(keep_valid + (1. - input_mask_arr), 0, 1)
+            keep_adding = arange_arr2 * keep_valid_plus  # adding this to sorting values
+        else:
+            keep_valid = None
+            keep_adding = (R*max(10, max_len)) * (1.-input_mask_arr)  # make invalid ones stay there
+        # shuffle as noise and rank
+        torank_values = torank_values*max(10, max_len) + keep_adding
+        noised_to_rank_values = torank_values + rand_gen.random_sample(size=orig_size)  # [bs, len]
+        sorted_idxes = noised_to_rank_values.argsort(axis=-1)  # [bs, len]
+        ret_abs_idxes = sorted_idxes.argsort(axis=-1)  # [bs, len], get real ranking
+        # [bs, 1, len] - [bs, len, 1] = [bs, len, len]
+        ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis]
+        return ret_abs_idxes, ret_rel_dists, keep_valid
+    # orig_idx + noise
+    def _disturb_abs_noise(self, input_mask_arr, disturb_range, rand_gen):
+        conf = self.conf
+        assert self._bin_blur_bag_keep_rate == 0., "Currently not supporting this one for this mode!"
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range * 2  # +/-range
+        # first assign ordinary range idxes
+        orig_idxes = np.arange(max_len)[np.newaxis, :]  # [1, max_len]
+        # add noises: [batch_size, max_len]
+        noised_idxes = orig_idxes + R * (rand_gen.random_sample((batch_size, max_len))-0.5)  # +/- disturb_range
+        if conf.abs_noise_sort:  # simple argsort
+            if conf.abs_keep_cls and self.add_root_token:
+                noised_idxes[:, 0] = -R  # make it rank 0
+            noised_idxes += (R*100) * (1.-input_mask_arr)  # make invalid ones stay there
+            ret_abs_idxes = noised_idxes.argsort(axis=-1)
+        else:  # clip and round
+            # todo(note): this can lead to repeated idxes
+            ret_abs_idxes = noised_idxes.clip(0, max_len-1).round().astype(np.int64)
+            if conf.abs_keep_cls and self.add_root_token:
+                ret_abs_idxes[:, 0] = 0  # directly put it 0
+        # [bs, 1, len] - [bs, len, 1] = [bs, len, len]
+        ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis]
+        return ret_abs_idxes, ret_rel_dists
+    # orig_dist + noise
+    def _disturb_rel_noise(self, input_mask_arr, disturb_range, rand_gen):
+        conf = self.conf
+        assert self._bin_blur_bag_keep_rate == 0., "Currently not supporting this one for this mode!"
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range * 2  # +/-range
+        # first assign ordinary range idxes
+        orig_idxes0 = np.arange(max_len)  # [max_len]
+        orig_idxes = orig_idxes0[np.newaxis, :]  # [1, max_len]
+        orig_rel_idxes = orig_idxes[:, np.newaxis, :] - orig_idxes[:, :, np.newaxis]  # [1, len, len]
+        # add noises +/- disturb_range: [batch_size, max_len, max_len]
+        noised_rel_idxes_f = orig_rel_idxes + R * (rand_gen.random_sample([batch_size, max_len, max_len])-0.5)
+        # round
+        noised_rel_idxes = noised_rel_idxes_f.round().astype(np.int64)  # [bs, len, len]
+        # keep diag as 0 and others not 0 (>0:+1 <0:-1)
+        noised_rel_idxes += (2*(noised_rel_idxes_f>=0) - 1) * (noised_rel_idxes==0)  # += (1 if f>=0 else -1) * (i==0)
+        noised_rel_idxes[:, orig_idxes0, orig_idxes0] = 0
+        # no abs_idxes available
+        return None, noised_rel_idxes
+    # =====
+    # helper for *bin*, orig_idxes should be >=0; keep the original ones if <keep_thresh
+    def blur_idxes(self, orig_idxes, R: int, keep_thresh: int, rand_offset_size, rand_gen):
+        offset_orig_idxes = (orig_idxes-keep_thresh).clip(min=0)  # offset by keep_thresh
+        # bin the idxes
+        blur_idxes, blur_keeps = self._blur_idxes_f(offset_orig_idxes, R, rand_offset_size, rand_gen)
+        # keep original ones if <thresh & starting with this for the blur ones
+        ret_idxes = np.where(orig_idxes < keep_thresh, orig_idxes, blur_idxes + keep_thresh)
+        return ret_idxes, blur_keeps
+    # simple blur
+    def _blur_idxes1(self, input_idxes, R: int, rand_offset_size, rand_gen):
+        bin_rand_offset = self.conf.bin_rand_offset
+        # add offsets (1): make it group slightly differently at the start
+        if bin_rand_offset:  # but keep at least a relatively large group >R/2
+            toblur_idxes = input_idxes + rand_gen.randint(1 + R // 2, size=rand_offset_size)
+        else:
+            toblur_idxes = input_idxes
+        blur_idxes = toblur_idxes // R * R
+        assert self._bin_blur_bag_keep_rate==0., "Currently not supporting this one for this mode!"
+        # add random offsets (2)
+        if bin_rand_offset:
+            blur_idxes += rand_gen.randint(R, size=rand_offset_size)
+        return blur_idxes, None
+    # random split bags with unequal sizes and random pick one as representative
+    # todo(note): input idxes cannot be >= shape[1]+R, otherwise this raises an out-of-range index error
+    def _blur_idxes2(self, orig_idxes, R: int, rand_offset_size, rand_gen):
+        _bin_blur_bag_keep_rate = self._bin_blur_bag_keep_rate
+        # conf = self.conf
+        batch_size, max_len = orig_idxes.shape[:2]  # todo(note): assume the first two dims means this!!
+        # --
+        prep_size = (batch_size, max_len+R)  # make slightly larger prep sizes since the input may be larger
+        arange_arr = np.arange(batch_size)[:, np.newaxis]  # [bs, 1]
+        # sample bag sizes
+        if self._bin_blur_bag_lbound <= 0:
+            _L = max(1, R+self._bin_blur_bag_lbound)  # at least bag_size >= 1
+        else:
+            _L = R//2+1  # automatically
+        bag_sizes = rand_gen.randint(_L, 1+R, size=prep_size)  # [bs, L+R], (3,4,3,5,...)
+        if _bin_blur_bag_keep_rate>0.:
+            bag_keeps = (rand_gen.random_sample(size=prep_size)<_bin_blur_bag_keep_rate)  # [bs, L+R], (1,0,1,1,...)
+        else:
+            bag_keeps = None  # no keeps
+        # csum for the idxes
+        mark_idxes = np.cumsum(bag_sizes, -1)  # [bs, L+R], (3,7,10,15,...)
+        mark_idxes_i = mark_idxes * (mark_idxes<max_len+R)  # [bs, L+R], make invalid ones 0
+        # mark idxes for later selecting
+        marked_seq = np.zeros(prep_size, dtype=int)  # [bs, L+R], (0,0,0,1,0,0,0,1,0,0,1,...)
+        marked_seq[arange_arr, mark_idxes_i] = 1  # [bs, L+R]
+        marked_seq[:, 0] = 0  # initial ones can only be marked by invalid ones
+        # get idxes for range(max_len)
+        range_idxes = np.cumsum(marked_seq, -1)  # [bs, L+R], (0,0,0,1,1,1,1,2,2,2,3,...)
+        # get representatives for each idx (each bag?): (3-?,7-?,10-?,15-?,...)
+        repr_idxes = np.floor(mark_idxes - (bag_sizes * rand_gen.random_sample(prep_size))).astype(int)  # [bs, L+R]
+        # --
+        # final select (twice)
+        sel_range_idxes = repr_idxes[arange_arr, range_idxes]  # [bs, L+R], (3-?,3-?,3-?,7-?,7-?,7-?,7-?,10-?,...)
+        if _bin_blur_bag_keep_rate>0.:
+            sel_range_keeps = bag_keeps[arange_arr, range_idxes]  # [bs, L+R], whether keep for each idx
+            sel_range_idxes = np.where(sel_range_keeps, np.arange(max_len+R)[np.newaxis,:], sel_range_idxes)  # [bs,L+R]
+            sel_ret_keeps = sel_range_keeps[arange_arr, orig_idxes.reshape([batch_size, -1])].reshape(orig_idxes.shape).astype(float)  # [bs, L]
+        else:
+            sel_ret_keeps = None
+        sel_ret_idxes = sel_range_idxes[arange_arr, orig_idxes.reshape([batch_size, -1])].reshape(orig_idxes.shape)  # [bs, L]
+        return sel_ret_idxes, sel_ret_keeps
+    # bin(orig_idx)
+    def _disturb_abs_bin(self, input_mask_arr, disturb_range, rand_gen):
+        conf = self.conf
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range  # bin size
+        rand_offset_size = (batch_size, 1)
+        # first assign ordinary range idxes
+        orig_idxes = np.arange(max_len) + np.zeros(rand_offset_size, dtype=int)  # [bs, len]
+        blur_keep_thresh = int(conf.abs_keep_cls and self.add_root_token)  # whether keep ARTI_ROOT as 0
+        ret_abs_idxes, ret_abs_keeps = self.blur_idxes(orig_idxes, R, blur_keep_thresh, (batch_size, 1), rand_gen)  # [bs, max_len]
+        # [bs, 1, len] - [bs, len, 1] = [bs, len, len]
+        ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis]
+        return ret_abs_idxes, ret_rel_dists, ret_abs_keeps
+    # bin(orig_dist)
+    def _disturb_rel_bin(self, input_mask_arr, disturb_range, rand_gen):
+        # conf = self.conf
+        batch_size, max_len = input_mask_arr.shape
+        # --
+        R = disturb_range  # bin size
+        rand_offset_size = (batch_size, 1, 1)
+        # first assign ordinary range idxes
+        orig_idxes = np.arange(max_len)  # [max_len]
+        orig_rel_idxes = orig_idxes[np.newaxis, :] - orig_idxes[:, np.newaxis]  # [len, len]
+        abs_rel_idxes = np.abs(orig_rel_idxes) + np.zeros(rand_offset_size, dtype=int)  # [bs, len, len]
+        # always keep 0 as special
+        blur_abs_rel_idxes, _ = self.blur_idxes(abs_rel_idxes, R, 1, rand_offset_size, rand_gen)  # [bs, len, len]
+        # recover sign
+        blur_rel_idxes = blur_abs_rel_idxes * (2*(orig_rel_idxes>=0) - 1)  # recover signs
+        # no abs_idxes available
+        return None, blur_rel_idxes
+    # -----
+    def disturb_input(self, input_map: Dict, rand_gen=None):
+        if rand_gen is None:
+            rand_gen = Random
+        # -----
+        # get shape
+        input_mask_arr = input_map["mask"]  # [bs, len]
+        disturb_range = rand_gen.randint(self.disturb_range_low, self.disturb_range_high)  # random range
+        if self.add_root_token:  # add +1 here!!
+            input_mask_arr = np.pad(input_mask_arr, ((0,0),(1,0)), constant_values=1.)  # [bs, 1+len]
+        # get disturbed idxes/dists
+        disturb_ret = self._disturb_f(input_mask_arr, disturb_range, rand_gen)
+        abs_idxes, rel_dists = disturb_ret[:2]
+        # shallow copy
+        ret_map = input_map.copy()
+        ret_map["posi"] = abs_idxes
+        ret_map["rel_dist"] = rel_dists
+        if len(disturb_ret)>2:
+            ret_map["disturb_keep"] = disturb_ret[2]
+        # # debug
+        # self.see_disturbed_inputs(input_map, abs_idxes)
+        # # -----
+        return ret_map
+    # helper function to see
+    def see_disturbed_inputs(self, input_map, abs_idxes):
+        input_word_idxes = input_map["word"]  # [bs, len?]
+        if self.add_root_token:
+            input_word_idxes = np.pad(input_word_idxes, ((0,0),(1,0)), constant_values=0)  # [bs, 1+len]
+        # argsort again to get inversed idxes
+        inversed_idxes = np.argsort(abs_idxes, -1)  # [bs, len?]
+        disturbed_word_idxes = input_word_idxes[np.arange(abs_idxes.shape[0])[:,np.newaxis], inversed_idxes]  # [bs, len?]
+        tshape = disturbed_word_idxes.shape
+        # -----
+        word_voc = self.inputter.vpack.get_voc("word")
+        orig_word_strs = np.array([word_voc.idx2word(z) for z in input_word_idxes.reshape(-1)], dtype=object).reshape(tshape)
+        disturbed_word_strs = np.array([word_voc.idx2word(z) for z in disturbed_word_idxes.reshape(-1)], dtype=object).reshape(tshape)
+        return orig_word_strs, disturbed_word_strs
+    # [bs, slen, *], [bs, len_q, len_k, head], [bs, slen], [bs, slen]
+    def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
+        conf = self.conf
+        CR, PR = conf.cand_range, conf.pred_range
+        # -----
+        mask_single = BK.copy(mask_t)
+        # no predictions for ARTI_ROOT
+        if self.add_root_token:
+            mask_single[:, 0] = 0.  # [bs, slen]
+        # casting predicting range
+        cur_slen = BK.get_shape(mask_single, -1)
+        arange_t = BK.arange_idx(cur_slen)  # [slen]
+        # [1, len] - [len, 1] = [len, len]
+        reldist_t = (arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1))  # [slen, slen]
+        mask_pair = ((reldist_t.abs() <= CR) & (reldist_t != 0)).float()  # within CR-range; [slen, slen]
+        mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
+        if disturb_keep_arr is not None:
+            mask_pair *= BK.input_real(1.-disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
+        # get all pair scores
+        score_t = self.ps_node.paired_score(repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
+        # -----
+        # loss: normalize on which dim?
+        # get the answers first
+        if conf.pred_abs:
+            answer_t = reldist_t.abs()  # [1,2,3,...,PR]
+            answer_t.clamp_(min=0, max=PR-1)  # [slen, slen], clip in range, distinguish using masks
+        else:
+            answer_t = BK.where((reldist_t>=0), reldist_t-1, reldist_t+2*PR)  # [1,2,3,...PR,-PR,-PR+1,...,-1]
+            answer_t.clamp_(min=0, max=2*PR-1)  # [slen, slen], clip in range, distinguish using masks
+        # expand answer into idxes
+        answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2*PR])  # [len_q, len_k, 2*R]
+        answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
+        answer_valid_t = ((reldist_t.abs() <= PR) & (reldist_t != 0)).float().unsqueeze(-1)  # [bs, len_q, len_k, 1]
+        answer_hit_t = answer_hit_t * mask_pair.unsqueeze(-1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
+        # get losses sum(log(answer*prob))
+        # -- dim=-1 is standard 2*PR classification, dim=-2 usually has 2*PR candidates, but can have fewer at the edges
+        all_losses = []
+        for one_dim, one_lambda in zip([-1, -2], [conf.lambda_n1, conf.lambda_n2]):
+            if one_lambda > 0.:
+                # since currently there can be only one or zero correct answer
+                logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*R]
+                sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*R]
+                cur_dim_mask_t = (answer_hit_t.sum(one_dim)>0.).float()  # [bs, len_q, len_k||2*R]
+                # loss
+                cur_dim_loss = - (sumlogprob_t * cur_dim_mask_t).sum()
+                cur_dim_count = cur_dim_mask_t.sum()
+                # argmax and corr (any correct counts)
+                _, cur_argmax_idxes = score_t.max(one_dim)
+                cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*R|1]
+                cur_dim_corr_count = cur_corrs.sum()
+                # compile loss
+                one_loss = LossHelper.compile_leaf_info(f"d{one_dim}", cur_dim_loss, cur_dim_count,
+                                                        loss_lambda=one_lambda, corr=cur_dim_corr_count)
+                all_losses.append(one_loss)
+        return self._compile_component_loss("orp", all_losses)
+    # =====
+    # add a linear layer for special loss
+    def add_node_special(self, masklm_node):
+        self.speical_hid_layer = self.add_sub_node("shid", Affine(
+            self.pc, self.input_dim, masklm_node.conf.hid_dim, act=masklm_node.conf.hid_act))
+    # borrow masklm and predict for local_shuffle2
+    def loss_special(self, repr_t, mask_t, disturb_keep_arr, input_map, masklm_node):
+        pred_mask_t = BK.copy(mask_t)
+        pred_mask_t *= BK.input_real(1.-disturb_keep_arr)  # not for the non-shuffled ones
+        abs_posi = input_map["posi"]  # shuffled positions
+        # no predictions for ARTI_ROOT
+        if self.add_root_token:
+            pred_mask_t[:, 0] = 0.  # [bs, slen]
+            abs_posi = (abs_posi[:,1:]-1)  # offset by 1
+            pred_mask_t = pred_mask_t[:,1:]  # remove root
+        corr_targets_arr = input_map["word"][np.arange(len(abs_posi))[:,np.newaxis], abs_posi]
+        repr_t_hid = self.speical_hid_layer(repr_t)  # go through hid here!!
+        loss_item = masklm_node.loss([repr_t_hid], pred_mask_t, {"word": corr_targets_arr}, active_hid=False)
+        if len(loss_item) == 0:
+            return loss_item
+        # todo(note): simply change its name
+        vs = [v for v in loss_item.values()]
+        assert len(vs)==1
+        return {"orp.d-2": vs[0]}
# b tasks/zmlm/model/mods/orderpr:305
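The simple bin blur in `_blur_idxes1` and the relative distances in `_disturb_abs_bin` can be illustrated with a minimal pure-Python sketch. The helper names `blur_positions` and `relative_distances` are hypothetical and not part of the repo. The random offsets are left out, and the sketch assumes that indexes are shifted by `keep_thresh` before binning, which is not shown in the code above.

```python
# Hypothetical helpers (not in the repo): a sketch of the simple bin blur,
# without the random offsets controlled by conf.bin_rand_offset.
def blur_positions(length, bin_size, keep_thresh=0):
    """Floor each position to the start of its size-`bin_size` bin.

    Positions below `keep_thresh` are kept as they are; the remaining ones
    are binned relative to `keep_thresh` (an assumption of this sketch).
    """
    blurred = []
    for idx in range(length):
        if idx < keep_thresh:
            blurred.append(idx)
        else:
            shifted = idx - keep_thresh
            blurred.append(shifted // bin_size * bin_size + keep_thresh)
    return blurred


def relative_distances(positions):
    """Entry [i][j] is positions[j] - positions[i], i.e. [1, len] - [len, 1]."""
    return [[pj - pi for pj in positions] for pi in positions]


if __name__ == "__main__":
    print(blur_positions(7, 3))
    print(relative_distances(blur_positions(3, 2)))
```

Tokens that fall into the same bin receive the same blurred position, so their blurred relative distance is 0 and their order inside the bin is no longer visible from the positions alone.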
+# the module for plain lm
+# todo(note): require directional encoders, like RNN with enc_rnn_sep_bidirection=True!
+from typing import List, Dict
+from msp.utils import Conf, zwarn, zlog
+from msp.nn import BK
+from msp.nn.layers import BasicNode, Affine, NoDropRop
+from ..base import BaseModuleConf, BaseModule, LossHelper
+from .embedder import Inputter
+# -----
+class PlainLMNodeConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        self.loss_lambda.val = 0.  # by default no such loss
+        # ----
+        # hidden layer
+        self.hid_dim = 0
+        self.hid_act = "elu"
+        # pred options
+        self.split_input_blm = True  # split the inputs into half for (l2r+r2l) for blm, otherwise input list or all l2r
+        self.min_pred_rank = 1  # >= min word idx to pred for the masked ones
+        self.max_pred_rank = 1  # <= max word idx to pred for the masked ones
+        # tie weights?
+        self.tie_input_embeddings = False  # tie all preds with input embeddings
+        self.tie_bidirect_pred = False  # tie params for bidirection preds
+        self.init_pred_from_pretrain = False
+class PlainLMNode(BaseModule):
+    def __init__(self, pc: BK.ParamCollection, input_dim: int, conf: PlainLMNodeConf, inputter: Inputter):
+        super().__init__(pc, conf, name="PLM")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        self.split_input_blm = conf.split_input_blm
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        # vocab and padder
+        vpack = inputter.vpack
+        vocab_word = vpack.get_voc("word")
+        # models
+        real_input_dim = input_dim//2 if self.split_input_blm else input_dim
+        if conf.hid_dim <= 0:  # no hidden layer
+            self.l2r_hid_layer = self.r2l_hid_layer = None
+            self.pred_input_dim = real_input_dim
+        else:
+            self.l2r_hid_layer = self.add_sub_node("l2r_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
+            self.r2l_hid_layer = self.add_sub_node("r2l_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act))
+            self.pred_input_dim = conf.hid_dim
+        # todo(note): unk is the first one above real words
+        self.pred_size = min(conf.max_pred_rank+1, vocab_word.unk)
+        if conf.tie_input_embeddings:
+            zwarn("Tie all preds in plm with input embeddings!!")
+            self.l2r_pred = self.r2l_pred = None
+            self.inputter_embed_node = self.inputter.embedder.get_node("word")
+        else:
+            self.l2r_pred = self.add_sub_node("l2r_p", Affine(pc, self.pred_input_dim, self.pred_size, init_rop=NoDropRop()))
+            if conf.tie_bidirect_pred:
+                self.r2l_pred = self.l2r_pred
+            else:
+                self.r2l_pred = self.add_sub_node("r2l_p", Affine(pc, self.pred_input_dim, self.pred_size, init_rop=NoDropRop()))
+            self.inputter_embed_node = None
+            if conf.init_pred_from_pretrain:
+                npvec = vpack.get_emb("word")
+                if npvec is None:
+                    zwarn("Pretrained vector not provided, skip init pred embeddings!!")
+                else:
+                    with BK.no_grad_env():
+                        self.l2r_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
+                        self.r2l_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T))
+                    zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_size}).")
+    # [bsize, slen, *]
+    # todo(note): be careful that repr_t can be root-appended!!
+    def loss(self, repr_t, orig_map: Dict, **kwargs):
+        conf = self.conf
+        _tie_input_embeddings = conf.tie_input_embeddings
+        # --
+        # specify input
+        add_root_token = self.add_root_token
+        # get from inputs
+        if isinstance(repr_t, (list, tuple)):
+            l2r_repr_t, r2l_repr_t = repr_t
+        elif self.split_input_blm:
+            l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1)
+        else:
+            l2r_repr_t, r2l_repr_t = repr_t, None
+        # l2r and r2l
+        word_t = BK.input_idx(orig_map["word"])  # [bs, rlen]
+        slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long()  # [bs, 1]
+        if add_root_token:
+            l2r_trg_t = BK.concat([word_t, slice_zero_t], -1)  # pad one extra 0, [bs, rlen+1]
+            r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:,:-1]], -1)  # pad two extra 0 at front, [bs, 2+rlen-1]
+        else:
+            l2r_trg_t = BK.concat([word_t[:,1:], slice_zero_t], -1)  # pad one extra 0, but remove the first one, [bs, -1+rlen+1]
+            r2l_trg_t = BK.concat([slice_zero_t, word_t[:,:-1]], -1)  # pad one extra 0 at front, [bs, 1+rlen-1]
+        # gather the losses
+        all_losses = []
+        pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size-1
+        if _tie_input_embeddings:
+            pred_W = self.inputter_embed_node.E.E[:self.pred_size]  # [PSize, Dim]
+        else:
+            pred_W = None
+        # get input embeddings for output
+        for pred_name, hid_node, pred_node, input_t, trg_t in \
+                    zip(["l2r", "r2l"], [self.l2r_hid_layer, self.r2l_hid_layer], [self.l2r_pred, self.r2l_pred],
+                        [l2r_repr_t, r2l_repr_t], [l2r_trg_t, r2l_trg_t]):
+            if input_t is None:
+                continue
+            # hidden
+            hid_t = hid_node(input_t) if hid_node else input_t  # [bs, slen, hid]
+            # pred: [bs, slen, Vsize]
+            if _tie_input_embeddings:
+                scores_t = BK.matmul(hid_t, pred_W.T)
+            else:
+                scores_t = pred_node(hid_t)
+            # loss
+            mask_t = ((trg_t >= pred_range_min) & (trg_t <= pred_range_max)).float()  # [bs, slen]
+            trg_t.clamp_(max=pred_range_max)  # make it in range
+            losses_t = BK.loss_nll(scores_t, trg_t) * mask_t  # [bs, slen]
+            _, argmax_idxes = scores_t.max(-1)  # [bs, slen]
+            corrs_t = (argmax_idxes == trg_t).float() * mask_t  # [bs, slen]
+            # compile leaf loss
+            one_loss = LossHelper.compile_leaf_info(pred_name, losses_t.sum(), mask_t.sum(), loss_lambda=1., corr=corrs_t.sum())
+            all_losses.append(one_loss)
+        return self._compile_component_loss("plm", all_losses)
# b zmlm/model/mods/plainlm:45
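The target construction in `PlainLMNode.loss` can be sketched for a single sentence in pure Python. The function name `blm_targets` is hypothetical; the padding index 0 follows the `slice_zero_t` padding above, and such positions are later masked out because they fall below `pred_range_min`.

```python
# Hypothetical helper (not in the repo): per-sentence targets for the
# left-to-right (l2r) and right-to-left (r2l) language-model losses.
def blm_targets(word_ids, add_root_token=False):
    """Return (l2r_targets, r2l_targets) aligned with the encoder positions.

    l2r predicts the next word and r2l predicts the previous word; 0 marks
    positions without a target. With an artificial root token prepended,
    the representation sequence is one longer than the word sequence.
    """
    word_ids = list(word_ids)
    if add_root_token:
        l2r = word_ids + [0]            # root predicts the first word
        r2l = [0, 0] + word_ids[:-1]    # root and first word have no target
    else:
        l2r = word_ids[1:] + [0]        # last word has no next word
        r2l = [0] + word_ids[:-1]       # first word has no previous word
    return l2r, r2l


if __name__ == "__main__":
    print(blm_targets([5, 6, 7]))
    print(blm_targets([5, 6, 7], add_root_token=True))
```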
+# sequence labeling module with linear CRF
+# adapted from NCRF++: https://github.com/jiesutd/NCRFpp/blob/master/model/crf.py (17 Dec 2018)
+from typing import List, Dict, Tuple
+import numpy as np
+from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn
+from msp.nn import BK
+from msp.nn.layers import Affine, NoDropRop
+from ..base import BaseModuleConf, BaseModule, LossHelper
+from .embedder import Inputter
+from ...data.insts import GeneralSentence, SeqField
+import torch
+# extra ones
+START_TAG = -2
+STOP_TAG = -1
+# Compute log sum exp in a numerically stable way for the forward algorithm
+def log_sum_exp(vec, m_size):
+    """
+    calculate log of exp sum
+    args:
+        vec (batch_size, vanishing_dim, hidden_dim) : input tensor
+        m_size : hidden_dim
+    return:
+        batch_size, hidden_dim
+    """
+    _, idx = torch.max(vec, 1)  # B * M
+    max_score = torch.gather(vec, 1, idx.view(-1, 1, m_size)).view(-1, 1, m_size)  # B * 1 * M
+    return max_score.view(-1, m_size) + torch.log(torch.sum(torch.exp(vec - max_score.expand_as(vec)), 1)).view(-1, m_size)  # B * M
+# --
+class SeqCrfNodeConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        self.loss_lambda.val = 0.  # by default no such loss
+        # --
+        # hidden layer
+        self.hid_dim = 0
+        self.hid_act = "elu"
+        # for loss function
+        self.div_by_tok = True
+class SeqCrfNode(BaseModule):
+    def __init__(self, pc: BK.ParamCollection, pname: str, input_dim: int, conf: SeqCrfNodeConf, inputter: Inputter):
+        super().__init__(pc, conf, name="CRF")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        # --
+        self.pname = pname
+        self.attr_name = pname + "_seq"  # attribute name in Instance
+        self.vocab = inputter.vpack.get_voc(pname)
+        # todo(note): we must make sure that 0 means NON
+        assert self.vocab.non == 0
+        # models
+        if conf.hid_dim <= 0:  # no hidden layer
+            self.hid_layer = None
+            self.pred_input_dim = input_dim
+        else:
+            self.hid_layer = self.add_sub_node("hid", Affine(pc, input_dim, conf.hid_dim, act=conf.hid_act))
+            self.pred_input_dim = conf.hid_dim
+        self.tagset_size = self.vocab.unk  # todo(note): UNK is the prediction boundary
+        self.pred_layer = self.add_sub_node("pr", Affine(pc, self.pred_input_dim, self.tagset_size+2, init_rop=NoDropRop()))
+        # transition matrix
+        init_transitions = np.zeros([self.tagset_size+2, self.tagset_size+2])
+        init_transitions[:, START_TAG] = -10000.0
+        init_transitions[STOP_TAG, :] = -10000.0
+        init_transitions[:, 0] = -10000.0
+        init_transitions[0, :] = -10000.0
+        self.transitions = self.add_param("T", (self.tagset_size+2, self.tagset_size+2), init=init_transitions)
+    # score
+    def _score(self, repr_t):
+        if self.hid_layer is None:
+            hid_t = repr_t.contiguous()
+        else:
+            hid_t = self.hid_layer(repr_t.contiguous())
+        out_t = self.pred_layer(hid_t)
+        return out_t
+    # loss
+    def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
+        conf = self.conf
+        if self.add_root_token:
+            repr_t = repr_t[:,1:]
+            mask_t = mask_t[:,1:]
+        # score
+        scores_t = self._score(repr_t)  # [bs, rlen, D]
+        # get gold
+        gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, ?+rlen]
+        for bidx, inst in enumerate(insts):
+            cur_seq_idxes = getattr(inst, self.attr_name).idxes
+            gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
+        # get loss
+        gold_pidxes_t = BK.input_idx(gold_pidxes)
+        nll_loss_sum = self.neg_log_likelihood_loss(scores_t, mask_t.bool(), gold_pidxes_t)
+        if not conf.div_by_tok:  # otherwise div by sent
+            nll_loss_sum *= (mask_t.sum() / len(insts))
+        # compile loss
+        crf_loss = LossHelper.compile_leaf_info("crf", nll_loss_sum.sum(), mask_t.sum())
+        return self._compile_component_loss(self.pname, [crf_loss])
+    # predict
+    def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
+        conf = self.conf
+        if self.add_root_token:
+            repr_t = repr_t[:,1:]
+            mask_t = mask_t[:,1:]
+        # score
+        scores_t = self._score(repr_t)  # [bs, rlen, D]
+        # decode
+        _, decode_idx = self._viterbi_decode(scores_t, mask_t.bool())
+        decode_idx_arr = BK.get_value(decode_idx)  # [bs, rlen]
+        for one_bidx, one_inst in enumerate(insts):
+            one_pidxes = decode_idx_arr[one_bidx].tolist()[:len(one_inst)]
+            one_pseq = SeqField(None)
+            one_pseq.build_vals(one_pidxes, self.vocab)
+            one_inst.add_item("pred_"+self.attr_name, one_pseq, assert_non_exist=False)
+        return
+    # =====
+    def _calculate_PZ(self, feats, mask):
+        """
+            input:
+                feats: (batch, seq_len, self.tag_size+2)
+                mask: (batch, seq_len)
+        """
+        batch_size = feats.size(0)
+        seq_len = feats.size(1)
+        tag_size = feats.size(2)
+        # print feats.view(seq_len, tag_size)
+        assert (tag_size == self.tagset_size+2)
+        mask = mask.transpose(1,0).contiguous()
+        ins_num = seq_len * batch_size
+        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
+        feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size)
+        ## need to consider start
+        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
+        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
+        # build iter
+        seq_iter = enumerate(scores)
+        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
+        # only need start from start_tag
+        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1)  # bat_size * to_target_size
+        ## add start score (from start to all tag, duplicate to batch_size)
+        # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1)
+        # iter over last scores
+        for idx, cur_values in seq_iter:
+            # previous to_target is current from_target
+            # partition: previous results log(exp(from_target)), #(batch_size * from_target)
+            # cur_values: bat_size * from_target * to_target
+            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
+            cur_partition = log_sum_exp(cur_values, tag_size)
+            # print cur_partition.data
+            # (bat_size * from_target * to_target) -> (bat_size * to_target)
+            # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1)
+            mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size)
+            ## effective updated partition part, only keep the partition value of mask value = 1
+            masked_cur_partition = cur_partition.masked_select(mask_idx)
+            ## let mask_idx broadcastable, to disable warning
+            mask_idx = mask_idx.contiguous().view(batch_size, tag_size, 1)
+            ## replace the partition where the maskvalue=1, other partition value keeps the same
+            partition.masked_scatter_(mask_idx, masked_cur_partition)
+        # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG
+        cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
+        cur_partition = log_sum_exp(cur_values, tag_size)
+        final_partition = cur_partition[:, STOP_TAG]
+        return final_partition.sum(), scores
+    def _viterbi_decode(self, feats, mask):
+        """
+            input:
+                feats: (batch, seq_len, self.tag_size+2)
+                mask: (batch, seq_len)
+            output:
+                decode_idx: (batch, seq_len) decoded sequence
+                path_score: (batch, 1) corresponding score for each sequence (to be implemented)
+        """
+        batch_size = feats.size(0)
+        seq_len = feats.size(1)
+        tag_size = feats.size(2)
+        assert(tag_size == self.tagset_size+2)
+        ## calculate sentence length for each sentence
+        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
+        ## mask to (seq_len, batch_size)
+        mask = mask.transpose(1,0).contiguous()
+        ins_num = seq_len * batch_size
+        ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1)
+        feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size)
+        ## need to consider start
+        scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size)
+        scores = scores.view(seq_len, batch_size, tag_size, tag_size)
+        # build iter
+        seq_iter = enumerate(scores)
+        ## record the position of best score
+        back_points = list()
+        partition_history = list()
+        ##  reverse mask (bug for mask = 1- mask, use this as alternative choice)
+        # mask = 1 + (-1)*mask
+        # mask =  (1 - mask.long()).byte()
+        mask =  (1 - mask.long()).bool()
+        _, inivalues = next(seq_iter)  # bat_size * from_target_size * to_target_size
+        # only need start from start_tag
+        partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size)  # bat_size * to_target_size
+        # print "init part:",partition.size()
+        partition_history.append(partition)
+        # iter over last scores
+        for idx, cur_values in seq_iter:
+            # previous to_target is current from_target
+            # partition: previous results log(exp(from_target)), #(batch_size * from_target)
+            # cur_values: batch_size * from_target * to_target
+            cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size)
+            ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG
+            # print "cur value:", cur_values.size()
+            partition, cur_bp = torch.max(cur_values, 1)
+            # print "partsize:",partition.size()
+            # exit(0)
+            # print partition
+            # print cur_bp
+            # print "one best, ",idx
+            partition_history.append(partition)
+            ## cur_bp: (batch_size, tag_size) max source score position in current tag
+            ## set padded label as 0, which will be filtered in post processing
+            cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0)
+            back_points.append(cur_bp)
+        # exit(0)
+        ### add score to final STOP_TAG
+        partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1,0).contiguous() ## (batch_size, seq_len, tag_size)
+        ### get the last position for each sentence, and select the last partitions using gather()
+        last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1
+        last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1)
+        ### calculate the score from last partition to end state (and then select the STOP_TAG from it)
+        last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size)
+        _, last_bp = torch.max(last_values, 1)
+        # pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long()
+        # if self.gpu:
+        #     pad_zero = pad_zero.cuda()
+        pad_zero = BK.zeros((batch_size, tag_size)).long()
+        back_points.append(pad_zero)
+        back_points  =  torch.cat(back_points).view(seq_len, batch_size, tag_size)
+        ## select end ids in STOP_TAG
+        pointer = last_bp[:, STOP_TAG]
+        insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size)
+        back_points = back_points.transpose(1,0).contiguous()
+        ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values
+        # print "lp:",last_position
+        # print "il:",insert_last
+        back_points.scatter_(1, last_position, insert_last)
+        # print "bp:",back_points
+        # exit(0)
+        back_points = back_points.transpose(1,0).contiguous()
+        ## decode from the end, padded position ids are 0, which will be filtered in the following evaluation
+        # decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size))
+        # if self.gpu:
+        #     decode_idx = decode_idx.cuda()
+        decode_idx = BK.zeros((seq_len, batch_size)).long()
+        decode_idx[-1] = pointer.detach()
+        for idx in range(len(back_points)-2, -1, -1):
+            pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1))
+            decode_idx[idx] = pointer.detach().view(batch_size)
+        path_score = None
+        decode_idx = decode_idx.transpose(1,0)
+        return path_score, decode_idx
+    def _score_sentence(self, scores, mask, tags):
+        """
+            input:
+                scores: variable (seq_len, batch, tag_size, tag_size)
+                mask: (batch, seq_len)
+                tags: tensor  (batch, seq_len)
+            output:
+                score: sum of score for gold sequences within whole batch
+        """
+        # Gives the score of a provided tag sequence
+        batch_size = scores.size(1)
+        seq_len = scores.size(0)
+        tag_size = scores.size(2)
+        ## convert tag value into a new format, recorded label bigram information to index
+        # new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len))
+        # if self.gpu:
+        #     new_tags = new_tags.cuda()
+        new_tags = BK.zeros((batch_size, seq_len)).long()
+        for idx in range(seq_len):
+            if idx == 0:
+                ## start -> first score
+                new_tags[:,0] =  (tag_size - 2)*tag_size + tags[:,0]
+            else:
+                new_tags[:,idx] =  tags[:,idx-1]*tag_size + tags[:,idx]
+        ## transition for label to STOP_TAG
+        end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size)
+        ## length for batch,  last word position = length - 1
+        length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long()
+        ## index the label id of last word
+        end_ids = torch.gather(tags, 1, length_mask - 1)
+        ## index the transition score for end_id to STOP_TAG
+        end_energy = torch.gather(end_transition, 1, end_ids)
+        ## convert tag as (seq_len, batch_size, 1)
+        new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1)
+        ### convert tag ids into indices over the flattened tag_size*tag_size positions of scores
+        tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size)  # seq_len * bat_size
+        ## mask transpose to (seq_len, batch_size)
+        tg_energy = tg_energy.masked_select(mask.transpose(1,0))
+        # ## calculate the score from START_TAG to first label
+        # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size)
+        # start_energy = torch.gather(start_transition, 1, tags[0,:])
+        ## add all score together
+        # gold_score = start_energy.sum() + tg_energy.sum() + end_energy.sum()
+        gold_score = tg_energy.sum() + end_energy.sum()
+        return gold_score
+    def neg_log_likelihood_loss(self, feats, mask, tags):
+        # negative log likelihood
+        batch_size = feats.size(0)
+        forward_score, scores = self._calculate_PZ(feats, mask)
+        gold_score = self._score_sentence(scores, mask, tags)
+        # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0]
+        # exit(0)
+        return forward_score - gold_score
diff --git a/tasks/zmlm/model/mods/seqlab.py b/tasks/zmlm/model/mods/seqlab.py
new file mode 100644
index 0000000..55e4ad9
--- /dev/null
+++ b/tasks/zmlm/model/mods/seqlab.py
@@ -0,0 +1,100 @@
+# sequence labeling module
+# -- simple individual classifiers
+from typing import List, Dict, Tuple
+import numpy as np
+from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn
+from msp.nn import BK
+from msp.nn.layers import Affine, NoDropRop
+from ..base import BaseModuleConf, BaseModule, LossHelper
+from .embedder import Inputter
+from ...data.insts import GeneralSentence, SeqField
+# --
+class SeqLabNodeConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        self.loss_lambda.val = 0.  # by default no such loss
+        # --
+        # hidden layer
+        self.hid_dim = 0
+        self.hid_act = "elu"
+class SeqLabNode(BaseModule):
+    def __init__(self, pc: BK.ParamCollection, pname: str, input_dim: int, conf: SeqLabNodeConf, inputter: Inputter):
+        super().__init__(pc, conf, name="SLB")
+        self.conf = conf
+        self.inputter = inputter
+        self.input_dim = input_dim
+        # this step is performed at the embedder, thus still does not influence the inputter
+        self.add_root_token = self.inputter.embedder.add_root_token
+        # --
+        self.pname = pname
+        self.attr_name = pname + "_seq"  # attribute name in Instance
+        self.vocab = inputter.vpack.get_voc(pname)
+        # todo(note): we must make sure that index 0 means NON (no label)
+        assert self.vocab.non == 0
+        # models
+        if conf.hid_dim <= 0:  # no hidden layer
+            self.hid_layer = None
+            self.pred_input_dim = input_dim
+        else:
+            self.hid_layer = self.add_sub_node("hid", Affine(pc, input_dim, conf.hid_dim, act=conf.hid_act))
+            self.pred_input_dim = conf.hid_dim
+        self.pred_out_dim = self.vocab.unk  # todo(note): UNK is the prediction boundary
+        self.pred_layer = self.add_sub_node("pr", Affine(pc, self.pred_input_dim, self.pred_out_dim, init_rop=NoDropRop()))
+    # score
+    def _score(self, repr_t):
+        if self.hid_layer is None:
+            hid_t = repr_t
+        else:
+            hid_t = self.hid_layer(repr_t)
+        out_t = self.pred_layer(hid_t)
+        return out_t
+    # loss
+    def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
+        conf = self.conf
+        # score
+        scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
+        # get gold
+        gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.int64)  # [bs, ?+rlen]; np.long is unavailable in some NumPy versions
+        for bidx, inst in enumerate(insts):
+            cur_seq_idxes = getattr(inst, self.attr_name).idxes
+            if self.add_root_token:
+                gold_pidxes[bidx, 1:1+len(cur_seq_idxes)] = cur_seq_idxes
+            else:
+                gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes
+        # get loss
+        margin = self.margin.value
+        gold_pidxes_t = BK.input_idx(gold_pidxes)
+        gold_pidxes_t *= (gold_pidxes_t<self.pred_out_dim).long()  # 0 means invalid ones!!
+        loss_mask_t = (gold_pidxes_t>0).float() * mask_t  # [bs, ?+rlen]
+        lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin)  # [bs, ?+rlen]
+        # argmax
+        _, argmax_idxes = scores_t.max(-1)
+        pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t
+        # compile loss
+        lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum())
+        return self._compile_component_loss(self.pname, [lab_loss])
+    # predict
+    def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs):
+        conf = self.conf
+        # score
+        scores_t = self._score(repr_t)  # [bs, ?+rlen, D]
+        _, argmax_idxes = scores_t.max(-1)  # [bs, ?+rlen]
+        argmax_idxes_arr = BK.get_value(argmax_idxes)  # [bs, ?+rlen]
+        # assign; todo(+2): record scores?
+        one_offset = int(self.add_root_token)
+        for one_bidx, one_inst in enumerate(insts):
+            one_pidxes = argmax_idxes_arr[one_bidx, one_offset:one_offset+len(one_inst)].tolist()
+            one_pseq = SeqField(None)
+            one_pseq.build_vals(one_pidxes, self.vocab)
+            one_inst.add_item("pred_"+self.attr_name, one_pseq, assert_non_exist=False)
+        return
diff --git a/tasks/zmlm/model/mods/vrec/__init__.py b/tasks/zmlm/model/mods/vrec/__init__.py
new file mode 100644
index 0000000..72a5b0d
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/__init__.py
@@ -0,0 +1,10 @@
+# the MuRec encoder
+# -----
+# for the murec-att-node, it aims to collect attentional evidence:
+# c1_scorer / c2_normer / c3_collector
+# for the murec-layer, it is composed of murec-att & rec-wrapper
+from .vrec import VRecEncoderConf, VRecEncoder, VRecCache, VRecConf, VRecNode
diff --git a/tasks/zmlm/model/mods/vrec/base.py b/tasks/zmlm/model/mods/vrec/base.py
new file mode 100644
index 0000000..0f6da94
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/base.py
@@ -0,0 +1,137 @@
+# some basic components
+from typing import List
+from msp.utils import Conf
+from msp.nn import BK
+from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, Dropout
+# -----
+# Affine/MLP layer for convenient use
+class AffineHelperNode(BasicNode):
+    def __init__(self, pc, input_dim, hid_dim: int, hid_act='linear', hid_drop=0., hid_piece4init=1,
+                 out_dim=0, out_fbias=0., out_fact="linear", out_piece4init=1, init_scale=1.):
+        super().__init__(pc, None, None)
+        # -----
+        # hidden layer
+        self.hid_aff = self.add_sub_node("hidden", Affine(
+            pc, input_dim, hid_dim, n_piece4init=hid_piece4init, init_scale=init_scale, init_rop=NoDropRop()))
+        self.hid_act_f = ActivationHelper.get_act(hid_act)
+        self.hid_drop = self.add_sub_node("drop", Dropout(pc, (hid_dim, ), fix_rate=hid_drop))
+        # -----
+        # output layer (optional)
+        self.final_output_dim = hid_dim
+        # todo(+N): how about split hidden layers for each specific output
+        self.out_fbias = out_fbias  # fixed extra bias
+        self.out_act_f = ActivationHelper.get_act(out_fact)
+        # no output dropouts
+        if out_dim > 0:
+            assert hid_act != "linear", "Please use non-linear activation for mlp!"
+            assert hid_piece4init == 1, "Strange hid_piece4init for hidden layer with out_dim>0"
+            self.final_aff = self.add_sub_node("final", Affine(
+                pc, hid_dim, out_dim, n_piece4init=out_piece4init, init_scale=init_scale, init_rop=NoDropRop()))
+            self.final_output_dim = out_dim
+        else:
+            self.final_aff = None
+    def get_output_dims(self, *input_dims):
+        return (self.final_output_dim, )
+    def __call__(self, input_t):
+        # hidden
+        hid_t = self.hid_aff(input_t)
+        # act and drop
+        act_hid_t = self.hid_act_f(hid_t)
+        drop_hid_t = self.hid_drop(act_hid_t)
+        # optional extra layer
+        if self.final_aff:
+            out0 = self.final_aff(drop_hid_t)
+            out1 = self.out_act_f(out0 + self.out_fbias)  # fixed bias and final activation
+            return out1
+        else:
+            return drop_hid_t
+# -----
+# gumbel-softmax / concrete-node
+class ConcreteNodeConf(Conf):
+    def __init__(self):
+        self.use_gumbel = False
+        self.use_argmax = False
+        self.prune_val = 0.  # hard prune for probs
+# discrete / gumbel softmax
+class ConcreteNode(BasicNode):
+    def __init__(self, pc, conf: ConcreteNodeConf):
+        super().__init__(pc, None, None)
+        self.use_gumbel = conf.use_gumbel
+        self.use_argmax = conf.use_argmax
+        self.prune_val = conf.prune_val
+        self.gumbel_eps = 1e-10
+    # [*, S] -> normalized [*, S]
+    def __call__(self, scores, temperature=1., dim=-1):
+        is_training = self.rop.training
+        # only use stochastic (gumbel) noise during training
+        if is_training:
+            if self.use_gumbel:
+                gumbel_eps = self.gumbel_eps
+                G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)  # in [eps, 1]
+                scores = scores - (gumbel_eps - G.log()).log()
+        # normalize
+        probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
+        # prune and re-normalize?
+        if self.prune_val > 0.:
+            probs = probs * (probs > self.prune_val).float()
+            # todo(note): currently no re-normalize
+            # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
+        # argmax and ste
+        if self.use_argmax:  # use the hard argmax
+            max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
+            # todo(+N): currently we do not re-normalize here, should it be done here?
+            st_probs = (probs>=max_probs).float() * probs  # [*, S]
+            if is_training:  # (hard-soft).detach() + soft
+                st_probs = (st_probs - probs).detach() + probs  # [*, S]
+            return st_probs
+        else:
+            return probs
+# -----
+# affine combiner
+class AffineCombiner(BasicNode):
+    def __init__(self, pc, input_dims: List[int], use_affs: List[bool], out_dim: int,
+                 out_act='linear', out_drop=0., param_init_scale=1.):
+        super().__init__(pc, None, None)
+        # -----
+        self.input_dims = input_dims
+        self.use_affs = use_affs
+        self.out_dim = out_dim
+        # =====
+        assert len(input_dims) == len(use_affs)
+        self.aff_nodes = []
+        for d, use in zip(input_dims, use_affs):
+            if use:
+                one_aff = self.add_sub_node("aff", Affine(pc, d, out_dim, init_scale=param_init_scale, init_rop=NoDropRop()))
+            else:
+                assert d == out_dim, f"Dimension mismatch for skipping affine: {d} vs {out_dim}!"
+                one_aff = None
+            self.aff_nodes.append(one_aff)
+        self.out_act_f = ActivationHelper.get_act(out_act)
+        self.out_drop = self.add_sub_node("drop", Dropout(pc, (out_dim, ), fix_rate=out_drop))
+    def get_output_dims(self, *input_dims):
+        return (self.out_dim, )
+    # input is List[*, len, dim]
+    def __call__(self, input_ts: List):
+        comb_t = 0.
+        assert len(input_ts) == len(self.aff_nodes), "Input number mismatch"
+        for one_t, one_aff in zip(input_ts, self.aff_nodes):
+            if one_aff:
+                add_t = one_aff(one_t)
+            else:
+                add_t = one_t
+            comb_t += add_t
+        comb_t_act = self.out_act_f(comb_t)
+        comb_t_drop = self.out_drop(comb_t_act)
+        return comb_t_drop
diff --git a/tasks/zmlm/model/mods/vrec/c1_score.py b/tasks/zmlm/model/mods/vrec/c1_score.py
new file mode 100644
index 0000000..92ad120
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/c1_score.py
@@ -0,0 +1,209 @@
+# Component 1 for MAtt: scorer
+# (Q, K, accu_attns, rel_dist) => scores [*, len_q, len_k, head]
+from typing import List, Union, Dict, Iterable
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm
+from .base import AffineHelperNode
+# -----
+# conf
+class MAttScorerConf(Conf):
+    def __init__(self):
+        # special
+        self.param_init_scale = 1.  # todo(+N): ugly extra scale
+        self.use_piece4init = False
+        # dot-product for Q*K
+        self.d_qk = 64  # query*key -> score
+        self.qk_act = 'linear'
+        self.qk_drop = 0.
+        self.score_scale_power = 0.5  # scale power for the dot product scoring
+        self.use_unhead_score = False  # general score (untyped) for q*k
+        # relative distances?
+        self.use_rel_dist = False
+        self.rel_emb_dim = 512  # this should be enough
+        self.rel_emb_zero0 = True  # zero the 0 item
+        self.rel_dist_clip = 1000  # this should be enough
+        self.rel_dist_abs = False  # discard sign on rel dist, if True: [0, clip] else [-clip, clip]
+        self.rel_init_sincos = True  # init and freeze rel_dist embeddings with sincos version
+        # q(/h)-specific lambdas for subtracting accu_attns
+        self.use_lambq = False
+        self.lambq_hdim = 128
+        self.lambq_hdrop = 0.1
+        self.lambq_hact = "elu"
+        self.lambq_fbias = 0.  # fixed extra bias
+        self.lambq_fact = "relu"  # output activation after adding bias
+        # q(/h)-specific score offset
+        self.use_soff = False
+        self.soff_hdim = 128
+        self.soff_hdrop = 0.1
+        self.soff_hact = "elu"
+        self.soff_fbias = 0.  # fixed extra bias
+        self.soff_fact = "relu"  # output activation after adding bias
+        # no (masking out) self-loop?
+        self.no_self_loop = False
+    def get_att_scale_qk(self):
+        return float(self.d_qk ** self.score_scale_power)
+# the scorer
+class MAttScorerNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, dim_q, dim_k, head_count, conf: MAttScorerConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        self.dim_q, self.dim_k = dim_q, dim_k
+        self.head_count = head_count
+        # -----
+        self.use_unhead_score = conf.use_unhead_score
+        _calc_head_count = (1+head_count) if self.use_unhead_score else head_count  # extra dim for overall unhead score
+        _d_qk = conf.d_qk
+        _hid_size_qk = _calc_head_count * _d_qk
+        self._att_scale_qk = conf.get_att_scale_qk()  # scale for dot product
+        self.split_dims = [_calc_head_count, -1]
+        # affine-qk
+        self.affine_q = self.add_sub_node("aq", AffineHelperNode(
+            pc, dim_q, _hid_size_qk, hid_act=conf.qk_act, hid_drop=conf.qk_drop,
+            hid_piece4init=(_calc_head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        self.affine_k = self.add_sub_node("ak", AffineHelperNode(
+            pc, dim_k, _hid_size_qk, hid_act=conf.qk_act, hid_drop=conf.qk_drop,
+            hid_piece4init=(_calc_head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        # rel positional
+        if conf.use_rel_dist:
+            self.dist_helper = self.add_sub_node("dh", RelDistHelperNode(pc, _calc_head_count, conf))
+        else:
+            self.dist_helper = None
+        # lambq
+        if conf.use_lambq:
+            self.lambq_aff = self.add_sub_node("al", AffineHelperNode(
+                pc, dim_q, hid_dim=conf.lambq_hdim, hid_act=conf.lambq_hact, hid_drop=conf.lambq_hdrop,
+                out_dim=head_count, out_fbias=conf.lambq_fbias, out_fact=conf.lambq_fact,
+                out_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        else:
+            self.lambq_aff = None
+        # soff
+        if conf.use_soff:
+            self.soff_aff = self.add_sub_node("as", AffineHelperNode(
+                pc, dim_q, hid_dim=conf.soff_hdim, hid_act=conf.soff_hact, hid_drop=conf.soff_hdrop,
+                out_dim=(1+head_count), out_fbias=conf.soff_fbias, out_fact=conf.soff_fact,
+                out_piece4init=((1+head_count) if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        else:
+            self.soff_aff = None
+        # no (masking out) self loop?
+        self.no_self_loop = conf.no_self_loop
+    # -----
+    # [*, len, head*dim] -> [*, head, len, dim] if do_transpose else [*, len, head, dim]
+    def _shape_project(self, x, do_transpose: bool):
+        x_size = BK.get_shape(x)[:-1] + self.split_dims
+        x_reshaped = x.view(x_size)
+        return x_reshaped.transpose(-2, -3) if do_transpose else x_reshaped
+    # input is *[*, len?, Din], [*, len_q, len_k, head], [*, len_k], [*, len_q, len_k]
+    def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
+        conf = self.conf
+        # == calculate the dot-product scores
+        # calculate the two projections (query and key): [bs, len_?, head*D]
+        query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
+        query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
+        # original scores
+        scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
+        # == adding rel_dist ones
+        if conf.use_rel_dist:
+            scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
+        # transpose
+        scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
+        # == unhead score
+        if conf.use_unhead_score:
+            scores_t0, score_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
+            scores = scores_t0 + score_t1  # [*, len_q, len_k, head]
+        # == combining with history accumulated attns
+        if conf.use_lambq and accu_attn is not None:
+            # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
+            lambq_vals = self.lambq_aff(query)  # [*, len_q, head]; e.g., with relu as fact, this is >=0
+            scores -= lambq_vals.unsqueeze(-2) * accu_attn
+        # == score offset
+        if conf.use_soff:
+            # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
+            score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
+            score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
+            scores -= score_offset_t0.unsqueeze(-2)
+            scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
+        # == apply mask & no-self-loop
+        # NEG_INF = Constants.REAL_PRAC_MIN
+        NEG_INF = -1000.  # this should be enough
+        NEG_INF2 = -2000.  # this should be enough
+        if mask_k is not None:  # [*, 1, len_k, 1]
+            scores += (1.-mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
+        if mask_qk is not None:  # [*, len_q, len_k, 1]
+            scores += (1.-mask_qk).unsqueeze(-1) * NEG_INF2
+        if self.no_self_loop:
+            query_len = BK.get_shape(query, -2)
+            assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
+            scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
+        return scores.contiguous()  # [*, len_q, len_k, head]
+# -----
+# RelDistHelper: here we apply (relative) positional info only to att/score calculation
+# -- following the ones in xl-transformer/xlnet
+class RelDistHelperNode(BasicNode):
+    def __init__(self, pc, head_count: int, conf: MAttScorerConf):
+        super().__init__(pc, None, None)
+        self.rel_dist_abs = conf.rel_dist_abs
+        _clip = conf.rel_dist_clip
+        _head_count = head_count
+        _d_qk = conf.d_qk
+        _d_emb = conf.rel_emb_dim
+        # scale
+        self._att_scale_qk = conf.get_att_scale_qk()  # scale for dot product
+        # positional embedding
+        self.E = self.add_sub_node("e", PosiEmbedding2(
+            pc, _d_emb, max_val=_clip, init_sincos=conf.rel_init_sincos, freeze=conf.rel_init_sincos, zero0=conf.rel_emb_zero0))
+        # different parameters for each head
+        self.split_dims = [_head_count, _d_qk]
+        self.affine_rel = self.add_sub_node("ar", AffineHelperNode(
+            pc, _d_emb, _head_count*_d_qk, hid_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        self.vec_u = self.add_param("vu", (_head_count, _d_qk), scale=conf.param_init_scale)
+        self.vec_v = self.add_param("vv", (_head_count, _d_qk), scale=conf.param_init_scale)
+    # [query, key]
+    def get_rel_dist(self, len_q: int, len_k: int):
+        dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
+        dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
+        distance = dist_x - dist_y  # [len_q, len_k]
+        return distance
+    # [*, head, len_q/len_k, D]; could be inplaced adding "input_scores"
+    def __call__(self, query_up, key_up, rel_dist=None, input_scores=None):
+        _att_scale_qk = self._att_scale_qk
+        # -----
+        # get dim info
+        len_q, len_k = BK.get_shape(query_up, -2), BK.get_shape(key_up, -2)
+        # get distance embeddings
+        if rel_dist is None:
+            rel_dist = self.get_rel_dist(len_q, len_k)
+        if self.rel_dist_abs:  # use abs?
+            rel_dist = BK.abs(rel_dist)
+        dist_embs = self.E(rel_dist)  # [len_q, len_k, Demb]
+        # -----
+        # dist_up
+        dist_up0 = self.affine_rel(dist_embs)  # [len_q, len_k, head*D]
+        # -> [head, len_q, len_k, D]
+        dist_up1 = dist_up0.view(BK.get_shape(dist_up0)[:-1] + self.split_dims).transpose(-2, -3).transpose(-3, -4)
+        # -----
+        # all items are [*, head, len_q, len_k]
+        posi_scores = (input_scores if (input_scores is not None) else 0.)
+        # item (b): <query, dist>: [head, len_q, len_k, D] * [*, head, len_q, D, 1] -> [*, head, len_q, len_k]
+        item_b = (BK.matmul(dist_up1, query_up.unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
+        posi_scores += item_b
+        # todo(note): item_c is removed since it is not related to rel_dist
+        # # item (c): <key, u>: [*, head, len_k, D] * [head, D, 1] -> [*, head, 1, len_k]
+        # item_c = (BK.matmul(key_up, self.vec_u.unsqueeze(-1)) / _att_scale_qk).squeeze(-1).unsqueeze(-2)
+        # posi_scores += item_c
+        # item (d): <dist, v>: [head, len_q, len_k, D] * [head, 1, D, 1] -> [head, len_q, len_k]
+        item_d = (BK.matmul(dist_up1, self.vec_v.unsqueeze(-2).unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
+        posi_scores += item_d
+        return posi_scores
diff --git a/tasks/zmlm/model/mods/vrec/c2_norm.py b/tasks/zmlm/model/mods/vrec/c2_norm.py
new file mode 100644
index 0000000..fbfcf88
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/c2_norm.py
@@ -0,0 +1,135 @@
+# Component 2 for MAtt: normer
+# (scores [*, len_q, len_k, head], accu_attns) => normalized-scores [*, len_q, len_k, head]
+from typing import List, Union, Dict, Iterable
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm
+from .base import ConcreteNodeConf, ConcreteNode
+# -----
+# conf
+class MAttNormerConf(Conf):
+    def __init__(self):
+        # special
+        # self.param_init_scale = 1.  # actually no params in this node
+        # noop compare score (fixed)
+        self.use_noop = False  # attend to nothing
+        self.noop_fixed_val = 0.  # fixed noop score
+        # norm
+        # todo(+N): here no longer norm multiple times since that seems complicated
+        self.norm_mode = 'cand'  # flatten/head/cand/head_cand/binary
+        self.norm_prune = 0.  # prune final normalized scores to 0. if <=this
+        # concrete
+        self.cconf = ConcreteNodeConf()
+        # dropout directly on the normalized scores
+        self.attn_dropout = 0.
+# for getting normalized scores (like probs)
+class MAttNormerNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, head_count, conf: MAttNormerConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        self.head_count = head_count
+        # -----
+        self._norm_f = getattr(self, "_norm_"+conf.norm_mode)  # shortcut
+        self._norm_dims = {'flatten': [-1], 'head': [-1], 'cand': [-2], 'head_cand': [-1, -2], 'binary': [-1]}[conf.norm_mode]
+        self.norm_prune = conf.norm_prune
+        # cnode: special attention
+        self.cnode = self.add_sub_node("cn", ConcreteNode(pc, conf.cconf))
+        # attention dropout: # no-fix attention dropout, not elegant here
+        rr = NoFixRop()
+        self.adrop = self.add_sub_node("adrop", Dropout(pc, (), init_rop=rr, fix_rate=conf.attn_dropout))
+    # =====
+    # different strategies to calculate prob matrix
+    @property
+    def norm_dims(self):
+        return self._norm_dims
+    # helper function for noop
+    def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool, noop_fixed_val: float, temperature: float, dim: int):
+        cur_shape = BK.get_shape(orig_scores)  # original
+        orig_that_dim = cur_shape[dim]
+        cur_shape[dim] = 1
+        if use_noop:
+            noop_scores = BK.constants(cur_shape, value=noop_fixed_val)  # [*, 1, *]
+            to_norm_scores = BK.concat([orig_scores, noop_scores], dim=dim)  # [*, D+1, *]
+        else:
+            to_norm_scores = orig_scores  # [*, D, *]
+        # normalize
+        prob_full = cnode(to_norm_scores, temperature=temperature, dim=dim)  # [*, ?, *]
+        if use_noop:
+            prob_valid, prob_noop = BK.split(prob_full, [orig_that_dim, 1], dim)  # [*, D|1, *]
+        else:
+            prob_valid, prob_noop = prob_full, None
+        return prob_valid, prob_noop, prob_full
+    # 1. direct normalize over flattened [len_k, head]
+    def _norm_flatten(self, scores, temperature):
+        conf = self.conf
+        cnode = self.cnode
+        use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val
+        # -----
+        orig_shape = BK.get_shape(scores)
+        prob_valid, prob_noop, prob_full = self._normalize(
+            cnode, scores.view(orig_shape[:-2]+[-1]), use_noop, noop_fixed_val, temperature, -1)
+        attn = prob_valid.view(orig_shape)  # back to 2d shape
+        # [*, len_q, len_k, head], [*, len_q, len_k*head], [*, len_q, ?], [*, len_q, len_k*head+?]
+        return attn, [prob_valid], [prob_noop], [prob_full], [-1]
+    # 2. norm on head-dim for each cand
+    def _norm_head(self, scores, temperature):
+        conf = self.conf
+        cnode = self.cnode
+        use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val
+        # -----
+        prob_valid, prob_noop, prob_full = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -1)
+        # [*, len_q, len_k, head], [*, len_q, len_k, head], [*, len_q, len_k, ?], [*, len_q, len_k, head+?]
+        return prob_valid, [prob_valid], [prob_noop], [prob_full], [-1]
+    # 3. norm on cand-dim for each head
+    def _norm_cand(self, scores, temperature):
+        conf = self.conf
+        cnode = self.cnode
+        use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val
+        # -----
+        prob_valid, prob_noop, prob_full = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -2)
+        # [*, len_q, len_k, head], [*, len_q, len_k, head], [*, len_q, ?, head], [*, len_q, len_k+?, head]
+        return prob_valid, [prob_valid], [prob_noop], [prob_full], [-2]
+    # 4. multiplying (2) and (3)
+    def _norm_head_cand(self, scores, temperature):
+        conf = self.conf
+        cnode = self.cnode
+        use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val
+        # -----
+        prob_valid1, prob_noop1, prob_full1 = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -1)
+        prob_valid2, prob_noop2, prob_full2 = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -2)
+        # [*, len_q, len_k, head], [*, len_q, len_k, head], ...
+        return prob_valid1*prob_valid2, [prob_valid1, prob_valid2], [prob_noop1, prob_noop2], [prob_full1, prob_full2], [-1, -2]
+    # 5. binary for each individual one
+    def _norm_binary(self, scores, temperature):
+        conf = self.conf
+        cnode = self.cnode
+        use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val
+        # -----
+        prob_valid, prob_noop, prob_full = self._normalize(cnode, scores.unsqueeze(-1), use_noop, noop_fixed_val, temperature, -1)
+        # attn [*, len_q, len_k, head], prob_valid [*, len_q, len_k, head, 1], prob_noop [*, len_q, len_k, head, ?], prob_full [*, len_q, len_k, head, 1+?]
+        attn = prob_valid.squeeze(-1)
+        return attn, [prob_valid], [prob_noop], [prob_full], [-1]
+    # input score is [*, len_q, len_k, head], output normalized ones and some extras
+    def __call__(self, scores, temperature):
+        attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims = self._norm_f(scores, temperature)
+        # prune
+        if self.norm_prune > 0.:
+            attn = attn * (attn > self.norm_prune).float()
+        # dropout
+        attn = self.adrop(attn)
+        return attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims
diff --git a/tasks/zmlm/model/mods/vrec/c3_collect.py b/tasks/zmlm/model/mods/vrec/c3_collect.py
new file mode 100644
index 0000000..00e0cbe
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/c3_collect.py
@@ -0,0 +1,166 @@
+# Component 3 for MAtt: collector
+# (V, normalized_scores [*, len_q, len_k, head], accu_attns) => collected-V [*, len_q, Dv]
+from typing import List, Union, Dict, Iterable
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm
+from .base import AffineHelperNode
+# -----
+# conf
+class MAttCollectorConf(Conf):
+    def __init__(self):
+        # special
+        self.param_init_scale = 1.  # todo(+N): ugly extra scale
+        self.use_piece4init = False
+        # content of values
+        self.use_affv = True  # if not using, then directly use input (dim_v)
+        self.d_v = 64  # content reprs
+        self.v_act = 'linear'
+        self.v_drop = 0.
+        # attn & accu_attn
+        self.accu_lambda = 0.  # e.g., -1 means delete previous ones, +1 means consider previous ones
+        self.accu_ceil = 1.  # clamp final applying scores by certain range
+        self.accu_floor = 0.
+        # model based clamp
+        self.use_mclp = False
+        self.mclp_hdim = 128
+        self.mclp_hdrop = 0.1
+        self.mclp_hact = "elu"
+        self.mclp_fbias = 1.  # fixed extra bias (making mclp by default positive value)
+        self.mclp_fact = "relu"  # output activation after adding bias
+        # collecting: 1. weighted sum over certain dims, 2. further reduce if needed
+        self.collect_mode = "ch"  # direct_2d(2d)/cand_head(ch)/head_cand(hc)
+        self.collect_renorm = False  # renorm to avoid weights too large?
+        self.collect_renorm_mindiv = 1.  # avoid renorm on small values
+        self.collect_reducer_mode = "aff"  # max/sum/avg/aff; (aff is only valid for multihead-reduce)
+        self.collect_red_aff_out = 512  # output for "aff" mode, <0 means the same as d_v
+        self.collect_red_aff_act = 'linear'
+# the collector: normalized_scores * V
+class MAttCollectorNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, dim_q, dim_v, head_count, conf: MAttCollectorConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        self.dim_q, self.dim_v, self.head_count = dim_q, dim_v, head_count
+        # -----
+        _d_v = conf.d_v
+        _hid_size_v = head_count * _d_v
+        self.split_dims = [head_count, -1]
+        # affine-v
+        if conf.use_affv:
+            self.out_dim = _d_v
+            self.affine_v = self.add_sub_node("av", AffineHelperNode(
+                pc, dim_v, _hid_size_v, hid_act=conf.v_act, hid_drop=conf.v_drop,
+                hid_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        else:
+            self.out_dim = dim_v
+            self.affine_v = None
+        # -----
+        # model based clamp
+        if conf.use_mclp:
+            self.mclp_aff = self.add_sub_node("am", AffineHelperNode(
+                pc, dim_q, hid_dim=conf.mclp_hdim, hid_act=conf.mclp_hact, hid_drop=conf.mclp_hdrop,
+                out_dim=head_count, out_fbias=conf.mclp_fbias, out_fact=conf.mclp_fact,
+                out_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
+        else:
+            self.mclp_aff = None
+        # -----
+        # how to reduce the two dims of [len_k, head]
+        self._collect_f = getattr(self, "_collect_"+conf.collect_mode)  # shortcut
+        if conf.collect_reducer_mode == "aff":
+            aff_out_dim = _d_v if (conf.collect_red_aff_out<=0) else conf.collect_red_aff_out
+            aff_out_act = conf.collect_red_aff_act
+            self.out_dim = aff_out_dim
+            # no param_init_scale here, since output is not multihead
+            self.multihead_aff_node = self.add_sub_node("ma", Affine(pc, _hid_size_v, aff_out_dim,
+                                                                     act=aff_out_act, init_rop=NoDropRop()))
+        else:
+            self.multihead_aff_node = None
+        # todo(note): always reduce on dim=-2!!
+        self._collect_reduce_f = {
+            "sum": lambda x: x.sum(-2), "avg": lambda x: x.mean(-2), "max": lambda x: x.max(-2)[0],
+            "aff": lambda x: self.multihead_aff_node(x.view(BK.get_shape(x)[:-2] + [-1])),  # flatten last two and aff
+            "cat": lambda x: x.view(BK.get_shape(x)[:-2] + [-1]),  # directly flatten
+        }[conf.collect_reducer_mode]
+        # special mode "cat"
+        if conf.collect_reducer_mode == "cat":
+            assert conf.collect_mode == "ch", "Must get rid of the var. dim of cand first"
+            self.out_dim = _hid_size_v
+        # =====
+    def get_output_dims(self, *input_dims):
+        return (self.out_dim, )
+    # input is [*, len_k, D], *[*, len_q, len_k, head]
+    def __call__(self, query, value, attn, accu_attn):
+        conf = self.conf
+        # get values
+        if self.affine_v:
+            value_up = self.affine_v(value).view(BK.get_shape(value)[:-1]+self.split_dims)  # [*, (len_k, head), Dv]
+        else:
+            value_up = value.unsqueeze(-2)  # [*, len_k, 1, D]
+            value_up_expand_shape = BK.get_shape(value_up)
+            value_up_expand_shape[-2] = self.head_count
+            value_up = value_up.expand(value_up_expand_shape).contiguous()  # [*, len_k, head, D]
+        # get final applying attn scores: [*, len_q, (len_k, head)]
+        if conf.accu_lambda != 0.:
+            combined_attn = attn + accu_attn * conf.accu_lambda
+            clamped_attn = combined_attn.clamp(min=conf.accu_floor, max=conf.accu_ceil)
+            # apply_attn = (clamped_attn - combined_attn).detach() + combined_attn  # todo(+N): make gradient flow?
+            apply_accu_attn = clamped_attn
+        else:
+            apply_accu_attn = attn
+        # model based clamp
+        if self.mclp_aff:
+            mclp_upper = self.mclp_aff(query).unsqueeze(-2)  # [*, len_q, 1, head]
+            apply_attn = BK.min_elem(apply_accu_attn, mclp_upper)  # clamp on the ceil side
+        else:
+            apply_attn = apply_accu_attn
+        # apply attn and get result
+        result_value = self._collect_f(value_up, apply_attn)  # [*, len_q, Dv]
+        return result_value
+    # =====
+    # specific strategies
+    # 1. simply weighted sum over 2d [len_k, head]
+    def _collect_2d(self, value_up, apply_attn):
+        conf = self.conf
+        fused_attn = apply_attn.view(BK.get_shape(apply_attn)[:-2]+[-1])  # [*, len_q, len_k*head]
+        if conf.collect_renorm:
+            fused_attn = fused_attn / fused_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv)
+        value_up_shape = BK.get_shape(value_up)
+        fused_value_up = value_up.view(value_up_shape[:-3] + [-1, value_up_shape[-1]])  # [*, len_k*head, Dv]
+        result_value = BK.matmul(fused_attn, fused_value_up)  # [*, len_q, Dv]
+        return result_value
+    # 2. weighted sum over cand, then reduce on head (pool/affine)
+    def _collect_ch(self, value_up, apply_attn):
+        conf = self.conf
+        # first weighted sum over cand
+        cand_final_attn = apply_attn.transpose(-1, -2).transpose(-2, -3)  # [*, head, len_q, len_k]
+        if conf.collect_renorm:
+            cand_final_attn = cand_final_attn / cand_final_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv)
+        transposed_value_up = value_up.transpose(-2, -3)  # [*, head, len_k, Dv]
+        multihead_context = BK.matmul(cand_final_attn, transposed_value_up).transpose(-2, -3)  # [*, len_q, head, Dv]
+        # then reduce head
+        result_value = self._collect_reduce_f(multihead_context.contiguous())  # [*, len_q, Dv]
+        return result_value
+    # 3. weighted sum over head, then reduce on cand (pool) (cannot use affine here)
+    def _collect_hc(self, value_up, apply_attn):
+        conf = self.conf
+        # first weighted sum over head
+        head_final_attn = apply_attn.transpose(-2, -3)  # [*, len_k, len_q, head]
+        if conf.collect_renorm:
+            head_final_attn = head_final_attn / head_final_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv)
+        apply_value_up = value_up  # [*, len_k, head, Dv]
+        multicand_context = BK.matmul(head_final_attn, apply_value_up).transpose(-2, -3)  # [*, len_q, len_k, Dv]
+        # then reduce cand
+        result_value = self._collect_reduce_f(multicand_context)  # [*, len_q, Dv]
+        return result_value
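Reviewer note on `MAttCollectorNode`: the collector first combines the current attention with the accumulated one (`attn + accu_attn * accu_lambda`, clamped to `[accu_floor, accu_ceil]`), optionally renormalizes, and in `"ch"` mode takes a weighted sum over candidates per head before reducing the head dimension. The sketch below reproduces those steps for a single query position in pure Python. All function names are hypothetical, plain lists stand in for `BK` tensors, and only the `"sum"` and `"cat"` reducers are covered.

```python
def apply_accu(attn, accu_attn, accu_lambda, floor=0.0, ceil=1.0):
    """Combine current and accumulated attention, then clamp to [floor, ceil]."""
    if accu_lambda == 0.0:
        return list(attn)
    return [min(max(a + accu_lambda * p, floor), ceil)
            for a, p in zip(attn, accu_attn)]


def renorm(weights, mindiv=1.0):
    """Divide by the sum, but never by less than mindiv (cf. collect_renorm_mindiv)."""
    div = max(sum(weights), mindiv)
    return [w / div for w in weights]


def collect_ch(attn, values, reducer="sum"):
    """attn: [len_k][head], values: [len_k][head][d_v].

    Returns [d_v] for "sum" or [head * d_v] for "cat".
    """
    len_k, head, d_v = len(values), len(values[0]), len(values[0][0])
    per_head = []
    for h in range(head):
        # weighted sum over candidates for this head
        ctx = [sum(attn[k][h] * values[k][h][d] for k in range(len_k))
               for d in range(d_v)]
        per_head.append(ctx)
    if reducer == "sum":
        return [sum(per_head[h][d] for h in range(head)) for d in range(d_v)]
    if reducer == "cat":
        return [x for ctx in per_head for x in ctx]
    raise ValueError(f"unsupported reducer: {reducer}")


attn = [[1.0, 0.0], [0.0, 1.0]]                          # len_k=2, head=2
values = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
print(collect_ch(attn, values, "sum"))
print(collect_ch(attn, values, "cat"))
```

With a negative `accu_lambda`, candidates that were already attended to in earlier steps are pushed toward the floor, which is why the clamp matters before the weights are applied.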
diff --git a/tasks/zmlm/model/mods/vrec/fcomb.py b/tasks/zmlm/model/mods/vrec/fcomb.py
new file mode 100644
index 0000000..1eb57f3
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/fcomb.py
@@ -0,0 +1,24 @@
+from typing import List, Union, Dict, Iterable
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode, Affine, NoDropRop, PosiEmbedding2, ActivationHelper, Dropout
+from .base import AffineHelperNode
+# =====
+# actually very similar to MattNode.c1_scorer
+# conf
+class FCombConf(Conf):
+    def __init__(self):
+        pass
+# node
+class FCombNode(BasicNode):
+    def __init__(self, pc, dim_q, dim_k, dim_v, conf: FCombConf):
+        super().__init__(pc, None, None)
+        pass
+    def __call__(self, *args, **kwargs):
+        raise NotImplementedError()
diff --git a/tasks/zmlm/model/mods/vrec/matt.py b/tasks/zmlm/model/mods/vrec/matt.py
new file mode 100644
index 0000000..d7374f8
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/matt.py
@@ -0,0 +1,49 @@
+# multihead attention node
+# -- include the components of: c{1,2,3}_*
+# -- (Q, K, V, accu_attns, rel_dist) => (results_value, attn, attn_info)
+from typing import List, Union, Dict, Iterable
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode
+from .c1_score import MAttScorerConf, MAttScorerNode
+from .c2_norm import MAttNormerConf, MAttNormerNode
+from .c3_collect import MAttCollectorConf, MAttCollectorNode
+# -----
+# conf
+class MAttConf(Conf):
+    def __init__(self):
+        self.head_count = 8
+        self.scorer_conf = MAttScorerConf()
+        self.normer_conf = MAttNormerConf()
+        self.collector_conf = MAttCollectorConf()
+# node
+class MAttNode(BasicNode):
+    def __init__(self, pc, dim_q, dim_k, dim_v, conf: MAttConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        # -----
+        self.scorer = self.add_sub_node("c1", MAttScorerNode(pc, dim_q, dim_k, conf.head_count, conf.scorer_conf))
+        self.normer = self.add_sub_node("c2", MAttNormerNode(pc, conf.head_count, conf.normer_conf))
+        self.collector = self.add_sub_node("c3", MAttCollectorNode(pc, dim_q, dim_v, conf.head_count, conf.collector_conf))
+    def get_output_dims(self, *input_dims):
+        return self.collector.get_output_dims(*input_dims)
+    # input *[*, len?, Din]
+    def __call__(self, query, key, value, accu_attn, mask_k=None, mask_qk=None, rel_dist=None, temperature=1., forced_attn=None):
+        if forced_attn is None:  # need to calculate the attns
+            scores = self.scorer(query, key, accu_attn, mask_k=mask_k, mask_qk=mask_qk, rel_dist=rel_dist)
+            normer_ret = self.normer(scores, temperature)
+        else:  # skip the structured part
+            scores = None
+            normer_ret = (forced_attn, [], [], [], [])
+        result_value = self.collector(query, value, normer_ret[0], accu_attn)
+        # return the output of the three components:
+        # [*, len_q, len_k, head], ([*, len_q, len_k, head], ...), [*, len_q, Dv]
+        return scores, normer_ret, result_value
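Reviewer note on `MAttNode.__call__`: the node chains scorer, normer and collector, and `forced_attn` bypasses the first two so that `scores` comes back as `None`. The toy single-head version below illustrates only that control flow. The dot-product scorer and the temperature-scaled softmax are assumptions made for illustration (the real `c1_score.py` and most of `c2_norm.py` are not shown here), and `toy_matt` is a hypothetical name.

```python
import math


def softmax(xs, temperature=1.0):
    """Numerically stable softmax with an (assumed) temperature scaling."""
    m = max(xs)
    exps = [math.exp((x - m) / temperature) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def toy_matt(query, keys, values, temperature=1.0, forced_attn=None):
    """Return (scores, attn, result), mirroring the three-part return value."""
    if forced_attn is None:
        # stage 1: score (assumed dot product), stage 2: normalize
        scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
        attn = softmax(scores, temperature)
    else:
        # skip scoring and normalization entirely
        scores, attn = None, list(forced_attn)
    # stage 3: collect, i.e. attention-weighted sum of the values
    d_v = len(values[0])
    result = [sum(a * v[d] for a, v in zip(attn, values)) for d in range(d_v)]
    return scores, attn, result


keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[2.0, 0.0], [0.0, 4.0]]
print(toy_matt([1.0, 0.0], keys, values))
print(toy_matt([1.0, 0.0], keys, values, forced_attn=[0.5, 0.5]))
```

Callers that log `scores` need to handle the `None` case whenever forced attention is supplied.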
diff --git a/tasks/zmlm/model/mods/vrec/vrec.py b/tasks/zmlm/model/mods/vrec/vrec.py
new file mode 100644
index 0000000..1c8e331
--- /dev/null
+++ b/tasks/zmlm/model/mods/vrec/vrec.py
@@ -0,0 +1,360 @@
+# the vertical recursive/recurrent module
+from typing import List, Union, Dict, Iterable
+import numpy as np
+import math
+from msp.utils import Constants, Conf, zwarn, Helper
+from msp.nn import BK
+from msp.nn.layers import BasicNode, LstmNode2, Affine, Dropout, NoDropRop, LayerNorm
+from msp.zext.process_train import SVConf, ScheduledValue
+from .matt import MAttConf, MAttNode
+from .fcomb import FCombConf, FCombNode
+from .base import AffineCombiner
+from ...base import BaseModuleConf, BaseModule, LossHelper
+# from ...helper import EntropyHelper
+# =====
+# single node
+# conf
+class VRecConf(Conf):
+    def __init__(self):
+        # layer norm
+        self.use_pre_norm = False  # combine(ss(LN(x)), x)
+        self.use_post_norm = False  # LN(combine(ss(x), x))
+        # -----
+        # selfatt conf
+        self.feat_mod = "matt"  # matt/fcomb
+        self.matt_conf = MAttConf()
+        self.fc_conf = FCombConf()
+        # what to feed
+        self.feat_qk_lambda_orig = 0.  # for Q/K, this*orig_t + (1-this)*rec_t
+        self.feat_v_lambda_orig = 0.  # for V(k), ...
+        # -----
+        # combiner
+        self.comb_mode = "affine"  # affine/lstm
+        # affine mode
+        self.comb_affine_q = False
+        self.comb_affine_v = False  # if d_v != dim_q, should affine this
+        self.comb_affine_act = 'linear'  # maybe linear is fine since at outside there may be LayerNorm
+        self.comb_affine_drop = 0.1
+        # what to feed
+        self.comb_q_lambda_orig = 0.  # for q, ...
+        # -----
+        # further ff?
+        self.ff_dim = 0
+        self.ff_drop = 0.1
+        self.ff_act = "relu"
+    @property
+    def attn_count(self):
+        if self.feat_mod == "matt":
+            return self.matt_conf.head_count
+        elif self.feat_mod == "fcomb":
+            return self.fc_conf.fc_count
+        else:
+            raise NotImplementedError()
+# cache
+class VRecCache:
+    def __init__(self):
+        # values for cur step: "orig" is the very first input, "rec" is the recursively built one
+        self.orig_t = None  # [*, len_q, D]
+        self.rec_t = None  # [*, len_q, D]
+        self.accu_attn = None  # [*, len_q, len_v, head]
+        self.rec_lstm_c_t = None  # optional for lstm mode
+        # list on steps
+        self.list_hidden = []  # List[*, len_q, D]
+        self.list_score = []  # List[*, len_q, len_v, head]
+        self.list_attn = []  # List[*, len_q, len_v, head]
+        self.list_accu_attn = []  # List[*, len_q, len_v, head]
+        self.list_attn_info = []  # List[Tuple(attn, [prob_valid], [prob_noop], [prob_full], [dims])]
+# one node (one step) in the recursion
+class VRecNode(BasicNode):
+    def __init__(self, pc, dim: int, conf: VRecConf):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        self.dim = dim
+        # =====
+        # Feat
+        if conf.feat_mod == "matt":
+            self.feat_node = self.add_sub_node("feat", MAttNode(pc, dim, dim, dim, conf.matt_conf))
+            self.attn_count = conf.matt_conf.head_count
+        elif conf.feat_mod == "fcomb":
+            self.feat_node = self.add_sub_node("feat", FCombNode(pc, dim, dim, dim, conf.fc_conf))
+            self.attn_count = conf.fc_conf.fc_count
+        else:
+            raise NotImplementedError()
+        feat_out_dim = self.feat_node.get_output_dims()[0]
+        # =====
+        # Combiner
+        if conf.comb_mode == "affine":
+            self.comb_aff = self.add_sub_node("aff", AffineCombiner(
+                pc, [dim, feat_out_dim], [conf.comb_affine_q, conf.comb_affine_v], dim,
+                out_act=conf.comb_affine_act, out_drop=conf.comb_affine_drop))
+            self.comb_f = lambda q, v, c: (self.comb_aff([q,v]), None)
+        elif conf.comb_mode == "lstm":
+            self.comb_lstm = self.add_sub_node("lstm", LstmNode2(pc, feat_out_dim, dim))
+            self.comb_f = self._call_lstm
+        else:
+            raise NotImplementedError()
+        # =====
+        # ff
+        if conf.ff_dim > 0:
+            self.has_ff = True
+            self.linear1 = self.add_sub_node("l1", Affine(pc, dim, conf.ff_dim, act=conf.ff_act, init_rop=NoDropRop()))
+            self.dropout1 = self.add_sub_node("d1", Dropout(pc, (conf.ff_dim, ), fix_rate=conf.ff_drop))
+            self.linear2 = self.add_sub_node("l2", Affine(pc, conf.ff_dim, dim, act="linear", init_rop=NoDropRop()))
+            self.dropout2 = self.add_sub_node("d2", Dropout(pc, (dim, ), fix_rate=conf.ff_drop))
+        else:
+            self.has_ff = False
+        # layer norms
+        if conf.use_pre_norm:
+            self.att_pre_norm = self.add_sub_node("aln1", LayerNorm(pc, dim))
+            self.ff_pre_norm = self.add_sub_node("fln1", LayerNorm(pc, dim))
+        else:
+            self.att_pre_norm = self.ff_pre_norm = None
+        if conf.use_post_norm:
+            self.att_post_norm = self.add_sub_node("aln2", LayerNorm(pc, dim))
+            self.ff_post_norm = self.add_sub_node("fln2", LayerNorm(pc, dim))
+        else:
+            self.att_post_norm = self.ff_post_norm = None
+    def get_output_dims(self, *input_dims):
+        return (self.dim, )
+    # -----
+    def _call_lstm(self, q, v, c):
+        prev_shapes = BK.get_shape(q)[:-1]  # [*, len_q]
+        in_reshape = [np.prod(prev_shapes), -1]
+        in_q, in_v, in_c = [z.view(in_reshape) for z in [q,v,c]]
+        ret_h, ret_c = self.comb_lstm(in_v, (in_q, in_c), None)  # v acts as input, q acts as hidden
+        out_reshape = prev_shapes + [-1]  # [*, len_q, D]
+        return ret_h.view(out_reshape), ret_c.view(out_reshape)
+    # init cache
+    def init_call(self, src):
+        # init accumulated attn: all 0.
+        src_shape = BK.get_shape(src)
+        attn_shape = src_shape[:-1] + [src_shape[-2], self.attn_count]  # [*, len_q, len_k, head]
+        cache = VRecCache()
+        cache.orig_t = src
+        cache.rec_t = src  # initially, the same as "orig_t"
+        cache.rec_lstm_c_t = BK.zeros(BK.get_shape(src))  # initially zero
+        cache.accu_attn = BK.zeros(attn_shape)
+        return cache
+    # one update step: *[bs, slen, dim]
+    def update_call(self, cache: VRecCache, src_mask=None, qk_mask=None,
+                    attn_range=None, rel_dist=None, temperature=1., forced_attn=None):
+        conf = self.conf
+        # -----
+        # first call matt to get v
+        matt_input_qk = cache.orig_t * conf.feat_qk_lambda_orig + cache.rec_t * (1. - conf.feat_qk_lambda_orig)
+        if self.att_pre_norm:
+            matt_input_qk = self.att_pre_norm(matt_input_qk)
+        matt_input_v = cache.orig_t * conf.feat_v_lambda_orig + cache.rec_t * (1. - conf.feat_v_lambda_orig)
+        # todo(note): currently no pre-norm for matt_input_v
+        # put attn_range as mask_qk
+        if attn_range is not None and attn_range >= 0:  # <0 means not effective
+            cur_slen = BK.get_shape(matt_input_qk, -2)
+            tmp_arange_t = BK.arange_idx(cur_slen)  # [slen]
+            # less or equal!!
+            mask_qk = ((tmp_arange_t.unsqueeze(-1) - tmp_arange_t.unsqueeze(0)).abs() <= attn_range).float()
+            if qk_mask is not None:  # further with input masks
+                mask_qk *= qk_mask
+        else:
+            mask_qk = qk_mask
+        scores, attn_info, result_value = self.feat_node(
+            matt_input_qk, matt_input_qk, matt_input_v, cache.accu_attn, mask_k=src_mask, mask_qk=mask_qk,
+            rel_dist=rel_dist, temperature=temperature, forced_attn=forced_attn)  # ..., [*, len_q, dv]
+        # -----
+        # then combine q(hidden) and v(input)
+        comb_input_q = cache.orig_t * conf.comb_q_lambda_orig + cache.rec_t * (1. - conf.comb_q_lambda_orig)
+        comb_result, comb_c = self.comb_f(comb_input_q, result_value, cache.rec_lstm_c_t)  # [*, len_q, dim]
+        if self.att_post_norm:
+            comb_result = self.att_post_norm(comb_result)
+        # -----
+        # ff
+        if self.has_ff:
+            if self.ff_pre_norm:
+                ff_input = self.ff_pre_norm(comb_result)
+            else:
+                ff_input = comb_result
+            ff_output = comb_result + self.dropout2(self.linear2(self.dropout1(self.linear1(ff_input))))
+            if self.ff_post_norm:
+                ff_output = self.ff_post_norm(ff_output)
+        else:  # otherwise no ff
+            ff_output = comb_result
+        # -----
+        # update cache and return output
+        # cache.orig_t = cache.orig_t  # this does not change
+        cache.rec_t = ff_output
+        cache.accu_attn = cache.accu_attn + attn_info[0]  # accumulating attn
+        cache.rec_lstm_c_t = comb_c  # optional C for lstm
+        cache.list_hidden.append(ff_output)  # all hidden layers
+        cache.list_score.append(scores)  # all un-normed scores
+        cache.list_attn.append(attn_info[0])  # all normed scores
+        cache.list_accu_attn.append(cache.accu_attn)  # all accumulated attns
+        cache.list_attn_info.append(attn_info)  # all attn infos
+        return ff_output
+# =====
+# encoder
+class VRecEncoderConf(BaseModuleConf):
+    def __init__(self):
+        super().__init__()
+        # -----
+        # repeat of common layers
+        self.vr_conf = VRecConf()  # vrec conf
+        self.num_layer = 4
+        self.share_layers = True  # recurrent/recursive if sharing
+        self.attn_ranges = []
+        # scheduled values
+        self.temperature = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0.1)
+        self.lambda_noop_prob = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.)
+        self.lambda_entropy = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.)
+        self.lambda_attn_l1 = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.)
+        # self.lambda_coverage = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.)
+        # special noop loss
+        self.noop_epsilon = 0.01  # if <0, means directly using prob themselves without log
+    def do_validate(self):
+        self.attn_ranges = [int(z) for z in self.attn_ranges]
+class VRecEncoder(BaseModule):
+    def __init__(self, pc, dim: int, conf: VRecEncoderConf):
+        super().__init__(pc, conf, output_dims=(dim,))
+        # -----
+        if conf.share_layers:
+            node = self.add_sub_node("m", VRecNode(pc, dim, conf.vr_conf))
+            self.layers = [node for _ in range(conf.num_layer)]  # use the same node!!
+        else:
+            self.layers = [self.add_sub_node("m", VRecNode(pc, dim, conf.vr_conf)) for _ in range(conf.num_layer)]
+        self.attn_ranges = conf.attn_ranges.copy()
+        self.attn_ranges.extend([None] * (conf.num_layer - len(self.attn_ranges)))  # None means no restrictions
+        # scheduled values
+        self.temperature = ScheduledValue(f"{self.name}:temperature", conf.temperature)
+        self.lambda_noop_prob = ScheduledValue(f"{self.name}:lambda_noop_prob", conf.lambda_noop_prob)
+        self.lambda_entropy = ScheduledValue(f"{self.name}:lambda_entropy", conf.lambda_entropy)
+        self.lambda_attn_l1 = ScheduledValue(f"{self.name}:lambda_attn_l1", conf.lambda_attn_l1)
+        # self.lambda_coverage = ScheduledValue(f"{self.name}:lambda_coverage", conf.lambda_coverage)
+    @property
+    def attn_count(self):
+        if len(self.layers) > 0:
+            all_counts = [z.attn_count for z in self.layers]
+            assert all(z==all_counts[0] for z in all_counts), "attn_count disagree!!"
+            return all_counts[0]
+        else:
+            return self.conf.vr_conf.attn_count
+    def get_scheduled_values(self):
+        return super().get_scheduled_values() + [self.temperature, self.lambda_noop_prob, self.lambda_entropy, self.lambda_attn_l1]
+    def __call__(self, src, src_mask=None, qk_mask=None, rel_dist=None, forced_attns=None, collect_loss=False):
+        if src_mask is not None:
+            src_mask = BK.input_real(src_mask)
+        if forced_attns is None:
+            forced_attns = [None] * len(self.layers)
+        # -----
+        # forward
+        temperature = self.temperature.value
+        cur_hidden = src
+        if len(self.layers) > 0:
+            cache = self.layers[0].init_call(src)
+            for one_lidx, one_layer in enumerate(self.layers):
+                cur_hidden = one_layer.update_call(cache, src_mask=src_mask, qk_mask=qk_mask, attn_range=self.attn_ranges[one_lidx],
+                                                   rel_dist=rel_dist, temperature=temperature, forced_attn=forced_attns[one_lidx])
+        else:
+            cache = VRecCache()  # empty one
+        # -----
+        # collect loss
+        if collect_loss:
+            loss_item = self._collect_losses(cache)
+        else:
+            loss_item = None
+        return cur_hidden, cache, loss_item
+    # =====
+    # helpers
+    def _loss_noop_prob(self, t):
+        noop_epsilon = self.conf.noop_epsilon
+        if noop_epsilon >= 0.:
+            ret_loss = ((1. + noop_epsilon - t).log() - math.log(noop_epsilon))
+        else:
+            ret_loss = (1. - t)
+        return ret_loss
+    def _loss_entropy(self, t_and_dim):
+        # the entropy
+        t, dim = t_and_dim
+        all_ents = - (t * (t + 1e-10).log()).sum(dim)  # reduce on specific dim
+        # all_ents = EntropyHelper.cross_entropy(t, t, dim)
+        return all_ents
+    def _loss_attn_l1(self, t):
+        # simply L1 reg for all attn (but only for >=0 part)
+        t_act = t * (t>=0.).float()
+        return t_act.abs()  # same shape as input
+    # collect losses over all update steps stored in the cache
+    def _collect_losses(self, cache: VRecCache):
+        special_losses = []
+        list_attn_info = cache.list_attn_info  # List[(attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims)]
+        # -----
+        cur_lambda_noop_prob = self.lambda_noop_prob.value
+        if cur_lambda_noop_prob > 0.:
+            noop_losses = self.get_losses_from_attn_list(
+                list_attn_info, lambda x: x[2], self._loss_noop_prob, "PNopS", cur_lambda_noop_prob)
+            special_losses.extend(noop_losses)
+        # -----
+        cur_lambda_entropy = self.lambda_entropy.value
+        if cur_lambda_entropy > 0.:
+            ent_losses = self.get_losses_from_attn_list(
+                list_attn_info, lambda x: list(zip(x[3],x[4])), self._loss_entropy, "EntRegS", cur_lambda_entropy)
+            special_losses.extend(ent_losses)
+        # -----
+        cur_lambda_attn_l1 = self.lambda_attn_l1.value
+        if cur_lambda_attn_l1 > 0.:
+            attn_l1_losses = self.get_losses_from_attn_list(
+                list_attn_info, lambda x: [x[0]], self._loss_attn_l1, "AttnL1S", cur_lambda_attn_l1)
+            special_losses.extend(attn_l1_losses)
+        # -----
+        # here combine them into one
+        ret_loss_item = self._compile_component_loss("vrec", special_losses)
+        return ret_loss_item
+    # common procedures for obtaining losses from "list_attn_info"
+    @staticmethod
+    def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f, loss_prefix, loss_lambda):
+        loss_num = None
+        loss_counts: List[int] = []
+        loss_sums: List[List] = []
+        rets = []
+        # -----
+        for one_attn_info in list_attn_info:  # each update step
+            one_ts: List = ts_f(one_attn_info)  # get tensor list from attn_info
+            # get number of losses
+            if loss_num is None:
+                loss_num = len(one_ts)
+                loss_counts = [0] * loss_num
+                loss_sums = [[] for _ in range(loss_num)]
+            else:
+                assert len(one_ts) == loss_num, "mismatched ts length"
+            # iter them
+            for one_t_idx, one_t in enumerate(one_ts):  # iter on the tensor list
+                one_loss = loss_f(one_t)
+                # need it to be in the corresponding shape
+                loss_counts[one_t_idx] += np.prod(BK.get_shape(one_loss)).item()
+                loss_sums[one_t_idx].append(one_loss.sum())
+        # one loss leaf per tensor position, summed over all steps
+        for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)), loss_counts, loss_sums):
+            loss_leaf = LossHelper.compile_leaf_info(f"{loss_prefix}{i}", BK.stack(one_loss_sums, 0).sum(),
+                                                     BK.input_real(one_loss_count), loss_lambda=loss_lambda)
+            rets.append(loss_leaf)
+        return rets
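Reviewer note on `vrec.py`: three small pieces of arithmetic carry most of the behavior here, namely the band mask built from `attn_range` (`|i - j| <= attn_range`), the no-op loss `log(1 + eps - t) - log(eps)`, and the entropy regularizer. The pure-Python sketch below restates them on scalars and lists. Function names are hypothetical and plain lists stand in for `BK` tensors.

```python
import math


def band_mask(slen, attn_range):
    """1.0 where |i - j| <= attn_range, else 0.0 (note: less or equal)."""
    return [[float(abs(i - j) <= attn_range) for j in range(slen)]
            for i in range(slen)]


def loss_noop_prob(t, noop_epsilon=0.01):
    """Loss on a no-op probability t; it reaches zero as t approaches 1."""
    if noop_epsilon >= 0.0:
        # as in the original branch; noop_epsilon == 0 would hit log(0)
        return math.log(1.0 + noop_epsilon - t) - math.log(noop_epsilon)
    return 1.0 - t  # negative epsilon: use the probability directly


def loss_entropy(probs):
    """Entropy of one distribution, with the same 1e-10 smoothing inside the log."""
    return -sum(p * math.log(p + 1e-10) for p in probs)


for mask_row in band_mask(4, 1):
    print(mask_row)
print(loss_noop_prob(0.0), loss_noop_prob(0.5), loss_noop_prob(1.0))
print(loss_entropy([0.5, 0.5]), loss_entropy([1.0, 0.0]))
```

Both forms of the no-op loss decrease as `t` grows, so a positive `lambda_noop_prob` pushes probability mass toward the no-op option.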
diff --git a/tasks/zmlm/model/mtl.py b/tasks/zmlm/model/mtl.py
new file mode 100644
index 0000000..2f7052a
--- /dev/null
+++ b/tasks/zmlm/model/mtl.py
@@ -0,0 +1,460 @@
+# the multitask model
+from typing import List, Dict
+import numpy as np
+from msp.utils import zlog, zwarn, Random, Helper, Conf
+from msp.data import VocabPackage
+from msp.nn import BK
+from msp.nn.layers import BasicNode, PosiEmbedding2, RefreshOptions, Dropout
+from msp.nn.modules import EncConf, MyEncoder
+from msp.nn.modules.berter2 import BertFeaturesWeightLayer
+from msp.zext.process_train import SVConf, ScheduledValue
+from .base import BaseModelConf, BaseModel, BaseModuleConf, BaseModule, LossHelper
+from .helper import RPrepConf, RPrepNode, EntropyHelper
+from .mods.embedder import EmbedderNodeConf, EmbedderNode, Inputter, AugWord2Node
+from .mods.vrec import VRecEncoderConf, VRecEncoder, VRecConf, VRecNode
+from .mods.masklm import MaskLMNodeConf, MaskLMNode
+from .mods.plainlm import PlainLMNodeConf, PlainLMNode
+from .mods.orderpr import OrderPredNodeConf, OrderPredNode
+from .mods.dpar import DparG1DecoderConf, DparG1Decoder
+from .mods.seqlab import SeqLabNodeConf, SeqLabNode
+from .mods.seqcrf import SeqCrfNodeConf, SeqCrfNode
+from ..data.insts import GeneralSentence, NpArrField
+# =====
+# the overall Model
+class MtlMlmModelConf(BaseModelConf):
+    def __init__(self):
+        super().__init__()
+        # components
+        self.emb_conf = EmbedderNodeConf()
+        self.mlm_conf = MaskLMNodeConf()
+        self.plm_conf = PlainLMNodeConf()
+        self.orp_conf = OrderPredNodeConf()
+        self.orp_loss_special = False  # special loss for orp
+        # non-pretraining parts
+        self.dpar_conf = DparG1DecoderConf()
+        self.upos_conf = SeqLabNodeConf()
+        self.ner_conf = SeqCrfNodeConf()
+        self.do_ner = False  # an extra flag
+        self.ner_use_crf = True  # whether use crf for ner?
+        # where pairwise repr comes from
+        self.default_attn_count = 128  # by default, setting this to what?
+        self.prepr_choice = "attn_max"  # rdist, rdist_abs, attn_sum, attn_avg, attn_max, attn_last
+        # which encoder to use
+        self.enc_choice = "vrec"  # by default, use the new one
+        self.venc_conf = VRecEncoderConf()
+        self.oenc_conf = EncConf().init_from_kwargs(enc_hidden=512, enc_rnn_layer=0)
+        # another small encoder
+        self.rprep_conf = RPrepConf()
+        # agreement module
+        self.lambda_agree = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none")
+        self.agree_loss_f = "js"
+        # separate generator seed especially for testing mask
+        self.testing_rand_gen_seed = 0
+        self.testing_get_attns = False
+        # todo(+N): should make this into another conf
+        # for training and testing
+        self.no_build_dict = False
+        # -- length thresh
+        self.train_max_length = 80
+        self.train_min_length = 5
+        self.test_single_length = 80
+        # -- batch size
+        self.train_shuffle = True
+        self.train_batch_size = 80
+        self.train_inst_copy = 1  # how many times to copy the training insts (for different masks)
+        self.test_batch_size = 32
+        # -- maxibatch for BatchArranger
+        self.train_maxibatch_size = 20  # cache for this number of batches in BatchArranger
+        self.test_maxibatch_size = -1  # -1 means cache all
+        # -- batch_size_f
+        self.train_batch_on_len = False  # whether to use sent length as budgets rather than sent count
+        self.test_batch_on_len = False
+        # -----
+        self.train_preload_model = False
+        self.train_preload_process = False
+        # interactive testing
+        self.test_interactive = False  # special mode
+        # load pre-training model?
+        self.load_pretrain_model_name = ""
+        # -----
+        # special mode with extra word2
+        # todo(+N): ugly patch!
+        self.aug_word2 = False
+        self.aug_word2_dim = 300
+        self.aug_word2_pretrain = ""
+        self.aug_word2_save_dir = ""
+        self.aug_word2_aug_encoder = True  # special model's special mode, feature-based original encoder and stack another one!!
+        self.aug_detach_ratio = 0.5
+        self.aug_detach_dropout = 0.33
+        self.aug_detach_numlayer = 6  # use mixture of how many last layers?
+    def do_validate(self):
+        if self.orp_loss_special:
+            assert self.orp_conf.disturb_mode == "local_shuffle2"
+class MtlMlmModel(BaseModel):
+    def __init__(self, conf: MtlMlmModelConf, vpack: VocabPackage):
+        super().__init__(conf)
+        # for easier checking
+        self.word_vocab = vpack.get_voc("word")
+        # components
+        self.embedder = self.add_node("emb", EmbedderNode(self.pc, conf.emb_conf, vpack))
+        self.inputter = Inputter(self.embedder, vpack)  # not a node
+        self.emb_out_dim = self.embedder.get_output_dims()[0]
+        self.enc_attn_count = conf.default_attn_count
+        if conf.enc_choice == "vrec":
+            self.encoder = self.add_component("enc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
+            self.enc_attn_count = self.encoder.attn_count
+        elif conf.enc_choice == "original":
+            conf.oenc_conf._input_dim = self.emb_out_dim
+            self.encoder = self.add_node("enc", MyEncoder(self.pc, conf.oenc_conf))
+        else:
+            raise NotImplementedError()
+        zlog(f"Finished building model's encoder {self.encoder}, all size is {self.encoder.count_allsize_parameters()}")
+        self.enc_out_dim = self.encoder.get_output_dims()[0]
+        # --
+        conf.rprep_conf._rprep_vr_conf.matt_conf.head_count = self.enc_attn_count  # make head-count agree
+        self.rpreper = self.add_node("rprep", RPrepNode(self.pc, self.enc_out_dim, conf.rprep_conf))
+        # --
+        self.lambda_agree = self.add_scheduled_value(ScheduledValue("agr:lambda", conf.lambda_agree))
+        self.agree_loss_f = EntropyHelper.get_method(conf.agree_loss_f)
+        # --
+        self.masklm = self.add_component("mlm", MaskLMNode(self.pc, self.enc_out_dim, conf.mlm_conf, self.inputter))
+        self.plainlm = self.add_component("plm", PlainLMNode(self.pc, self.enc_out_dim, conf.plm_conf, self.inputter))
+        # todo(note): here we use attn as dim_pair, do not use pair if not using vrec!!
+        self.orderpr = self.add_component("orp", OrderPredNode(
+            self.pc, self.enc_out_dim, self.enc_attn_count, conf.orp_conf, self.inputter))
+        # =====
+        # pre-training pre-load point!!
+        if conf.load_pretrain_model_name:
+            zlog(f"At preload_pretrain point: Loading from {conf.load_pretrain_model_name}")
+            self.pc.load(conf.load_pretrain_model_name, strict=False)
+        # =====
+        self.dpar = self.add_component("dpar", DparG1Decoder(
+            self.pc, self.enc_out_dim, self.enc_attn_count, conf.dpar_conf, self.inputter))
+        self.upos = self.add_component("upos", SeqLabNode(
+            self.pc, "pos", self.enc_out_dim, self.conf.upos_conf, self.inputter))
+        if conf.do_ner:
+            if conf.ner_use_crf:
+                self.ner = self.add_component("ner", SeqCrfNode(
+                    self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
+            else:
+                self.ner = self.add_component("ner", SeqLabNode(
+                    self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter))
+        else:
+            self.ner = None
+        # for pairwise reprs (no trainable params here!)
+        self.rel_dist_embed = self.add_node("oremb", PosiEmbedding2(self.pc, n_dim=self.enc_attn_count, max_val=100))
+        self._prepr_f_attn_sum = lambda cache, rdist: BK.stack(cache.list_attn, 0).sum(0) if (len(cache.list_attn))>0 else None
+        self._prepr_f_attn_avg = lambda cache, rdist: BK.stack(cache.list_attn, 0).mean(0) if (len(cache.list_attn))>0 else None
+        self._prepr_f_attn_max = lambda cache, rdist: BK.stack(cache.list_attn, 0).max(0)[0] if (len(cache.list_attn))>0 else None
+        self._prepr_f_attn_last = lambda cache, rdist: cache.list_attn[-1] if (len(cache.list_attn))>0 else None
+        self._prepr_f_rdist = lambda cache, rdist: self._get_rel_dist_embed(rdist, False)
+        self._prepr_f_rdist_abs = lambda cache, rdist: self._get_rel_dist_embed(rdist, True)
+        self.prepr_f = getattr(self, "_prepr_f_"+conf.prepr_choice)  # shortcut
+        # --
+        self.testing_rand_gen = Random.create_sep_generator(conf.testing_rand_gen_seed)  # separate generator for testing
+        # =====
+        if conf.orp_loss_special:
+            self.orderpr.add_node_special(self.masklm)
+        # =====
+        # extra one!!
+        self.aug_word2 = self.aug_encoder = self.aug_mixturer = None
+        if conf.aug_word2:
+            self.aug_word2 = self.add_node("aug2", AugWord2Node(self.pc, conf.emb_conf, vpack,
+                                                                "word2", conf.aug_word2_dim, self.emb_out_dim))
+            if conf.aug_word2_aug_encoder:
+                assert conf.enc_choice == "vrec"
+                self.aug_detach_drop = self.add_node("dd", Dropout(self.pc, (self.enc_out_dim,), fix_rate=conf.aug_detach_dropout))
+                self.aug_encoder = self.add_component("Aenc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf))
+                self.aug_mixturer = self.add_node("Amix", BertFeaturesWeightLayer(self.pc, conf.aug_detach_numlayer))
+    # helper: embed and encode
+    def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None):
+        conf = self.conf
+        # -----
+        # special mode
+        if conf.aug_word2 and conf.aug_word2_aug_encoder:
+            _rop = RefreshOptions(training=False)  # special feature-mode!!
+            self.embedder.refresh(_rop)
+            self.encoder.refresh(_rop)
+        # -----
+        emb_t, mask_t = self.embedder(cur_input_map)
+        rel_dist = cur_input_map.get("rel_dist", None)
+        if rel_dist is not None:
+            rel_dist = BK.input_idx(rel_dist)
+        if conf.enc_choice == "vrec":
+            enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss)
+        elif conf.enc_choice == "original":  # todo(note): change back to arr for back compatibility
+            assert rel_dist is None, "Original encoder does not support rel_dist"
+            enc_t = self.encoder(emb_t, BK.get_value(mask_t))
+            cache, enc_loss = None, None
+        else:
+            raise NotImplementedError()
+        # another encoder based on attn
+        final_enc_t = self.rpreper(emb_t, enc_t, cache)  # [*, slen, D] => final encoder output
+        if conf.aug_word2:
+            emb2_t = self.aug_word2(insts)
+            if conf.aug_word2_aug_encoder:
+                # simply add them all together, detach orig-enc as features
+                stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach()
+                features = self.aug_mixturer(stack_hidden_t)
+                aug_input = (emb2_t + conf.aug_detach_ratio*self.aug_detach_drop(features))
+                final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t,
+                                                                rel_dist=rel_dist, collect_loss=collect_loss)
+            else:
+                final_enc_t = (final_enc_t + emb2_t)  # otherwise, simply adding
+        return emb_t, mask_t, final_enc_t, cache, enc_loss
+    # ---
+    # helper on agreement loss
+    def _get_agr_loss(self, loss_prefix, cache1, cache2=None, copy_num=1):
+        # =====
+        def _ts_pair_f(_attn_info_pair):
+            _ainfo1, _ainfo2 = _attn_info_pair
+            _rets = []
+            for _t1, _d1, _t2, _d2 in zip(_ainfo1[3], _ainfo1[4], _ainfo2[3], _ainfo2[4]):
+                assert _d1 == _d2, "disagree on which dim to collect agr_loss!"
+                _rets.append((_t1, _t2, _d1))
+            return _rets
+        # --
+        def _ts_self_f(_list_attn_info):
+            _rets = []
+            for _t, _d in zip(_list_attn_info[3], _list_attn_info[4]):
+                # extra dim at idx 0; todo(note): insts must be repeated at the outermost idx: repeated = insts * copy_num
+                _t1 = _t.view([copy_num, -1] + BK.get_shape(_t)[1:])  # [copy, bs, ...]
+                # roll it by 1
+                _t2 = BK.concat([_t1[-1].unsqueeze(0), _t1[:-1]], dim=0)
+                _rets.append((_t1, _t2, _d))
+            return _rets
+        # --
+        _arg_loss_f = lambda x: self.agree_loss_f(*x)
+        # =====
+        cur_lambda_agr = self.lambda_agree.value
+        if cur_lambda_agr > 0.:
+            if cache2 is None:  # roll self
+                assert copy_num > 1
+                rets = VRecEncoder.get_losses_from_attn_list(
+                    cache1.list_attn_info, _ts_self_f, _arg_loss_f, loss_prefix, cur_lambda_agr)
+            else:
+                rets = VRecEncoder.get_losses_from_attn_list(
+                    list(zip(cache1.list_attn_info, cache2.list_attn_info)), _ts_pair_f, _arg_loss_f, loss_prefix, cur_lambda_agr)
+        else:
+            rets = []
+        return rets
+    def _get_rel_dist_embed(self, rel_dist, use_abs: bool):
+        if use_abs:
+            rel_dist = BK.input_idx(rel_dist).abs()
+        ret = self.rel_dist_embed(rel_dist)  # [bs, len, len, H]
+        return ret
+    def _get_rel_dist(self, len_q: int, len_k: int = None):
+        if len_k is None:
+            len_k = len_q
+        dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
+        dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
+        distance = dist_x - dist_y  # [len_q, len_k]
+        return distance
+    # ---
+    def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1.,
+                    rand_gen=None, assign_attns=False, **kwargs):
+        # =====
+        # import torch
+        # torch.autograd.set_detect_anomaly(True)
+        # with torch.autograd.detect_anomaly():
+        # =====
+        conf = self.conf
+        self.refresh_batch(training)
+        if len(insts) == 0:
+            return {"fb": 0, "sent": 0, "tok": 0}
+        # -----
+        # copying instances for training: expand at dim0
+        cur_copy = conf.train_inst_copy if training else 1
+        copied_insts = insts * cur_copy
+        all_losses = []
+        # -----
+        # original input
+        input_map = self.inputter(copied_insts)
+        # for the pretraining modules
+        has_loss_mlm, has_loss_orp = (self.masklm.loss_lambda.value > 0.), (self.orderpr.loss_lambda.value > 0.)
+        if (not has_loss_orp) and has_loss_mlm:  # only for mlm
+            masked_input_map, input_erase_mask_arr = self.masklm.mask_input(input_map, rand_gen=rand_gen)
+            emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(masked_input_map, collect_loss=True)
+            all_losses.append(enc_loss)
+            # mlm loss; todo(note): currently only using one layer
+            mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
+            all_losses.append(mlm_loss)
+            # assign values
+            if assign_attns:  # may repeat and only keep that last one, but does not matter!
+                self._assign_attns_item(copied_insts, "mask", input_erase_mask_arr=input_erase_mask_arr, cache=cache)
+            # agreement loss
+            if cur_copy > 1:
+                all_losses.extend(self._get_agr_loss("agr_mlm", cache, copy_num=cur_copy))
+        if has_loss_orp:
+            disturbed_input_map = self.orderpr.disturb_input(input_map, rand_gen=rand_gen)
+            if has_loss_mlm:  # further mask some
+                disturb_keep_arr = disturbed_input_map.get("disturb_keep", None)
+                assert disturb_keep_arr is not None, "No keep region for mlm!"
+                # todo(note): in this mode we assume add_root, so here exclude arti-root by [:,1:]
+                masked_input_map, input_erase_mask_arr = \
+                    self.masklm.mask_input(input_map, rand_gen=rand_gen, extra_mask_arr=disturb_keep_arr[:,1:])
+                disturbed_input_map.update(masked_input_map)  # update
+            emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(disturbed_input_map, collect_loss=True)
+            all_losses.append(enc_loss)
+            # orp loss
+            if conf.orp_loss_special:
+                orp_loss = self.orderpr.loss_special(enc_t, mask_t, disturbed_input_map.get("disturb_keep", None),
+                                                     disturbed_input_map, self.masklm)
+            else:
+                orp_input_attn = self.prepr_f(cache, disturbed_input_map.get("rel_dist"))
+                orp_loss = self.orderpr.loss(enc_t, orp_input_attn, mask_t, disturbed_input_map.get("disturb_keep", None))
+            all_losses.append(orp_loss)
+            # mlm loss
+            if has_loss_mlm:
+                mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map)
+                all_losses.append(mlm_loss)
+            # assign values
+            if assign_attns:  # may repeat and only keep that last one, but does not matter!
+                self._assign_attns_item(copied_insts, "dist", abs_posi_arr=disturbed_input_map.get("posi"), cache=cache)
+            # agreement loss
+            if cur_copy > 1:
+                all_losses.extend(self._get_agr_loss("agr_orp", cache, copy_num=cur_copy))
+        if self.plainlm.loss_lambda.value > 0.:
+            if conf.enc_choice == "vrec":  # special case for blm
+                emb_t, mask_t = self.embedder(input_map)
+                rel_dist = input_map.get("rel_dist", None)
+                if rel_dist is not None:
+                    rel_dist = BK.input_idx(rel_dist)
+                # two directions
+                true_rel_dist = self._get_rel_dist(BK.get_shape(mask_t, -1))  # q-k: [len_q, len_k]
+                enc_t1, cache1, enc_loss1 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist<=0).float(),
+                                                         rel_dist=rel_dist, collect_loss=True)
+                enc_t2, cache2, enc_loss2 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist>=0).float(),
+                                                         rel_dist=rel_dist, collect_loss=True)
+                assert not self.rpreper.active, "TODO: Not supported for this mode"
+                all_losses.extend([enc_loss1, enc_loss2])
+                # plm loss with two explicit inputs (one per direction)
+                plm_loss = self.plainlm.loss([enc_t1, enc_t2], input_map)
+                all_losses.append(plm_loss)
+            else:
+                # here use original input
+                emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True)
+                all_losses.append(enc_loss)
+                # plm loss
+                plm_loss = self.plainlm.loss(enc_t, input_map)
+                all_losses.append(plm_loss)
+            # agreement loss
+            assert self.lambda_agree.value==0., "Not implemented for this mode"
+        # =====
+        # task loss
+        dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda = \
+            [0. if z is None else z.loss_lambda.value for z in [self.dpar, self.upos, self.ner]]
+        if any(z>0. for z in [dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda]):
+            # here use original input
+            emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True, insts=copied_insts)
+            all_losses.append(enc_loss)
+            # parsing loss
+            if dpar_loss_lambda > 0.:
+                dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
+                dpar_loss = self.dpar.loss(copied_insts, enc_t, dpar_input_attn, mask_t)
+                all_losses.append(dpar_loss)
+            # pos loss
+            if upos_loss_lambda > 0.:
+                upos_loss = self.upos.loss(copied_insts, enc_t, mask_t)
+                all_losses.append(upos_loss)
+            # ner loss
+            if ner_loss_lambda > 0.:
+                ner_loss = self.ner.loss(copied_insts, enc_t, mask_t)
+                all_losses.append(ner_loss)
+        # -----
+        info = self.collect_loss_and_backward(all_losses, training, loss_factor)
+        info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)})
+        return info
+    def inference_on_batch(self, insts: List[GeneralSentence], **kwargs):
+        conf = self.conf
+        self.refresh_batch(False)
+        with BK.no_grad_env():
+            # special mode
+            # use: CUDA_VISIBLE_DEVICES=3 PYTHONPATH=../../src/ python3 -m pdb ../../src/tasks/cmd.py zmlm.main.test ${RUN_DIR}/_conf device:0 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best test:./_en.debug test_interactive:1
+            if conf.test_interactive:
+                iinput_sent = input(">> (Interactive testing) Input sent sep by blanks: ")
+                iinput_tokens = iinput_sent.split()
+                if len(iinput_tokens) > 0:  # check tokens, so whitespace-only input is skipped as well
+                    iinput_inst = GeneralSentence.create(iinput_tokens)
+                    iinput_inst.word_seq.set_idxes([self.word_vocab.get_else_unk(w) for w in iinput_inst.word_seq.vals])
+                    iinput_inst.char_seq.build_idxes(self.inputter.vpack.get_voc("char"))
+                    iinput_map = self.inputter([iinput_inst])
+                    iinput_erase_mask = np.asarray([[z=="Z" for z in iinput_tokens]]).astype(dtype=np.float32)
+                    iinput_masked_map = self.inputter.mask_input(iinput_map, iinput_erase_mask, {"pos"})  # note: set("pos") would give {'p','o','s'}
+                    emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(iinput_masked_map, collect_loss=False, insts=[iinput_inst])
+                    mlm_loss = self.masklm.loss(enc_t, iinput_erase_mask, iinput_map)
+                    dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
+                    self.dpar.predict([iinput_inst], enc_t, dpar_input_attn, mask_t)
+                    self.upos.predict([iinput_inst], enc_t, mask_t)
+                    # print them
+                    import pandas as pd
+                    cur_fields = {
+                        "idxes": list(range(1, len(iinput_inst)+1)),
+                        "word": iinput_inst.word_seq.vals, "pos": iinput_inst.pred_pos_seq.vals,
+                        "head": iinput_inst.pred_dep_tree.heads[1:], "dlab": iinput_inst.pred_dep_tree.labels[1:]}
+                    zlog(f"Result:\n{pd.DataFrame(cur_fields).to_string()}")
+                return {}  # simply return here for interactive mode
+            # -----
+            # test for MLM simply as in training (use special separate rand_gen to keep the masks the same for testing)
+            # todo(+2): do we need to keep the masks the same for testing/validating during training? Currently not!
+            info = self.fb_on_batch(insts, training=False, rand_gen=self.testing_rand_gen, assign_attns=conf.testing_get_attns)
+            # -----
+            if len(insts) == 0:
+                return info
+            # decode for dpar
+            input_map = self.inputter(insts)
+            emb_t, mask_t, enc_t, cache, _ = self._emb_and_enc(input_map, collect_loss=False, insts=insts)
+            dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1)))
+            self.dpar.predict(insts, enc_t, dpar_input_attn, mask_t)
+            self.upos.predict(insts, enc_t, mask_t)
+            if self.ner is not None:
+                self.ner.predict(insts, enc_t, mask_t)
+            # -----
+            if conf.testing_get_attns:
+                if conf.enc_choice == "vrec":
+                    self._assign_attns_item(insts, "orig", cache=cache)
+                elif conf.enc_choice in ["original"]:
+                    pass
+                else:
+                    raise NotImplementedError()
+            return info
+    def _assign_attns_item(self, insts, prefix, input_erase_mask_arr=None, abs_posi_arr=None, cache=None):
+        if cache is not None:
+            attn_names, attn_list = [], []
+            for one_sidx, one_attn in enumerate(cache.list_attn):
+                attn_names.append(f"{prefix}_att{one_sidx}")
+                attn_list.append(one_attn)
+            if cache.accu_attn is not None:
+                attn_names.append(f"{prefix}_att_accu")
+                attn_list.append(cache.accu_attn)
+            for one_name, one_attn in zip(attn_names, attn_list):
+                # (step_idx, ) -> [bs, len_q, len_k, head]
+                one_attn_arr = BK.get_value(one_attn)
+                for bidx, inst in enumerate(insts):
+                    save_arr = one_attn_arr[bidx]
+                    inst.add_item(one_name, NpArrField(save_arr, float_decimal=4), assert_non_exist=False)
+        if abs_posi_arr is not None:
+            for bidx, inst in enumerate(insts):
+                inst.add_item(f"{prefix}_abs_posi",
+                              NpArrField(abs_posi_arr[bidx], float_decimal=0), assert_non_exist=False)
+        if input_erase_mask_arr is not None:
+            for bidx, inst in enumerate(insts):
+                inst.add_item(f"{prefix}_erase_mask",
+                              NpArrField(input_erase_mask_arr[bidx], float_decimal=4), assert_non_exist=False)
+        # -----
+# b tasks/zmlm/model/mtl:56
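
Reviewer note: the two-direction branch of `fb_on_batch` builds its `qk_mask` arguments from `_get_rel_dist`, whose entry `(q, k)` is `k - q`. The following is a minimal pure-Python sketch of that arithmetic, not the repo's code: `rel_dist` and `directional_masks` are hypothetical names, plain lists stand in for `BK` tensors, and it assumes that a mask value of 1 means "query may attend to key".

```python
from typing import List, Optional, Tuple

Matrix = List[List[float]]

def rel_dist(len_q: int, len_k: Optional[int] = None) -> List[List[int]]:
    """Return a [len_q, len_k] matrix whose entry (q, k) is k - q."""
    if len_k is None:
        len_k = len_q
    return [[k - q for k in range(len_k)] for q in range(len_q)]

def directional_masks(dist: List[List[int]]) -> Tuple[Matrix, Matrix]:
    """Build the two masks that correspond to (dist <= 0) and (dist >= 0)."""
    # dist <= 0: the key is at or before the query (left context only)
    left = [[1.0 if d <= 0 else 0.0 for d in row] for row in dist]
    # dist >= 0: the key is at or after the query (right context only)
    right = [[1.0 if d >= 0 else 0.0 for d in row] for row in dist]
    return left, right

dist = rel_dist(3)
left_mask, right_mask = directional_masks(dist)
for row in left_mask:
    print(row)
```

Under that assumption the first mask is lower-triangular and the second upper-triangular, both including the diagonal, so each encoder pass sees the current token plus one side of the context.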
diff --git a/tasks/zmlm/run/__init__.py b/tasks/zmlm/run/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tasks/zmlm/run/confs.py b/tasks/zmlm/run/confs.py
new file mode 100644
index 0000000..f420db7
--- /dev/null
+++ b/tasks/zmlm/run/confs.py
@@ -0,0 +1,91 @@
+# common configurations
+import json
+from msp import utils, nn
+from msp.nn import NIConf
+from msp.utils import Conf, Logger, zopen, zfatal, zwarn
+from .vocab import MLMVocabPackageConf
+# todo(note): when testing, unless in the default setting, need to specify: "dict_dir" and "model_load_name"
+class DConf(Conf):
+    def __init__(self):
+        # vocab
+        self.vconf = MLMVocabPackageConf()
+        # data paths
+        self.train = ""
+        self.dev = ""
+        self.test = ""
+        self.cache_data = True          # turn off if large data
+        self.to_cache_shuffle = False
+        self.dict_dir = "./"
+        # cuttings for training (simply for convenience without especially preparing data...)
+        self.cut_train = -1  # <0 means no cut!
+        self.cut_dev = -1
+        # save name for trainer not here!!
+        self.model_load_name = "zmodel.best"         # load name
+        self.output_file = "zout"
+        # format (conllu, plain, json)
+        self.input_format = "conllu"
+        self.dev_input_format = ""  # special mixing case
+        self.output_format = "json"
+        # special processing
+        self.lower_case = False
+        self.norm_digit = False  # norm digits to 0
+        # todo(note): deprecated; do not mix things in this way. Use pos field and word-thresh to control things!
+        self.backoff_pos_idx = -1  # <0 means nope, when idx>=this, then backoff to pos ones
+    def do_validate(self):
+        if len(self.dev_input_format)==0:
+            self.dev_input_format = self.input_format
+# the overall conf
+class OverallConf(Conf):
+    def __init__(self, args):
+        # top-levels
+        self.conf_output = ""
+        self.log_file = Logger.MAGIC_CODE
+        self.msp_seed = 9341
+        #
+        self.niconf = NIConf()      # nn-init-conf
+        self.dconf = DConf()        # data-conf
+        # model
+        from ..model.mtl import MtlMlmModelConf
+        self.mconf = MtlMlmModelConf()
+        # =====
+        #
+        self.update_from_args(args)
+        self.validate()
+# get model
+def build_model(conf, vpack):
+    mconf = conf.mconf
+    from ..model.mtl import MtlMlmModel
+    model = MtlMlmModel(mconf, vpack)
+    return model
+# the start of everything
+def init_everything(args, ConfType=None, quite=False):
+    # search for basic confs
+    all_argv, basic_argv = Conf.search_args(args, ["conf_output", "log_file", "msp_seed"],
+                                            [str, str, int], [None, Logger.MAGIC_CODE, None])
+    # for basic argvs
+    conf_output = basic_argv["conf_output"]
+    log_file = basic_argv["log_file"]
+    msp_seed = basic_argv["msp_seed"]
+    if conf_output:
+        with zopen(conf_output, "w") as fd:
+            for k,v in all_argv.items():
+                # todo(note): do not save this one
+                if k != "conf_output":
+                    fd.write(f"{k}:{v}\n")
+    utils.init(log_file, msp_seed, quite=quite)
+    # real init of the conf
+    if ConfType is None:
+        conf = OverallConf(args)
+    else:
+        conf = ConfType(args)
+    nn.init(conf.niconf)
+    return conf
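
Reviewer note: `init_everything` writes every parsed argument except `conf_output` itself as one `key:value` line. The sketch below reproduces only that file format with the standard library, so it can be checked without `msp`. `dump_conf` and `parse_conf` are hypothetical helpers, the argument values are made up for illustration, and how `msp` actually reads such a file back is not shown in this diff.

```python
import io
from typing import Dict, TextIO

def dump_conf(all_argv: Dict[str, str], fd: TextIO, skip_key: str = "conf_output") -> None:
    """Write one 'key:value' line per argument, skipping the output path itself."""
    for k, v in all_argv.items():
        if k != skip_key:
            fd.write(f"{k}:{v}\n")

def parse_conf(text: str) -> Dict[str, str]:
    """Hypothetical inverse: split each non-empty line at the first ':' only."""
    out: Dict[str, str] = {}
    for line in text.splitlines():
        if line:
            k, v = line.split(":", 1)
            out[k] = v
    return out

buf = io.StringIO()
dump_conf({"conf_output": "_conf", "msp_seed": "9341", "train": "data/train.conllu"}, buf)
dumped = buf.getvalue()
print(dumped, end="")
```

Splitting at the first colon only keeps values that themselves contain `:` intact, which matters for nested options written as `a:b:c`.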
diff --git a/tasks/zmlm/run/run.py b/tasks/zmlm/run/run.py
new file mode 100644
index 0000000..9604e52
--- /dev/null
+++ b/tasks/zmlm/run/run.py
@@ -0,0 +1,291 @@
+# some processing tools on the data stream
+from typing import List, Iterable, Dict
+from collections import OrderedDict
+from copy import deepcopy
+from msp.utils import zlog, zwarn, zopen, GLOBAL_RECORDER, Helper, zcheck
+from msp.data import FAdapterStreamer, BatchArranger, InstCacher
+from msp.zext.process_train import TrainingRunner, RecordResult
+from msp.zext.process_test import TestingRunner, ResultManager
+from msp.zext.dpar import ParserEvaler
+from msp.data import WordNormer
+from ..data.insts import GeneralSentence
+from ..data.io import BaseDataWriter, BaseDataReader, ConlluParseReader, PlainTextReader, ConllNerReader
+from .vocab import MLMVocabPackage, UD2_POS_UNK_MAP
+from .seqeval import get_prf_scores
+# =====
+# preparing data
+# pre-processing (lower_case, norm_digit)
+class PreprocessStreamer(FAdapterStreamer):
+    def __init__(self, in_stream, lower_case: bool, norm_digit: bool):
+        super().__init__(in_stream, self._go_prep, True)
+        self.normer = WordNormer(lower_case, norm_digit)
+    def _go_prep(self, inst: GeneralSentence):
+        inst.word_seq.reset(self.normer.norm_stream(inst.word_seq.vals))
+# for indexing instances
+class IndexerStreamer(FAdapterStreamer):
+    def __init__(self, in_stream, vpack: MLMVocabPackage, inst_preparer, backoff_pos_idx: int):
+        super().__init__(in_stream, self._go_index, False)
+        # -----
+        self.w_vocab = vpack.get_voc("word")
+        self.w2_vocab = vpack.get_voc("word2")  # extra set
+        self.c_vocab = vpack.get_voc("char")
+        self.p_vocab = vpack.get_voc("pos")
+        self.l_vocab = vpack.get_voc("deplabel")
+        self.n_vocab = vpack.get_voc("ner")
+        self.inst_preparer = inst_preparer
+        self.backoff_pos_idx = backoff_pos_idx
+    def _go_index(self, inst: GeneralSentence):
+        # word
+        # todo(warn): remember to norm word; replace singleton at model's input, not here
+        if inst.word_seq.has_vals():
+            w_voc = self.w_vocab
+            word_idxes = [w_voc.get_else_unk(w) for w in inst.word_seq.vals]  # todo(note): currently unk-idx is large
+            if self.backoff_pos_idx >= 0:  # when using this mode, there must be backoff strings in word vocab
+                zwarn("The option of 'backoff_pos_idx' is deprecated, do not mix things in this way!")
+                word_backoff_idxes = [w_voc[UD2_POS_UNK_MAP[z]] for z in inst.pos_seq.vals]
+                word_idxes = [(zb if z>=self.backoff_pos_idx else z) for z, zb in zip(word_idxes, word_backoff_idxes)]
+            inst.word_seq.set_idxes(word_idxes)
+        # =====
+        # aug extra word set!!
+        if self.w2_vocab is not None:
+            w2_voc = self.w2_vocab
+            inst.word2_seq = deepcopy(inst.word_seq)
+            word_idxes = [w2_voc.get_else_unk(w) for w in inst.word2_seq.vals]  # todo(note): currently unk-idx is large
+            inst.word2_seq.set_idxes(word_idxes)
+        # =====
+        # others
+        char_seq, pos_seq, dep_tree, ner_seq = [getattr(inst, z, None) for z in ["char_seq", "pos_seq", "dep_tree", "ner_seq"]]
+        if char_seq is not None and char_seq.has_vals():
+            char_seq.build_idxes(self.c_vocab)
+        if pos_seq is not None and pos_seq.has_vals():
+            pos_seq.build_idxes(self.p_vocab)
+        if dep_tree is not None and dep_tree.has_vals():
+            dep_tree.build_label_idxes(self.l_vocab)
+        if ner_seq is not None and ner_seq.has_vals():
+            ner_seq.build_idxes(self.n_vocab)
+        if self.inst_preparer is not None:
+            inst = self.inst_preparer(inst)
+        # in fact, the instance is modified in place unless a model-specific preparer wraps it
+        return inst
+def index_stream(in_stream, vpack, cached, cache_shuffle, inst_preparer, backoff_pos_idx):
+    i_stream = IndexerStreamer(in_stream, vpack, inst_preparer, backoff_pos_idx)
+    if cached:
+        return InstCacher(i_stream, shuffle=cache_shuffle)
+    else:
+        return i_stream
+# for arranging batches
+def batch_stream(in_stream, batch_size, ticonf, training):
+    if training:
+        batch_size_f = (lambda x: max(len(x), MIN_SENT_LEN_FOR_BSIZE)) if ticonf.train_batch_on_len else None
+        b_stream = BatchArranger(in_stream, batch_size=batch_size, maxibatch_size=ticonf.train_maxibatch_size, batch_size_f=batch_size_f,
+                                 dump_detectors=lambda one: len(one)>=ticonf.train_max_length or len(one)<ticonf.train_min_length,
+                                 single_detectors=None, sorting_keyer=len, shuffling=ticonf.train_shuffle)
+    else:
+        batch_size_f = (lambda x: max(len(x), MIN_SENT_LEN_FOR_BSIZE)) if ticonf.test_batch_on_len else None
+        b_stream = BatchArranger(in_stream, batch_size=batch_size, maxibatch_size=ticonf.test_maxibatch_size, batch_size_f=batch_size_f,
+                                 dump_detectors=None, single_detectors=lambda one: len(one)>=ticonf.test_single_length,
+                                 sorting_keyer=len, shuffling=False)
+    return b_stream
+# =====
+# data io
+def get_data_reader(file_or_fd, input_format, **kwargs):
+    if input_format == "conllu":
+        r = ConlluParseReader(file_or_fd, **kwargs)
+    elif input_format == "ner":
+        r = ConllNerReader(file_or_fd, **kwargs)
+    elif input_format == "json":
+        r = BaseDataReader(GeneralSentence, file_or_fd)
+    elif input_format == "plain":
+        r = PlainTextReader(file_or_fd, **kwargs)
+    else:
+        raise NotImplementedError(f"Unknown input_format {input_format}")
+    return r
+def get_data_writer(file_or_fd, output_format, **kwargs):
+    if output_format == "json":
+        return BaseDataWriter(file_or_fd, **kwargs)
+    else:
+        raise NotImplementedError(f"Unknown output_format {output_format}")
+# =====
+# runner
+class MltResultMannager(ResultManager):
+    def __init__(self, vpack, outf, goldf, out_format):
+        self.insts = []
+        #
+        self.vpack = vpack
+        self.outf = outf
+        self.goldf = goldf  # todo(note): goldf is only for reporting which file to compare?
+        self.out_format = out_format
+    def add(self, ones: List[GeneralSentence]):
+        self.insts.extend(ones)
+    # write, eval & summary
+    def end(self):
+        # sorting by idx of reading
+        self.insts.sort(key=lambda x: x.sid)
+        # todo(+1): write other output file
+        if self.outf is not None:
+            with zopen(self.outf, "w") as fd:
+                data_writer = get_data_writer(fd, self.out_format)
+                data_writer.write_list(self.insts)
+        # eval for parsing & pos & ner
+        evaler = ParserEvaler()
+        ner_golds, ner_preds = [], []
+        has_syntax, has_ner = False, False
+        for one_inst in self.insts:
+            gold_pos_seq = getattr(one_inst, "pos_seq", None)
+            pred_pos_seq = getattr(one_inst, "pred_pos_seq", None)
+            gold_ner_seq = getattr(one_inst, "ner_seq", None)
+            pred_ner_seq = getattr(one_inst, "pred_ner_seq", None)
+            gold_tree = getattr(one_inst, "dep_tree", None)
+            pred_tree = getattr(one_inst, "pred_dep_tree", None)
+            if gold_pos_seq is not None and gold_tree is not None:
+                # todo(+N): only ARTI_ROOT in trees
+                gold_pos, gold_heads, gold_labels = gold_pos_seq.vals, gold_tree.heads[1:], gold_tree.labels[1:]
+                pred_pos = pred_pos_seq.vals if (pred_pos_seq is not None) else [""]*len(gold_pos)
+                pred_heads, pred_labels = (pred_tree.heads[1:], pred_tree.labels[1:]) if (pred_tree is not None) \
+                    else ([-1]*len(gold_heads), [""]*len(gold_labels))
+                evaler.eval_one(gold_pos, gold_heads, gold_labels, pred_pos, pred_heads, pred_labels)
+                has_syntax = True
+            if gold_ner_seq is not None and pred_ner_seq is not None:
+                ner_golds.append(gold_ner_seq.vals)
+                ner_preds.append(pred_ner_seq.vals)
+                has_ner = True
+        # -----
+        zlog("Results of %s vs. %s" % (self.outf, self.goldf), func="result")
+        if has_syntax:
+            report_str, res = evaler.summary()
+            zlog(report_str, func="result")
+        else:
+            res = {}
+        if has_ner:
+            for _n, _v in zip("prf", get_prf_scores(ner_golds, ner_preds)):
+                res[f"ner_{_n}"] = _v
+        res["gold"] = self.goldf  # record which file
+        zlog("zzzzztest: testing result is " + str(res))
+        # note that res['res'] will not be used
+        if 'res' in res:
+            del res['res']
+        # -----
+        return res
+# testing runner
+class MltTestingRunner(TestingRunner):
+    def __init__(self, model, vpack, outf, goldf, out_format):
+        super().__init__(model)
+        self.res_manager = MltResultMannager(vpack, outf, goldf, out_format)
+    # run and record for one batch
+    def _run_batch(self, insts):
+        res = self.model.inference_on_batch(insts)
+        self.res_manager.add(insts)
+        return res
+    # eval & report
+    def _run_end(self):
+        x = self.test_recorder.summary()
+        res = self.res_manager.end()
+        x.update(res)
+        MltDevResult.calc_acc(x)
+        Helper.printd(x, sep=" || ")
+        return x
+class MltDevResult(RecordResult):
+    def __init__(self, result_ds: List[Dict]):
+        self.all_results = result_ds
+        rr = {}
+        for idx, vv in enumerate(result_ds):
+            rr["zres"+str(idx)] = vv                        # record all the results for printing
+        # res = 0. if (len(result_ds)==0) else result_ds[0]["res"]        # but only use the first one as DEV-res
+        # todo(note): only the first dev result is used as the DEV score (0. if there are none)
+        if len(result_ds) == 0:
+            res = 0.
+        else:
+            res = result_ds[0]["res"]
+        super().__init__(rr, res)
+    @staticmethod
+    def calc_acc(x):
+        # put results for MLM/ORP
+        addings = OrderedDict()
+        for comp_name, comp_loss_names in zip(["mlm", "plm", "orp", "dp", "pos", "ner"],
+                                              [["word", "pos"], ["l2r", "r2l"], ["d-1", "d-2"], ["head", "label"],
+                                               ["slab"], ["slab", "crf"]]):
+            for c in comp_loss_names:
+                addings[f"res_{c}.acc"] = x.get(f"loss:{comp_name}.{c}_corr", 0.) / (x.get(f"loss:{comp_name}.{c}_count", 0.) + 1e-5)
+                addings[f"res_{c}.nll"] = x.get(f"loss:{comp_name}.{c}_sum", 0.) / (x.get(f"loss:{comp_name}.{c}_count", 0.) + 1e-5)
+        for k,v in addings.items():
+            if v>0.:  # only care active ones
+                x[k] = v
+        # -----
+        # set "res"
+        res_cands = [-addings["res_word.nll"], -addings["res_l2r.nll"], -addings["res_r2l.nll"],
+                     -addings["res_d-1.nll"], -addings["res_d-2.nll"], x.get("tok_las", 0.),
+                     addings["res_slab.acc"], x.get("ner_f", 0.)]
+        x["res"] = sum(z for z in res_cands if z!=0.)
+# training runner
+class MltTrainingRunner(TrainingRunner):
+    def __init__(self, rconf, model, vpack, dev_outfs, dev_goldfs, dev_out_format):
+        super().__init__(rconf, model)
+        self.vpack = vpack
+        self.dev_out_format = dev_out_format
+        #
+        self.dev_goldfs = dev_goldfs
+        if isinstance(dev_outfs, (list, tuple)):
+            zcheck(len(dev_outfs) == len(dev_goldfs), "Mismatched number of output and gold!")
+            self.dev_outfs = dev_outfs
+        else:
+            self.dev_outfs = [dev_outfs] * len(dev_goldfs)
+    # implementations of the TrainingRunner hooks
+    def _run_fb(self, annotated_insts, loss_factor: float):
+        res = self.model.fb_on_batch(annotated_insts, loss_factor=loss_factor)
+        return res
+    def _run_train_report(self):
+        x = self.train_recorder.summary()
+        y = GLOBAL_RECORDER.summary()
+        if len(y) > 0:
+            x.update(y)
+        MltDevResult.calc_acc(x)
+        Helper.printd(x, " || ")
+        return RecordResult(x, score=x.get("res", 0.))
+    def _run_update(self, lrate: float, grad_factor: float):
+        self.model.update(lrate, grad_factor)
+    def _run_validate(self, dev_streams):
+        if not isinstance(dev_streams, Iterable):
+            dev_streams = [dev_streams]
+        dev_results = []
+        zcheck(len(dev_streams) == len(self.dev_goldfs), "Mismatch number of streams!")
+        dev_idx = 0
+        for one_stream, one_dev_outf, one_dev_goldf in zip(dev_streams, self.dev_outfs, self.dev_goldfs):
+            # todo(+2): overwrite previous ones?
+            rr = MltTestingRunner(self.model, self.vpack, one_dev_outf+".dev"+str(dev_idx), one_dev_goldf, self.dev_out_format)
+            x = rr.run(one_stream)
+            dev_results.append(x)
+            dev_idx += 1
+        return MltDevResult(dev_results)
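Note on `MltDevResult.calc_acc`: each component's accuracy and average negative log-likelihood are derived from recorder entries keyed as `loss:{comp}.{name}_corr`, `loss:{comp}.{name}_sum` and `loss:{comp}.{name}_count`, with `1e-5` added to the count so that inactive components yield 0 instead of a division error. The snippet below is a minimal standalone sketch of that computation only. The helper name `ratio_stats` and the sample `stats` dict are hypothetical and not part of the repo.

```python
def ratio_stats(stats, comp, name, eps=1e-5):
    """Return (accuracy, average nll) for one loss component.

    Mirrors the key pattern used in calc_acc; `stats` stands in for the
    recorder summary dict (an assumption made for this sketch).
    """
    # eps keeps the division safe when the component was never active
    count = stats.get(f"loss:{comp}.{name}_count", 0.0) + eps
    acc = stats.get(f"loss:{comp}.{name}_corr", 0.0) / count
    nll = stats.get(f"loss:{comp}.{name}_sum", 0.0) / count
    return acc, nll


# hypothetical recorder summary for the MLM word loss
stats = {
    "loss:mlm.word_corr": 80.0,
    "loss:mlm.word_sum": 230.0,
    "loss:mlm.word_count": 100.0,
}
acc, nll = ratio_stats(stats, "mlm", "word")
print(f"acc={acc:.4f} nll={nll:.4f}")
# a component with no recorded entries gives (0.0, 0.0)
print(ratio_stats(stats, "dp", "head"))
```

In `calc_acc` itself, only values greater than zero are copied back into the summary, and `x["res"]` is the sum of the non-zero candidates (negated nll values plus `tok_las`, the `slab` accuracy and `ner_f`).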
diff --git a/tasks/zmlm/run/seqeval.py b/tasks/zmlm/run/seqeval.py
new file mode 100644
index 0000000..211f871
--- /dev/null
+++ b/tasks/zmlm/run/seqeval.py
@@ -0,0 +1,132 @@
+# directly from https://github.com/chakki-works/seqeval/ (on 29 Apr 2019)
+"""Metrics to assess performance on sequence labeling task given prediction
+Functions named as ``*_score`` return a scalar value to maximize: the higher
+the better
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from collections import defaultdict
+import numpy as np
+def get_entities(seq, suffix=False):
+    """Gets entities from sequence.
+    Args:
+        seq (list): sequence of labels.
+    Returns:
+        list: list of (chunk_type, chunk_start, chunk_end).
+    Example:
+        # >>> from seqeval.metrics.sequence_labeling import get_entities
+        # >>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
+        # >>> get_entities(seq)
+        [('PER', 0, 1), ('LOC', 3, 3)]
+    """
+    # for nested list
+    if any(isinstance(s, list) for s in seq):
+        seq = [item for sublist in seq for item in sublist + ['O']]
+    prev_tag = 'O'
+    prev_type = ''
+    begin_offset = 0
+    chunks = []
+    for i, chunk in enumerate(seq + ['O']):
+        if suffix:
+            tag = chunk[-1]
+            type_ = chunk.split('-')[0]
+        else:
+            tag = chunk[0]
+            type_ = chunk.split('-')[-1]
+        if end_of_chunk(prev_tag, tag, prev_type, type_):
+            chunks.append((prev_type, begin_offset, i-1))
+        if start_of_chunk(prev_tag, tag, prev_type, type_):
+            begin_offset = i
+        prev_tag = tag
+        prev_type = type_
+    return chunks
+def end_of_chunk(prev_tag, tag, prev_type, type_):
+    """Checks if a chunk ended between the previous and current word.
+    Args:
+        prev_tag: previous chunk tag.
+        tag: current chunk tag.
+        prev_type: previous type.
+        type_: current type.
+    Returns:
+        chunk_end: boolean.
+    """
+    chunk_end = False
+    if prev_tag == 'E': chunk_end = True
+    if prev_tag == 'S': chunk_end = True
+    if prev_tag == 'B' and tag == 'B': chunk_end = True
+    if prev_tag == 'B' and tag == 'S': chunk_end = True
+    if prev_tag == 'B' and tag == 'O': chunk_end = True
+    if prev_tag == 'I' and tag == 'B': chunk_end = True
+    if prev_tag == 'I' and tag == 'S': chunk_end = True
+    if prev_tag == 'I' and tag == 'O': chunk_end = True
+    if prev_tag != 'O' and prev_tag != '.' and prev_type != type_:
+        chunk_end = True
+    return chunk_end
+def start_of_chunk(prev_tag, tag, prev_type, type_):
+    """Checks if a chunk started between the previous and current word.
+    Args:
+        prev_tag: previous chunk tag.
+        tag: current chunk tag.
+        prev_type: previous type.
+        type_: current type.
+    Returns:
+        chunk_start: boolean.
+    """
+    chunk_start = False
+    if tag == 'B': chunk_start = True
+    if tag == 'S': chunk_start = True
+    if prev_tag == 'E' and tag == 'E': chunk_start = True
+    if prev_tag == 'E' and tag == 'I': chunk_start = True
+    if prev_tag == 'S' and tag == 'E': chunk_start = True
+    if prev_tag == 'S' and tag == 'I': chunk_start = True
+    if prev_tag == 'O' and tag == 'E': chunk_start = True
+    if prev_tag == 'O' and tag == 'I': chunk_start = True
+    if tag != 'O' and tag != '.' and prev_type != type_:
+        chunk_start = True
+    return chunk_start
+def get_prf_scores(y_true, y_pred, suffix=False):
+    """Computes entity-level precision, recall and F1 over exact (type, start, end) matches."""
+    true_entities = set(get_entities(y_true, suffix))
+    pred_entities = set(get_entities(y_pred, suffix))
+    nb_correct = len(true_entities & pred_entities)
+    nb_pred = len(pred_entities)
+    nb_true = len(true_entities)
+    p = nb_correct / nb_pred if nb_pred > 0 else 0
+    r = nb_correct / nb_true if nb_true > 0 else 0
+    f1 = 2 * p * r / (p + r) if p + r > 0 else 0
+    return p, r, f1
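Note on `seqeval.py`: `get_prf_scores` scores predictions at the entity level, so a predicted chunk counts as correct only if its type, start and end all match a gold chunk. The snippet below is a simplified standalone sketch of the same idea, restricted to prefix-style BIO tags. The names `bio_entities` and `entity_prf` are hypothetical. Unlike the vendored code, which flattens all sentences into one sequence separated by `O`, this sketch keeps sentences apart by adding a sentence index to each span.

```python
def bio_entities(tags):
    """Return the set of (type, start, end) spans in one flat BIO tag list."""
    spans, start, cur_type = set(), None, None
    for i, tag in enumerate(list(tags) + ["O"]):  # sentinel closes a trailing span
        prefix, _, type_ = tag.partition("-")
        # close the open span when the current tag does not continue it
        if start is not None and (prefix != "I" or type_ != cur_type):
            spans.add((cur_type, start, i - 1))
            start, cur_type = None, None
        # open a span on B, or on an I that has nothing to continue
        if prefix in ("B", "I") and start is None:
            start, cur_type = i, type_
    return spans


def entity_prf(gold_seqs, pred_seqs):
    """Entity-level precision, recall and F1 over a list of sentences."""
    gold, pred = set(), set()
    for sid, (g_tags, p_tags) in enumerate(zip(gold_seqs, pred_seqs)):
        gold |= {(sid,) + s for s in bio_entities(g_tags)}
        pred |= {(sid,) + s for s in bio_entities(p_tags)}
    correct = len(gold & pred)
    p = correct / len(pred) if pred else 0.0
    r = correct / len(gold) if gold else 0.0
    f1 = 2 * p * r / (p + r) if p + r > 0 else 0.0
    return p, r, f1


gold = [["B-PER", "I-PER", "O", "B-LOC"]]
pred = [["B-PER", "I-PER", "O", "O"]]
print(entity_prf(gold, pred))
```

Here the prediction finds one of the two gold entities and proposes nothing spurious, so precision is 1.0 and recall is 0.5.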
diff --git a/tasks/zmlm/run/vocab.py b/tasks/zmlm/run/vocab.py
new file mode 100644
index 0000000..8dbe35a
--- /dev/null
+++ b/tasks/zmlm/run/vocab.py
@@ -0,0 +1,149 @@
+# the vocab package
+from typing import Dict
+from msp.utils import Conf, zlog, zwarn
+from msp.data import VocabPackage, VocabBuilder, WordVectors, VocabHelper
+# -----
+# UD2 pre-defined types
+UD2_POS_LIST = ["NOUN", "PUNCT", "VERB", "PRON", "ADP", "DET", "PROPN", "ADJ", "AUX", "ADV", "CCONJ", "PART", "NUM", "SCONJ", "X", "INTJ", "SYM"]
+UD2_LABEL_LIST = ["punct", "case", "nsubj", "det", "root", "nmod", "advmod", "obj", "obl", "amod", "compound", "aux", "conj", "mark", "cc", "cop", "advcl", "acl", "xcomp", "nummod", "ccomp", "appos", "flat", "parataxis", "discourse", "expl", "fixed", "list", "iobj", "csubj", "goeswith", "vocative", "reparandum", "orphan", "dep", "dislocated", "clf"]
+# for POS UNK backoff
+UD2_POS_UNK_MAP = {p: VocabHelper.convert_special_pattern("UNK_"+p) for p in UD2_POS_LIST}
+# -----
+class MLMVocabPackageConf(Conf):
+    def __init__(self):
+        self.add_ud2_prevalues = True  # add for UD2 pos and labels
+        self.add_ud2_pos_backoffs = False
+        # pretrain
+        self.pretrain_file = []
+        self.read_from_pretrain = False
+        self.pretrain_scale = 1.0
+        self.pretrain_init_nohit = 1.0
+        self.pretrain_codes = []
+        self.test_extra_pretrain_files = []  # extra embeddings for testing
+        self.test_extra_pretrain_codes = []
+        # thresholds for word
+        self.word_rthres = 1000000  # rank <= this
+        self.word_fthres = 1  # freq >= this
+        self.ignore_thresh_with_pretrain = True  # if hit in pretrain, also include and ignore thresh
+class MLMVocabPackage(VocabPackage):
+    def __init__(self, vocabs: Dict, embeds: Dict):
+        super().__init__(vocabs, embeds)
+        # -----
+    @staticmethod
+    def build_by_reading(prefix):
+        zlog("Load vocabs from files.")
+        possible_vocabs = ["word", "char", "pos", "deplabel", "ner", "word2"]
+        one = MLMVocabPackage({n:None for n in possible_vocabs}, {n:None for n in possible_vocabs})
+        one.load(prefix=prefix)
+        return one
+    @staticmethod
+    def build_from_stream(build_conf: MLMVocabPackageConf, stream, extra_stream):
+        zlog("Build vocabs from streams.")
+        ret = MLMVocabPackage({}, {})
+        # -----
+        if build_conf.add_ud2_pos_backoffs:
+            ud2_pos_pre_list = list(VocabBuilder.DEFAULT_PRE_LIST) + [UD2_POS_UNK_MAP[p] for p in UD2_POS_LIST]
+            word_builder = VocabBuilder("word", pre_list=ud2_pos_pre_list)
+        else:
+            word_builder = VocabBuilder("word")
+        char_builder = VocabBuilder("char")
+        pos_builder = VocabBuilder("pos")
+        deplabel_builder = VocabBuilder("deplabel")
+        ner_builder = VocabBuilder("ner")
+        if build_conf.add_ud2_prevalues:
+            zlog(f"Add pre-defined UD2 values for upos({len(UD2_POS_LIST)}) and ulabel({len(UD2_LABEL_LIST)}).")
+            pos_builder.feed_stream(UD2_POS_LIST)
+            deplabel_builder.feed_stream(UD2_LABEL_LIST)
+        for inst in stream:
+            word_builder.feed_stream(inst.word_seq.vals)
+            for w in inst.word_seq.vals:
+                char_builder.feed_stream(w)
+            # todo(+N): currently we are assuming that we are using UD pos/deps, and directly go with the default ones
+            # pos and label can be optional??
+            # if inst.poses.has_vals():
+            #     pos_builder.feed_stream(inst.poses.vals)
+            # if inst.deplabels.has_vals():
+            #     deplabel_builder.feed_stream(inst.deplabels.vals)
+            if hasattr(inst, "ner_seq") and inst.ner_seq.has_vals():
+                ner_builder.feed_stream(inst.ner_seq.vals)
+        # ===== embeddings
+        w2vec = None
+        if build_conf.read_from_pretrain:
+            # todo(warn): for convenience, extra vocab (usually dev&test) is only used for collecting pre-train vecs
+            extra_word_set = set(w for inst in extra_stream for w in inst.word_seq.vals)
+            # ----- load (possibly multiple) pretrain embeddings
+            # must provide build_conf.pretrain_file (there can be multiple pretrain files!)
+            list_pretrain_file, list_code_pretrain = build_conf.pretrain_file, build_conf.pretrain_codes
+            list_code_pretrain.extend([""] * len(list_pretrain_file))  # pad default ones
+            w2vec = WordVectors.load(list_pretrain_file[0], aug_code=list_code_pretrain[0])
+            if len(list_pretrain_file) > 1:
+                w2vec.merge_others([WordVectors.load(list_pretrain_file[i], aug_code=list_code_pretrain[i])
+                                    for i in range(1, len(list_pretrain_file))])
+            # -----
+            # first filter according to thresholds
+            word_builder.filter(
+                lambda ww, rank, val: (val >= build_conf.word_fthres and rank <= build_conf.word_rthres) or
+                                      (build_conf.ignore_thresh_with_pretrain and w2vec.has_key(ww)))
+            # then add extra ones
+            if build_conf.ignore_thresh_with_pretrain:
+                for w in extra_word_set:
+                    if w2vec.has_key(w) and (not word_builder.has_key_currently(w)):
+                        word_builder.feed_one(w)
+            word_vocab = word_builder.finish()
+            word_embed1 = word_vocab.filter_embed(w2vec, init_nohit=build_conf.pretrain_init_nohit,
+                                                  scale=build_conf.pretrain_scale)
+        else:
+            word_vocab = word_builder.finish_thresh(rthres=build_conf.word_rthres, fthres=build_conf.word_fthres)
+            word_embed1 = None
+        #
+        char_vocab = char_builder.finish()
+        pos_vocab = pos_builder.finish(sort_by_count=False)
+        deplabel_vocab = deplabel_builder.finish(sort_by_count=False)
+        ner_vocab = ner_builder.finish()
+        # assign
+        ret.put_voc("word", word_vocab)
+        ret.put_voc("char", char_vocab)
+        ret.put_voc("pos", pos_vocab)
+        ret.put_voc("deplabel", deplabel_vocab)
+        ret.put_voc("ner", ner_vocab)
+        ret.put_emb("word", word_embed1)
+        #
+        return ret
+    # ======
+    # special mode: augment the package with a second word vocab ("word2")
+    # todo(+N): ugly patch!!
+    def aug_word2_vocab(self, stream, extra_stream, extra_embed_file: str):
+        zlog(f"Aug another word vocab from streams and extra_embed_file={extra_embed_file}")
+        word_builder = VocabBuilder("word2")
+        for inst in stream:
+            word_builder.feed_stream(inst.word_seq.vals)
+        # embeddings
+        if len(extra_embed_file) > 0:
+            extra_word_set = set(w for inst in extra_stream for w in inst.word_seq.vals)
+            w2vec = WordVectors.load(extra_embed_file)
+            for w in extra_word_set:
+                if w2vec.has_key(w) and (not word_builder.has_key_currently(w)):
+                    word_builder.feed_one(w)
+            word_vocab = word_builder.finish()  # no filtering!!
+            word_embed1 = word_vocab.filter_embed(w2vec, init_nohit=1.0, scale=1.0)
+        else:
+            zwarn("WARNING: No pretrain file for aug node!!")
+            word_vocab = word_builder.finish()  # no filtering!!
+            word_embed1 = None
+        self.put_voc("word2", word_vocab)
+        self.put_emb("word2", word_embed1)
+    # =====
+# b tasks/zmlm/run/vocab:146
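Note on the word filter in `build_from_stream`: when pre-trained vectors are loaded, a word is kept if it passes both thresholds (`freq >= word_fthres` and `rank <= word_rthres`), or if `ignore_thresh_with_pretrain` is set and the word has a pre-trained vector. The snippet below is a standalone sketch of that predicate using only the standard library. The function `filter_vocab` is hypothetical, and two details are assumptions of this sketch rather than facts about `VocabBuilder`: ranks start at 1, and ties in frequency are broken alphabetically.

```python
from collections import Counter


def filter_vocab(counts, rthres, fthres, pretrained=frozenset(),
                 ignore_thresh_with_pretrain=True):
    """Keep words passing the rank/frequency thresholds or found in `pretrained`."""
    # assumption: rank 1 is the most frequent word, ties broken alphabetically
    ranked = sorted(counts.items(), key=lambda kv: (-kv[1], kv[0]))
    kept = []
    for rank, (word, freq) in enumerate(ranked, start=1):
        passes = freq >= fthres and rank <= rthres
        rescued = ignore_thresh_with_pretrain and word in pretrained
        if passes or rescued:
            kept.append(word)
    return kept


counts = Counter("a a a b b c".split())
print(filter_vocab(counts, rthres=2, fthres=2))
print(filter_vocab(counts, rthres=2, fthres=2, pretrained={"c"}))
```

The second call keeps the rare word `c` only because it appears in the pre-trained set, which matches the intent of `ignore_thresh_with_pretrain`.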