diff --git a/README.md b/README.md index 3a8fcf5..266327e 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,8 @@ How to configurate, generally: [here](docs/conf.md) Related works: +"An Empirical Exploration of Local Ordering Pre-training for Structured Prediction": [TODO](??) + "A Two-Step Approach for Implicit Event Argument Detection": [details](docs/iarg.md) Some other parsers for interested readers: [details](docs/sop.md) diff --git a/docs/lbag.md b/docs/lbag.md new file mode 100644 index 0000000..c0b50e3 --- /dev/null +++ b/docs/lbag.md @@ -0,0 +1,68 @@ +### For the Local Bag pre-training method + +Hi, this describes our implementation for our work: "An Empirical Exploration of Local Ordering Pre-training for Structured Prediction". + +Please refer to the paper for more details: [[paper]](TODO) [[bib]](TODO) + +### Repo + +When we were carrying out our experiments for this work, we used the repo at this commit [`here`](TODO). In later versions of this repo, there may be slight changes (for example, default hyper-parameter change or hyper-parameter name change). + +### Environment + +As those of the main `msp` package: + + python>=3.6 + dependencies: pytorch>=1.0.0, numpy, scipy, gensim, cython, transformers, ... + +### Data + +- Pre-training data: any large corpus can be utilized, we use a random subset of wikipedia. (The format is simply one sentence per line, but **need to be tokenized (separated by spaces)!!**) +- Task data: The dependency parsing data are in CoNLL-U format, which are available from the official UD website. The NER data should otherwise be in the format like those in CoNLL03. 
+ +### Running + +- Step 0: Setup + +Assume we are at a new DIR, and please download this repo into a DIR called `src`: `git clone https://github.com/zzsfornlp/zmsp src` and specify some ENV variables (for convenience): + + SRC_DIR: Root dir of this repo + CUR_LANG: Lang id of the current language (for example en) + WIKI_PRETRAIN_SIZE: Pretraining size + UD_TRAIN_SIZE: Task training size + +- Step 1: Build dictionary with pre-trained data + +Before this, we should have the data prepared (for pre-training and task-training). + +Assume that we have UD files at `data/UD_RUN/ud24s/${CUR_LANG}_train.${UD_TRAIN_SIZE}.conllu`, and pre-training (wiki) files at `data/UD_RUN/wikis/wiki_${CUR_LANG}.${WIKI_PRETRAIN_SIZE}.txt`. + +Assuming now we are at DIR `data/UD_RUN/vocabs/voc_${CUR_LANG}`, we first create vocabulary for this setting with: + + PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zmlm.main.vocab_utils train:../../wikis/wiki_${CUR_LANG}.${WIKI_PRETRAIN_SIZE}.txt input_format:plain norm_digit:1 >vv.list + +The `vv_*` files at this dir will be the vocabularies that will be utilized for the remaining steps. + +- Step 2: Do pre-training + +Assuming now we are at DIR `data/..` + +Simply use the script of `${SRC_DIR}/scripts/lbag/run.py` for pre-training. + + python3 ${SRC_DIR}/scripts/lbag/run.py -l ${CUR_LANG} --rgpu 0 --run_dir run_orp_${CUR_LANG} --enc_type trans --run_mode pre --pre_mode orp --train_size ${WIKI_PRETRAIN_SIZE} --do_test 0 + +Note that by default, the data dirs are already pre-set as the ones in step 1, the paths can also be specified, please refer to the script for more details. + +There are various modes for pre-training, the most typical ones are: orp (or lbag, our local reordering strategy), mlm (masked LM), om (orp+mlm). Please use `--pre_mode` to specify. + +This may take a while (it took us three days to pretrain with 1M data on a single GPU). After this, we get the pre-trained models at `run_orp_${CUR_LANG}`. 
def system(cmd, pp=True, ass=True, popen=False):
    """Run ``cmd`` through the shell and return its exit status.

    Args:
        cmd: Shell command line, handed verbatim to ``os.system``.
        pp: If true, echo the command with ``printing`` before running it.
        ass: If true, raise ``AssertionError`` when the status is non-zero.
        popen: Unused; kept so that existing call sites stay valid.

    Returns:
        The raw status value reported by ``os.system``.

    Raises:
        AssertionError: If ``ass`` is true and the status is non-zero.
    """
    if pp:
        printing("Executing cmd: %s" % cmd)
    n = os.system(cmd)
    # Raise explicitly rather than via ``assert``: assert statements are removed
    # under ``python -O``, which would let a failed command pass unnoticed.
    if ass and n != 0:
        raise AssertionError(f"Executing previous cmd returns error {n}.")
    return n
# --
def step0_dir(args):
    """Create and enter the run directory, then resolve all working dirs.

    The process changes its working directory to ``args.run_dir``. When
    ``args.base_dir`` is not given, the base dir is searched by walking upwards
    (at most 4 levels) from the run dir until ``args.src_dir`` is found.

    Args:
        args: Parsed arguments; reads ``run_dir``, ``base_dir``, ``src_dir``,
            ``dict_dir``, ``embed_dir``, ``ud_dir``, ``wiki_dir`` and ``ner_dir``.

    Returns:
        Tuple ``(run_dir, src_dir, dict_dir, embed_dir, ud_dir, wiki_dir, ner_dir)``;
        every entry but ``run_dir`` is joined onto the resolved base dir.

    Raises:
        OSError: If the run directory cannot be created or entered.
        AssertionError: If the source directory cannot be located.
    """
    printing("# =====\n=> Step 0: determine dirs")
    run_dir = args.run_dir
    # Create the directory directly instead of through a shell `mkdir -p`, so
    # that paths with spaces or shell metacharacters are handled correctly.
    os.makedirs(run_dir, exist_ok=True)
    os.chdir(run_dir)
    # check out dirs based on src
    base_dir = args.base_dir
    if base_dir is None:
        # search for src dir: go one layer up per iteration, bounded by _MAX_TIME
        base_dir = ""
        _MAX_TIME = 4
        while not os.path.isdir(os.path.join(base_dir, args.src_dir)) and _MAX_TIME > 0:
            base_dir = os.path.join(base_dir, "..")  # upper one layer
            _MAX_TIME -= 1
    src_dir = os.path.join(base_dir, args.src_dir)
    assert os.path.isdir(src_dir), f"Failed to find src dir: {args.src_dir}"
    printing(f"Finally run_dir={os.getcwd()}, base_dir={base_dir}, src_dir={src_dir}")
    dict_dir = os.path.join(base_dir, args.dict_dir)
    embed_dir = os.path.join(base_dir, args.embed_dir)
    ud_dir = os.path.join(base_dir, args.ud_dir)
    wiki_dir = os.path.join(base_dir, args.wiki_dir)
    ner_dir = os.path.join(base_dir, args.ner_dir)
    return run_dir, src_dir, dict_dir, embed_dir, ud_dir, wiki_dir, ner_dir
dropout + base_opt += " drop_hidden:0.1 gdrop_rnn:0.33 idrop_rnn:0.33 fix_drop:0" + # => model + # embedder + base_opt += " ec_word.comp_dim:300 ec_word.comp_drop:0.1 ec_word.comp_init_scale:1. ec_word.comp_init_from_pretrain:0" + base_opt += " ec_char.comp_dim:50 ec_char.comp_drop:0.1 ec_char.comp_init_scale:1." + # (NOPOS!!) base_opt+=" ec_pos.comp_dim:50 ec_pos.comp_drop:0.1 ec_pos.comp_init_scale:1." + base_opt += f" emb_proj_dim:{MODEL_DIM} emb_proj_act:linear" # proj after embedder + # encoder + if ENC_TYPE == "rnn": + # RNN baseline + base_opt += " matt_conf.head_count:256 default_attn_count:256 prepr_choice:rdist" # this is not used, simply avoid err + base_opt += f" enc_choice:original oenc_conf.enc_hidden:{MODEL_DIM} oenc_conf.enc_rnn_layer:3 oenc_conf.enc_rnn_sep_bidirection:1" + else: + # use VRec + base_opt += " enc_choice:vrec" + base_opt += " venc_conf.num_layer:6 venc_conf.attn_ranges:" # vrec layers + base_opt += " use_pre_norm:0 use_post_norm:1" # layer norm + # # no scale now!! + # base_opt += " scorer_conf.param_init_scale:0.5 collector_conf.param_init_scale:0.5" # other init + # base_opt += " ec_word.comp_init_scale:0.5 ec_char.comp_init_scale:0.5" # other init + # use MAtt + base_opt += f" matt_conf.head_count:{HEAD_COUNT}" # head count + base_opt += f" scorer_conf.d_qk:{HEAD_DIM} scorer_conf.use_rel_dist:1 scorer_conf.no_self_loop:1" # matt scorer + base_opt += f" normer_conf.use_noop:1 normer_conf.noop_fixed_val:0. 
normer_conf.norm_mode:cand normer_conf.attn_dropout:0.1" # matt norm + base_opt += f" collector_conf.d_v:{HEAD_DIM} collector_conf.collect_mode:ch collect_reducer_mode:aff collect_red_aff_out:{MODEL_DIM}" # matt col + if ENC_TYPE == "trans": + base_opt += f" venc_conf.share_layers:0 comb_mode:affine ff_dim:{MODEL_DIM*2}" # no share + elif ENC_TYPE == "vrec": + base_opt += " venc_conf.share_layers:1 comb_mode:lstm ff_dim:0" # share layers + else: + raise NotImplementedError() + # => prediction & loss + # vocab + base_opt += " lower_case:0 norm_digit:1 word_rthres:1000000 word_fthres:1" # vocab + base_opt += " ec_word.comp_rare_unk:0.5 ec_word.comp_rare_thr:10" # word rare unk in training + # masklm + base_opt += " mlm_conf.loss_lambda.val:0. mlm_conf.max_pred_rank:100" # masklm + base_opt += " mlm_conf.hid_dim:300 mlm_conf.hid_act:elu" # hid + base_opt += f" mlm_conf.tie_input_embeddings:{args.lm_tie_inputs}" # tie embeddings + # plainlm + base_opt += " plm_conf.loss_lambda.val:0. plm_conf.max_pred_rank:100" # plainlm + base_opt += " plm_conf.hid_dim:300 plm_conf.hid_act:elu" # hid + base_opt += f" plm_conf.tie_input_embeddings:{args.lm_tie_inputs}" # tie embeddings + # orderpr + base_opt += " orp_conf.loss_lambda.val:0. orp_conf.disturb_mode:abs_bin orp_conf.bin_blur_bag_keep_rate:0." + # base_opt += " orp_conf.disturb_ranges:8,10 orp_conf.bin_blur_bag_lbound:-2" + base_opt += " orp_conf.disturb_ranges:7,8" # original one + # base_opt += " orp_conf.pred_range:1 orp_conf.cand_range:4 orp_conf.lambda_n1:0. orp_conf.lambda_n2:1." + base_opt += " orp_conf.pred_range:1 orp_conf.cand_range:1000 orp_conf.lambda_n1:0. orp_conf.lambda_n2:1." # AllCand + base_opt += " orp_conf.ps_conf.use_input_pair:0" # scorer + # better ones for orp + base_opt += " orp_conf.bin_blur_bag_keep_rate:0.5 orp_conf.pred_range:2" \ + " orp_conf.ps_conf.use_biaffine:1 orp_conf.ps_conf.use_ff2:0" + # dpar + base_opt += " dpar_conf.loss_lambda.val:0. 
# ----
def find_model(model_prefix):
    """Resolve a (possibly partial) model prefix to a full model path.

    Looks in the directory part of ``model_prefix`` for files whose name starts
    with the basename of the prefix and ends with ``.pr.json``; the returned
    path is that file name with the suffix stripped.

    Args:
        model_prefix: Path prefix of the model; may be a bare file-name prefix,
            in which case the current directory is searched.

    Returns:
        The directory joined with the matched model name (without ``.pr.json``).

    Raises:
        AssertionError: If no matching file is found.
    """
    suffix = ".pr.json"
    # dirname() is "" for a bare prefix and os.listdir("") raises, so fall back
    # to the current directory.
    model_dir = os.path.dirname(model_prefix) or "."
    model_name_prefix = os.path.basename(model_prefix)
    model_name_full = None
    # Sort so the choice is deterministic if several files match; as before,
    # the last match wins.
    for f in sorted(os.listdir(model_dir)):
        if f.endswith(suffix) and f.startswith(model_name_prefix):
            model_name_full = f[:-len(suffix)]
    if model_name_full is None:
        raise AssertionError(
            f"No '*{suffix}' file matching prefix {model_prefix!r} found in {model_dir!r}")
    model_path = os.path.join(model_dir, model_name_full)
    printing(f"Find model at {model_path}")
    return model_path
+ cur_opt += f" aug_word2:1 aug_word2_pretrain:{embed_dir}/wiki.{CUR_LANG}.vec" + # modes + # -- pre-training + if RUN_MODE == "pre": + cur_opt += " min_save_epochs:0 max_epochs:200 patience:1000 anneal_times:5" + cur_opt += f" train:{wiki_dir}/wiki_{CUR_LANG}.{TRAIN_SIZE}.txt dev: test: input_format:plain" \ + f" cache_data:0 no_build_dict:1 dict_dir:{dict_dir}/voc_{CUR_LANG}/" + # different pre-training tasks for different archs + if PRE_MODE == "plm": + cur_opt += f" plm_conf.loss_lambda.val:1. plm_conf.max_pred_rank:{args.lm_pred_size}" + if ENC_TYPE != "rnn": + cur_opt += f" plm_conf.split_input_blm:0" # input list for vrec + elif PRE_MODE == "mlm": + cur_opt += f" mlm_conf.loss_lambda.val:1. mlm_conf.max_pred_rank:{args.lm_pred_size}" + cur_opt += f" mlm_conf.mask_rate:0.15" + elif PRE_MODE == "orp": + cur_opt += " orp_conf.loss_lambda.val:1." + elif PRE_MODE == "om": # orp + mlm + cur_opt += " orp_conf.loss_lambda.val:0.5 orp_conf.bin_blur_bag_keep_rate:0.5" # keep for mlm + cur_opt += f" mlm_conf.loss_lambda.val:0.5 mlm_conf.max_pred_rank:{args.lm_pred_size}" + cur_opt += f" mlm_conf.mask_rate:0.15" + else: + raise NotImplementedError() + # ----- + # # schedule + # # (OldRun) when bs=80: (UPE*80 as one epoch): 600*CHECKPOINT seems to be enough (~50 epochs for 1m) + # # (NewRun) when bs=240: (UPE*240 as one epoch): use 300*CHECKPOINT (~72 epochs) + # UPE = 1000 # number of Update per Pseudo Epoch + # cur_opt += " train_batch_size:240 split_batch:15" # larger batch size for pre-training + # cur_opt += f" report_freq:{UPE} valid_freq:{UPE} max_updates:{304*UPE}" + # cur_opt += " save_freq:50 validate_epoch:0" # save more points + # cur_opt += f" lrate_warmup:{5*UPE} lrate.which_idx:uidx lrate.start_bias:{50*UPE} lrate.scale:{50*UPE}" + # ----- + # schedule2 + # (NewRun2) when bs=480: (UPE*480 as one epoch): use 200*CHECKPOINT (~96 epochs) + UPE = args.pre_upe # number of Update per Pseudo Epoch + cur_opt += " train_batch_size:480 split_batch:30" # larger batch 
size for pre-training + cur_opt += f" report_freq:{UPE} valid_freq:{UPE} max_updates:{201*UPE} save_freq:40 validate_epoch:0" + cur_opt += f" lrate_warmup:{5*UPE} lrate.which_idx:uidx lrate.start_bias:{40*UPE} lrate.scale:{40*UPE} lrate.m:0.5" + # -- parser/tagger training + else: + assert RUN_MODE in ["par0", "par1", "pos0", "pos1", "ppp0", "ppp1", "ner0", "ner1"] + if RUN_MODE.startswith("ner"): + cur_opt += " do_ner:1 input_format:ner div_by_tok:0" # remember to add the extra flag!! + cur_opt += f" train:{ner_dir}/{CUR_LANG}/train.{TRAIN_SIZE}.txt dev:{ner_dir}/{CUR_LANG}/dev.txt" \ + f" test:{ner_dir}/{CUR_LANG}/test.txt" + # extra settings for ner + # cur_opt += " train_batch_size:80 split_batch:2 drop_hidden:0.5 max_epochs:10000000 validate_epoch:0 valid_freq:200" \ + # " lrate_warmup:500 max_updates:30000 div_by_tok:0 ec_word.comp_drop:0.5 ec_char.comp_drop:0.5" + # cur_opt += " drop_hidden:0.5 ec_word.comp_drop:0.5 ec_char.comp_drop:0.5" + else: + cur_opt += f" train:{ud_dir}/{CUR_LANG}_train.{TRAIN_SIZE}.conllu dev:{ud_dir}/{CUR_LANG}_dev.conllu" \ + f" test:{ud_dir}/{CUR_LANG}_test.conllu" + task_comp_names = {"par": ["dpar"], "pos": ["upos"], "ppp": ["dpar", "upos"], "ner": ["ner"]}[RUN_MODE[:3]] + for one_tcn in task_comp_names: + cur_opt += f" {one_tcn}_conf.loss_lambda.val:1." 
+ if RUN_MODE[-1] == "1": + # use pretrain vocab + cur_opt += f" no_build_dict:1 dict_dir:{dict_dir}/voc_{CUR_LANG}/" + # find preload model + cur_opt += f" load_pretrain_model_name:{find_model(args.preload_prefix)}" + # use smaller lrate + cur_opt += " lrate.val:0.0002" + if ENC_TYPE == "rnn": + # for rnn, better to use no warmup and larger dropout + cur_opt += " lrate_warmup:0 drop_hidden:0.33" + # gpu and seed + cur_opt += f" device:{-1 if args.rgpu<0 else 0}" \ + f" niconf.random_seed:9347{args.cur_run} niconf.random_cuda_seed:9349{args.cur_run}" + # finally extras + cur_opt += " " + args.extras + # ===== + # training + CMD = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 " + (" -m pdb " if args.debug else "") \ + + f"{src_dir}/tasks/cmd.py zmlm.main.train {cur_opt} " + ("" if args.debug else f"2>&1 | tee _log") + system(CMD) + # testing + if not args.debug and args.do_test and (not RUN_MODE.startswith("pre")): + for wset in ["dev", "test"]: + if RUN_MODE.startswith("ner"): + CMD2 = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 {src_dir}/tasks/cmd.py zmlm.main.test" \ + f" _conf input_format:ner do_ner:1 test:{ner_dir}/{CUR_LANG}/{wset}.txt output_file:zout.{wset}" \ + f" model_load_name:zmodel.best" + else: + CMD2 = f"CUDA_VISIBLE_DEVICES={args.rgpu} PYTHONPATH={src_dir} python3 {src_dir}/tasks/cmd.py zmlm.main.test" \ + f" _conf input_format:conllu test:{ud_dir}/{CUR_LANG}_{wset}.conllu output_file:zout.{wset}" \ + f" model_load_name:zmodel.best" + system(CMD2) + # running with pre-train + if not args.debug and args.do_test and (RUN_MODE == "pre"): + CMD = f"python3 {SCRIPT_ABSPATH} -l {args.lang} --rgpu {args.rgpu} --run_dir _ppp1 --run_mode ppp1 --preload_prefix ../zmodel.c200" + system(CMD) + # ===== + +# ===== +def main(): + printing(f"Run with {sys.argv}") + args = parse() + for k in sorted(dir(args)): + if not k.startswith("_"): + printing(f"--{k} = {getattr(args, k)}") + # -- + DIRs = step0_dir(args) + step_run(args, 
# base protocol for data io

class BaseDataItem:
    """Base class for data items that convert to and from builtin types.

    Subclasses implement ``to_builtin`` / ``from_builtin``. The ``*_wrapped``
    variants additionally record the class name, so that the concrete class
    can be recovered from the registry when loading.
    """

    # =====
    # from name to cls

    _registered_mappings = {}  # name to class mapping

    @staticmethod
    def register(cls):
        """Register ``cls`` under its ``__name__`` and return it (decorator use)."""
        name = cls.__name__
        mapping = BaseDataItem._registered_mappings
        assert name not in mapping, f"Repeated entry name: {cls}({name})"
        mapping[name] = cls
        return cls

    # =====
    # IO: wrapped versions are useful for optional fields/components

    def to_builtin(self, *args, **kwargs):
        """Return a builtin-type representation; must be overridden."""
        raise NotImplementedError()

    @classmethod
    def from_builtin(cls, d):
        """Build an instance from the value made by ``to_builtin``; must be overridden."""
        raise NotImplementedError()

    def to_builtin_wrapped(self, *args, **kwargs):
        """Return ``{"_type": <class name>, "_v": <to_builtin result>}``."""
        class_name = self.__class__.__name__
        assert class_name in BaseDataItem._registered_mappings, f"Entry not found: {self.__class__}({class_name})"
        return {"_type": class_name, "_v": self.to_builtin(*args, **kwargs)}

    @staticmethod
    def from_builtin_wrapped(d):
        """Rebuild an item from the dict made by ``to_builtin_wrapped``.

        Raises:
            KeyError: If ``d`` lacks ``"_type"``/``"_v"`` or the recorded type
                name has not been registered.
        """
        class_name, val = d["_type"], d["_v"]
        try:
            class_type = BaseDataItem._registered_mappings[class_name]
        except KeyError:
            # same exception type as before, but say which type is missing
            raise KeyError(f"Unregistered data type: {class_name}") from None
        return class_type.from_builtin(val)

# short-cut
data_type_reg = BaseDataItem.register
# DepTreeField: todo(note): remember everything is considering offset=1 for ARTI_ROOT
@data_type_reg
class DepTreeField(DataField):
    """Dependency tree stored as parallel ``heads`` / ``labels`` lists.

    Position 0 holds the artificial root (ARTI_ROOT); it is expected to be
    added by the caller, with head ``0`` and label ``""``.
    """

    def __init__(self, heads: List[int], labels: List[str], label_idxes: List[int] = None):
        # todo(note): ARTI_ROOT as 0 should be included here, added from the outside!
        self.heads = heads
        self.labels = labels
        self.label_idxes = label_idxes
        # == some cache values
        self._dep_dists = None  # m-h
        self._label_matrix = None  # Arr[m,h] of int
        # -----
        if self.heads is not None:
            # only warn (do not fail) when the root entry looks wrong
            if self.heads[0] != 0 or self.labels[0] != "":
                zwarn("Bad values for ARTI_ROOT!!")
            self._build_tree()  # build and check

    def __len__(self):
        return len(self.heads)

    def __repr__(self):
        return f"DepTree of len={len(self)}"

    def has_vals(self):
        """Return whether heads have been provided."""
        return self.heads is not None

    # build the extra info for the trees
    def _build_tree(self):
        # TODO(!)
        pass

    # For labels: two directions: val <=> idx
    # init-val -> idx
    def build_label_idxes(self, l_voc: Vocab):
        """Fill ``label_idxes`` by looking every label up in ``l_voc``."""
        self.label_idxes = [l_voc.get_else_unk(w) for w in self.labels]

    # set idx & val
    @staticmethod
    def build_label_vals(label_idxes: List[int], l_voc: Vocab):
        """Map label indexes back to strings; entry 0 is forced to ``""``."""
        labels = [l_voc.idx2word(i) for i in label_idxes]
        labels[0] = ""  # ARTI_ROOT
        return labels

    # =====

    def to_builtin(self, *args, **kwargs):
        return (self.heads, self.labels, self.label_idxes)

    @classmethod
    def from_builtin(cls, v):
        heads, labels, label_idxes = v
        return DepTreeField(heads, labels, label_idxes)

    # =====
    @property
    def dep_dists(self):
        """Per-token distance ``m - h`` (modifier position minus head), cached."""
        if self._dep_dists is None:
            self._dep_dists = [m-h for m, h in enumerate(self.heads)]
        return self._dep_dists

    @property
    def label_matrix(self):
        """Square int matrix with ``mat[m, heads[m]] = label_idxes[m]``, cached.

        Row 0 (ARTI_ROOT) is zeroed since the root has no head.
        NOTE(review): reads ``label_idxes``, so presumably ``build_label_idxes``
        must have been called first -- confirm against callers.
        """
        if self._label_matrix is None:
            # todo(+N): currently only mh_dep, may be extended to other relations or multi-hop relations
            cur_heads, cur_lidxs = self.heads, self.label_idxes  # [Alen=1+Rlen]
            cur_alen = len(cur_heads)
            # `np.long` was deprecated in NumPy 1.20 and removed in 1.24;
            # use an explicit 64-bit integer dtype instead.
            _mat = np.zeros([cur_alen, cur_alen], dtype=np.int64)  # [Alen, Alen]
            _mat[np.arange(cur_alen), cur_heads] = cur_lidxs  # assign mh_dep
            _mat[0] = 0  # no head for ARTI_ROOT
            self._label_matrix = _mat
        return self._label_matrix
# recursive general instances
@data_type_reg
class GeneralInstance(BaseDataItem):
    """Generic container of named items and info, each mirrored as an attribute.

    ``items`` and ``info`` are what ``to_builtin`` / ``from_builtin`` handle;
    ``features`` is scratch space that is not written out.
    """

    def __init__(self):
        self.items: Dict = {}  # main items (serializable)
        self.info: Dict = {}  # other contents (serializable)
        # -----
        self.features = {}  # running/extra/tmp/cached features (non-serializable)

    # shared by add_item / add_info: store in the collection and as an attribute
    def _add_sth(self, name: str, sth, v, assert_non_exist: bool):
        if assert_non_exist:
            assert not hasattr(self, name), f"Inside the class: {name} already exists"
            assert name not in v, f"Inside the collection: {name} already exists!"
        v[name] = sth
        setattr(self, name, sth)

    # dynamically adding items
    def add_item(self, name: str, item: Union[BaseDataItem, List, Tuple], assert_non_exist=True):
        self._add_sth(name, item, self.items, assert_non_exist)
        return item

    # adding info (should be builtin-type)
    def add_info(self, name: str, info, assert_non_exist=True):
        self._add_sth(name, info, self.info, assert_non_exist)
        return info

    # =====
    # creating

    @classmethod
    def create(cls, *args, **kwargs):
        """Blank-construct, then run ``_init`` and ``_finish_create``."""
        inst = cls()
        inst._init(*args, **kwargs)
        inst._finish_create()
        return inst

    # real creating (to be overridden, replacing original __init__)
    def _init(self, *args, **kwargs):
        pass

    # finish up init (to be overridden, common routine used in both create and from_builtin)
    def _finish_create(self):
        pass

    # =====
    # general IO

    @classmethod
    def get_unwrap_name_mappings(cls):
        return {}  # name to class to avoid wrapping (str -> FieldType)

    @classmethod
    def get_unwrap_middle_mappings(cls):
        return {}  # name to extra middle builtin types, usually List (str -> type)

    def to_builtin(self, *args, **kwargs):
        """Convert info and items into a plain dict of builtin values."""
        unwrapped_names = self.__class__.get_unwrap_name_mappings()
        out = {}
        if len(self.info) > 0:
            out["_info"] = self.info
        for key, value in self.items.items():
            # names with a known class skip the type-recording wrapper
            method_name = "to_builtin" if key in unwrapped_names else "to_builtin_wrapped"
            # only one layer of list/tuple nesting is supported
            if isinstance(value, BaseDataItem):
                out[key] = getattr(value, method_name)(*args, **kwargs)
            elif isinstance(value, (List, Tuple)):
                out[key] = [getattr(one, method_name)(*args, **kwargs) for one in value]
            else:
                raise NotImplementedError(f"Unknown class for 'to_builtin': {value.__class__}")
        return out

    @classmethod
    def from_builtin(cls, data: Dict):
        """Rebuild an instance from the dict produced by ``to_builtin``."""
        name_to_cls = cls.get_unwrap_name_mappings()
        name_to_middle = cls.get_unwrap_middle_mappings()
        inst = cls()  # blank init
        for key, value in data.items():
            if key == "_info":
                for info_name, info_val in value.items():
                    inst.add_info(info_name, info_val)
                continue
            # known names load through their class, others through the registry
            if key in name_to_cls:
                loader = name_to_cls[key].from_builtin
            else:
                loader = BaseDataItem.from_builtin_wrapped
            middle = name_to_middle.get(key)
            # todo(+N): more middle types?
            if middle is not None and issubclass(middle, (List, Tuple)):
                loaded = [loader(one) for one in value]
            else:
                loaded = loader(value)
            inst.add_item(key, loaded)
        # finish up
        inst._finish_create()
        return inst
get_unwrap_middle_mappings(cls): + return {"paras": list} + + def _init(self, paras: List[GeneralParagraph]): + self.add_item("paras", paras) + + @property + def sents(self): + # on-the-fly joining all (no cache) + return Helper.join_list(z.sents for z in self.paras) diff --git a/tasks/zmlm/data/io/__init__.py b/tasks/zmlm/data/io/__init__.py new file mode 100644 index 0000000..0c8fe81 --- /dev/null +++ b/tasks/zmlm/data/io/__init__.py @@ -0,0 +1,4 @@ +# + +from .base import BaseDataReader, BaseDataWriter, PlainTextReader +from .conllu import ConlluParseReader, ConllNerReader diff --git a/tasks/zmlm/data/io/base.py b/tasks/zmlm/data/io/base.py new file mode 100644 index 0000000..29214af --- /dev/null +++ b/tasks/zmlm/data/io/base.py @@ -0,0 +1,150 @@ +# + +# base IO with BaseDataItem using "to_builtin" and "from_builtin" + +import os +import re +import json +from typing import Iterable +from msp.utils import zopen +from msp.data import Streamer, FileOrFdStreamer +from ..insts import BaseDataItem, GeneralSentence + +# ===== +# an improved FileStreamer: read files for dir_mode, read lines for file_mode +class PathOrFdStreamer(Streamer): + def __init__(self, path_or_fd: str, is_dir=False, re_pat=""): + super().__init__() + self.path_or_fd = path_or_fd + # for dir mode + self.is_dir = is_dir + if is_dir: + file_list0 = os.listdir(path_or_fd) + if re_pat != "": + if isinstance(re_pat, str): + re_pat = re.compile(re_pat) + file_list0 = [f for f in file_list0 if re.fullmatch(re_pat, f)] + file_list0.sort() # sort by string name + file_list = [os.path.join(path_or_fd, f) for f in file_list0] + else: + file_list = None + self.dir_file_list = file_list + self.dir_file_ptr = 0 + # ----- + self.input_is_fd = not isinstance(path_or_fd, str) + if self.input_is_fd: + self.fd = path_or_fd + else: + self.fd = None + + def __del__(self): + if (self.fd is not None) and (not self.input_is_fd) and (not self.fd.closed): + self.fd.close() + + def _restart(self): + if self.is_dir: + 
self.dir_file_ptr = 0 + elif not self.input_is_fd: + if self.fd is not None: + self.fd.close() + self.fd = zopen(self.path_or_fd) + else: + assert self.restart_times_==0, "Cannot restart (read multiple times) a FdStreamer" + + def _next(self): + if self.is_dir: + cur_ptr = self.dir_file_ptr + if cur_ptr >= len(self.dir_file_list): + return None + else: + with zopen(self.dir_file_list[cur_ptr]) as fd: + ss = fd.read() + self.dir_file_ptr = cur_ptr + 1 + return ss + else: + line = self.fd.readline() + if len(line) == 0: + return None + else: + return line + +# further using json to load +class BaseDataReader(PathOrFdStreamer): + def __init__(self, cls, path_or_fd: str, is_dir=False, re_pat="", cut=-1): + super().__init__(path_or_fd, is_dir, re_pat) + # + assert issubclass(cls, BaseDataItem) + self.cls = cls # the main BaseDataItem type + self.cut = cut + + def _next(self): + if self.count_ == self.cut: + return None + ss = super()._next() + if ss is None: + return None + else: + d = json.loads(ss) + return self.cls.from_builtin(d) + +# writing +class BaseDataWriter: + def __init__(self, path_or_fd: str, is_dir=False, fname_getter=None, suffix=""): + self.path_or_fd = path_or_fd + # dir mode + self.is_dir = is_dir + self.anon_counter = -1 + self.fname_getter = fname_getter if fname_getter else self.anon_name_getter + self.suffix = suffix + # otherwise (dir mode opens one file per instance inside write_one, so there is no shared fd to open) + if is_dir: + self.fd = None + elif isinstance(path_or_fd, str): + self.fd = zopen(path_or_fd, "w") + else: + self.fd = path_or_fd + + # default file-name getter; takes (and ignores) the instance because write_one calls fname_getter(inst) + def anon_name_getter(self, inst=None): + self.anon_counter += 1 + return f"anon.{self.anon_counter}{self.suffix}" + + def write_list(self, insts: Iterable[BaseDataItem]): + for inst in insts: + self.write_one(inst) + + def write_one(self, inst: BaseDataItem): + if self.is_dir: + fname = self.fname_getter(inst) + with zopen(os.path.join(self.path_or_fd, fname), 'w') as fd: + json.dump(inst.to_builtin(), fd) + else: + self.fd.write(json.dumps(inst.to_builtin())) + self.fd.write("\n") + + def finish(self): + if self.fd is not None
and (not self.fd.closed): + self.fd.close() + self.fd = None + + def __del__(self): + self.finish() + +# ===== +# plain reader + +class PlainTextReader(FileOrFdStreamer): + def __init__(self, file_or_fd, aug_code="", cut=-1, sep=None): + super().__init__(file_or_fd) + self.aug_code = aug_code + self.cut = cut + self.sep = sep + + def _next(self): + if self.count_ == self.cut: + return None + line = self.fd.readline() + if len(line) == 0: + return None + tokens = line.rstrip("\n").split(self.sep) + one = GeneralSentence.create(tokens) + one.add_info("sid", self.count()) + one.add_info("aug_code", self.aug_code) + return one diff --git a/tasks/zmlm/data/io/conllu.py b/tasks/zmlm/data/io/conllu.py new file mode 100644 index 0000000..94465ca --- /dev/null +++ b/tasks/zmlm/data/io/conllu.py @@ -0,0 +1,118 @@ +# + +# read CoNLL's column format + +from msp.utils import Conf, FileHelper +from msp.data import FileOrFdStreamer +from msp.zext.dpar import ConlluReader +from ..insts import GeneralSentence, SeqField, DepTreeField + +# read from column styled file and return Dict +class ColumnReader(FileOrFdStreamer): + def __init__(self, file_or_fd, sep_line_f=lambda x: len(x.strip()) == 0, sep_field_t="\t"): + super().__init__(file_or_fd) + self.sep_line_f = sep_line_f + self.sep_field_t = sep_field_t + + # + def _get(self, fd): + lines = FileHelper.read_multiline(fd, self.sep_line_f) + if lines is None: + return None + ret = {} + headlines = [] + for one_line in lines: + if one_line.startswith("#"): + headlines.append(one_line) + continue + fileds = one_line.strip("\n").split(self.sep_field_t) + if len(ret) == 0: + for i in range(len(fileds)): + ret[i] = [] # the start + else: + assert len(ret) == len(fileds), "Fields length um-matched!" 
+ for i, x in enumerate(fileds): + ret[i].append(x) + ret["headlines"] = headlines + return ret + + # not from self.fd + def yield_insts(self, file_or_fd): + if isinstance(file_or_fd, str): + with open(file_or_fd) as fd: + yield from self.yield_insts(fd) + else: + while True: + z = self._get(file_or_fd) + if z is None: + break + yield z + + def _next(self): + return self._get(self.fd) + +# +class ConlluParseReader(FileOrFdStreamer): + def __init__(self, file_or_fd, aug_code="", use_xpos=False, use_la0=True, cut=-1): + super().__init__(file_or_fd) + self.aug_code = aug_code + self.reader = ConlluReader() + self.cut = cut + # + if use_xpos: + self.pos_f = lambda t: t.xpos + else: + self.pos_f = lambda t: t.upos + if use_la0: + self.label_f = lambda t: t.label0 + else: + self.label_f = lambda t: t.label + + def _next(self): + if self.count_ == self.cut: + return None + parse = self.reader.read_one(self.fd) + if parse is None: + return None + else: + tokens = parse.get_tokens() + # one = ParseInstance([t.word for t in tokens], [self.pos_f(t) for t in tokens], + # [t.head for t in tokens], [self.label_f(t) for t in tokens], code=self.aug_code) + # one.init_idx = self.count() + one = GeneralSentence.create([t.word for t in tokens]) + one.add_item("pos_seq", SeqField([self.pos_f(t) for t in tokens])) + one.add_item("dep_tree", DepTreeField([0]+[t.head for t in tokens], ['']+[self.label_f(t) for t in tokens])) + one.add_info("sid", self.count()) + one.add_info("aug_code", self.aug_code) + return one + +# BIO NER tags reader +class ConllNerReader(FileOrFdStreamer): + def __init__(self, file_or_fd, aug_code="", cut=-1): + super().__init__(file_or_fd) + self.aug_code = aug_code + self.cut = cut + + def _next(self): + if self.count_ == self.cut: + return None + lines = FileHelper.read_multiline(self.fd) + if lines is None: + return None + else: + tokens, tags = [], [] + for line in lines: + one_tok, one_tag = line.split() + if one_tag != "O": # checking + t0, t1 = 
one_tag.split("-", 1) + # split only on the first "-" so entity types containing "-" still unpack; compare against a tuple because the substring test `in "BI"` would also accept "" and "BI" + assert t0 in ("B", "I") + tokens.append(one_tok) + tags.append(one_tag) + # one = ParseInstance([t.word for t in tokens], [self.pos_f(t) for t in tokens], + # [t.head for t in tokens], [self.label_f(t) for t in tokens], code=self.aug_code) + # one.init_idx = self.count() + one = GeneralSentence.create(tokens) + one.add_item("ner_seq", SeqField(tags)) + one.add_info("sid", self.count()) + one.add_info("aug_code", self.aug_code) + return one diff --git a/tasks/zmlm/main/test.py b/tasks/zmlm/main/test.py new file mode 100644 index 0000000..b5ef009 --- /dev/null +++ b/tasks/zmlm/main/test.py @@ -0,0 +1,96 @@ +# + +import msp +from msp import utils +from msp.utils import Helper, zlog, zwarn +from msp.data import MultiHelper, WordVectors, VocabBuilder + +from ..run.confs import OverallConf, init_everything, build_model +from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner, MltTestingRunner +from ..run.vocab import MLMVocabPackage + +# yield every word of the stream that is found in `emb`, then log the type counts +def iter_hit_words(parse_streamer, emb): + words = set() + hit_count = 0 + # + for inst in parse_streamer: + # GeneralSentence keeps its tokens in the "word_seq" item (same access as in vocab_utils.main) + for one in inst.word_seq.vals: + # todo(note): only those hit in the embeddings are added + if emb.has_key(one): + yield one + # count how many hit in embeddings + if one not in words: + hit_count += 1 + words.add(one) + zlog(f"Iter hit words: all={len(words)}, hit={hit_count}") + +# simply use an external one as here +def aug_words_and_embs(model, aug_vocab, aug_wv): + orig_vocab = model.word_vocab + word_emb_node = model.embedder.get_node('word') + if word_emb_node is not None: + orig_arr = word_emb_node.E.detach().cpu().numpy() + # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed?
+ # todo(warn): here aug_vocab should be find in aug_wv + aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True) + new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True) + # assign + model.word_vocab = new_vocab # there should be more to replace? but since only testing maybe no need... + word_emb_node.replace_weights(new_arr) + else: + zwarn("No need to aug vocab since delexicalized model!!") + new_vocab = orig_vocab + return new_vocab + +# +def prepare_test(args, ConfType=None): + # conf + conf = init_everything(args, ConfType) + dconf, mconf = conf.dconf, conf.mconf + # vocab + vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir) + # prepare data + test_streamer = PreprocessStreamer(get_data_reader(dconf.test, dconf.input_format), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + # model + model = build_model(conf, vpack) + if dconf.model_load_name != "": + model.load(dconf.model_load_name) + else: + zwarn("No model to load, Debugging mode??") + # ----- + # augment with extra embeddings for test stream? + extra_embed_files = dconf.vconf.test_extra_pretrain_files + if len(extra_embed_files) > 0: + # get embeddings + extra_codes = dconf.vconf.test_extra_pretrain_codes + if len(extra_codes) == 0: + extra_codes = [""] * len(extra_embed_files) + extra_embedding = WordVectors.load(extra_embed_files[0], aug_code=extra_codes[0]) + extra_embedding.merge_others([WordVectors.load(one_file, aug_code=one_code) for one_file, one_code in + zip(extra_embed_files[1:], extra_codes[1:])]) + # get extra dictionary (only those words hit in extra-embed) + extra_vocab = VocabBuilder.build_from_stream(iter_hit_words(test_streamer, extra_embedding), + sort_by_count=True, pre_list=(), post_list=()) + # give them to the model + new_vocab = aug_words_and_embs(model, extra_vocab, extra_embedding) + vpack.put_voc("word", new_vocab) + # ===== + # No Cache!! 
+ test_inst_preparer = model.get_inst_preper(False) + backoff_pos_idx = dconf.backoff_pos_idx + test_iter = batch_stream(index_stream(test_streamer, vpack, False, False, test_inst_preparer, backoff_pos_idx), + mconf.test_batch_size, mconf, False) + return conf, model, vpack, test_iter + +# +def main(args): + conf, model, vpack, test_iter = prepare_test(args) + dconf = conf.dconf + # go + rr = MltTestingRunner(model, vpack, dconf.output_file, dconf.test, dconf.output_format) + x = rr.run(test_iter) + utils.printing("The end.") + +# b tasks/zmlm/main/test.py:46 diff --git a/tasks/zmlm/main/train.py b/tasks/zmlm/main/train.py new file mode 100644 index 0000000..e6ffa8b --- /dev/null +++ b/tasks/zmlm/main/train.py @@ -0,0 +1,60 @@ +# + +# +from msp import utils +from msp.data import MultiCatStreamer, InstCacher + +from ..run.confs import OverallConf, init_everything, build_model +from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner +from ..run.vocab import MLMVocabPackage + +# +def main(args): + conf = init_everything(args) + dconf, mconf = conf.dconf, conf.mconf + # dev/test can be non-existing! + if not dconf.dev and dconf.test: + utils.zwarn("No dev but give test, actually use test as dev (for early stopping)!!") + dt_golds, dt_cuts = [], [] + for file, one_cut in [(dconf.dev, dconf.cut_dev), (dconf.test, "")]: # no cut for test! 
+ if len(file)>0: + utils.zlog(f"Add file `{file}(cut={one_cut})' as dt-file #{len(dt_golds)}.") + dt_golds.append(file) + dt_cuts.append(one_cut) + if len(dt_golds) == 0: + utils.zwarn("No dev set, then please specify static lrate schedule!!") + # data + train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format, cut=dconf.cut_train), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + dt_streamers = [PreprocessStreamer(get_data_reader(f, dconf.dev_input_format, cut=one_cut), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + for f, one_cut in zip(dt_golds, dt_cuts)] + # vocab + if mconf.no_build_dict: + vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir) + else: + # include dev/test only for convenience of including words hit in pre-trained embeddings + vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, MultiCatStreamer(dt_streamers)) + vpack.save(dconf.dict_dir) + # aug2 + if mconf.aug_word2: + vpack.aug_word2_vocab(train_streamer, MultiCatStreamer(dt_streamers), mconf.aug_word2_pretrain) + vpack.save(mconf.aug_word2_save_dir) + # model + model = build_model(conf, vpack) + # index the data + train_inst_preparer = model.get_inst_preper(True) + test_inst_preparer = model.get_inst_preper(False) + to_cache = dconf.cache_data + to_cache_shuffle = dconf.to_cache_shuffle + # todo(note): make sure to cache both train and dev to save time for cached computation + backoff_pos_idx = dconf.backoff_pos_idx + train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle, train_inst_preparer, backoff_pos_idx), mconf.train_batch_size, mconf, True) + dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer, backoff_pos_idx), mconf.test_batch_size, mconf, False) for z in dt_streamers] + # training runner + tr = MltTrainingRunner(mconf.rconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds, dev_out_format=dconf.output_format) + if 
mconf.train_preload_model: + tr.load(dconf.model_load_name, mconf.train_preload_process) + # go + tr.run(train_iter, dt_iters) + utils.zlog("The end of Training.") diff --git a/tasks/zmlm/main/train_linear.py b/tasks/zmlm/main/train_linear.py new file mode 100644 index 0000000..78c0243 --- /dev/null +++ b/tasks/zmlm/main/train_linear.py @@ -0,0 +1,144 @@ +# + +# train a simple linear prober on the attentions + +from msp import utils +from msp.data import MultiCatStreamer, InstCacher + +from ..run.confs import OverallConf, init_everything, build_model +from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner +from ..run.vocab import MLMVocabPackage + +from typing import List +from copy import deepcopy +from msp.nn import BK +from ..model.mtl import BaseModel, BaseModelConf, MtlMlmModel, GeneralSentence +from ..model.mods.dpar import DparG1DecoderConf, DparG1Decoder, PairScorerConf + +# +class LinearProbeConf(BaseModelConf): + def __init__(self): + super().__init__() + +class LinearProbeModel(BaseModel): + def __init__(self, model: MtlMlmModel): + super().__init__(BaseModelConf()) + # ----- + self.model = model + model.refresh_batch(False) + # simple linear layer + dpar_conf = deepcopy(model.conf.dpar_conf) + dpar_conf.pre_dp_space = 0 + dpar_conf.dps_conf = PairScorerConf().init_from_kwargs(use_input0=False, use_input1=False, use_input_pair=True, + use_biaffine=False, use_ff1=False, use_ff2=True, + ff2_hid_layer=0, use_bias=False) + dpar_conf.optim.optim = model.conf.main_optim.optim + inputp_dim = len(model.encoder.layers) * model.enc_attn_count # concat all attns + self.dpar = self.add_component("dpar", DparG1Decoder(self.pc, None, inputp_dim, dpar_conf, model.inputter)) + + def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1., + rand_gen=None, assign_attns=False, **kwargs): + self.refresh_batch(training) + # get inputs with models + with BK.no_grad_env(): + input_map = 
self.model.inputter(insts) + emb_t, mask_t, enc_t, cache, enc_loss = self.model._emb_and_enc(input_map, collect_loss=True) + input_t = BK.concat(cache.list_attn, -1) # [bs, slen, slen, L*H] + losses = [self.dpar.loss(insts, BK.zeros([1,1]), input_t, mask_t)] + # ----- + info = self.collect_loss_and_backward(losses, training, loss_factor) + info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)}) + return info + + def inference_on_batch(self, insts: List[GeneralSentence], **kwargs): + conf = self.conf + self.refresh_batch(False) + # print(f"{len(insts)}: {insts[0].sid}") + with BK.no_grad_env(): + # decode for dpar + input_map = self.model.inputter(insts) + emb_t, mask_t, enc_t, cache, _ = self.model._emb_and_enc(input_map, collect_loss=False) + input_t = BK.concat(cache.list_attn, -1) # [bs, slen, slen, L*H] + self.dpar.predict(insts, BK.zeros([1,1]), input_t, mask_t) + return {} + +# +def main(args): + conf = init_everything(args) + dconf, mconf = conf.dconf, conf.mconf + # dev/test can be non-existing! + if not dconf.dev and dconf.test: + utils.zwarn("No dev but give test, actually use test as dev (for early stopping)!!") + dt_golds, dt_cuts = [], [] + for file, one_cut in [(dconf.dev, dconf.cut_dev), (dconf.test, "")]: # no cut for test! 
+ if len(file)>0: + utils.zlog(f"Add file `{file}(cut={one_cut})' as dt-file #{len(dt_golds)}.") + dt_golds.append(file) + dt_cuts.append(one_cut) + if len(dt_golds) == 0: + utils.zwarn("No dev set, then please specify static lrate schedule!!") + # data + train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format, cut=dconf.cut_train), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + dt_streamers = [PreprocessStreamer(get_data_reader(f, dconf.dev_input_format, cut=one_cut), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + for f, one_cut in zip(dt_golds, dt_cuts)] + # vocab + if mconf.no_build_dict: + vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir) + else: + # include dev/test only for convenience of including words hit in pre-trained embeddings + vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, MultiCatStreamer(dt_streamers)) + vpack.save(dconf.dict_dir) + # model + model = build_model(conf, vpack) + # index the data + train_inst_preparer = model.get_inst_preper(True) + test_inst_preparer = model.get_inst_preper(False) + to_cache = dconf.cache_data + to_cache_shuffle = dconf.to_cache_shuffle + # todo(note): make sure to cache both train and dev to save time for cached computation + backoff_pos_idx = dconf.backoff_pos_idx + train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle, train_inst_preparer, backoff_pos_idx), mconf.train_batch_size, mconf, True) + dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer, backoff_pos_idx), mconf.test_batch_size, mconf, False) for z in dt_streamers] + # training runner + tr = MltTrainingRunner(mconf.rconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds, dev_out_format=dconf.output_format) + if mconf.train_preload_model: + tr.load(dconf.model_load_name, mconf.train_preload_process) + # ===== + # switch with the linear model + linear_model = 
LinearProbeModel(model) + tr.model = linear_model + # ===== + # go + tr.run(train_iter, dt_iters) + utils.zlog("The end of Training.") + +def lookat_weights(): + import torch + m = torch.load("zmodel2.best") + assert len(m) == 1 + map_w = list(m.values())[0] # [output, input] + map_w = map_w[1:] + # ----- + num_step, num_head = 6, 8 + input_labels = [f"seeS{s}H{h}" for s in range(num_step) for h in range(num_head)] + output_labels = ["punct", "case", "nsubj", "det", "root", "nmod", "advmod", "obj", "obl", "amod", "compound", "aux", "conj", "mark", "cc", "cop", "advcl", "acl", "xcomp", "nummod", "ccomp", "appos", "flat", "parataxis", "discourse", "expl", "fixed", "list", "iobj", "csubj", "goeswith", "vocative", "reparandum", "orphan", "dep", "dislocated", "clf"] + assert tuple(map_w.shape) == (len(output_labels), len(input_labels)) + # + labels = [input_labels, output_labels] + other_labels = [output_labels, input_labels] + weights = [map_w.T, map_w] + for one_label, one_other_label, one_weight in zip(labels, other_labels, weights): + topk_vals, topk_idxes = one_weight.topk(5, dim=-1) # [?, 5] + assert len(one_label) == len(topk_vals) + for i, lab in enumerate(one_label): + print(f"{lab}: " + ", ".join([f"{one_other_label[o_idx]}({o_val})" for o_val, o_idx in zip(topk_vals[i], topk_idxes[i])])) + +# run +""" +SRC_DIR=../../src/ +MODEL_DIR=../run_trans_enu10k/ +PYTHONPATH=${SRC_DIR} CUDA_VISIBLE_DEVICES=2 python3 -m pdb ${SRC_DIR}/tasks/cmd.py zmlm.main.train_linear ${MODEL_DIR}/_conf2 train_preload_model:1 model_load_name:${MODEL_DIR}/zmodel2.best lrate.val:0.5 main_optim.optim:sgd lrate_warmup:0 max_epochs:20 lrate.start_bias:1 lrate.scale:5 lrate.m:0.5 dpar_conf.label_neg_rate:0. 
+# b tasks/zmlm/main/train_linear:50 +""" diff --git a/tasks/zmlm/main/vocab_utils.py b/tasks/zmlm/main/vocab_utils.py new file mode 100644 index 0000000..0ea2ca8 --- /dev/null +++ b/tasks/zmlm/main/vocab_utils.py @@ -0,0 +1,76 @@ +# + +# to build vocab: first step of training + +# +from collections import Counter +from msp import utils +from msp.data import MultiCatStreamer, InstCacher + +from ..run.confs import OverallConf, init_everything, build_model +from ..run.run import get_data_reader, PreprocessStreamer, index_stream, batch_stream, MltTrainingRunner +from ..run.vocab import MLMVocabPackage + +# render one tab-separated row per vocab entry (idx, word, vocab count/ratio, accumulated, dev count/ratio, accumulated) plus dev coverage summary lines +def print_vocab_txt(vocab, dev_counter=None): + lines = [] + lines.append(f"# Details of Vocab {vocab.name}") + if dev_counter is None: + dev_counter = {} + # iterate through vocab + all_count_voc = sum(z for z in vocab.final_vals if z is not None) + all_count_dev = sum(dev_counter.values()) + all_hit_dev = sum(1 for z in dev_counter.values() if z>0) + accu_count_voc, accu_count_dev = 0, 0 + accu_hit_dev_set = set() + for idx, key in enumerate(vocab.final_words): + # in vocab + count_voc = vocab.final_vals[idx] + if count_voc is None: + count_voc = 0 + accu_count_voc += count_voc + # in dev + count_dev = dev_counter.get(key, 0) + accu_count_dev += count_dev + if count_dev > 0: + accu_hit_dev_set.add(key) + # ----- (ratios use max(1, total) so a zero total, e.g. no dev_counter given, does not raise ZeroDivisionError) + added_line = [ + str(idx), key, + f"{count_voc}({count_voc/max(1, all_count_voc):.4f})", f"{accu_count_voc}({accu_count_voc/max(1, all_count_voc):.4f})", + f"{count_dev}({count_dev/max(1, all_count_dev):.4f})", f"{accu_count_dev}({accu_count_dev/max(1, all_count_dev):.4f})", + ] + lines.append("\t".join(added_line)) + lines.append(f"Coverage: type={len(accu_hit_dev_set)}/{all_hit_dev}={len(accu_hit_dev_set)/max(1, all_hit_dev)}, " + f"count={accu_count_dev}/{all_count_dev}={accu_count_dev/max(1, all_count_dev)}") + unhit_dev_counter = Counter({k:v for k,v in dev_counter.items() if k not in accu_hit_dev_set}) + lines.append(f"Unhit dev top-50: {unhit_dev_counter.most_common(50)}") +
return "\n".join(lines) + +# ----- +def main(args): + conf = init_everything(args) + dconf, mconf = conf.dconf, conf.mconf + # ===== + if dconf.train: # build + train_streamer = PreprocessStreamer(get_data_reader(dconf.train, dconf.input_format), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + vpack = MLMVocabPackage.build_from_stream(dconf.vconf, train_streamer, []) + vpack.save(dconf.dict_dir) + else: # read + vpack = MLMVocabPackage.build_by_reading(dconf.dict_dir) + # check dev + if dconf.dev: + dev_streamer = PreprocessStreamer(get_data_reader(dconf.dev, dconf.dev_input_format), + lower_case=dconf.lower_case, norm_digit=dconf.norm_digit) + dev_counter = Counter() + for inst in dev_streamer: + for w in inst.word_seq.vals: + dev_counter[w] += 1 + print(print_vocab_txt(vpack.get_voc("word"), dev_counter)) + +# examples +""" +SRC_DIR=../../../src/ +PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zmlm.main.vocab_utils train:../wikis/wiki_en.100k.txt dev:../ud24s/en_train.-1.conllu input_format:plain dev_input_format:conllu norm_digit:1 >vv.list +""" diff --git a/tasks/zmlm/model/__init__.py b/tasks/zmlm/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tasks/zmlm/model/base.py b/tasks/zmlm/model/base.py new file mode 100644 index 0000000..23c47cf --- /dev/null +++ b/tasks/zmlm/model/base.py @@ -0,0 +1,222 @@ +# + +# base models with msp.nn +from typing import List, Dict, Union +from collections import OrderedDict +from msp.model import Model +from msp.nn import BK +from msp.nn import refresh as nn_refresh +from msp.nn.layers import RefreshOptions, BasicNode +from msp.utils import Conf, zlog, JsonRW, Helper +from msp.zext.process_train import RConf, SVConf, ScheduledValue, OptimConf + +# ===== +# conf +class BaseModelConf(Conf): + def __init__(self): + self.rconf = RConf() + # overall conf + # optimizer + self.main_optim = OptimConf() + # lrate factor (<0 means not activated) + self.main_lrf = 
SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + # dropouts + # todo(+N): should have distribute dropouts into each Node, but ... + self.drop_hidden = 0.1 # hdrop + self.gdrop_rnn = 0.33 # gdrop (always fixed for recurrent connections) + self.idrop_rnn = 0.33 # idrop for rnn + self.fix_drop = False # fix drop for one run for each dropout + +# detailize common routines, with sub_module components +class BaseModel(Model): + def __init__(self, conf: BaseModelConf): + self.conf = conf + # ===== Model ===== + self.pc = BK.ParamCollection() + self.main_lrf = ScheduledValue(f"main:lrf", conf.main_lrf) + self._scheduled_values = [self.main_lrf] + # ----- + self.nodes: Dict[str, BasicNode] = OrderedDict() + self.components: Dict[str, BaseModule] = OrderedDict() + # for refreshing dropouts + self.previous_refresh_training = True + + # called before each mini-batch + def refresh_batch(self, training: bool): + # refresh graph + # todo(warn): make sure to remember clear this one + nn_refresh() + # refresh nodes + if not training: + if not self.previous_refresh_training: + # todo(+1): currently no need to refresh testing mode multiple times + return + self.previous_refresh_training = False + rop = RefreshOptions(training=False) # default no dropout + else: + conf = self.conf + rop = RefreshOptions(training=True, hdrop=conf.drop_hidden, idrop=conf.idrop_rnn, + gdrop=conf.gdrop_rnn, fix_drop=conf.fix_drop) + self.previous_refresh_training = True + for node in self.nodes.values(): + node.refresh(rop) + for node in self.components.values(): + node.refresh(rop) + + def update(self, lrate, grad_factor): + self.pc.optimizer_update(lrate, grad_factor) + + def add_scheduled_value(self, sv): + self._scheduled_values.append(sv) + return sv + + def get_scheduled_values(self): + return self._scheduled_values + Helper.join_list(z.get_scheduled_values() for z in self.components.values()) + + # == load and save models + # todo(warn): no need to load 
confs here + def load(self, path, strict=True): + self.pc.load(path, strict) + # self.conf = JsonRW.load_from_file(path+".json") + zlog(f"Load {self.__class__.__name__} model from {path}.", func="io") + + def save(self, path): + self.pc.save(path) + JsonRW.to_file(self.conf, path + ".json") + zlog(f"Save {self.__class__.__name__} model to {path}.", func="io") + + # ===== + def add_component(self, name: str, node: 'BaseModule'): + assert name not in self.components + self.components[name] = node + # set up optim + self.pc.optimizer_set(node.conf.optim.optim, node.lrf, node.conf.optim, params=node.get_parameters(), + check_repeat=True, check_full=True) + # pop off + self.pc.nnc_pop(node.name) + zlog(f"Add component module, {name}: {node}") + return node + + def add_node(self, name: str, node: BasicNode): + assert name not in self.nodes + self.nodes[name] = node + # set up optim + self.pc.optimizer_set(self.conf.main_optim.optim, self.main_lrf, self.conf.main_optim, + params=node.get_parameters(), check_repeat=True, check_full=True) + # pop off + self.pc.nnc_pop(node.name) + zlog(f"Add node, {name}: {node}") + return node + + # collect all loss + def collect_loss_and_backward(self, loss_info_cols: List[Dict], training: bool, loss_factor: float): + final_loss_dict = LossHelper.combine_multiple(loss_info_cols) # loss_name -> {} + if len(final_loss_dict) <= 0: + return {} # no loss! 
+ final_losses = [] + ret_info_vals = OrderedDict() + for loss_name, loss_info in final_loss_dict.items(): + final_losses.append(loss_info['sum']/(loss_info['count']+1e-5)) + for k in loss_info.keys(): + one_item = loss_info[k] + ret_info_vals[f"loss:{loss_name}_{k}"] = one_item.item() if hasattr(one_item, "item") else float(one_item) + final_loss = BK.stack(final_losses).sum() + if training and final_loss.requires_grad: + BK.backward(final_loss, loss_factor) + return ret_info_vals + +# ===== +class BaseModuleConf(Conf): + def __init__(self): + # optimizer + self.optim = OptimConf() + # lrate factor (<0 means not activated) + self.lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + # margin (<0 means not activated) + self.margin = SVConf().init_from_kwargs(val=0.) + # loss lambda (for the whole module) + self.loss_lambda = SVConf().init_from_kwargs(val=1.) + # my hdrop (<0 means not activated) + self.hdrop = -1. + +class BaseModule(BasicNode): + def __init__(self, pc: BK.ParamCollection, conf: BaseModuleConf, name=None, init_rop=None, output_dims=None): + super().__init__(pc, name, init_rop) + self.conf = conf + self.output_dims = output_dims + # scheduled values + self.lrf = ScheduledValue(f"{self.name}:lrf", conf.lrf) + self.margin = ScheduledValue(f"{self.name}:margin", conf.margin) + self.loss_lambda = ScheduledValue(f"{self.name}:lambda", conf.loss_lambda) + self.hdrop = conf.hdrop + if self.hdrop >= 0.: + zlog(f"Forcing different hdrop at {self.name}: {self.hdrop}") + + def refresh(self, rop=None): + local_changed = (self.hdrop>=0. 
# =====
# loss helper

# common format of loss_info is: {'name1': {'sum', 'count', ...}, 'name2': ...}
class LossHelper:
    """Static helpers that build, prefix and merge flattened loss-info dicts.

    A loss-info dict maps a (possibly dotted) loss name to a leaf dict holding
    at least ``sum0`` (raw summed loss), ``sum`` (lambda-weighted sum),
    ``count`` (normalizer) and ``run`` (how many leaves were merged in).
    """

    @staticmethod
    def compile_leaf_info(name: str, loss_sum: "BK.Expr", loss_count: "BK.Expr", loss_lambda=1., **other_values):
        """Wrap one loss into ``{name: leaf}``.

        ``other_values`` are stored in the leaf alongside the standard keys
        (and may override them, since they are applied last).
        """
        local_dict = {"sum0": loss_sum, "sum": loss_sum*loss_lambda, "count": loss_count, "run": 1}
        local_dict.update(other_values)
        return OrderedDict({name: local_dict})

    @staticmethod
    def compile_component_info(name: str, sub_losses: List[Dict], loss_lambda=1.):
        """Flatten ``sub_losses`` under the prefix ``name.`` and rescale their ``sum``.

        NOTE: the leaf dicts are re-used (not copied), so the ``sum`` entry of
        the dicts passed in is rescaled in place by ``loss_lambda``.
        """
        ret_dict = OrderedDict()
        for one_sub_loss in sub_losses:
            for k, v in one_sub_loss.items():
                ret_dict[f"{name}.{k}"] = v
        for one_item in ret_dict.values():
            one_item["sum"] = one_item["sum"] * loss_lambda
        return ret_dict

    @staticmethod
    def combine_multiple(inputs: List[Dict]):
        """Merge several flattened loss dicts, adding leaves that share a name key-by-key.

        ``None`` entries are skipped. The inputs are left untouched: the first
        leaf seen for a name is shallow-copied before later leaves are
        accumulated into it (previously the caller's own dict was used as the
        accumulator, so combining the same dict twice double-counted it).
        """
        # each input is a flattened loss Dict
        ret = OrderedDict()
        for one_input in inputs:
            if one_input is None:  # skip None
                continue
            for name, leaf_info in one_input.items():
                target_info = ret.get(name)
                if target_info is None:
                    # first occurrence: copy so that later additions do not leak into the input
                    ret[name] = dict(leaf_info)
                else:
                    # adding together
                    for k, v in leaf_info.items():
                        target_info[k] = target_info.get(k, 0.) + v
        return ret
# input from emb(very-first input), enc(encoder's output) and cache(vrec-enc's cache)
def __call__(self, emb_t, enc_t, cache):
    """Re-encode the representation with the fixed-attention vrec node.

    Args:
        emb_t: the very first (embedding) input.
        enc_t: the encoder's output; returned unchanged when the node is inactive.
        cache: the vrec encoder's cache; its ``list_accu_attn[-1]`` and
            ``list_attn[-1]`` are interpolated into the forced attention.

    Returns:
        ``enc_t`` if ``rprep_num_update <= 0``, otherwise the hidden state after
        ``rprep_num_update`` calls of the vrec node with the forced attention.
    """
    if not self.active:
        return enc_t
    conf = self.conf
    rprep_lambda_emb, rprep_lambda_accu = conf.rprep_lambda_emb, conf.rprep_lambda_accu
    # input_t: [*, len_q, D]
    if rprep_lambda_emb == 0.:
        input_t = enc_t  # by default, only enc_t
    else:
        input_t = emb_t * rprep_lambda_emb + enc_t * (1. - rprep_lambda_emb)
    # attn_t: [*, len_q, len_k, head] = accu_attn*L + last_attn*(1-L)
    # (the second weight used to be L as well, which zeroed the attention for the default L=0.)
    attn_t = cache.list_accu_attn[-1] * rprep_lambda_accu + cache.list_attn[-1] * (1. - rprep_lambda_accu)
    # another encoder: `self.active` already guarantees rprep_num_update > 0 here
    cur_hidden = input_t
    cur_cache = self.vrec_node.init_call(input_t)
    for _ in range(conf.rprep_num_update):
        cur_hidden = self.vrec_node.update_call(cur_cache, forced_attn=attn_t)
    return cur_hidden
# conf
# Configuration of the first-order graph dependency parser head (DparG1Decoder).
class DparG1DecoderConf(BaseModuleConf):
    def __init__(self):
        super().__init__()
        # --
        # space transferring for individual inputs (0 means no this layer)
        # (when > 0, two separate Affine layers project the modifier / head inputs to this size)
        self.pre_dp_space = 0
        self.pre_dp_act = "elu"
        # dependency relation pairwise conf
        self.dps_conf = PairScorerConf().init_from_kwargs(use_input_pair=True, use_biaffine=False, use_ff1=False, use_ff2=True)
        # fix idx=0 score?
        self.fix_s0 = True  # whether fix idx=0(NOPE) scores
        self.fix_s0_val = 0.  # if fix, use what fix val?
        # minus idx=0 score?
        self.minus_s0 = True  # whether minus idx=0(NOPE) scores (subtracted from every label score)
        # loss lambdas
        self.lambda_label = 1.  # on the last dim; the label loss is only computed when this is > 0
        self.label_neg_rate = 0.1  # sample how much neg examples (label=0=nope)
        self.lambda_head = 1.  # simply max on last dim and on -2 dim; the head loss is only computed when this is > 0
        # detach input? (for example, warmup for dec?)
        # "dpar_conf.no_detach_input.mode:linear dpar_conf.no_detach_input.start_bias:2"
        # (the decoder detaches its input representation while the scheduled value is <= 0)
        self.no_detach_input = SVConf().init_from_kwargs(val=1.)
# module
class DparG1Decoder(BaseModule):
    """First-order graph dependency parser head.

    Scores every (modifier, head) pair over ``[NOPE] + real labels``, trains
    with a pairwise label loss plus a head-selection loss, and decodes with
    non-projective MST.
    """

    def __init__(self, pc, input_dim: int, inputp_dim: int, conf: DparG1DecoderConf, inputter: Inputter):
        super().__init__(pc, conf, name="dp")
        self.conf = conf
        self.inputter = inputter
        self.input_dim = input_dim
        self.inputp_dim = inputp_dim
        # checkout and assign vocab
        self._check_vocab()
        # -----
        # this step is performed at the embedder, thus still does not influence the inputter
        self.add_root_token = self.inputter.embedder.add_root_token
        assert self.add_root_token, "Currently assert this one!!"  # todo(+N)
        # -----
        # transform dp space
        if conf.pre_dp_space > 0:
            dp_space = conf.pre_dp_space
            self.pre_aff_m = self.add_sub_node("pm", Affine(pc, input_dim, dp_space, act=conf.pre_dp_act))
            self.pre_aff_h = self.add_sub_node("ph", Affine(pc, input_dim, dp_space, act=conf.pre_dp_act))
        else:
            dp_space = input_dim
            self.pre_aff_m = self.pre_aff_h = lambda x: x
        # dep pairwise scorer: output includes [0, r1) -> [non]+valid_words
        self.dps_node = self.add_sub_node("dps", PairScorer(pc, dp_space, dp_space, self.dlab_r1,
                                                            conf=conf.dps_conf, in_size_pair=inputp_dim))
        # one-hot selector of the idx=0 (NOPE) column
        self.dps_s0_mask = np.array([1.]+[0.]*(self.dlab_r1-1))  # [1, 0, ..., 0]
        # whether detach input?
        self.no_detach_input = ScheduledValue("dpar:no_detach", conf.no_detach_input)

    def get_scheduled_values(self):
        """Scheduled values of the base module plus the detach switch."""
        return super().get_scheduled_values() + [self.no_detach_input]

    # check vocab at init
    def _check_vocab(self):
        """Fetch the "deplabel" vocab and record its non-special index range [r0, r1)."""
        # check range
        _l_vocab: Vocab = self.inputter.vpack.get_voc("deplabel")
        r0, r1 = _l_vocab.nonspecial_idx_range()  # [r0, r1) is the valid range
        assert _l_vocab.non==0 and r0 == 1, "Should have preserve only idx=0 as NO-Relation"
        # assign
        self.dlab_vocab: Vocab = _l_vocab
        self.dlab_r0, self.dlab_r1 = r0, r1  # no unk or others

    # -----
    # scoring
    # [bs, slen, D], [bs, len_q, len_k, D'], [bs, slen]
    def _score(self, repr_t, attn_t, mask_t):
        """Return (label scores [bs, len_q, len_k, 1+N], validity mask [bs, len_q, len_k])."""
        conf = self.conf
        # -----
        repr_m = self.pre_aff_m(repr_t)  # [bs, slen, S]
        repr_h = self.pre_aff_h(repr_t)  # [bs, slen, S]
        scores0 = self.dps_node.paired_score(repr_m, repr_h, inputp=attn_t)  # [bs, len_q, len_k, 1+N]
        # mask at outside
        slen = BK.get_shape(mask_t, -1)
        score_mask = BK.constants(BK.get_shape(scores0)[:-1], 1.)  # [bs, len_q, len_k]
        score_mask *= (1.-BK.eye(slen))  # no diag
        score_mask *= mask_t.unsqueeze(-1)  # [bs, slen, 1]: input mask along len_q
        score_mask *= mask_t.unsqueeze(-2)  # [bs, 1, slen]: input mask along len_k
        NEG = Constants.REAL_PRAC_MIN
        scores1 = scores0 + NEG * (1.-score_mask.unsqueeze(-1))  # [bs, len_q, len_k, 1+N]
        # add fixed idx0 scores if set
        if conf.fix_s0:
            fix_s0_mask_t = BK.input_real(self.dps_s0_mask)  # [1+N]
            scores1 = (1.-fix_s0_mask_t) * scores1 + fix_s0_mask_t * conf.fix_s0_val  # [bs, len_q, len_k, 1+N]
        # minus s0
        if conf.minus_s0:
            scores1 = scores1 - scores1.narrow(-1, 0, 1)  # minus idx=0 scores
        return scores1, score_mask

    # get ranged label scores for head
    def _ranged_label_scores(self, label_scores):
        """Slice out the real-label columns [r0, r1) on the last dim."""
        return label_scores.narrow(-1, self.dlab_r0, self.dlab_r1-self.dlab_r0)

    # loss
    def loss(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
        """Compute the label and head losses against each instance's ``dep_tree``.

        Returns the component loss dict built by ``_compile_component_loss("dp", ...)``.
        """
        conf = self.conf
        # detach input?
        if self.no_detach_input.value <= 0.:
            repr_t = repr_t.detach()  # no grad back if no_detach_input<=0.
        # scoring
        label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
        # -----
        # get golds
        bsize, max_len = BK.get_shape(mask_t)
        shape_lidxes = [bsize, max_len, max_len]
        # np.int64 instead of the alias np.long, which NumPy 1.24 removed
        gold_lidxes = np.zeros(shape_lidxes, dtype=np.int64)  # [bs, mlen, mlen]
        gold_heads = np.zeros(shape_lidxes[:-1], dtype=np.int64)  # [bs, mlen]
        for bidx, inst in enumerate(insts):
            cur_dep_tree = inst.dep_tree
            cur_len = len(cur_dep_tree)
            gold_lidxes[bidx, :cur_len, :cur_len] = cur_dep_tree.label_matrix
            gold_heads[bidx, :cur_len] = cur_dep_tree.heads
        # -----
        margin = self.margin.value
        all_losses = []
        # first is loss_labels
        lambda_label = conf.lambda_label
        if lambda_label > 0.:
            gold_lidxes_t = BK.input_idx(gold_lidxes)  # [bs, len_q, len_k]
            label_losses = BK.loss_nll(label_scores, gold_lidxes_t, margin=margin)  # [bs, mlen, mlen]
            positive_mask_t = (gold_lidxes_t>0).float()  # [bs, mlen, mlen]
            negative_mask_t = (BK.rand(shape_lidxes) < conf.label_neg_rate).float()  # [bs, mlen, mlen]
            loss_mask_t = score_masks * (positive_mask_t + negative_mask_t)  # [bs, mlen, mlen]
            loss_mask_t.clamp_(max=1.)  # a sampled negative may coincide with a positive
            masked_label_losses = label_losses * loss_mask_t
            # compile loss
            final_label_loss = LossHelper.compile_leaf_info("label", masked_label_losses.sum(), loss_mask_t.sum(),
                                                            loss_lambda=lambda_label, npos=positive_mask_t.sum())
            all_losses.append(final_label_loss)
        # then head loss
        lambda_head = conf.lambda_head
        if lambda_head > 0.:
            # get head score simply by argmax on ranges
            head_scores, _ = self._ranged_label_scores(label_scores).max(-1)  # [bs, mlen, mlen]
            gold_heads_t = BK.input_idx(gold_heads)
            head_losses = BK.loss_nll(head_scores, gold_heads_t, margin=margin)  # [bs, mlen]
            # mask
            head_mask_t = BK.copy(mask_t)
            head_mask_t[:, 0] = 0  # not for ARTI_ROOT
            masked_head_losses = head_losses * head_mask_t
            # compile loss (weighted by lambda_head; lambda_label was used here by mistake)
            final_head_loss = LossHelper.compile_leaf_info("head", masked_head_losses.sum(), head_mask_t.sum(),
                                                           loss_lambda=lambda_head)
            all_losses.append(final_head_loss)
        # --
        return self._compile_component_loss("dp", all_losses)

    # decode
    def predict(self, insts: List[GeneralSentence], repr_t, attn_t, mask_t, **kwargs):
        """MST-decode every instance and attach the result as ``pred_dep_tree``."""
        conf = self.conf
        # scoring
        label_scores, score_masks = self._score(repr_t, attn_t, mask_t)  # [bs, len_q, len_k, 1+N], [bs, len_q, len_k]
        # get ranged label scores
        ranged_label_scores = self._ranged_label_scores(label_scores)  # [bs, m, h, RealLab]
        # decode
        mst_lengths = [len(z) + 1 for z in insts]  # +1 to include ROOT for mst decoding
        mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32)
        mst_heads_arr, mst_labels_arr, mst_scores_arr = \
            nmst_unproj(ranged_label_scores, mask_t, mst_lengths_arr, labeled=True, ret_arr=True)
        mst_labels_arr += self.dlab_r0  # add back offset
        # assign; todo(+2): record scores?
        for one_idx, one_inst in enumerate(insts):
            cur_length = mst_lengths[one_idx]
            cur_heads = [0] + mst_heads_arr[one_idx][1:cur_length].tolist()
            cur_labels = DepTreeField.build_label_vals(mst_labels_arr[one_idx][:cur_length], self.dlab_vocab)
            pred_dep_tree = DepTreeField(cur_heads, cur_labels, mst_labels_arr[one_idx][:cur_length].tolist())
            one_inst.add_item("pred_dep_tree", pred_dep_tree, assert_non_exist=False)
        return
@property
def ec_dict(self):
    """Component confs keyed by component name, in the fixed order word/pos/char/posi/bert."""
    comp_names = ("word", "pos", "char", "posi", "bert")
    # each component conf lives in the attribute `ec_<name>`
    name_conf_pairs = [(one_name, getattr(self, "ec_" + one_name)) for one_name in comp_names]
    return OrderedDict(name_conf_pairs)
# =====
# ModelNode and Helper for each type of input

# -----
# Input helper is used to prepare and transform inputs
class InputHelper:
    """Prepares (pads) and masks the index input of one embedder component."""

    def __init__(self, comp_name, vpack: VocabPackage):
        self.comp_name = comp_name
        self.comp_seq_name = f"{comp_name}_seq"  # attribute of the instance holding this component's sequence
        self.voc = vpack.get_voc(comp_name)
        self.padder = DataPadder(2, pad_vals=0)  # pad 0

    # return batched arr
    def prepare(self, insts: List[GeneralSentence]):
        """Collect `<comp>_seq.idxes` of every instance and pad them into one batch array."""
        seq_attr = self.comp_seq_name
        idx_lists = [getattr(one_inst, seq_attr).idxes for one_inst in insts]
        padded_arr, _ = self.padder.pad(idx_lists)
        return padded_arr

    def mask(self, v, erase_mask):
        """Replace the entries of `v` selected by `erase_mask` (0/1) with the mask index."""
        # todo(note): `voc.err` is unused otherwise, thus just take it as the mask idx!
        mask_idx = self.voc.err
        kept_part = v * (1-erase_mask)
        erased_part = mask_idx * erase_mask
        return kept_part + erased_part

    @staticmethod
    def get_input_helper(name, *args, **kwargs):
        """Instantiate the helper class registered for component `name` (plain one by default)."""
        special_helpers = {"char": CharCnnHelper, "posi": PosiHelper, "bert": BertInputHelper}
        helper_cls = special_helpers.get(name, PlainSeqHelper)
        return helper_cls(*args, **kwargs)
@staticmethod
def get_input_embed_node(name, *args, **kwargs):
    """Instantiate the embedding node registered for component `name` (PlainSeqNode by default)."""
    special_nodes = {"char": CharCnnNode, "posi": PosiNode, "bert": BertInputNode}
    node_cls = special_nodes.get(name, PlainSeqNode)
    return node_cls(*args, **kwargs)
# [*, slen] -> [*, 1?+slen, D]
def __call__(self, input, add_root_token: bool):
    """Embed an index batch, optionally UNK-replacing rare words and prepending a root token.

    During training (and only if rare-unk is enabled) entries flagged in
    `self.rare_mask` are replaced by `voc.unk` with probability `comp_rare_unk`.
    With `add_root_token`, a column of `voc.bos` is prepended on the last dim.
    """
    idx_t = BK.input_idx(input)  # [*, slen]
    # rare unk in training
    if self.rop.training and self.use_rare_unk:
        unk_rate = self.ec_conf.comp_rare_unk
        sampled_t = BK.rand(BK.get_shape(idx_t)) < unk_rate
        replace_t = (self.rare_mask[idx_t] * sampled_t).detach().long()  # 1 where to put UNK
        idx_t = idx_t * (1-replace_t) + self.voc.unk * replace_t
    # root
    # todo(note): append a [cls/root] idx, currently use "bos"
    if add_root_token:
        root_shape = BK.get_shape(idx_t)[:-1]+[1]
        root_t = BK.constants(root_shape, self.voc.bos, dtype=idx_t.dtype)  # [*, 1]
        idx_t = BK.concat([root_t, idx_t], -1)  # [*, 1+slen]
    embedded = self.E(idx_t)  # [*, 1?+slen, D]
    return self.dropout(embedded)
# -----
# positional seq (absolute position)

class PosiHelper(InputHelper):
    """Input helper for absolute positions: only the batch shape is needed."""

    def __init__(self, comp_name, vpack: VocabPackage):
        super().__init__(comp_name, vpack)

    # return the batch shape instead of an array
    def prepare(self, insts: List[GeneralSentence]):
        """Return `(batch_size, max_instance_length)`."""
        batch_size = len(insts)
        longest = max(len(one_inst) for one_inst in insts)
        return (batch_size, longest)

    def mask(self, v, erase_mask):
        """Positions are returned untouched when words get masked."""
        # todo(note): generally no need to change posi info for word mask
        return v
# -----
# for bert

class BertInputHelper(InputHelper):
    """Input helper that subword-tokenizes sentences with an externally provided Berter2."""

    def __init__(self, comp_name, vpack: VocabPackage):
        super().__init__(comp_name, vpack)
        # ----
        # both are filled in by `set_berter`
        self.berter: Berter2 = None
        self.mask_id = None

    def set_berter(self, berter):
        """Attach the berter and remember its tokenizer's mask-token id."""
        # todo(note): need to set from outside!
        self.berter = berter
        self.mask_id = self.berter.tokenizer.mask_token_id

    # return List[Bert]
    def prepare(self, insts: List[GeneralSentence]):
        """Return one subword sequence per instance, cached in `inst.features["_bert2seq"]`."""
        cache_key = "_bert2seq"
        prepared = []
        for one_inst in insts:
            one_b2s = one_inst.features.get(cache_key)
            if one_b2s is None:  # no cache, rebuild it
                one_b2s = self.berter.subword_tokenize2(one_inst.word_seq.vals, True)  # here use orig strs
                one_inst.features[cache_key] = one_b2s
            prepared.append(one_b2s)
        return prepared

    def mask(self, v, erase_mask):
        """Apply each row of `erase_mask` to the matching subword sequence using the mask-token id."""
        masked = []
        for one_b2s, one_mask in zip(v, erase_mask):
            masked.append(one_b2s.apply_mask_new(one_mask, self.mask_id))
        return masked
def get_node(self, name, df=None):
    """Return the embedding node of component `name`, or `df` if that component is not active.

    Only the lookup failure (`ValueError` from `list.index`) is treated as
    "not found"; other errors are no longer swallowed by a bare `except`.
    """
    try:
        _idx = self.comp_names.index(name)
    except ValueError:
        return df
    return self.nodes[_idx]
# =====
# the overall input helper
class Inputter:
    """Builds the per-component input map (plus the valid mask) for an EmbedderNode."""

    def __init__(self, embedder: EmbedderNode, vpack: VocabPackage):
        self.vpack = vpack
        self.embedder = embedder
        # -----
        # one helper per active embedder component, in the embedder's order
        self.comp_names = embedder.comp_names
        self.comp_helpers = []
        for one_name in self.comp_names:
            cur_helper = InputHelper.get_input_helper(one_name, one_name, vpack)
            self.comp_helpers.append(cur_helper)
            if one_name == "bert":
                # the bert helper needs the embedder's berter for tokenization
                assert embedder.berter is not None
                cur_helper.set_berter(berter=embedder.berter)
        # ====
        self.mask_padder = DataPadder(2, pad_vals=0., mask_range=2)

    def __call__(self, insts: List[GeneralSentence]):
        """Return `{"mask": masks, <comp_name>: prepared input, ...}` for the batch."""
        # first pad words to get masks
        word_idxes = [one_inst.word_seq.idxes for one_inst in insts]
        _, masks_arr = self.mask_padder.pad(word_idxes)
        # then get each one
        ret_map = {"mask": masks_arr}
        ret_map.update(
            (one_name, one_helper.prepare(insts))
            for one_name, one_helper in zip(self.comp_names, self.comp_helpers)
        )
        return ret_map

    # todo(note): return new masks (input is read only!!)
    def mask_input(self, input_map: Dict, input_erase_mask, nomask_names_set: Set):
        """Return a new input map where every component not in `nomask_names_set` is masked."""
        ret_map = {"mask": input_map["mask"]}
        for one_name, one_helper in zip(self.comp_names, self.comp_helpers):
            orig_input = input_map[one_name]
            # components excluded from masking directly borrow the original input
            ret_map[one_name] = orig_input if one_name in nomask_names_set \
                else one_helper.mask(orig_input, input_erase_mask)
        return ret_map
# -----
# Configuration of the masked-LM module (MaskLMNode).
class MaskLMNodeConf(BaseModuleConf):
    def __init__(self):
        super().__init__()
        self.loss_lambda.val = 0.  # by default no such loss
        # ----
        # hidden layer
        self.hid_dim = 0  # <=0 means no hidden layer: predict directly from the input repr
        self.hid_act = "elu"
        # mask/pred options
        self.mask_rate = 0.  # a valid token is erased when a uniform sample falls below this rate
        self.min_mask_rank = 1  # min word idx to mask
        self.min_pred_rank = 1  # min word idx to pred for the masked ones
        self.max_pred_rank = 1  # max word idx to pred for the masked ones (output size is capped at vocab.unk)
        self.tie_input_embeddings = False  # tie all preds with input embeddings
        self.init_pred_from_pretrain = False  # copy the pretrained word vectors into the word prediction layer
        # lambdas for word/pos
        self.lambda_word = 1.
        self.lambda_pos = 0.
        # sometimes maybe we don't want to mask some fields, for example, predicting words with full pos
        self.nomask_names = ["pos"]
        # =====
        # special loss considering multiple layers
        self.loss_layers = [-1]  # by default use only the last layer
        self.loss_weights = [1.]  # weights for each layer?
        self.loss_comb_method = "min"  # sum/avg/min/max
        self.score_comb_method = "sum"  # how to combine scores (logprobs); same choices: sum/avg/min/max
def mask_input(self, input_map: Dict, rand_gen=None, extra_mask_arr=None):
    """Sample which tokens to erase and build the masked copy of the inputs.

    Args:
        input_map: input arrays; ``"mask"`` is required, ``"word"`` is
            optional (read with ``.get``).
        rand_gen: object providing ``random_sample(shape)``; falls back to
            ``Random`` when None.
        extra_mask_arr: optional array; only entries ``> 0`` may be erased.

    Returns:
        ``(new_map, erase_mask)``: whatever ``self.inputter.mask_input``
        returns for the erased inputs, and the erase mask as float32.
    """
    if rand_gen is None:
        rand_gen = Random
    # -----
    conf = self.conf
    # original valid mask
    input_mask = input_map["mask"]  # [*, len]
    # word idxes are optional: without them no rank-based filtering is done
    input_word = input_map.get("word")  # [*, len]
    # erase with probability `mask_rate`, but only at valid positions
    input_erase_mask = (rand_gen.random_sample(input_mask.shape) < conf.mask_rate) & (input_mask > 0.)
    if input_word is not None:  # min word to mask
        input_erase_mask &= (input_word >= conf.min_mask_rank)
    if extra_mask_arr is not None:  # extra mask for valid points
        input_erase_mask &= (extra_mask_arr > 0.)
    # get the new masked inputs (mask for all input components)
    # note: the `np.int` alias was removed in NumPy 1.24, use an explicit integer dtype
    new_map = self.inputter.mask_input(input_map, input_erase_mask.astype(np.int64), self.nomask_names_set)
    return new_map, input_erase_mask.astype(np.float32)
def loss(self, repr_ts, input_erase_mask_arr, orig_map: Dict, active_hid=True, **kwargs):
    """Compute the masked-LM loss at the erased positions.

    Args:
        repr_ts: one representation or a list/tuple of them (a single one is
            wrapped into a list); indexed by the entries of ``conf.loss_layers``.
        input_erase_mask_arr: erase mask, passed through ``BK.input_real``.
        orig_map: original (unmasked) inputs; ``orig_map["word"]`` supplies the
            gold idxes (``"pos"`` too if ``conf.lambda_pos > 0``).
        active_hid: when False, ``self.hid_layer`` is skipped even if present.

    Returns:
        The value of ``self._compile_component_loss("mlm", ...)``; it is called
        with an empty list when nothing was erased.
    """
    conf = self.conf
    _tie_input_embeddings = conf.tie_input_embeddings
    # prepare idxes for the masked ones
    if self.add_root_token:  # offset for the special root added in embedder
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr), padding_idx=-1)  # [bsize, ?]
        # repr idxes are shifted by one; with padding_idx=-1 the padded slots become 0 here
        repr_mask_idxes = mask_idxes + 1
        # and the -1 paddings are clamped to 0 so they can be used for gathering targets
        mask_idxes.clamp_(min=0)
    else:
        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(input_erase_mask_arr))  # [bsize, ?]
        repr_mask_idxes = mask_idxes
    # get the losses
    if BK.get_shape(mask_idxes, -1) == 0:  # no loss
        return self._compile_component_loss("mlm", [])
    else:
        if not isinstance(repr_ts, (List, Tuple)):
            repr_ts = [repr_ts]
        target_word_scores, target_pos_scores = [], []
        # NOTE(review): pos scores stay None, so `conf.lambda_pos > 0` would reach
        # `enumerate(None)` below -- confirm lambda_pos is always 0 for now.
        target_pos_scores = None  # todo(+N): for simplicity, currently ignore this one!!
        for layer_idx in conf.loss_layers:
            # calculate scores
            target_reprs = BK.gather_first_dims(repr_ts[layer_idx], repr_mask_idxes, 1)  # [bsize, ?, *]
            if self.hid_layer and active_hid:  # todo(+N): sometimes, we only want last softmax, need to ensure dim at outside!
                target_hids = self.hid_layer(target_reprs)
            else:
                target_hids = target_reprs
            if _tie_input_embeddings:
                # score against the first `pred_word_size` rows of the input embedding matrix
                pred_W = self.inputter_word_node.E.E[:self.pred_word_size]  # [PSize, Dim]
                target_word_scores.append(BK.matmul(target_hids, pred_W.T))  # List[bsize, ?, Vw]
            else:
                target_word_scores.append(self.pred_word_layer(target_hids))  # List[bsize, ?, Vw]
        # gather the losses
        all_losses = []
        for pred_name, target_scores, loss_lambda, range_min, range_max in \
                zip(["word", "pos"], [target_word_scores, target_pos_scores], [conf.lambda_word, conf.lambda_pos],
                    [conf.min_pred_rank, 0], [min(conf.max_pred_rank, self.pred_word_size-1), self.pred_pos_size-1]):
            if loss_lambda > 0.:
                seq_idx_t = BK.input_idx(orig_map[pred_name])  # [bsize, slen]
                target_idx_t = seq_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
                # valid = really erased AND gold idx inside [range_min, range_max]
                ranged_mask_valids = mask_valids * (target_idx_t>=range_min).float() * (target_idx_t<=range_max).float()
                target_idx_t[(ranged_mask_valids < 1.)] = 0  # make sure invalid ones in range
                # calculate for each layer
                all_layer_losses, all_layer_scores = [], []
                # note: loss_weights is indexed by the enumeration order, not by layer_idx
                for one_layer_idx, one_target_scores in enumerate(target_scores):
                    # get loss: [bsize, ?]
                    one_pred_losses = BK.loss_nll(one_target_scores, target_idx_t) * conf.loss_weights[one_layer_idx]
                    all_layer_losses.append(one_pred_losses)
                    # get scores
                    one_pred_scores = BK.log_softmax(one_target_scores, -1) * conf.loss_weights[one_layer_idx]
                    all_layer_scores.append(one_pred_scores)
                # combine all layers
                pred_losses = self.loss_comb_f(all_layer_losses)
                pred_loss_sum = (pred_losses * ranged_mask_valids).sum()
                pred_loss_count = ranged_mask_valids.sum()
                # argmax
                _, argmax_idxes = self.score_comb_f(all_layer_scores).max(-1)
                pred_corrs = (argmax_idxes == target_idx_t).float() * ranged_mask_valids
                pred_corr_count = pred_corrs.sum()
                # compile leaf loss
                r_loss = LossHelper.compile_leaf_info(pred_name, pred_loss_sum, pred_loss_count,
                                                      loss_lambda=loss_lambda, corr=pred_corr_count)
                all_losses.append(r_loss)
        return self._compile_component_loss("mlm", all_losses)
# by default no such loss + # -- + self.disturb_mode = "abs_bin" # abs_noise/abs_bin/rel_noise/rel_bin + self.disturb_ranges = [0, 1] # [a, b) range for either abs or rel + self.abs_noise_sort = True # for "abs_noise", sort the idxes, making it like local shuffling + self.abs_keep_cls = True # if using "abs_*", keep cls/arti_root/... as it is + self.bin_rand_offset = True # add random offsets [0, R) to idxes (only in blur1) + self.bin_blur_method = 2 # which blurring method for bin + self.bin_blur_bag_lbound = 100 # <=0 means [R+this, R], >0 means (auto) [R//2+1, R] + self.bin_blur_bag_keep_rate = 0. # keep the original order in what percentage of the bags + # pairwise scorer + self.ps_conf = PairScorerConf().init_from_kwargs(use_bias=False, use_input_pair=False, + use_biaffine=False, use_ff1=False, use_ff2=True) + self.pred_abs = False # predicting only absolute distance? + self.pred_range = 0 # [-pr, 0) + (0, pr], not predicting self + self.cand_range = -1 # at least as pred_range + # simple clasification losses + self.lambda_n1 = 1. # reduce on dim=-1(2*R): classify dist for each cand + self.lambda_n2 = 1. 
def do_validate(self):
    """Normalize and sanity-check the configured ranges.

    Casts ``disturb_ranges`` to a list of ints, requires it to be a strictly
    increasing pair ``[a, b)``, and raises ``cand_range`` up to ``pred_range``
    when it is smaller.
    """
    ranges = list(map(int, self.disturb_ranges))
    self.disturb_ranges = ranges
    assert len(ranges) == 2 and ranges[1]>ranges[0]
    # candidates must cover at least the predicted range
    if self.cand_range < self.pred_range:
        self.cand_range = self.pred_range
+ batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range + # first assign ordinary range idxes + orig_idxes = np.arange(max_len)[np.newaxis, :] # [1, max_len] + # add noise: either stay at this bag or jump to the next one + noised_idxes = orig_idxes + R * rand_gen.randint(2, size=(batch_size, max_len)) # [bs, len] + blur_idxes = noised_idxes // R * R # [bs, len] + torank_values = blur_idxes + rand_gen.random_sample(size=(batch_size, max_len)) # sample small from [0,1) + if conf.abs_keep_cls and self.add_root_token: + torank_values[:, 0] = -R # make it rank 0 + torank_values += (R*100) * (1.-input_mask_arr) # make invalid ones stay there + ret_abs_idxes = torank_values.argsort(axis=-1) # [bs, len] + # [bs, 1, len] - [bs, len, 1] = [bs, len, len] + ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis] + return ret_abs_idxes, ret_rel_dists + + # first split into local bags and then shuffle for local bags + def _disturb_local_shuffle2(self, input_mask_arr, disturb_range, rand_gen): + conf = self.conf + batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range + arange_arr = np.arange(batch_size)[:, np.newaxis] # [bs, 1] + arange_arr2 = np.arange(max_len)[np.newaxis, :] # [1, bs] + # sample bag sizes + orig_size = (batch_size, max_len) + # sample bag sizes + if self._bin_blur_bag_lbound <= 0: + _L = max(1, R + self._bin_blur_bag_lbound) # at least bag_size >= 1 + else: + _L = R // 2 + 1 # automatically + bag_sizes = rand_gen.randint(_L, 1+R, size=orig_size) # [bs, len] + mark_idxes = np.cumsum(bag_sizes, -1) # [bs*len] + mark_idxes = mark_idxes * (mark_idxes<max_len) # [bs*len], make invalid ones 0 + # marked seq + marked_seq = np.zeros(orig_size, dtype=np.int) # [bs, len] + marked_seq[arange_arr, mark_idxes] = 1 # [bs, len] + marked_seq[:,0] = 1 # make them all start at 1, since some invalid ones can have idx=0 + # again cumsum to get values + cumsum_values = np.cumsum(marked_seq.reshape(orig_size), -1) - 1 # [bs, 
len], minus 1 to make it start with 0 + torank_values = cumsum_values.copy() + if conf.abs_keep_cls and self.add_root_token: + torank_values[:, 0] = -R # make it rank 0 + # keep some bags? + _bin_blur_bag_keep_rate = self._bin_blur_bag_keep_rate + if _bin_blur_bag_keep_rate>0.: + bag_keeps = (rand_gen.random_sample(size=orig_size)<_bin_blur_bag_keep_rate).astype(np.float) # [bs, len], whether keep bag + keep_valid = bag_keeps[arange_arr, cumsum_values] # [bs, len], whether keep token + keep_valid_plus = np.clip(keep_valid + (1. - input_mask_arr), 0, 1) + keep_adding = arange_arr2 * keep_valid_plus # adding this to sorting values + else: + keep_valid = None + keep_adding = (R*max(10, max_len)) * (1.-input_mask_arr) # make invalid ones stay there + # shuffle as noise and rank + torank_values = torank_values*max(10, max_len) + keep_adding + noised_to_rank_values = torank_values + rand_gen.random_sample(size=orig_size) # [bs, len] + sorted_idxes = noised_to_rank_values.argsort(axis=-1) # [bs, len] + ret_abs_idxes = sorted_idxes.argsort(axis=-1) # [bs, len], get real ranking + # [bs, 1, len] - [bs, len, 1] = [bs, len, len] + ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis] + return ret_abs_idxes, ret_rel_dists, keep_valid + + # orig_idx + noise + def _disturb_abs_noise(self, input_mask_arr, disturb_range, rand_gen): + conf = self.conf + assert self._bin_blur_bag_keep_rate == 0., "Currently not supporting this one for this mode!" 
+ batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range * 2 # +/-range + # first assign ordinary range idxes + orig_idxes = np.arange(max_len)[np.newaxis, :] # [1, max_len] + # add noises: [batch_size, max_len] + noised_idxes = orig_idxes + R * (rand_gen.random_sample((batch_size, max_len))-0.5) # +/- disturb_range + if conf.abs_noise_sort: # simple argsort + if conf.abs_keep_cls and self.add_root_token: + noised_idxes[:, 0] = -R # make it rank 0 + noised_idxes += (R*100) * (1.-input_mask_arr) # make invalid ones stay there + ret_abs_idxes = noised_idxes.argsort(axis=-1) + else: # clip and round + # todo(note): this can lead to repeated idxes + ret_abs_idxes = noised_idxes.clip(0, max_len-1).round().astype(np.int64) + if conf.abs_keep_cls and self.add_root_token: + ret_abs_idxes[:, 0] = 0 # directly put it 0 + # [bs, 1, len] - [bs, len, 1] = [bs, len, len] + ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis] + return ret_abs_idxes, ret_rel_dists + + # orig_dist + noise + def _disturb_rel_noise(self, input_mask_arr, disturb_range, rand_gen): + conf = self.conf + assert self._bin_blur_bag_keep_rate == 0., "Currently not supporting this one for this mode!" 
+ batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range * 2 # +/-range + # first assign ordinary range idxes + orig_idxes0 = np.arange(max_len) # [max_len] + orig_idxes = orig_idxes0[np.newaxis, :] # [1, max_len] + orig_rel_idxes = orig_idxes[:, np.newaxis, :] - orig_idxes[:, :, np.newaxis] # [1, len, len] + # add noises +/- disturb_range: [batch_size, max_len, max_len] + noised_rel_idxes_f = orig_rel_idxes + R * (rand_gen.random_sample([batch_size, max_len, max_len])-0.5) + # round + noised_rel_idxes = noised_rel_idxes_f.round().astype(np.int64) # [bs, len, len] + # keep diag as 0 and others not 0 (>0:+1 <0:-1) + noised_rel_idxes += (2*(noised_rel_idxes_f>=0) - 1) * (noised_rel_idxes==0) # += (1 if f>=0 else -1) * (i==0) + noised_rel_idxes[:, orig_idxes0, orig_idxes0] = 0 + # no abs_idxes available + return None, noised_rel_idxes + + # ===== + + # helper for *bin*, orig_idxes should be >=0; keep the original ones if <keep_thresh + def blur_idxes(self, orig_idxes, R: int, keep_thresh: int, rand_offset_size, rand_gen): + offset_orig_idxes = (orig_idxes-keep_thresh).clip(min=0) # offset by keep_thresh + # bin the idxes + blur_idxes, blur_keeps = self._blur_idxes_f(offset_orig_idxes, R, rand_offset_size, rand_gen) + # keep original ones if <thresh & starting with this for the blur ones + ret_idxes = np.where(orig_idxes < keep_thresh, orig_idxes, blur_idxes + keep_thresh) + return ret_idxes, blur_keeps + + # simple blur + def _blur_idxes1(self, input_idxes, R: int, rand_offset_size, rand_gen): + bin_rand_offset = self.conf.bin_rand_offset + # add offsets (1): make it group slightly differently at the start + if bin_rand_offset: # but keep at least a relatively large group >R/2 + toblur_idxes = input_idxes + rand_gen.randint(1 + R // 2, size=rand_offset_size) + else: + toblur_idxes = input_idxes + blur_idxes = toblur_idxes // R * R + assert self._bin_blur_bag_keep_rate==0., "Currently not supporting this one for this mode!" 
+ # add random offsets (2) + if bin_rand_offset: + blur_idxes += rand_gen.randint(R, size=rand_offset_size) + return blur_idxes, None + + # random split bags with unequal sizes and random pick one as representative + # todo(note): input idxes cannot be >= shape[1]+R, otherwise will oor-error + def _blur_idxes2(self, orig_idxes, R: int, rand_offset_size, rand_gen): + _bin_blur_bag_keep_rate = self._bin_blur_bag_keep_rate + # conf = self.conf + batch_size, max_len = orig_idxes.shape[:2] # todo(note): assume the first two dims means this!! + # -- + prep_size = (batch_size, max_len+R) # make slightly larger prep sizes since the input make be larger + arange_arr = np.arange(batch_size)[:, np.newaxis] # [bs, 1] + # sample bag sizes + if self._bin_blur_bag_lbound <= 0: + _L = max(1, R+self._bin_blur_bag_lbound) # at least bag_size >= 1 + else: + _L = R//2+1 # automatically + bag_sizes = rand_gen.randint(_L, 1+R, size=prep_size) # [bs, L+R], (3,4,3,5,...) + if _bin_blur_bag_keep_rate>0.: + bag_keeps = (rand_gen.random_sample(size=prep_size)<_bin_blur_bag_keep_rate) # [bs, L+R], (1,0,1,1,...) + else: + bag_keeps = None # no keeps + # csum for the idxes + mark_idxes = np.cumsum(bag_sizes, -1) # [bs, L+R], (3,7,10,15,...) + mark_idxes_i = mark_idxes * (mark_idxes<max_len+R) # [bs, L+R], make invalid ones 0 + # mark idxes for later selecting + marked_seq = np.zeros(prep_size, dtype=np.int) # [bs, L+R], (0,0,0,1,0,0,0,0,1,...) + marked_seq[arange_arr, mark_idxes_i] = 1 # [bs, L+R] + marked_seq[:, 0] = 0 # initial ones can only be marked by invalid ones + # get idxes for range(max_len) + range_idxes = np.cumsum(marked_seq, -1) # [bs, L+R], (0,0,0,1,1,1,1,1,2,...) + # get representatives for each idx (each bag?): (3-?,7-?,10-?,15-?,...) 
+ repr_idxes = np.floor(mark_idxes - (bag_sizes * rand_gen.random_sample(prep_size))).astype(np.int) # [bs, L+R] + # -- + # final select (twice) + sel_range_idxes = repr_idxes[arange_arr, range_idxes] # [bs, L+R], (3-?,3-?,3-?,7-?,7-?,7-?,7-?,7-?,10-?,...) + if _bin_blur_bag_keep_rate>0.: + sel_range_keeps = bag_keeps[arange_arr, range_idxes] # [bs, L+R], whether keep for each idx + sel_range_idxes = np.where(sel_range_keeps, np.arange(max_len+R)[np.newaxis,:], sel_range_idxes) # [bs,L+R] + sel_ret_keeps = sel_range_keeps[arange_arr, orig_idxes.reshape([batch_size, -1])].reshape(orig_idxes.shape).astype(np.float) # [bs, L] + else: + sel_ret_keeps = None + sel_ret_idxes = sel_range_idxes[arange_arr, orig_idxes.reshape([batch_size, -1])].reshape(orig_idxes.shape) # [bs, L] + return sel_ret_idxes, sel_ret_keeps + + # bin(orig_idx) + def _disturb_abs_bin(self, input_mask_arr, disturb_range, rand_gen): + conf = self.conf + batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range # bin size + rand_offset_size = (batch_size, 1) + # first assign ordinary range idxes + orig_idxes = np.arange(max_len) + np.zeros(rand_offset_size, dtype=np.int) # [bs, len] + blur_keep_thresh = int(conf.abs_keep_cls and self.add_root_token) # whether keep ARTI_ROOT as 0 + ret_abs_idxes, ret_abs_keeps = self.blur_idxes(orig_idxes, R, blur_keep_thresh, (batch_size, 1), rand_gen) # [bs, max_len] + # [bs, 1, len] - [bs, len, 1] = [bs, len, len] + ret_rel_dists = ret_abs_idxes[:, np.newaxis, :] - ret_abs_idxes[:, :, np.newaxis] + return ret_abs_idxes, ret_rel_dists, ret_abs_keeps + + # bin(orig_dist) + def _disturb_rel_bin(self, input_mask_arr, disturb_range, rand_gen): + # conf = self.conf + batch_size, max_len = input_mask_arr.shape + # -- + R = disturb_range # bin size + rand_offset_size = (batch_size, 1, 1) + # first assign ordinary range idxes + orig_idxes = np.arange(max_len) # [max_len] + orig_rel_idxes = orig_idxes[np.newaxis, :] - orig_idxes[:, np.newaxis] # [len, len] + 
def disturb_input(self, input_map: Dict, rand_gen=None):
    """Return a shallow copy of ``input_map`` extended with disturbed positions.

    A range is drawn with ``randint(self.disturb_range_low, self.disturb_range_high)``
    and handed to ``self._disturb_f`` together with the mask (left-padded by one
    column of 1 when a root token is added). The result is stored under
    ``"posi"`` and ``"rel_dist"``, plus ``"disturb_keep"`` when the disturbing
    function returns a third value. ``input_map`` itself is not modified.
    """
    rng = Random if rand_gen is None else rand_gen
    # -----
    mask_arr = input_map["mask"]  # [bs, len]
    cur_range = rng.randint(self.disturb_range_low, self.disturb_range_high)  # random range
    if self.add_root_token:  # add +1 here!!
        mask_arr = np.pad(mask_arr, ((0,0),(1,0)), constant_values=1.)  # [bs, 1+len]
    # get disturbed idxes/dists
    outputs = self._disturb_f(mask_arr, cur_range, rng)
    abs_idxes, rel_dists = outputs[:2]
    extras = {"posi": abs_idxes, "rel_dist": rel_dists}
    if len(outputs)>2:
        extras["disturb_keep"] = outputs[2]
    # shallow copy, then add the new entries
    ret_map = input_map.copy()
    ret_map.update(extras)
    return ret_map
# [bs, slen, *], [bs, len_q, len_k, head], [bs, slen], [bs, slen]
def loss(self, repr_t, attn_t, mask_t, disturb_keep_arr, **kwargs):
    """Compute the order-prediction loss: classify the relative distance of token pairs.

    Args:
        repr_t: token representations, used as both sides of the pair scorer.
        attn_t: pairwise input forwarded to ``self.ps_node.paired_score``.
        mask_t: valid-token mask ``[bs, slen]``; it is copied, not modified.
        disturb_keep_arr: optional keep flags; pairs whose first (query) token
            is flagged are excluded from prediction. May be None.

    Returns:
        The value of ``self._compile_component_loss("orp", ...)`` with one leaf
        per normalization dim whose lambda (``lambda_n1`` / ``lambda_n2``) is > 0.
    """
    conf = self.conf
    CR, PR = conf.cand_range, conf.pred_range
    # -----
    mask_single = BK.copy(mask_t)
    # no predictions for ARTI_ROOT
    if self.add_root_token:
        mask_single[:, 0] = 0.  # [bs, slen]
    # casting predicting range
    cur_slen = BK.get_shape(mask_single, -1)
    arange_t = BK.arange_idx(cur_slen)  # [slen]
    # [1, len] - [len, 1] = [len, len]
    # note: distances are computed from the plain position range, i.e. the true order
    reldist_t = (arange_t.unsqueeze(-2) - arange_t.unsqueeze(-1))  # [slen, slen]
    mask_pair = ((reldist_t.abs() <= CR) & (reldist_t != 0)).float()  # within CR-range; [slen, slen]
    mask_pair = mask_pair * mask_single.unsqueeze(-1) * mask_single.unsqueeze(-2)  # [bs, slen, slen]
    if disturb_keep_arr is not None:
        mask_pair *= BK.input_real(1.-disturb_keep_arr).unsqueeze(-1)  # no predictions for the kept ones!
    # get all pair scores
    score_t = self.ps_node.paired_score(repr_t, repr_t, attn_t, maskp=mask_pair)  # [bs, len_q, len_k, 2*R]
    # -----
    # loss: normalize on which dim?
    # get the answers first
    if conf.pred_abs:
        # NOTE(review): unlike the signed branch there is no "-1" here, so distance 1 maps to
        # class 1, class 0 is never a target, and distances PR-1 and PR are both clamped to
        # class PR-1 -- looks like an off-by-one, confirm before relying on pred_abs.
        answer_t = reldist_t.abs()  # [1,2,3,...,PR]
        answer_t.clamp_(min=0, max=PR-1)  # [slen, slen], clip in range, distinguish using masks
    else:
        # positive distance d -> class d-1, negative distance d -> class d+2*PR
        answer_t = BK.where((reldist_t>=0), reldist_t-1, reldist_t+2*PR)  # [1,2,3,...PR,-PR,-PR+1,...,-1]
        answer_t.clamp_(min=0, max=2*PR-1)  # [slen, slen], clip in range, distinguish using masks
    # expand answer into idxes (one-hot on the last dim)
    answer_hit_t = BK.zeros(BK.get_shape(answer_t) + [2*PR])  # [len_q, len_k, 2*R]
    answer_hit_t.scatter_(-1, answer_t.unsqueeze(-1), 1.)
    # only pairs within PR (and not the token itself) have a real answer
    answer_valid_t = ((reldist_t.abs() <= PR) & (reldist_t != 0)).float().unsqueeze(-1)  # [bs, len_q, len_k, 1]
    answer_hit_t = answer_hit_t * mask_pair.unsqueeze(-1) * answer_valid_t  # clear invalid ones; [bs, len_q, len_k, 2*R]
    # get losses sum(log(answer*prob))
    # -- dim=-1 is standard 2*PR classification, dim=-2 usually have 2*PR candidates, but can be less at edges
    all_losses = []
    for one_dim, one_lambda in zip([-1, -2], [conf.lambda_n1, conf.lambda_n2]):
        if one_lambda > 0.:
            # since currently there can be only one or zero correct answer
            logprob_t = BK.log_softmax(score_t, one_dim)  # [bs, len_q, len_k, 2*R]
            sumlogprob_t = (logprob_t * answer_hit_t).sum(one_dim)  # [bs, len_q, len_k||2*R]
            # only slices that contain an answer are counted
            cur_dim_mask_t = (answer_hit_t.sum(one_dim)>0.).float()  # [bs, len_q, len_k||2*R]
            # loss
            cur_dim_loss = - (sumlogprob_t * cur_dim_mask_t).sum()
            cur_dim_count = cur_dim_mask_t.sum()
            # argmax and corr (any correct counts)
            _, cur_argmax_idxes = score_t.max(one_dim)
            cur_corrs = answer_hit_t.gather(one_dim, cur_argmax_idxes.unsqueeze(one_dim))  # [bs, len_q, len_k|1, 2*R|1]
            cur_dim_corr_count = cur_corrs.sum()
            # compile loss
            one_loss = LossHelper.compile_leaf_info(f"d{one_dim}", cur_dim_loss, cur_dim_count,
                                                    loss_lambda=one_lambda, corr=cur_dim_corr_count)
            all_losses.append(one_loss)
    return self._compile_component_loss("orp", all_losses)
# borrow masklm and predict for local_shuffle2
def loss_special(self, repr_t, mask_t, disturb_keep_arr, input_map, masklm_node):
    """Predict the words at the shuffled positions by borrowing the MaskLM loss.

    Args:
        repr_t: representations, passed through ``self.speical_hid_layer``
            (created by ``add_node_special``) before scoring.
        mask_t: valid-token mask; it is copied, not modified.
        disturb_keep_arr: optional keep flags (1 = token kept in place, no
            prediction for it). ``None`` means nothing was kept, which is what
            ``_disturb_local_shuffle2`` returns when no keep rate is set.
        input_map: reads ``"posi"`` (shuffled positions) and ``"word"``.
        masklm_node: node whose ``loss`` is called with ``active_hid=False``.

    Returns:
        The (possibly empty) loss item of ``masklm_node.loss``; a non-empty one
        is renamed to the single key ``"orp.d-2"``.
    """
    pred_mask_t = BK.copy(mask_t)
    # guard None like `loss` does, previously `1.-None` raised a TypeError here
    if disturb_keep_arr is not None:
        pred_mask_t *= BK.input_real(1.-disturb_keep_arr)  # not for the non-shuffled ones
    abs_posi = input_map["posi"]  # shuffled positions
    # no predictions for ARTI_ROOT
    if self.add_root_token:
        pred_mask_t[:, 0] = 0.  # [bs, slen]
        abs_posi = (abs_posi[:,1:]-1)  # offset by 1
        pred_mask_t = pred_mask_t[:,1:]  # remove root
    # pick, for every token, the word idx found at its shuffled position
    corr_targets_arr = input_map["word"][np.arange(len(abs_posi))[:,np.newaxis], abs_posi]
    repr_t_hid = self.speical_hid_layer(repr_t)  # go through hid here!!
    loss_item = masklm_node.loss([repr_t_hid], pred_mask_t, {"word": corr_targets_arr}, active_hid=False)
    if len(loss_item) == 0:
        return loss_item
    # todo(note): simply change its name
    vs = list(loss_item.values())
    assert len(vs)==1
    return {"orp.d-2": vs[0]}
# by default no such loss + # ---- + # hidden layer + self.hid_dim = 0 + self.hid_act = "elu" + # pred options + self.split_input_blm = True # split the inputs into half for (l2r+r2l) for blm, otherwise input list or all l2r + self.min_pred_rank = 1 # >= min word idx to pred for the masked ones + self.max_pred_rank = 1 # <= max word idx to pred for the masked ones + # tie weights? + self.tie_input_embeddings = False # tie all preds with input embeddings + self.tie_bidirect_pred = False # tie params for bidirection preds + self.init_pred_from_pretrain = False + +class PlainLMNode(BaseModule): + def __init__(self, pc: BK.ParamCollection, input_dim: int, conf: PlainLMNodeConf, inputter: Inputter): + super().__init__(pc, conf, name="PLM") + self.conf = conf + self.inputter = inputter + self.input_dim = input_dim + self.split_input_blm = conf.split_input_blm + # this step is performed at the embedder, thus still does not influence the inputter + self.add_root_token = self.inputter.embedder.add_root_token + # vocab and padder + vpack = inputter.vpack + vocab_word = vpack.get_voc("word") + # models + real_input_dim = input_dim//2 if self.split_input_blm else input_dim + if conf.hid_dim <= 0: # no hidden layer + self.l2r_hid_layer = self.r2l_hid_layer = None + self.pred_input_dim = real_input_dim + else: + self.l2r_hid_layer = self.add_sub_node("l2r_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act)) + self.r2l_hid_layer = self.add_sub_node("r2l_h", Affine(pc, real_input_dim, conf.hid_dim, act=conf.hid_act)) + self.pred_input_dim = conf.hid_dim + # todo(note): unk is the first one above real words + self.pred_size = min(conf.max_pred_rank+1, vocab_word.unk) + if conf.tie_input_embeddings: + zwarn("Tie all preds in plm with input embeddings!!") + self.l2r_pred = self.r2l_pred = None + self.inputter_embed_node = self.inputter.embedder.get_node("word") + else: + self.l2r_pred = self.add_sub_node("l2r_p", Affine(pc, self.pred_input_dim, self.pred_size, 
init_rop=NoDropRop())) + if conf.tie_bidirect_pred: + self.r2l_pred = self.l2r_pred + else: + self.r2l_pred = self.add_sub_node("r2l_p", Affine(pc, self.pred_input_dim, self.pred_size, init_rop=NoDropRop())) + self.inputter_embed_node = None + if conf.init_pred_from_pretrain: + npvec = vpack.get_emb("word") + if npvec is None: + zwarn("Pretrained vector not provided, skip init pred embeddings!!") + else: + with BK.no_grad_env(): + self.l2r_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T)) + self.r2l_pred.ws[0].copy_(BK.input_real(npvec[:self.pred_size].T)) + zlog(f"Init pred embeddings from pretrained vectors (size={self.pred_size}).") + + # [bsize, slen, *] + # todo(note): be careful that repr_t can be root-appended!! + def loss(self, repr_t, orig_map: Dict, **kwargs): + conf = self.conf + _tie_input_embeddings = conf.tie_input_embeddings + # -- + # specify input + add_root_token = self.add_root_token + # get from inputs + if isinstance(repr_t, (list, tuple)): + l2r_repr_t, r2l_repr_t = repr_t + elif self.split_input_blm: + l2r_repr_t, r2l_repr_t = BK.chunk(repr_t, 2, -1) + else: + l2r_repr_t, r2l_repr_t = repr_t, None + # l2r and r2l + word_t = BK.input_idx(orig_map["word"]) # [bs, rlen] + slice_zero_t = BK.zeros([BK.get_shape(word_t, 0), 1]).long() # [bs, 1] + if add_root_token: + l2r_trg_t = BK.concat([word_t, slice_zero_t], -1) # pad one extra 0, [bs, rlen+1] + r2l_trg_t = BK.concat([slice_zero_t, slice_zero_t, word_t[:,:-1]], -1) # pad two extra 0 at front, [bs, 2+rlen-1] + else: + l2r_trg_t = BK.concat([word_t[:,1:], slice_zero_t], -1) # pad one extra 0, but remove the first one, [bs, -1+rlen+1] + r2l_trg_t = BK.concat([slice_zero_t, word_t[:,:-1]], -1) # pad one extra 0 at front, [bs, 1+rlen-1] + # gather the losses + all_losses = [] + pred_range_min, pred_range_max = max(1, conf.min_pred_rank), self.pred_size-1 + if _tie_input_embeddings: + pred_W = self.inputter_embed_node.E.E[:self.pred_size] # [PSize, Dim] + else: + pred_W = None + # get input 
# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec, m_size):
    """
    calculate log of exp sum over dim 1
    args:
        vec (batch_size, vanishing_dim, hidden_dim) : input tensor
        m_size : hidden_dim
    return:
        tensor of shape (batch_size, hidden_dim)
    note:
        relies on the built-in ``torch.logsumexp`` (stable max-shifted
        reduction); unlike the manual max/gather version, a slice that is
        entirely ``-inf`` yields ``-inf`` instead of NaN.
    """
    return torch.logsumexp(vec, 1).view(-1, m_size)  # B * M
(self.tagset_size+2, self.tagset_size+2), init=init_transitions) + + # score + def _score(self, repr_t): + if self.hid_layer is None: + hid_t = repr_t.contiguous() + else: + hid_t = self.hid_layer(repr_t.contiguous()) + out_t = self.pred_layer(hid_t) + return out_t + + # loss + def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs): + conf = self.conf + if self.add_root_token: + repr_t = repr_t[:,1:] + mask_t = mask_t[:,1:] + # score + scores_t = self._score(repr_t) # [bs, rlen, D] + # get gold + gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.long) # [bs, ?+rlen] + for bidx, inst in enumerate(insts): + cur_seq_idxes = getattr(inst, self.attr_name).idxes + gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes + # get loss + gold_pidxes_t = BK.input_idx(gold_pidxes) + nll_loss_sum = self.neg_log_likelihood_loss(scores_t, mask_t.bool(), gold_pidxes_t) + if not conf.div_by_tok: # otherwise div by sent + nll_loss_sum *= (mask_t.sum() / len(insts)) + # compile loss + crf_loss = LossHelper.compile_leaf_info("crf", nll_loss_sum.sum(), mask_t.sum()) + return self._compile_component_loss(self.pname, [crf_loss]) + + # predict + def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs): + conf = self.conf + if self.add_root_token: + repr_t = repr_t[:,1:] + mask_t = mask_t[:,1:] + # score + scores_t = self._score(repr_t) # [bs, rlen, D] + # decode + _, decode_idx = self._viterbi_decode(scores_t, mask_t.bool()) + decode_idx_arr = BK.get_value(decode_idx) # [bs, rlen] + for one_bidx, one_inst in enumerate(insts): + one_pidxes = decode_idx_arr[one_bidx].tolist()[:len(one_inst)] + one_pseq = SeqField(None) + one_pseq.build_vals(one_pidxes, self.vocab) + one_inst.add_item("pred_"+self.attr_name, one_pseq, assert_non_exist=False) + return + + # ===== + + def _calculate_PZ(self, feats, mask): + """ + input: + feats: (batch, seq_len, self.tag_size+2) + masks: (batch, seq_len) + """ + batch_size = feats.size(0) + seq_len = feats.size(1) + tag_size 
= feats.size(2) + # print feats.view(seq_len, tag_size) + assert (tag_size == self.tagset_size+2) + mask = mask.transpose(1,0).contiguous() + ins_num = seq_len * batch_size + ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) + feats = feats.transpose(1,0).contiguous().view(ins_num,1, tag_size).expand(ins_num, tag_size, tag_size) + ## need to consider start + scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) + scores = scores.view(seq_len, batch_size, tag_size, tag_size) + # build iter + seq_iter = enumerate(scores) + _, inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size + # only need start from start_tag + partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size, 1) # bat_size * to_target_size + + ## add start score (from start to all tag, duplicate to batch_size) + # partition = partition + self.transitions[START_TAG,:].view(1, tag_size, 1).expand(batch_size, tag_size, 1) + # iter over last scores + for idx, cur_values in seq_iter: + # previous to_target is current from_target + # partition: previous results log(exp(from_target)), #(batch_size * from_target) + # cur_values: bat_size * from_target * to_target + + cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) + cur_partition = log_sum_exp(cur_values, tag_size) + # print cur_partition.data + + # (bat_size * from_target * to_target) -> (bat_size * to_target) + # partition = utils.switch(partition, cur_partition, mask[idx].view(bat_size, 1).expand(bat_size, self.tagset_size)).view(bat_size, -1) + mask_idx = mask[idx, :].view(batch_size, 1).expand(batch_size, tag_size) + + ## effective updated partition part, only keep the partition value of mask value = 1 + masked_cur_partition = cur_partition.masked_select(mask_idx) + ## let mask_idx broadcastable, to disable warning + mask_idx = 
mask_idx.contiguous().view(batch_size, tag_size, 1) + + ## replace the partition where the maskvalue=1, other partition value keeps the same + partition.masked_scatter_(mask_idx, masked_cur_partition) + # until the last state, add transition score for all partition (and do log_sum_exp) then select the value in STOP_TAG + cur_values = self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) + cur_partition = log_sum_exp(cur_values, tag_size) + final_partition = cur_partition[:, STOP_TAG] + return final_partition.sum(), scores + + def _viterbi_decode(self, feats, mask): + """ + input: + feats: (batch, seq_len, self.tag_size+2) + mask: (batch, seq_len) + output: + decode_idx: (batch, seq_len) decoded sequence + path_score: (batch, 1) corresponding score for each sequence (to be implementated) + """ + batch_size = feats.size(0) + seq_len = feats.size(1) + tag_size = feats.size(2) + assert(tag_size == self.tagset_size+2) + ## calculate sentence length for each sentence + length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() + ## mask to (seq_len, batch_size) + mask = mask.transpose(1,0).contiguous() + ins_num = seq_len * batch_size + ## be careful the view shape, it is .view(ins_num, 1, tag_size) but not .view(ins_num, tag_size, 1) + feats = feats.transpose(1,0).contiguous().view(ins_num, 1, tag_size).expand(ins_num, tag_size, tag_size) + ## need to consider start + scores = feats + self.transitions.view(1,tag_size,tag_size).expand(ins_num, tag_size, tag_size) + scores = scores.view(seq_len, batch_size, tag_size, tag_size) + + # build iter + seq_iter = enumerate(scores) + ## record the position of best score + back_points = list() + partition_history = list() + ## reverse mask (bug for mask = 1- mask, use this as alternative choice) + # mask = 1 + (-1)*mask + # mask = (1 - mask.long()).byte() + mask = (1 - mask.long()).bool() + _, 
inivalues = next(seq_iter) # bat_size * from_target_size * to_target_size + # only need start from start_tag + partition = inivalues[:, START_TAG, :].clone().view(batch_size, tag_size) # bat_size * to_target_size + # print "init part:",partition.size() + partition_history.append(partition) + # iter over last scores + for idx, cur_values in seq_iter: + # previous to_target is current from_target + # partition: previous results log(exp(from_target)), #(batch_size * from_target) + # cur_values: batch_size * from_target * to_target + cur_values = cur_values + partition.contiguous().view(batch_size, tag_size, 1).expand(batch_size, tag_size, tag_size) + ## forscores, cur_bp = torch.max(cur_values[:,:-2,:], 1) # do not consider START_TAG/STOP_TAG + # print "cur value:", cur_values.size() + partition, cur_bp = torch.max(cur_values, 1) + # print "partsize:",partition.size() + # exit(0) + # print partition + # print cur_bp + # print "one best, ",idx + partition_history.append(partition) + ## cur_bp: (batch_size, tag_size) max source score position in current tag + ## set padded label as 0, which will be filtered in post processing + cur_bp.masked_fill_(mask[idx].view(batch_size, 1).expand(batch_size, tag_size), 0) + back_points.append(cur_bp) + # exit(0) + ### add score to final STOP_TAG + partition_history = torch.cat(partition_history, 0).view(seq_len, batch_size, -1).transpose(1,0).contiguous() ## (batch_size, seq_len. 
tag_size) + ### get the last position for each setences, and select the last partitions using gather() + last_position = length_mask.view(batch_size,1,1).expand(batch_size, 1, tag_size) -1 + last_partition = torch.gather(partition_history, 1, last_position).view(batch_size,tag_size,1) + ### calculate the score from last partition to end state (and then select the STOP_TAG from it) + last_values = last_partition.expand(batch_size, tag_size, tag_size) + self.transitions.view(1,tag_size, tag_size).expand(batch_size, tag_size, tag_size) + _, last_bp = torch.max(last_values, 1) + # pad_zero = autograd.Variable(torch.zeros(batch_size, tag_size)).long() + # if self.gpu: + # pad_zero = pad_zero.cuda() + pad_zero = BK.zeros((batch_size, tag_size)).long() + back_points.append(pad_zero) + back_points = torch.cat(back_points).view(seq_len, batch_size, tag_size) + + ## select end ids in STOP_TAG + pointer = last_bp[:, STOP_TAG] + insert_last = pointer.contiguous().view(batch_size,1,1).expand(batch_size,1, tag_size) + back_points = back_points.transpose(1,0).contiguous() + ## move the end ids(expand to tag_size) to the corresponding position of back_points to replace the 0 values + # print "lp:",last_position + # print "il:",insert_last + back_points.scatter_(1, last_position, insert_last) + # print "bp:",back_points + # exit(0) + back_points = back_points.transpose(1,0).contiguous() + ## decode from the end, padded position ids are 0, which will be filtered if following evaluation + # decode_idx = autograd.Variable(torch.LongTensor(seq_len, batch_size)) + # if self.gpu: + # decode_idx = decode_idx.cuda() + decode_idx = BK.zeros((seq_len, batch_size)).long() + decode_idx[-1] = pointer.detach() + for idx in range(len(back_points)-2, -1, -1): + pointer = torch.gather(back_points[idx], 1, pointer.contiguous().view(batch_size, 1)) + decode_idx[idx] = pointer.detach().view(batch_size) + path_score = None + decode_idx = decode_idx.transpose(1,0) + return path_score, decode_idx + + def 
_score_sentence(self, scores, mask, tags): + """ + input: + scores: variable (seq_len, batch, tag_size, tag_size) + mask: (batch, seq_len) + tags: tensor (batch, seq_len) + output: + score: sum of score for gold sequences within whole batch + """ + # Gives the score of a provided tag sequence + batch_size = scores.size(1) + seq_len = scores.size(0) + tag_size = scores.size(2) + ## convert tag value into a new format, recorded label bigram information to index + # new_tags = autograd.Variable(torch.LongTensor(batch_size, seq_len)) + # if self.gpu: + # new_tags = new_tags.cuda() + new_tags = BK.zeros((batch_size, seq_len)).long() + for idx in range(seq_len): + if idx == 0: + ## start -> first score + new_tags[:,0] = (tag_size - 2)*tag_size + tags[:,0] + + else: + new_tags[:,idx] = tags[:,idx-1]*tag_size + tags[:,idx] + + ## transition for label to STOP_TAG + end_transition = self.transitions[:,STOP_TAG].contiguous().view(1, tag_size).expand(batch_size, tag_size) + ## length for batch, last word position = length - 1 + length_mask = torch.sum(mask.long(), dim = 1).view(batch_size,1).long() + ## index the label id of last word + end_ids = torch.gather(tags, 1, length_mask - 1) + + ## index the transition score for end_id to STOP_TAG + end_energy = torch.gather(end_transition, 1, end_ids) + + ## convert tag as (seq_len, batch_size, 1) + new_tags = new_tags.transpose(1,0).contiguous().view(seq_len, batch_size, 1) + ### need convert tags id to search from 400 positions of scores + tg_energy = torch.gather(scores.view(seq_len, batch_size, -1), 2, new_tags).view(seq_len, batch_size) # seq_len * bat_size + ## mask transpose to (seq_len, batch_size) + tg_energy = tg_energy.masked_select(mask.transpose(1,0)) + + # ## calculate the score from START_TAG to first label + # start_transition = self.transitions[START_TAG,:].view(1, tag_size).expand(batch_size, tag_size) + # start_energy = torch.gather(start_transition, 1, tags[0,:]) + + ## add all score together + # gold_score = 
start_energy.sum() + tg_energy.sum() + end_energy.sum() + gold_score = tg_energy.sum() + end_energy.sum() + return gold_score + + def neg_log_likelihood_loss(self, feats, mask, tags): + # nonegative log likelihood + batch_size = feats.size(0) + forward_score, scores = self._calculate_PZ(feats, mask) + gold_score = self._score_sentence(scores, mask, tags) + # print "batch, f:", forward_score.data[0], " g:", gold_score.data[0], " dis:", forward_score.data[0] - gold_score.data[0] + # exit(0) + return forward_score - gold_score diff --git a/tasks/zmlm/model/mods/seqlab.py b/tasks/zmlm/model/mods/seqlab.py new file mode 100644 index 0000000..55e4ad9 --- /dev/null +++ b/tasks/zmlm/model/mods/seqlab.py @@ -0,0 +1,100 @@ +# + +# sequence labeling module +# -- simple individual classifiers + +from typing import List, Dict, Tuple +import numpy as np + +from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn +from msp.nn import BK +from msp.nn.layers import Affine, NoDropRop + +from ..base import BaseModuleConf, BaseModule, LossHelper +from .embedder import Inputter +from ...data.insts import GeneralSentence, SeqField + +# -- +class SeqLabNodeConf(BaseModuleConf): + def __init__(self): + super().__init__() + self.loss_lambda.val = 0. 
# by default no such loss + # -- + # hidden layer + self.hid_dim = 0 + self.hid_act = "elu" + +class SeqLabNode(BaseModule): + def __init__(self, pc: BK.ParamCollection, pname: str, input_dim: int, conf: SeqLabNodeConf, inputter: Inputter): + super().__init__(pc, conf, name="SLB") + self.conf = conf + self.inputter = inputter + self.input_dim = input_dim + # this step is performed at the embedder, thus still does not influence the inputter + self.add_root_token = self.inputter.embedder.add_root_token + # -- + self.pname = pname + self.attr_name = pname + "_seq" # attribute name in Instance + self.vocab = inputter.vpack.get_voc(pname) + # todo(note): we must make sure that 0 means NAN + assert self.vocab.non == 0 + # models + if conf.hid_dim <= 0: # no hidden layer + self.hid_layer = None + self.pred_input_dim = input_dim + else: + self.hid_layer = self.add_sub_node("hid", Affine(pc, input_dim, conf.hid_dim, act=conf.hid_act)) + self.pred_input_dim = conf.hid_dim + self.pred_out_dim = self.vocab.unk # todo(note): UNK is the prediction boundary + self.pred_layer = self.add_sub_node("pr", Affine(pc, self.pred_input_dim, self.pred_out_dim, init_rop=NoDropRop())) + + # score + def _score(self, repr_t): + if self.hid_layer is None: + hid_t = repr_t + else: + hid_t = self.hid_layer(repr_t) + out_t = self.pred_layer(hid_t) + return out_t + + # loss + def loss(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs): + conf = self.conf + # score + scores_t = self._score(repr_t) # [bs, ?+rlen, D] + # get gold + gold_pidxes = np.zeros(BK.get_shape(mask_t), dtype=np.long) # [bs, ?+rlen] + for bidx, inst in enumerate(insts): + cur_seq_idxes = getattr(inst, self.attr_name).idxes + if self.add_root_token: + gold_pidxes[bidx, 1:1+len(cur_seq_idxes)] = cur_seq_idxes + else: + gold_pidxes[bidx, :len(cur_seq_idxes)] = cur_seq_idxes + # get loss + margin = self.margin.value + gold_pidxes_t = BK.input_idx(gold_pidxes) + gold_pidxes_t *= (gold_pidxes_t<self.pred_out_dim).long() # 0 
means invalid ones!! + loss_mask_t = (gold_pidxes_t>0).float() * mask_t # [bs, ?+rlen] + lab_losses_t = BK.loss_nll(scores_t, gold_pidxes_t, margin=margin) # [bs, ?+rlen] + # argmax + _, argmax_idxes = scores_t.max(-1) + pred_corrs = (argmax_idxes == gold_pidxes_t).float() * loss_mask_t + # compile loss + lab_loss = LossHelper.compile_leaf_info("slab", lab_losses_t.sum(), loss_mask_t.sum(), corr=pred_corrs.sum()) + return self._compile_component_loss(self.pname, [lab_loss]) + + # predict + def predict(self, insts: List[GeneralSentence], repr_t, mask_t, **kwargs): + conf = self.conf + # score + scores_t = self._score(repr_t) # [bs, ?+rlen, D] + _, argmax_idxes = scores_t.max(-1) # [bs, ?+rlen] + argmax_idxes_arr = BK.get_value(argmax_idxes) # [bs, ?+rlen] + # assign; todo(+2): record scores? + one_offset = int(self.add_root_token) + for one_bidx, one_inst in enumerate(insts): + one_pidxes = argmax_idxes_arr[one_bidx, one_offset:one_offset+len(one_inst)].tolist() + one_pseq = SeqField(None) + one_pseq.build_vals(one_pidxes, self.vocab) + one_inst.add_item("pred_"+self.attr_name, one_pseq, assert_non_exist=False) + return diff --git a/tasks/zmlm/model/mods/vrec/__init__.py b/tasks/zmlm/model/mods/vrec/__init__.py new file mode 100644 index 0000000..72a5b0d --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/__init__.py @@ -0,0 +1,10 @@ +# + +# the MuRec encoder + +# ----- +# for the murec-att-node, it aims to collect attentinal evidences: +# c1_scorer / c2_normer / c3_collector +# for the murec-layer, it is composed of murec-att & rec-wrapper + +from .vrec import VRecEncoderConf, VRecEncoder, VRecCache, VRecConf, VRecNode diff --git a/tasks/zmlm/model/mods/vrec/base.py b/tasks/zmlm/model/mods/vrec/base.py new file mode 100644 index 0000000..0f6da94 --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/base.py @@ -0,0 +1,137 @@ +# + +# some basic components + +from typing import List +from msp.utils import Conf +from msp.nn import BK +from msp.nn.layers import BasicNode, 
ActivationHelper, Affine, NoDropRop, Dropout + +# ----- +# Affine/MLP layer for convenient use +class AffineHelperNode(BasicNode): + def __init__(self, pc, input_dim, hid_dim: int, hid_act='linear', hid_drop=0., hid_piece4init=1, + out_dim=0, out_fbias=0., out_fact="linear", out_piece4init=1, init_scale=1.): + super().__init__(pc, None, None) + # ----- + # hidden layer + self.hid_aff = self.add_sub_node("hidden", Affine( + pc, input_dim, hid_dim, n_piece4init=hid_piece4init, init_scale=init_scale, init_rop=NoDropRop())) + self.hid_act_f = ActivationHelper.get_act(hid_act) + self.hid_drop = self.add_sub_node("drop", Dropout(pc, (hid_dim, ), fix_rate=hid_drop)) + # ----- + # output layer (optional) + self.final_output_dim = hid_dim + # todo(+N): how about split hidden layers for each specific output + self.out_fbias = out_fbias # fixed extra bias + self.out_act_f = ActivationHelper.get_act(out_fact) + # no output dropouts + if out_dim > 0: + assert hid_act != "linear", "Please use non-linear activation for mlp!" + assert hid_piece4init == 1, "Strange hid_piece4init for hidden layer with out_dim>0" + self.final_aff = self.add_sub_node("final", Affine( + pc, hid_dim, out_dim, n_piece4init=out_piece4init, init_scale=init_scale, init_rop=NoDropRop())) + self.final_output_dim = out_dim + else: + self.final_aff = None + + def get_output_dims(self, *input_dims): + return (self.final_output_dim, ) + + def __call__(self, input_t): + # hidden + hid_t = self.hid_aff(input_t) + # act and drop + act_hid_t = self.hid_act_f(hid_t) + drop_hid_t = self.hid_drop(act_hid_t) + # optional extra layer + if self.final_aff: + out0 = self.final_aff(drop_hid_t) + out1 = self.out_act_f(out0 + self.out_fbias) # fixed bias and final activation + return out1 + else: + return drop_hid_t + +# ----- +# gumbel-softmax / concrete-node +class ConcreteNodeConf(Conf): + def __init__(self): + self.use_gumbel = False + self.use_argmax = False + self.prune_val = 0. 
# hard prune for probs
# NOTE(review): this chunk was recovered from a whitespace-collapsed patch; code tokens are unchanged,
# indentation/nesting is reconstructed from syntax -- confirm against the original files.

# discrete / gumbel softmax
class ConcreteNode(BasicNode):
    """
    Softmax normalizer with optional Gumbel noise, hard pruning and straight-through argmax.

    All three options are read from the ConcreteNodeConf passed to the constructor.
    """

    def __init__(self, pc, conf: ConcreteNodeConf):
        super().__init__(pc, None, None)
        self.use_gumbel = conf.use_gumbel
        self.use_argmax = conf.use_argmax
        self.prune_val = conf.prune_val
        self.gumbel_eps = 1e-10  # keeps both log() calls below away from log(0)

    # [*, S] -> normalized [*, S]
    def __call__(self, scores, temperature=1., dim=-1):
        """
        Normalize `scores` along `dim` with a temperature softmax.

        At training time with use_gumbel, Gumbel noise is added to the scores first. With prune_val > 0,
        probabilities not above prune_val are zeroed (no re-normalization). With use_argmax, only the
        maximal entries are kept; at training time this is done in straight-through fashion.
        """
        is_training = self.rop.training  # NOTE(review): `rop` is presumably set up by BasicNode -- confirm
        # only use stochastic at training
        if is_training:
            if self.use_gumbel:
                gumbel_eps = self.gumbel_eps
                # NOTE(review): assuming BK.rand samples from [0,1), G is in [eps, 1] after +eps and clamp
                G = (BK.rand(BK.get_shape(scores)) + gumbel_eps).clamp(max=1.)  # [0,1)
                # subtracting log(-log(G)) is the same as adding Gumbel noise -log(-log(G))
                scores = scores - (gumbel_eps - G.log()).log()
        # normalize
        probs = BK.softmax(scores / temperature, dim=dim)  # [*, S]
        # prune and re-normalize?
        if self.prune_val > 0.:
            probs = probs * (probs > self.prune_val).float()
            # todo(note): currently no re-normalize
            # probs = probs / probs.sum(dim=dim, keepdim=True)  # [*, S]
        # argmax and ste
        if self.use_argmax:  # use the hard argmax
            max_probs, _ = probs.max(dim, keepdim=True)  # [*, 1]
            # todo(+N): currently we do not re-normalize here, should it be done here?
            st_probs = (probs>=max_probs).float() * probs  # [*, S]
            if is_training:  # (hard-soft).detach() + soft
                # forward value stays the hard one, gradient flows through the soft probs
                st_probs = (st_probs - probs).detach() + probs  # [*, S]
            return st_probs
        else:
            return probs

# -----
# affine combiner
class AffineCombiner(BasicNode):
    """
    Sum several inputs (each optionally passed through its own Affine layer), then apply activation and dropout.

    Inputs whose use_affs entry is False are added as they are and must therefore already have size out_dim.
    """

    def __init__(self, pc, input_dims: List[int], use_affs: List[bool], out_dim: int,
                 out_act='linear', out_drop=0., param_init_scale=1.):
        super().__init__(pc, None, None)
        # -----
        self.input_dims = input_dims
        self.use_affs = use_affs
        self.out_dim = out_dim
        # =====
        assert len(input_dims) == len(use_affs)
        self.aff_nodes = []  # one entry per input: an Affine node, or None for "add directly"
        for d, use in zip(input_dims, use_affs):
            if use:
                one_aff = self.add_sub_node("aff", Affine(pc, d, out_dim, init_scale=param_init_scale, init_rop=NoDropRop()))
            else:
                assert d == out_dim, f"Dimension mismatch for skipping affine: {d} vs {out_dim}!"
                one_aff = None
            self.aff_nodes.append(one_aff)
        self.out_act_f = ActivationHelper.get_act(out_act)
        self.out_drop = self.add_sub_node("drop", Dropout(pc, (out_dim, ), fix_rate=out_drop))

    def get_output_dims(self, *input_dims):
        """Return the output size as a 1-tuple."""
        return (self.out_dim, )

    # input is List[*, len, dim]
    def __call__(self, input_ts: List):
        """Combine the inputs (same number and order as input_dims) into one tensor."""
        # starting from the float 0. makes the first "+=" create a new tensor instead of modifying an input
        comb_t = 0.
        assert len(input_ts) == len(self.aff_nodes), "Input number mismatch"
        for one_t, one_aff in zip(input_ts, self.aff_nodes):
            if one_aff:
                add_t = one_aff(one_t)
            else:
                add_t = one_t
            comb_t += add_t
        comb_t_act = self.out_act_f(comb_t)
        comb_t_drop = self.out_drop(comb_t_act)
        return comb_t_drop

# diff --git a/tasks/zmlm/model/mods/vrec/c1_score.py b/tasks/zmlm/model/mods/vrec/c1_score.py
# new file mode 100644
# index 0000000..92ad120
# --- /dev/null
# +++ b/tasks/zmlm/model/mods/vrec/c1_score.py
# @@ -0,0 +1,209 @@

#

# Component 1 for MAtt: scorer
# (Q, K, accu_attns, rel_dist) => scores [*, len_q, len_k, head]

from typing import List, Union, Dict, Iterable
from msp.utils import Constants, Conf, zwarn, Helper
from msp.nn import BK
from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm
from .base import AffineHelperNode

# -----

# conf
class MAttScorerConf(Conf):
    """Configuration for MAttScorerNode (and for RelDistHelperNode, which reads the rel_* fields)."""

    def __init__(self):
        # special
        self.param_init_scale = 1.  # todo(+N): ugly extra scale
        self.use_piece4init = False
        # dot-product for Q*K
        self.d_qk = 64  # query*key -> score
        self.qk_act = 'linear'
        self.qk_drop = 0.
        self.score_scale_power = 0.5  # scale power for the dot product scoring
        self.use_unhead_score = False  # general score (untyped) for q*k
        # relative distances?
        self.use_rel_dist = False
        self.rel_emb_dim = 512  # this should be enough
        self.rel_emb_zero0 = True  # zero the 0 item
        self.rel_dist_clip = 1000  # this should be enough
        self.rel_dist_abs = False  # discard sign on rel dist, if True: [0, clip] else [-clip, clip]
        self.rel_init_sincos = True  # init and freeze rel_dist embeddings with sincos version
        # q(/h)-specific lambdas to minus accu_attns
        self.use_lambq = False
        self.lambq_hdim = 128
        self.lambq_hdrop = 0.1
        self.lambq_hact = "elu"
        self.lambq_fbias = 0.  # fixed extra bias
        self.lambq_fact = "relu"  # output activation after adding bias
        # q(/h)-specific score offset
        self.use_soff = False
        self.soff_hdim = 128
        self.soff_hdrop = 0.1
        self.soff_hact = "elu"
        self.soff_fbias = 0.  # fixed extra bias
        self.soff_fact = "relu"  # output activation after adding bias
        # no (masking out) self-loop?
        self.no_self_loop = False

    def get_att_scale_qk(self):
        """Return the divisor applied to dot-product scores: d_qk ** score_scale_power."""
        return float(self.d_qk ** self.score_scale_power)

# the scorer
class MAttScorerNode(BasicNode):
    """
    Multi-head attention scorer: (query, key, accumulated attention, masks, rel_dist) -> raw scores.

    Scores are scaled dot products of projected queries/keys, optionally extended with relative-distance
    terms, an extra head-independent ("unhead") score, a query-specific penalty on accumulated attention
    (lambq), a query-specific offset (soff), and additive masking.
    """

    def __init__(self, pc: BK.ParamCollection, dim_q, dim_k, head_count, conf: MAttScorerConf):
        super().__init__(pc, None, None)
        self.conf = conf
        self.dim_q, self.dim_k = dim_q, dim_k
        self.head_count = head_count
        # -----
        self.use_unhead_score = conf.use_unhead_score
        _calc_head_count = (1+head_count) if self.use_unhead_score else head_count  # extra dim for overall unhead score
        _d_qk = conf.d_qk
        _hid_size_qk = _calc_head_count * _d_qk
        self._att_scale_qk = conf.get_att_scale_qk()  # scale for dot product
        self.split_dims = [_calc_head_count, -1]  # used by _shape_project to split the last dim into heads
        # affine-qk
        self.affine_q = self.add_sub_node("aq", AffineHelperNode(
            pc, dim_q, _hid_size_qk, hid_act=conf.qk_act, hid_drop=conf.qk_drop,
            hid_piece4init=(_calc_head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
        self.affine_k = self.add_sub_node("ak", AffineHelperNode(
            pc, dim_k, _hid_size_qk, hid_act=conf.qk_act, hid_drop=conf.qk_drop,
            hid_piece4init=(_calc_head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
        # rel positional
        if conf.use_rel_dist:
            self.dist_helper = self.add_sub_node("dh", RelDistHelperNode(pc, _calc_head_count, conf))
        else:
            self.dist_helper = None
        # lambq
        if conf.use_lambq:
            self.lambq_aff = self.add_sub_node("al", AffineHelperNode(
                pc, dim_q, hid_dim=conf.lambq_hdim, hid_act=conf.lambq_hact, hid_drop=conf.lambq_hdrop,
                out_dim=head_count, out_fbias=conf.lambq_fbias, out_fact=conf.lambq_fact,
                out_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
        else:
            self.lambq_aff = None
        # soff
        if conf.use_soff:
            self.soff_aff = self.add_sub_node("as", AffineHelperNode(
                pc, dim_q, hid_dim=conf.soff_hdim, hid_act=conf.soff_hact, hid_drop=conf.soff_hdrop,
                out_dim=(1+head_count), out_fbias=conf.soff_fbias, out_fact=conf.soff_fact,
                out_piece4init=((1+head_count) if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
        else:
            self.soff_aff = None
        # no (masking out) self loop?
        self.no_self_loop = conf.no_self_loop

    # -----
    # [*, len, head*dim] -> [*, head, len, dim] if do_transpose else [*, len, head, dim]
    def _shape_project(self, x, do_transpose: bool):
        """Split the last dimension of `x` into (head, dim) and optionally move the head axis before len."""
        x_size = BK.get_shape(x)[:-1] + self.split_dims
        x_reshaped = x.view(x_size)
        return x_reshaped.transpose(-2, -3) if do_transpose else x_reshaped

    # input is *[*, len?, Din], [*, len_q, len_k, head], [*, len_k], [*, len_q, len_k]
    def __call__(self, query, key, accu_attn, mask_k, mask_qk, rel_dist):
        """
        Compute attention scores of shape [*, len_q, len_k, head].

        query, key: [*, len?, Din]; accu_attn: [*, len_q, len_k, head] or None;
        mask_k: [*, len_k] or None; mask_qk: [*, len_q, len_k] or None (1. = keep, 0. = masked out);
        rel_dist: passed through to the relative-distance helper (may be None).
        """
        conf = self.conf
        # == calculate the dot-product scores
        # calculate the three: # [bs, len_?, head*D]; and also add sta ones if needed
        query_up, key_up = self.affine_q(query), self.affine_k(key)  # [*, len?, head?*Dqk]
        query_up, key_up = self._shape_project(query_up, True), self._shape_project(key_up, True)  # [*, head?, len_?, D]
        # original scores
        scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) / self._att_scale_qk  # [*, head?, len_q, len_k]
        # == adding rel_dist ones
        if conf.use_rel_dist:
            scores = self.dist_helper(query_up, key_up, rel_dist=rel_dist, input_scores=scores)
        # transpose
        scores = scores.transpose(-2, -3).transpose(-1, -2)  # [*, len_q, len_k, head?]
        # == unhead score
        if conf.use_unhead_score:
            # the first slice is the shared (head-independent) score, broadcast-added to every head
            scores_t0, score_t1 = BK.split(scores, [1, self.head_count], -1)  # [*, len_q, len_k, 1|head]
            scores = scores_t0 + score_t1  # [*, len_q, len_k, head]
        # == combining with history accumulated attns
        if conf.use_lambq and accu_attn is not None:
            # todo(note): here we only consider "query" and "head", would it be necessary for "key"?
            lambq_vals = self.lambq_aff(query)  # [*, len_q, head], if for eg., using relu as fact, this>=0
            scores -= lambq_vals.unsqueeze(-2) * accu_attn
        # == score offset
        if conf.use_soff:
            # todo(note): here we only consider "query" and "head", key may be handled by "unhead_score"
            score_offset_t = self.soff_aff(query)  # [*, len_q, 1+head]
            score_offset_t0, score_offset_t1 = BK.split(score_offset_t, [1, self.head_count], -1)  # [*, len_q, 1|head]
            scores -= score_offset_t0.unsqueeze(-2)
            scores -= score_offset_t1.unsqueeze(-2)  # still [*, len_q, len_k, head]
        # == apply mask & no-self-loop
        # NEG_INF = Constants.REAL_PRAC_MIN
        # finite penalties (not -inf) are added; masked entries get the larger one
        NEG_INF = -1000.  # this should be enough
        NEG_INF2 = -2000.  # this should be enough
        if mask_k is not None:  # [*, 1, len_k, 1]
            scores += (1.-mask_k).unsqueeze(-2).unsqueeze(-1) * NEG_INF2
        if mask_qk is not None:  # [*, len_q, len_k, 1]
            scores += (1.-mask_qk).unsqueeze(-1) * NEG_INF2
        if self.no_self_loop:
            query_len = BK.get_shape(query, -2)
            assert query_len == BK.get_shape(key, -2), "Shape not matched for no_self_loop"
            scores += BK.eye(query_len).unsqueeze(-1) * NEG_INF  # [len_q, len_k, 1]
        return scores.contiguous()  # [*, len_q, len_k, head]

# -----
# RelDistHelper: here we apply (relative) positional info only to att/score calculation
# -- following the ones in xl-transformer/xlnet
class RelDistHelperNode(BasicNode):
    """Adds relative-distance terms (query-distance and distance-bias) to attention scores."""

    def __init__(self, pc, head_count: int, conf: MAttScorerConf):
        super().__init__(pc, None, None)
        self.rel_dist_abs = conf.rel_dist_abs
        _clip = conf.rel_dist_clip
        _head_count = head_count
        _d_qk = conf.d_qk
        _d_emb = conf.rel_emb_dim
        # scale
        self._att_scale_qk = conf.get_att_scale_qk()  # scale for dot product
        # positional embedding
        self.E = self.add_sub_node("e", PosiEmbedding2(
            pc, _d_emb, max_val=_clip, init_sincos=conf.rel_init_sincos, freeze=conf.rel_init_sincos, zero0=conf.rel_emb_zero0))
        # different parameters for each head
        self.split_dims = [_head_count, _d_qk]
        self.affine_rel = self.add_sub_node("ar", AffineHelperNode(
            pc, _d_emb, _head_count*_d_qk, hid_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale))
        # vec_u is only referenced by the commented-out item (c) in __call__; vec_v is used by item (d)
        self.vec_u = self.add_param("vu", (_head_count, _d_qk), scale=conf.param_init_scale)
        self.vec_v = self.add_param("vv", (_head_count, _d_qk), scale=conf.param_init_scale)

    # [query, key]
    def get_rel_dist(self, len_q: int, len_k: int):
        """Return the [len_q, len_k] matrix of (key position - query position)."""
        dist_x = BK.arange_idx(0, len_k).unsqueeze(0)  # [1, len_k]
        dist_y = BK.arange_idx(0, len_q).unsqueeze(1)  # [len_q, 1]
        distance = dist_x - dist_y  # [len_q, len_k]
        return distance

    # [*, head, len_q/len_k, D]; could be inplaced adding "input_scores"
    def __call__(self, query_up, key_up, rel_dist=None, input_scores=None):
        """
        Return relative-distance scores [*, head, len_q, len_k], added onto `input_scores` if given.

        Note that a given `input_scores` tensor is modified in place by the "+=" operations below.
        """
        _att_scale_qk = self._att_scale_qk
        # -----
        # get dim info
        len_q, len_k = BK.get_shape(query_up, -2), BK.get_shape(key_up, -2)
        # get distance embeddings
        if rel_dist is None:
            rel_dist = self.get_rel_dist(len_q, len_k)
        if self.rel_dist_abs:  # use abs?
            rel_dist = BK.abs(rel_dist)
        dist_embs = self.E(rel_dist)  # [len_q, len_k, Demb]
        # -----
        # dist_up
        dist_up0 = self.affine_rel(dist_embs)  # [len_q, len_k, head*D]
        # -> [head, len_q, len_k, D]
        dist_up1 = dist_up0.view(BK.get_shape(dist_up0)[:-1] + self.split_dims).transpose(-2, -3).transpose(-3, -4)
        # -----
        # all items are [*, head, len_q, len_k]
        posi_scores = (input_scores if (input_scores is not None) else 0.)
        # item (b): <query, dist>: [head, len_q, len_k, D] * [*, head, len_q, D, 1] -> [*, head, len_q, len_k]
        item_b = (BK.matmul(dist_up1, query_up.unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
        posi_scores += item_b
        # todo(note): remove this item_c since it is not related with rel_dist
        # # item (c): <key, u>: [*, head, len_k, D] * [head, D, 1] -> [*, head, 1, len_k]
        # item_c = (BK.matmul(key_up, self.vec_u.unsqueeze(-1)) / _att_scale_qk).squeeze(-1).unsqueeze(-2)
        # posi_scores += item_c
        # item (d): <dist, v>: [head, len_q, len_k, D] * [head, 1, D, 1] -> [head, len_q, len_k]
        item_d = (BK.matmul(dist_up1, self.vec_v.unsqueeze(-2).unsqueeze(-1)) / _att_scale_qk).squeeze(-1)
        posi_scores += item_d
        return posi_scores

# diff --git a/tasks/zmlm/model/mods/vrec/c2_norm.py b/tasks/zmlm/model/mods/vrec/c2_norm.py
# new file mode 100644
# index 0000000..fbfcf88
# --- /dev/null
# +++ b/tasks/zmlm/model/mods/vrec/c2_norm.py
# @@ -0,0 +1,135 @@

#

# Component 2 for MAtt: normer
# (scores [*, len_q, len_k, head], accu_attns) => normalized-scores [*, len_q, len_k, head]

from typing import List, Union, Dict, Iterable
from msp.utils import Constants, Conf, zwarn, Helper
from msp.nn import BK
from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm
from .base import ConcreteNodeConf, ConcreteNode

# -----

# conf
class MAttNormerConf(Conf):
    """Configuration for the attention normalizer (noop score, normalization mode, pruning, dropout)."""

    def __init__(self):
        # special
        # self.param_init_scale = 1.  # actually no params in this node
        # noop compare score (fixed)
        self.use_noop = False  # attend to nothing
        self.noop_fixed_val = 0.  # fixed noop score
        # norm
        # todo(+N): here no longer norm multiple times since that seems complicated
        self.norm_mode = 'cand'  # flatten/head/cand/head_cand/binary
        self.norm_prune = 0.  # prune final normalized scores to 0. if <=this
        # concrete
        self.cconf = ConcreteNodeConf()
        # dropout directly on the normalized scores
        self.attn_dropout = 0.
# for getting normalized scores (like probs)
class MAttNormerNode(BasicNode):
    # Normalizes raw attention scores; the normalization strategy is chosen once,
    # at construction time, from `conf.norm_mode` (one `_norm_<mode>` method per mode).
    def __init__(self, pc: BK.ParamCollection, head_count, conf: MAttNormerConf):
        super().__init__(pc, None, None)
        self.conf = conf
        self.head_count = head_count
        # -----
        mode = conf.norm_mode
        # resolve the strategy method first, so an unknown mode fails on this lookup
        self._norm_f = getattr(self, "_norm_" + mode)  # shortcut
        # the dims that each strategy normalizes over
        dims_by_mode = {
            'flatten': [-1],
            'head': [-1],
            'cand': [-2],
            'head_cand': [-1, -2],
            'binary': [-1],
        }
        self._norm_dims = dims_by_mode[mode]
        self.norm_prune = conf.norm_prune
        # cnode: special attention
        self.cnode = self.add_sub_node("cn", ConcreteNode(pc, conf.cconf))
        # attention dropout: no-fix attention dropout, not elegant here
        self.adrop = self.add_sub_node(
            "adrop", Dropout(pc, (), init_rop=NoFixRop(), fix_rate=conf.attn_dropout))

    # =====
    # different strategies to calculate prob matrix

    @property
    def norm_dims(self):
        # the dims normalized by the configured mode (fixed in __init__)
        return self._norm_dims

    # helper function for noop
    # Normalizes `orig_scores` along `dim` with `cnode`; when `use_noop` is set, one extra
    # slot filled with `noop_fixed_val` is appended on that dim before normalizing.
    # Returns (prob_valid, prob_noop, prob_full); prob_noop is None without noop.
    def _normalize(self, cnode: ConcreteNode, orig_scores, use_noop: bool, noop_fixed_val: float, temperature: float, dim: int):
        pad_shape = BK.get_shape(orig_scores)  # original shape, edited below for the noop slot
        valid_size = pad_shape[dim]
        if not use_noop:
            # nothing to append: the full distribution is the valid one
            prob_full = cnode(orig_scores, temperature=temperature, dim=dim)  # [*, D, *]
            return prob_full, None, prob_full
        # append a single constant-valued slot along `dim`
        pad_shape[dim] = 1
        noop_slot = BK.constants(pad_shape, value=noop_fixed_val)  # [*, 1, *]
        padded_scores = BK.concat([orig_scores, noop_slot], dim=dim)  # [*, D+1, *]
        # normalize
        prob_full = cnode(padded_scores, temperature=temperature, dim=dim)  # [*, D+1, *]
        # cut the result back into the original part and the noop part
        prob_valid, prob_noop = BK.split(prob_full, [valid_size, 1], dim)  # [*, D|1, *]
        return prob_valid, prob_noop, prob_full

    # 1. direct normalize over flattened [len_k, head]
    def _norm_flatten(self, scores, temperature):
        conf = self.conf
        shape_2d = BK.get_shape(scores)
        # merge the last two dims so that a single normalization covers both
        flat_scores = scores.view(shape_2d[:-2] + [-1])
        prob_valid, prob_noop, prob_full = self._normalize(
            self.cnode, flat_scores, conf.use_noop, conf.noop_fixed_val, temperature, -1)
        # back to 2d shape for the attn itself
        # [*, len_q, len_k, head], [*, len_q, len_k*head], [*, len_q, ?], [*, len_q, len_k*head+?]
        return prob_valid.view(shape_2d), [prob_valid], [prob_noop], [prob_full], [-1]

    # 2. norm on head-dim for each cand
    def _norm_head(self, scores, temperature):
        conf = self.conf
        prob_valid, prob_noop, prob_full = self._normalize(
            self.cnode, scores, conf.use_noop, conf.noop_fixed_val, temperature, -1)
        # [*, len_q, len_k, head], [*, len_q, len_k, head], [*, len_q, len_k, ?], [*, len_q, len_k, head+?]
        return prob_valid, [prob_valid], [prob_noop], [prob_full], [-1]

    # 3. norm on cand-dim for each head
    def _norm_cand(self, scores, temperature):
        conf = self.conf
        prob_valid, prob_noop, prob_full = self._normalize(
            self.cnode, scores, conf.use_noop, conf.noop_fixed_val, temperature, -2)
        # [*, len_q, len_k, head], [*, len_q, len_k, head], [*, len_q, ?, head], [*, len_q, len_k+?, head]
        return prob_valid, [prob_valid], [prob_noop], [prob_full], [-2]

    # 4.
multiplying (2) and (3) + def _norm_head_cand(self, scores, temperature): + conf = self.conf + cnode = self.cnode + use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val + # ----- + prob_valid1, prob_noop1, prob_full1 = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -1) + prob_valid2, prob_noop2, prob_full2 = self._normalize(cnode, scores, use_noop, noop_fixed_val, temperature, -2) + # [*, len_q, len_k, head], [*, len_q, len_k, head], ... + return prob_valid1*prob_valid2, [prob_valid1, prob_valid2], [prob_noop1, prob_noop2], [prob_full1, prob_full2], [-1, -2] + + # 5. binary for each individual one + def _norm_binary(self, scores, temperature): + conf = self.conf + cnode = self.cnode + use_noop, noop_fixed_val = conf.use_noop, conf.noop_fixed_val + # ----- + prob_valid, prob_noop, prob_full = self._normalize(cnode, scores.unsqueeze(-1), use_noop, noop_fixed_val, temperature, -1) + # [*, len_q, len_k, head], [*, len_q, len_k, head], [*, len_q, len_k, head, ?], [*, len_q, len_k, head, 1+?] 
+ attn = prob_valid.squeeze(-1) + return attn, [prob_valid], [prob_noop], [prob_full], [-1] + + # input score is [*, len_q, len_k, head], output normalized ones and some extras + def __call__(self, scores, temperature): + attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims = self._norm_f(scores, temperature) + # prune + if self.norm_prune > 0.: + attn = attn * (attn > self.norm_prune).float() + # dropout + attn = self.adrop(attn) + return attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims diff --git a/tasks/zmlm/model/mods/vrec/c3_collect.py b/tasks/zmlm/model/mods/vrec/c3_collect.py new file mode 100644 index 0000000..00e0cbe --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/c3_collect.py @@ -0,0 +1,166 @@ +# + +# Component 3 for MAtt: collector +# (V, normalized_scores [*, len_q, len_k, head], accu_attns) => collected-V [*, len_q, Dv] + +from typing import List, Union, Dict, Iterable +from msp.utils import Constants, Conf, zwarn, Helper +from msp.nn import BK +from msp.nn.layers import BasicNode, ActivationHelper, Affine, NoDropRop, NoFixRop, Dropout, PosiEmbedding2, LayerNorm +from .base import AffineHelperNode + +# ----- + +# conf +class MAttCollectorConf(Conf): + def __init__(self): + # special + self.param_init_scale = 1. # todo(+N): ugly extra scale + self.use_piece4init = False + # content of values + self.use_affv = True # if not using, then directly use input (dim_v) + self.d_v = 64 # content reprs + self.v_act = 'linear' + self.v_drop = 0. + # attn & accu_attn + self.accu_lambda = 0. # for eg., -1 means delete previous ones, +1 means consider previous ones + self.accu_ceil = 1. # clamp final applying scores by certain range + self.accu_floor = 0. + # model based clamp + self.use_mclp = False + self.mclp_hdim = 128 + self.mclp_hdrop = 0.1 + self.mclp_hact = "elu" + self.mclp_fbias = 1. # fixed extra bias (making mclp by default positive value) + self.mclp_fact = "relu" # output activation after adding bias + # collecting: 1. 
weighted sum over certain dims, 2. further reduce if needed + self.collect_mode = "ch" # direct_2d(2d)/cand_head(ch)/head_cand(hc) + self.collect_renorm = False # renorm to avoid weights too large? + self.collect_renorm_mindiv = 1. # avoid renorm on small values + self.collect_reducer_mode = "aff" # max/sum/avg/aff; (aff is only valid for multihead-reduce) + self.collect_red_aff_out = 512 # output for "aff" mode, <0 means the same as d_v + self.collect_red_aff_act = 'linear' + +# the collector: normalized_scores * V +class MAttCollectorNode(BasicNode): + def __init__(self, pc: BK.ParamCollection, dim_q, dim_v, head_count, conf: MAttCollectorConf): + super().__init__(pc, None, None) + self.conf = conf + self.dim_q, self.dim_v, self.head_count = dim_q, dim_v, head_count + # ----- + _d_v = conf.d_v + _hid_size_v = head_count * _d_v + self.split_dims = [head_count, -1] + # affine-v + if conf.use_affv: + self.out_dim = _d_v + self.affine_v = self.add_sub_node("av", AffineHelperNode( + pc, dim_v, _hid_size_v, hid_act=conf.v_act, hid_drop=conf.v_drop, + hid_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale)) + else: + self.out_dim = dim_v + self.affine_v = None + # ----- + # model based clamp + if conf.use_mclp: + self.mclp_aff = self.add_sub_node("am", AffineHelperNode( + pc, dim_q, hid_dim=conf.mclp_hdim, hid_act=conf.mclp_hact, hid_drop=conf.mclp_hdrop, + out_dim=head_count, out_fbias=conf.mclp_fbias, out_fact=conf.mclp_fact, + out_piece4init=(head_count if conf.use_piece4init else 1), init_scale=conf.param_init_scale)) + else: + self.mclp_aff = None + # ----- + # how to reduce the two dims of [len_k, head] + self._collect_f = getattr(self, "_collect_"+conf.collect_mode) # shortcut + if conf.collect_reducer_mode == "aff": + aff_out_dim = _d_v if (conf.collect_red_aff_out<=0) else conf.collect_red_aff_out + aff_out_act = conf.collect_red_aff_act + self.out_dim = aff_out_dim + # no param_init_scale here, since output is not multihead 
+ self.multihead_aff_node = self.add_sub_node("ma", Affine(pc, _hid_size_v, aff_out_dim, + act=aff_out_act, init_rop=NoDropRop())) + else: + self.multihead_aff_node = None + # todo(note): always reduce on dim=-2!! + self._collect_reduce_f = { + "sum": lambda x: x.sum(-2), "avg": lambda x: x.mean(-2), "max": lambda x: x.max(-2)[0], + "aff": lambda x: self.multihead_aff_node(x.view(BK.get_shape(x)[:-2] + [-1])), # flatten last two and aff + "cat": lambda x: x.view(BK.get_shape(x)[:-2] + [-1]), # directly flatten + }[conf.collect_reducer_mode] + # special mode "cat" + if conf.collect_reducer_mode == "cat": + assert conf.collect_mode == "ch", "Must get rid of the var. dim of cand first" + self.out_dim = _hid_size_v + # ===== + + def get_output_dims(self, *input_dims): + return (self.out_dim, ) + + # input is [*, len_k, D], *[*, len_q, len_k, head] + def __call__(self, query, value, attn, accu_attn): + conf = self.conf + # get values + if self.affine_v: + value_up = self.affine_v(value).view(BK.get_shape(value)[:-1]+self.split_dims) # [*, (len_k, head), Dv] + else: + value_up = value.unsqueeze(-2) # [*, len_k, 1, D] + value_up_expand_shape = BK.get_shape(value_up) + value_up_expand_shape[-2] = self.head_count + value_up = value_up.expand(value_up_expand_shape).contiguous() # [*, len_k, head, D] + # get final applying attn scores: [*, len_q, (len_k, head)] + if conf.accu_lambda != 0.: + combined_attn = attn + accu_attn * conf.accu_lambda + clamped_attn = combined_attn.clamp(min=conf.accu_floor, max=conf.accu_ceil) + # apply_attn = (clamped_attn - combined_attn).detach() + combined_attn # todo(+N): make gradient flow? 
+ apply_accu_attn = clamped_attn + else: + apply_accu_attn = attn + # model based clamp + if self.mclp_aff: + mclp_upper = self.mclp_aff(query).unsqueeze(-2) # [*, len_q, 1, head] + apply_attn = BK.min_elem(apply_accu_attn, mclp_upper) # clamp on the ceil side + else: + apply_attn = apply_accu_attn + # apply attn and get result + result_value = self._collect_f(value_up, apply_attn) # [*, len_q, Dv] + return result_value + + # ===== + # specific strategies + + # 1. simply weighted sum over 2d [len_k, head] + def _collect_2d(self, value_up, apply_attn): + conf = self.conf + fused_attn = apply_attn.view(BK.get_shape(apply_attn)[:-2]+[-1]) # [*, len_q, len_k*head] + if conf.collect_renorm: + fused_attn = fused_attn / fused_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv) + value_up_shape = BK.get_shape(value_up) + fused_value_up = value_up.view(value_up_shape[:-3] + [-1, value_up_shape[-1]]) # [*, len_k*head, Dv] + result_value = BK.matmul(fused_attn, fused_value_up) # [*, len_q, Dv] + return result_value + + # 2. weighted sum over cand, then reduce on head (pool/affine) + def _collect_ch(self, value_up, apply_attn): + conf = self.conf + # first weighted sum over cand + cand_final_attn = apply_attn.transpose(-1, -2).transpose(-2, -3) # [*, head, len_q, len_k] + if conf.collect_renorm: + cand_final_attn = cand_final_attn / cand_final_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv) + transposed_value_up = value_up.transpose(-2, -3) # [*, head, len_k, Dv] + multihead_context = BK.matmul(cand_final_attn, transposed_value_up).transpose(-2, -3) # [*, len_q, head, Dv] + # then reduce head + result_value = self._collect_reduce_f(multihead_context.contiguous()) # [*, len_q, Dv] + return result_value + + # 3. 
weighted sum over head, then reduce on cand (pool) (cannot use affine here) + def _collect_hc(self, value_up, apply_attn): + conf = self.conf + # first weighted sum over head + head_final_attn = apply_attn.transpose(-2, -3) # [*, len_k, len_q, head] + if conf.collect_renorm: + head_final_attn = head_final_attn / head_final_attn.sum(-1, keepdims=True).clamp(min=conf.collect_renorm_mindiv) + apply_value_up = value_up # [*, len_k, head, Dv] + multicand_context = BK.matmul(head_final_attn, apply_value_up).transpose(-2, -3) # [*, len_q, len_k, Dv] + # then reduce cand + result_value = self._collect_reduce_f(multicand_context) # [*, len_q, Dv] + return result_value diff --git a/tasks/zmlm/model/mods/vrec/fcomb.py b/tasks/zmlm/model/mods/vrec/fcomb.py new file mode 100644 index 0000000..1eb57f3 --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/fcomb.py @@ -0,0 +1,24 @@ +# + +from typing import List, Union, Dict, Iterable +from msp.utils import Constants, Conf, zwarn, Helper +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, NoDropRop, PosiEmbedding2, ActivationHelper, Dropout +from .base import AffineHelperNode + +# ===== +# actually very similar to MattNode.c1_scorer + +# conf +class FCombConf(Conf): + def __init__(self): + pass + +# node +class FCombNode(BasicNode): + def __init__(self, pc, dim_q, dim_k, dim_v, conf: FCombConf): + super().__init__(pc, None, None) + pass + + def __call__(self, *args, **kwargs): + raise NotImplementedError() diff --git a/tasks/zmlm/model/mods/vrec/matt.py b/tasks/zmlm/model/mods/vrec/matt.py new file mode 100644 index 0000000..d7374f8 --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/matt.py @@ -0,0 +1,49 @@ +# + +# multihead attention node +# -- include the components of: c{1,2,3}_* +# -- (Q, K, V, accu_attns, rel_dist) => (results_value, attn, attn_info) + +from typing import List, Union, Dict, Iterable +from msp.utils import Constants, Conf, zwarn, Helper +from msp.nn import BK +from msp.nn.layers import BasicNode +from 
.c1_score import MAttScorerConf, MAttScorerNode +from .c2_norm import MAttNormerConf, MAttNormerNode +from .c3_collect import MAttCollectorConf, MAttCollectorNode + +# ----- + +# conf +class MAttConf(Conf): + def __init__(self): + self.head_count = 8 + self.scorer_conf = MAttScorerConf() + self.normer_conf = MAttNormerConf() + self.collector_conf = MAttCollectorConf() + +# node +class MAttNode(BasicNode): + def __init__(self, pc, dim_q, dim_k, dim_v, conf: MAttConf): + super().__init__(pc, None, None) + self.conf = conf + # ----- + self.scorer = self.add_sub_node("c1", MAttScorerNode(pc, dim_q, dim_k, conf.head_count, conf.scorer_conf)) + self.normer = self.add_sub_node("c2", MAttNormerNode(pc, conf.head_count, conf.normer_conf)) + self.collector = self.add_sub_node("c3", MAttCollectorNode(pc, dim_q, dim_v, conf.head_count, conf.collector_conf)) + + def get_output_dims(self, *input_dims): + return self.collector.get_output_dims(*input_dims) + + # input *[*, len?, Din] + def __call__(self, query, key, value, accu_attn, mask_k=None, mask_qk=None, rel_dist=None, temperature=1., forced_attn=None): + if forced_attn is None: # need to calculate the attns + scores = self.scorer(query, key, accu_attn, mask_k=mask_k, mask_qk=mask_qk, rel_dist=rel_dist) + normer_ret = self.normer(scores, temperature) + else: # skip the structured part + scores = None + normer_ret = (forced_attn, [], [], [], []) + result_value = self.collector(query, value, normer_ret[0], accu_attn) + # return the output of the three compoents: + # [*, len_q, len_k, head], ([*, len_q, len_k, head], ...), [*, len_q, Dv] + return scores, normer_ret, result_value diff --git a/tasks/zmlm/model/mods/vrec/vrec.py b/tasks/zmlm/model/mods/vrec/vrec.py new file mode 100644 index 0000000..1c8e331 --- /dev/null +++ b/tasks/zmlm/model/mods/vrec/vrec.py @@ -0,0 +1,360 @@ +# + +# the vertical recursive/recurrent module + +from typing import List, Union, Dict, Iterable +import numpy as np +import math +from msp.utils import 
Constants, Conf, zwarn, Helper +from msp.nn import BK +from msp.nn.layers import BasicNode, LstmNode2, Affine, Dropout, NoDropRop, LayerNorm +from msp.zext.process_train import SVConf, ScheduledValue +from .matt import MAttConf, MAttNode +from .fcomb import FCombConf, FCombNode +from .base import AffineCombiner +from ...base import BaseModuleConf, BaseModule, LossHelper +# from ...helper import EntropyHelper + +# ===== +# single node + +# conf +class VRecConf(Conf): + def __init__(self): + # layer norm + self.use_pre_norm = False # combine(ss(LN(x)), x) + self.use_post_norm = False # LN(combine(ss(x), x) + # ----- + # selfatt conf + self.feat_mod = "matt" # matt/fcomb + self.matt_conf = MAttConf() + self.fc_conf = FCombConf() + # what to feed + self.feat_qk_lambda_orig = 0. # for Q/K, this*orig_t + (1-this)*rec_t + self.feat_v_lambda_orig = 0. # for V(k), ... + # ----- + # combiner + self.comb_mode = "affine" # affine/lstm + # affine mode + self.comb_affine_q = False + self.comb_affine_v = False # if d_v != dim_q, should affine this + self.comb_affine_act = 'linear' # maybe linear is fine since at outside there may be LayerNorm + self.comb_affine_drop = 0.1 + # what to feed + self.comb_q_lambda_orig = 0. # for q, ... + # ----- + # further ff? 
+ self.ff_dim = 0 + self.ff_drop = 0.1 + self.ff_act = "relu" + + @property + def attn_count(self): + if self.feat_mod == "matt": + return self.matt_conf.head_count + elif self.feat_mod == "fcomb": + return self.fc_conf.fc_count + else: + raise NotImplementedError() + +# cache +class VRecCache: + def __init__(self): + # values for cur step: "orig" is the very first input, "rec" is the recusively built one + self.orig_t = None # [*, len_q, D] + self.rec_t = None # [*, len_q, D] + self.accu_attn = None # [*, len_q, len_v, head] + self.rec_lstm_c_t = None # optional for lstm mode + # list on steps + self.list_hidden = [] # List[*, len_q, D] + self.list_score = [] # List[*, len_q, len_v, head] + self.list_attn = [] # List[*, len_q, len_v, head] + self.list_accu_attn = [] # List[*, len_q, len_v, head] + self.list_attn_info = [] # Tuple(attn, [prob_valid], [prob_noop], [prob_full]) + +# one node (one step) in the recursion +class VRecNode(BasicNode): + def __init__(self, pc, dim: int, conf: VRecConf): + super().__init__(pc, None, None) + self.conf = conf + self.dim = dim + # ===== + # Feat + if conf.feat_mod == "matt": + self.feat_node = self.add_sub_node("feat", MAttNode(pc, dim, dim, dim, conf.matt_conf)) + self.attn_count = conf.matt_conf.head_count + elif conf.feat_mod == "fcomb": + self.feat_node = self.add_sub_node("feat", FCombNode(pc, dim, dim, dim, conf.fc_conf)) + self.attn_count = conf.fc_conf.fc_count + else: + raise NotImplementedError() + feat_out_dim = self.feat_node.get_output_dims()[0] + # ===== + # Combiner + if conf.comb_mode == "affine": + self.comb_aff = self.add_sub_node("aff", AffineCombiner( + pc, [dim, feat_out_dim], [conf.comb_affine_q, conf.comb_affine_v], dim, + out_act=conf.comb_affine_act, out_drop=conf.comb_affine_drop)) + self.comb_f = lambda q, v, c: (self.comb_aff([q,v]), None) + elif conf.comb_mode == "lstm": + self.comb_lstm = self.add_sub_node("lstm", LstmNode2(pc, feat_out_dim, dim)) + self.comb_f = self._call_lstm + else: + raise 
NotImplementedError() + # ===== + # ff + if conf.ff_dim > 0: + self.has_ff = True + self.linear1 = self.add_sub_node("l1", Affine(pc, dim, conf.ff_dim, act=conf.ff_act, init_rop=NoDropRop())) + self.dropout1 = self.add_sub_node("d1", Dropout(pc, (conf.ff_dim, ), fix_rate=conf.ff_drop)) + self.linear2 = self.add_sub_node("l2", Affine(pc, conf.ff_dim, dim, act="linear", init_rop=NoDropRop())) + self.dropout2 = self.add_sub_node("d2", Dropout(pc, (dim, ), fix_rate=conf.ff_drop)) + else: + self.has_ff = False + # layer norms + if conf.use_pre_norm: + self.att_pre_norm = self.add_sub_node("aln1", LayerNorm(pc, dim)) + self.ff_pre_norm = self.add_sub_node("fln1", LayerNorm(pc, dim)) + else: + self.att_pre_norm = self.ff_pre_norm = None + if conf.use_post_norm: + self.att_post_norm = self.add_sub_node("aln2", LayerNorm(pc, dim)) + self.ff_post_norm = self.add_sub_node("fln2", LayerNorm(pc, dim)) + else: + self.att_post_norm = self.ff_post_norm = None + + def get_output_dims(self, *input_dims): + return (self.dim, ) + + # ----- + def _call_lstm(self, q, v, c): + prev_shapes = BK.get_shape(q)[:-1] # [*, len_q] + in_reshape = [np.prod(prev_shapes), -1] + in_q, in_v, in_c = [z.view(in_reshape) for z in [q,v,c]] + ret_h, ret_c = self.comb_lstm(in_v, (in_q, in_c), None) # v acts as input, q acts as hidden + out_reshape = prev_shapes + [-1] # [*, len_q, D] + return ret_h.view(out_reshape), ret_c.view(out_reshape) + + # init cache + def init_call(self, src): + # init accumulated attn: all 0. 
+ src_shape = BK.get_shape(src) + attn_shape = src_shape[:-1] + [src_shape[-2], self.attn_count] # [*, len_q, len_k, head] + cache = VRecCache() + cache.orig_t = src + cache.rec_t = src # initially, the same as "orig_t" + cache.rec_lstm_c_t = BK.zeros(BK.get_shape(src)) # initially zero + cache.accu_attn = BK.zeros(attn_shape) + return cache + + # one update step: *[bs, slen, dim] + def update_call(self, cache: VRecCache, src_mask=None, qk_mask=None, + attn_range=None, rel_dist=None, temperature=1., forced_attn=None): + conf = self.conf + # ----- + # first call matt to get v + matt_input_qk = cache.orig_t * conf.feat_qk_lambda_orig + cache.rec_t * (1. - conf.feat_qk_lambda_orig) + if self.att_pre_norm: + matt_input_qk = self.att_pre_norm(matt_input_qk) + matt_input_v = cache.orig_t * conf.feat_v_lambda_orig + cache.rec_t * (1. - conf.feat_v_lambda_orig) + # todo(note): currently no pre-norm for matt_input_v + # put attn_range as mask_qk + if attn_range is not None and attn_range >= 0: # <0 means not effective + cur_slen = BK.get_shape(matt_input_qk, -2) + tmp_arange_t = BK.arange_idx(cur_slen) # [slen] + # less or equal!! + mask_qk = ((tmp_arange_t.unsqueeze(-1) - tmp_arange_t.unsqueeze(0)).abs() <= attn_range).float() + if qk_mask is not None: # further with input masks + mask_qk *= qk_mask + else: + mask_qk = qk_mask + scores, attn_info, result_value = self.feat_node( + matt_input_qk, matt_input_qk, matt_input_v, cache.accu_attn, mask_k=src_mask, mask_qk=mask_qk, + rel_dist=rel_dist, temperature=temperature, forced_attn=forced_attn) # ..., [*, len_q, dv] + # ----- + # then combine q(hidden) and v(input) + comb_input_q = cache.orig_t * conf.comb_q_lambda_orig + cache.rec_t * (1. 
- conf.comb_q_lambda_orig) + comb_result, comb_c = self.comb_f(comb_input_q, result_value, cache.rec_lstm_c_t) # [*, len_q, dim] + if self.att_post_norm: + comb_result = self.att_post_norm(comb_result) + # ----- + # ff + if self.has_ff: + if self.ff_pre_norm: + ff_input = self.ff_pre_norm(comb_result) + else: + ff_input = comb_result + ff_output = comb_result + self.dropout2(self.linear2(self.dropout1(self.linear1(ff_input)))) + if self.ff_post_norm: + ff_output = self.ff_post_norm(ff_output) + else: # otherwise no ff + ff_output = comb_result + # ----- + # update cache and return output + # cache.orig_t = cache.orig_t # this does not change + cache.rec_t = ff_output + cache.accu_attn = cache.accu_attn + attn_info[0] # accumulating attn + cache.rec_lstm_c_t = comb_c # optional C for lstm + cache.list_hidden.append(ff_output) # all hidden layers + cache.list_score.append(scores) # all un-normed scores + cache.list_attn.append(attn_info[0]) # all normed scores + cache.list_accu_attn.append(cache.accu_attn) # all accumulated attns + cache.list_attn_info.append(attn_info) # all attn infos + return ff_output + +# ===== +# encoder + +class VRecEncoderConf(BaseModuleConf): + def __init__(self): + super().__init__() + # ----- + # repeat of common layers + self.vr_conf = VRecConf() # vrec conf + self.num_layer = 4 + self.share_layers = True # recurrent/recursive if sharing + self.attn_ranges = [] + # scheduled values + self.temperature = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0.1) + self.lambda_noop_prob = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.) + self.lambda_entropy = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.) + self.lambda_attn_l1 = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.) + # self.lambda_coverage = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0.) 
+ # special noop loss + self.noop_epsilon = 0.01 # if <0, means directly using prob themselves without log + + def do_validate(self): + self.attn_ranges = [int(z) for z in self.attn_ranges] + +class VRecEncoder(BaseModule): + def __init__(self, pc, dim: int, conf: VRecEncoderConf): + super().__init__(pc, conf, output_dims=(dim,)) + # ----- + if conf.share_layers: + node = self.add_sub_node("m", VRecNode(pc, dim, conf.vr_conf)) + self.layers = [node for _ in range(conf.num_layer)] # use the same node!! + else: + self.layers = [self.add_sub_node("m", VRecNode(pc, dim, conf.vr_conf)) for _ in range(conf.num_layer)] + self.attn_ranges = conf.attn_ranges.copy() + self.attn_ranges.extend([None] * (conf.num_layer - len(self.attn_ranges))) # None means no restrictions + # scheduled values + self.temperature = ScheduledValue(f"{self.name}:temperature", conf.temperature) + self.lambda_noop_prob = ScheduledValue(f"{self.name}:lambda_noop_prob", conf.lambda_noop_prob) + self.lambda_entropy = ScheduledValue(f"{self.name}:lambda_entropy", conf.lambda_entropy) + self.lambda_attn_l1 = ScheduledValue(f"{self.name}:lambda_attn_l1", conf.lambda_attn_l1) + # self.lambda_coverage = ScheduledValue(f"{self.name}:lambda_coverage", conf.lambda_coverage) + + @property + def attn_count(self): + if len(self.layers) > 0: + all_counts = [z.attn_count for z in self.layers] + assert all(z==all_counts[0] for z in all_counts), "attn_count disagree!!" 
+ return all_counts[0] + else: + return self.conf.vr_conf.attn_count + + def get_scheduled_values(self): + return super().get_scheduled_values() + [self.temperature, self.lambda_noop_prob, self.lambda_entropy, self.lambda_attn_l1] + + def __call__(self, src, src_mask=None, qk_mask=None, rel_dist=None, forced_attns=None, collect_loss=False): + if src_mask is not None: + src_mask = BK.input_real(src_mask) + if forced_attns is None: + forced_attns = [None] * len(self.layers) + # ----- + # forward + temperature = self.temperature.value + cur_hidden = src + if len(self.layers) > 0: + cache = self.layers[0].init_call(src) + for one_lidx, one_layer in enumerate(self.layers): + cur_hidden = one_layer.update_call(cache, src_mask=src_mask, qk_mask=qk_mask, attn_range=self.attn_ranges[one_lidx], + rel_dist=rel_dist, temperature=temperature, forced_attn=forced_attns[one_lidx]) + else: + cache = VRecCache() # empty one + # ----- + # collect loss + if collect_loss: + loss_item = self._collect_losses(cache) + else: + loss_item = None + return cur_hidden, cache, loss_item + + # ===== + # helpers + def _loss_noop_prob(self, t): + noop_epsilon = self.conf.noop_epsilon + if noop_epsilon >= 0.: + ret_loss = ((1. + noop_epsilon - t).log() - math.log(noop_epsilon)) + else: + ret_loss = (1. 
- t) + return ret_loss + + def _loss_entropy(self, t_and_dim): + # the entropy + t, dim = t_and_dim + all_ents = - (t * (t + 1e-10).log()).sum(dim) # reduce on specific dim + # all_ents = EntropyHelper.cross_entropy(t, t, dim) + return all_ents + + def _loss_attn_l1(self, t): + # simply L1 reg for all attn (but only for >=0 part) + t_act = t * (t>=0.).float() + return t_act.abs() # same shape as input + + # collect loss from one layer + def _collect_losses(self, cache: VRecCache): + special_losses = [] + list_attn_info = cache.list_attn_info # List[(attn, list_prob_valid, list_prob_noop, list_prob_full, list_dims)] + # ----- + cur_lambda_noop_prob = self.lambda_noop_prob.value + if cur_lambda_noop_prob > 0.: + noop_losses = self.get_losses_from_attn_list( + list_attn_info, lambda x: x[2], self._loss_noop_prob, "PNopS", cur_lambda_noop_prob) + special_losses.extend(noop_losses) + # ----- + cur_lambda_entropy = self.lambda_entropy.value + if cur_lambda_entropy > 0.: + ent_losses = self.get_losses_from_attn_list( + list_attn_info, lambda x: list(zip(x[3],x[4])), self._loss_entropy, "EntRegS", cur_lambda_entropy) + special_losses.extend(ent_losses) + # ----- + cur_lambda_attn_l1 = self.lambda_attn_l1.value + if cur_lambda_attn_l1 > 0.: + attn_l1_losses = self.get_losses_from_attn_list( + list_attn_info, lambda x: [x[0]], self._loss_attn_l1, "AttnL1S", cur_lambda_attn_l1) + special_losses.extend(attn_l1_losses) + # ----- + # here combine them into one + ret_loss_item = self._compile_component_loss("vrec", special_losses) + return ret_loss_item + + # common procedures for obtaining losses from "list_attn_info" + @staticmethod + def get_losses_from_attn_list(list_attn_info: List, ts_f, loss_f, loss_prefix, loss_lambda): + loss_num = None + loss_counts: List[int] = [] + loss_sums: List[List] = [] + rets = [] + # ----- + for one_attn_info in list_attn_info: # each update step + one_ts: List = ts_f(one_attn_info) # get tensor list from attn_info + # get number of losses + if 
loss_num is None: + loss_num = len(one_ts) + loss_counts = [0] * loss_num + loss_sums = [[] for _ in range(loss_num)] + else: + assert len(one_ts) == loss_num, "mismatched ts length" + # iter them + for one_t_idx, one_t in enumerate(one_ts): # iter on the tensor list + one_loss = loss_f(one_t) + # need it to be in the corresponding shape + loss_counts[one_t_idx] += np.prod(BK.get_shape(one_loss)).item() + loss_sums[one_t_idx].append(one_loss.sum()) + # for different steps + for i, one_loss_count, one_loss_sums in zip(range(len(loss_counts)), loss_counts, loss_sums): + loss_leaf = LossHelper.compile_leaf_info(f"{loss_prefix}{i}", BK.stack(one_loss_sums, 0).sum(), + BK.input_real(one_loss_count), loss_lambda=loss_lambda) + rets.append(loss_leaf) + return rets diff --git a/tasks/zmlm/model/mtl.py b/tasks/zmlm/model/mtl.py new file mode 100644 index 0000000..2f7052a --- /dev/null +++ b/tasks/zmlm/model/mtl.py @@ -0,0 +1,460 @@ +# + +# the multitask model + +from typing import List, Dict +import numpy as np +from msp.utils import zlog, zwarn, Random, Helper, Conf +from msp.data import VocabPackage +from msp.nn import BK +from msp.nn.layers import BasicNode, PosiEmbedding2, RefreshOptions, Dropout +from msp.nn.modules import EncConf, MyEncoder +from msp.nn.modules.berter2 import BertFeaturesWeightLayer +from msp.zext.process_train import SVConf, ScheduledValue +from .base import BaseModelConf, BaseModel, BaseModuleConf, BaseModule, LossHelper +from .helper import RPrepConf, RPrepNode, EntropyHelper +from .mods.embedder import EmbedderNodeConf, EmbedderNode, Inputter, AugWord2Node +from .mods.vrec import VRecEncoderConf, VRecEncoder, VRecConf, VRecNode +from .mods.masklm import MaskLMNodeConf, MaskLMNode +from .mods.plainlm import PlainLMNodeConf, PlainLMNode +from .mods.orderpr import OrderPredNodeConf, OrderPredNode +from .mods.dpar import DparG1DecoderConf, DparG1Decoder +from .mods.seqlab import SeqLabNodeConf, SeqLabNode +from .mods.seqcrf import SeqCrfNodeConf, 
SeqCrfNode +from ..data.insts import GeneralSentence, NpArrField + +# ===== +# the overall Model + +class MtlMlmModelConf(BaseModelConf): + def __init__(self): + super().__init__() + # components + self.emb_conf = EmbedderNodeConf() + self.mlm_conf = MaskLMNodeConf() + self.plm_conf = PlainLMNodeConf() + self.orp_conf = OrderPredNodeConf() + self.orp_loss_special = False # special loss for orp + # non-pretraining parts + self.dpar_conf = DparG1DecoderConf() + self.upos_conf = SeqLabNodeConf() + self.ner_conf = SeqCrfNodeConf() + self.do_ner = False # an extra flag + self.ner_use_crf = True # whether use crf for ner? + # where pairwise repr comes from + self.default_attn_count = 128 # by default, setting this to what? + self.prepr_choice = "attn_max" # rdist, rdist_abs, attn_sum, attn_avg, attn_max, attn_last + # which encoder to use + self.enc_choice = "vrec" # by default, use the new one + self.venc_conf = VRecEncoderConf() + self.oenc_conf = EncConf().init_from_kwargs(enc_hidden=512, enc_rnn_layer=0) + # another small encoder + self.rprep_conf = RPrepConf() + # agreement module + self.lambda_agree = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none") + self.agree_loss_f = "js" + # separate generator seed especially for testing mask + self.testing_rand_gen_seed = 0 + self.testing_get_attns = False + # todo(+N): should make this into another conf + # for training and testing + self.no_build_dict = False + # -- length thresh + self.train_max_length = 80 + self.train_min_length = 5 + self.test_single_length = 80 + # -- batch size + self.train_shuffle = True + self.train_batch_size = 80 + self.train_inst_copy = 1 # how many times to copy the training insts (for different masks) + self.test_batch_size = 32 + # -- maxibatch for BatchArranger + self.train_maxibatch_size = 20 # cache for this number of batches in BatchArranger + self.test_maxibatch_size = -1 # -1 means cache all + # -- batch_size_f + self.train_batch_on_len = False # whether use sent length 
as budgets rather then sent count + self.test_batch_on_len = False + # ----- + self.train_preload_model = False + self.train_preload_process = False + # interactive testing + self.test_interactive = False # special mode + # load pre-training model? + self.load_pretrain_model_name = "" + # ----- + # special mode with extra word2 + # todo(+N): ugly patch! + self.aug_word2 = False + self.aug_word2_dim = 300 + self.aug_word2_pretrain = "" + self.aug_word2_save_dir = "" + self.aug_word2_aug_encoder = True # special model's special mode, feature-based original encoder and stack another one!! + self.aug_detach_ratio = 0.5 + self.aug_detach_dropout = 0.33 + self.aug_detach_numlayer = 6 # use mixture of how many last layers? + + def do_validate(self): + if self.orp_loss_special: + assert self.orp_conf.disturb_mode == "local_shuffle2" + +class MtlMlmModel(BaseModel): + def __init__(self, conf: MtlMlmModelConf, vpack: VocabPackage): + super().__init__(conf) + # for easier checking + self.word_vocab = vpack.get_voc("word") + # components + self.embedder = self.add_node("emb", EmbedderNode(self.pc, conf.emb_conf, vpack)) + self.inputter = Inputter(self.embedder, vpack) # not a node + self.emb_out_dim = self.embedder.get_output_dims()[0] + self.enc_attn_count = conf.default_attn_count + if conf.enc_choice == "vrec": + self.encoder = self.add_component("enc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf)) + self.enc_attn_count = self.encoder.attn_count + elif conf.enc_choice == "original": + conf.oenc_conf._input_dim = self.emb_out_dim + self.encoder = self.add_node("enc", MyEncoder(self.pc, conf.oenc_conf)) + else: + raise NotImplementedError() + zlog(f"Finished building model's encoder {self.encoder}, all size is {self.encoder.count_allsize_parameters()}") + self.enc_out_dim = self.encoder.get_output_dims()[0] + # -- + conf.rprep_conf._rprep_vr_conf.matt_conf.head_count = self.enc_attn_count # make head-count agree + self.rpreper = self.add_node("rprep", 
RPrepNode(self.pc, self.enc_out_dim, conf.rprep_conf)) + # -- + self.lambda_agree = self.add_scheduled_value(ScheduledValue(f"agr:lambda", conf.lambda_agree)) + self.agree_loss_f = EntropyHelper.get_method(conf.agree_loss_f) + # -- + self.masklm = self.add_component("mlm", MaskLMNode(self.pc, self.enc_out_dim, conf.mlm_conf, self.inputter)) + self.plainlm = self.add_component("plm", PlainLMNode(self.pc, self.enc_out_dim, conf.plm_conf, self.inputter)) + # todo(note): here we use attn as dim_pair, do not use pair if not using vrec!! + self.orderpr = self.add_component("orp", OrderPredNode( + self.pc, self.enc_out_dim, self.enc_attn_count, conf.orp_conf, self.inputter)) + # ===== + # pre-training pre-load point!! + if conf.load_pretrain_model_name: + zlog(f"At preload_pretrain point: Loading from {conf.load_pretrain_model_name}") + self.pc.load(conf.load_pretrain_model_name, strict=False) + # ===== + self.dpar = self.add_component("dpar", DparG1Decoder( + self.pc, self.enc_out_dim, self.enc_attn_count, conf.dpar_conf, self.inputter)) + self.upos = self.add_component("upos", SeqLabNode( + self.pc, "pos", self.enc_out_dim, self.conf.upos_conf, self.inputter)) + if conf.do_ner: + if conf.ner_use_crf: + self.ner = self.add_component("ner", SeqCrfNode( + self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter)) + else: + self.ner = self.add_component("ner", SeqLabNode( + self.pc, "ner", self.enc_out_dim, self.conf.ner_conf, self.inputter)) + else: + self.ner = None + # for pairwise reprs (no trainable params here!) 
+ self.rel_dist_embed = self.add_node("oremb", PosiEmbedding2(self.pc, n_dim=self.enc_attn_count, max_val=100)) + self._prepr_f_attn_sum = lambda cache, rdist: BK.stack(cache.list_attn, 0).sum(0) if (len(cache.list_attn))>0 else None + self._prepr_f_attn_avg = lambda cache, rdist: BK.stack(cache.list_attn, 0).mean(0) if (len(cache.list_attn))>0 else None + self._prepr_f_attn_max = lambda cache, rdist: BK.stack(cache.list_attn, 0).max(0)[0] if (len(cache.list_attn))>0 else None + self._prepr_f_attn_last = lambda cache, rdist: cache.list_attn[-1] if (len(cache.list_attn))>0 else None + self._prepr_f_rdist = lambda cache, rdist: self._get_rel_dist_embed(rdist, False) + self._prepr_f_rdist_abs = lambda cache, rdist: self._get_rel_dist_embed(rdist, True) + self.prepr_f = getattr(self, "_prepr_f_"+conf.prepr_choice) # shortcut + # -- + self.testing_rand_gen = Random.create_sep_generator(conf.testing_rand_gen_seed) # especial gen for testing + # ===== + if conf.orp_loss_special: + self.orderpr.add_node_special(self.masklm) + # ===== + # extra one!! + self.aug_word2 = self.aug_encoder = self.aug_mixturer = None + if conf.aug_word2: + self.aug_word2 = self.add_node("aug2", AugWord2Node(self.pc, conf.emb_conf, vpack, + "word2", conf.aug_word2_dim, self.emb_out_dim)) + if conf.aug_word2_aug_encoder: + assert conf.enc_choice == "vrec" + self.aug_detach_drop = self.add_node("dd", Dropout(self.pc, (self.enc_out_dim,), fix_rate=conf.aug_detach_dropout)) + self.aug_encoder = self.add_component("Aenc", VRecEncoder(self.pc, self.emb_out_dim, conf.venc_conf)) + self.aug_mixturer = self.add_node("Amix", BertFeaturesWeightLayer(self.pc, conf.aug_detach_numlayer)) + + # helper: embed and encode + def _emb_and_enc(self, cur_input_map: Dict, collect_loss: bool, insts=None): + conf = self.conf + # ----- + # special mode + if conf.aug_word2 and conf.aug_word2_aug_encoder: + _rop = RefreshOptions(training=False) # special feature-mode!! 
+ self.embedder.refresh(_rop) + self.encoder.refresh(_rop) + # ----- + emb_t, mask_t = self.embedder(cur_input_map) + rel_dist = cur_input_map.get("rel_dist", None) + if rel_dist is not None: + rel_dist = BK.input_idx(rel_dist) + if conf.enc_choice == "vrec": + enc_t, cache, enc_loss = self.encoder(emb_t, src_mask=mask_t, rel_dist=rel_dist, collect_loss=collect_loss) + elif conf.enc_choice == "original": # todo(note): change back to arr for back compatibility + assert rel_dist is None, "Original encoder does not support rel_dist" + enc_t = self.encoder(emb_t, BK.get_value(mask_t)) + cache, enc_loss = None, None + else: + raise NotImplementedError() + # another encoder based on attn + final_enc_t = self.rpreper(emb_t, enc_t, cache) # [*, slen, D] => final encoder output + if conf.aug_word2: + emb2_t = self.aug_word2(insts) + if conf.aug_word2_aug_encoder: + # simply add them all together, detach orig-enc as features + stack_hidden_t = BK.stack(cache.list_hidden[-conf.aug_detach_numlayer:], -2).detach() + features = self.aug_mixturer(stack_hidden_t) + aug_input = (emb2_t + conf.aug_detach_ratio*self.aug_detach_drop(features)) + final_enc_t, cache, enc_loss = self.aug_encoder(aug_input, src_mask=mask_t, + rel_dist=rel_dist, collect_loss=collect_loss) + else: + final_enc_t = (final_enc_t + emb2_t) # otherwise, simply adding + return emb_t, mask_t, final_enc_t, cache, enc_loss + + # --- + # helper on agreement loss + def _get_agr_loss(self, loss_prefix, cache1, cache2=None, copy_num=1): + # ===== + def _ts_pair_f(_attn_info_pair): + _ainfo1, _ainfo2 = _attn_info_pair + _rets = [] + for _t1, _d1, _t2, _d2 in zip(_ainfo1[3], _ainfo1[4], _ainfo2[3], _ainfo2[4]): + assert _d1 == _d2, "disagree on which dim to collect agr_loss!" 
+ _rets.append((_t1, _t2, _d1)) + return _rets + # -- + def _ts_self_f(_list_attn_info): + _rets = [] + for _t, _d in zip(_list_attn_info[3], _list_attn_info[4]): + # extra dim at idx 0; todo(note): must repeat insts at outmost idx: repeated = insts * copy_num + _t1 = _t.view([copy_num, -1] + BK.get_shape(_t)[1:]) # [copy, bs, ...] + # roll it by 1 + _t2 = BK.concat([_t1[-1].unsqueeze(0), _t1[:-1]], dim=0) + _rets.append((_t1, _t2, _d)) + return _rets + # -- + _arg_loss_f = lambda x: self.agree_loss_f(*x) + # ===== + cur_lambda_agr = self.lambda_agree.value + if cur_lambda_agr > 0.: + if cache2 is None: # roll self + assert copy_num > 1 + rets = VRecEncoder.get_losses_from_attn_list( + cache1.list_attn_info, _ts_self_f, _arg_loss_f, loss_prefix, cur_lambda_agr) + else: + rets = VRecEncoder.get_losses_from_attn_list( + list(zip(cache1.list_attn_info, cache2.list_attn_info)), _ts_pair_f, _arg_loss_f, loss_prefix, cur_lambda_agr) + else: + rets = [] + return rets + + def _get_rel_dist_embed(self, rel_dist, use_abs: bool): + if use_abs: + rel_dist = BK.input_idx(rel_dist).abs() + ret = self.rel_dist_embed(rel_dist) # [bs, len, len, H] + return ret + + def _get_rel_dist(self, len_q: int, len_k: int = None): + if len_k is None: + len_k = len_q + dist_x = BK.arange_idx(0, len_k).unsqueeze(0) # [1, len_k] + dist_y = BK.arange_idx(0, len_q).unsqueeze(1) # [len_q, 1] + distance = dist_x - dist_y # [len_q, len_k] + return distance + + # --- + def fb_on_batch(self, insts: List[GeneralSentence], training=True, loss_factor=1., + rand_gen=None, assign_attns=False, **kwargs): + # ===== + # import torch + # torch.autograd.set_detect_anomaly(True) + # with torch.autograd.detect_anomaly(): + # ===== + conf = self.conf + self.refresh_batch(training) + if len(insts) == 0: + return {"fb": 0, "sent": 0, "tok": 0} + # ----- + # copying instances for training: expand at dim0 + cur_copy = conf.train_inst_copy if training else 1 + copied_insts = insts * cur_copy + all_losses = [] + # ----- + 
# original input + input_map = self.inputter(copied_insts) + # for the pretraining modules + has_loss_mlm, has_loss_orp = (self.masklm.loss_lambda.value > 0.), (self.orderpr.loss_lambda.value > 0.) + if (not has_loss_orp) and has_loss_mlm: # only for mlm + masked_input_map, input_erase_mask_arr = self.masklm.mask_input(input_map, rand_gen=rand_gen) + emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(masked_input_map, collect_loss=True) + all_losses.append(enc_loss) + # mlm loss; todo(note): currently only using one layer + mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map) + all_losses.append(mlm_loss) + # assign values + if assign_attns: # may repeat and only keep that last one, but does not matter! + self._assign_attns_item(copied_insts, "mask", input_erase_mask_arr=input_erase_mask_arr, cache=cache) + # agreement loss + if cur_copy > 1: + all_losses.extend(self._get_agr_loss("agr_mlm", cache, copy_num=cur_copy)) + if has_loss_orp: + disturbed_input_map = self.orderpr.disturb_input(input_map, rand_gen=rand_gen) + if has_loss_mlm: # further mask some + disturb_keep_arr = disturbed_input_map.get("disturb_keep", None) + assert disturb_keep_arr is not None, "No keep region for mlm!" 
+ # todo(note): in this mode we assume add_root, so here exclude arti-root by [:,1:] + masked_input_map, input_erase_mask_arr = \ + self.masklm.mask_input(input_map, rand_gen=rand_gen, extra_mask_arr=disturb_keep_arr[:,1:]) + disturbed_input_map.update(masked_input_map) # update + emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(disturbed_input_map, collect_loss=True) + all_losses.append(enc_loss) + # orp loss + if conf.orp_loss_special: + orp_loss = self.orderpr.loss_special(enc_t, mask_t, disturbed_input_map.get("disturb_keep", None), + disturbed_input_map, self.masklm) + else: + orp_input_attn = self.prepr_f(cache, disturbed_input_map.get("rel_dist")) + orp_loss = self.orderpr.loss(enc_t, orp_input_attn, mask_t, disturbed_input_map.get("disturb_keep", None)) + all_losses.append(orp_loss) + # mlm loss + if has_loss_mlm: + mlm_loss = self.masklm.loss(enc_t, input_erase_mask_arr, input_map) + all_losses.append(mlm_loss) + # assign values + if assign_attns: # may repeat and only keep that last one, but does not matter! 
+ self._assign_attns_item(copied_insts, "dist", abs_posi_arr=disturbed_input_map.get("posi"), cache=cache) + # agreement loss + if cur_copy > 1: + all_losses.extend(self._get_agr_loss("agr_orp", cache, copy_num=cur_copy)) + if self.plainlm.loss_lambda.value > 0.: + if conf.enc_choice == "vrec": # special case for blm + emb_t, mask_t = self.embedder(input_map) + rel_dist = input_map.get("rel_dist", None) + if rel_dist is not None: + rel_dist = BK.input_idx(rel_dist) + # two directions + true_rel_dist = self._get_rel_dist(BK.get_shape(mask_t, -1)) # q-k: [len_q, len_k] + enc_t1, cache1, enc_loss1 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist<=0).float(), + rel_dist=rel_dist, collect_loss=True) + enc_t2, cache2, enc_loss2 = self.encoder(emb_t, src_mask=mask_t, qk_mask=(true_rel_dist>=0).float(), + rel_dist=rel_dist, collect_loss=True) + assert not self.rpreper.active, "TODO: Not supported for this mode" + all_losses.extend([enc_loss1, enc_loss2]) + # plm loss with explict two inputs + plm_loss = self.plainlm.loss([enc_t1, enc_t2], input_map) + all_losses.append(plm_loss) + else: + # here use original input + emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True) + all_losses.append(enc_loss) + # plm loss + plm_loss = self.plainlm.loss(enc_t, input_map) + all_losses.append(plm_loss) + # agreement loss + assert self.lambda_agree.value==0., "Not implemented for this mode" + # ===== + # task loss + dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda = \ + [0. if z is None else z.loss_lambda.value for z in [self.dpar, self.upos, self.ner]] + if any(z>0. 
for z in [dpar_loss_lambda, upos_loss_lambda, ner_loss_lambda]): + # here use original input + emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(input_map, collect_loss=True, insts=insts) + all_losses.append(enc_loss) + # parsing loss + if dpar_loss_lambda > 0.: + dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1))) + dpar_loss = self.dpar.loss(copied_insts, enc_t, dpar_input_attn, mask_t) + all_losses.append(dpar_loss) + # pos loss + if upos_loss_lambda > 0.: + upos_loss = self.upos.loss(copied_insts, enc_t, mask_t) + all_losses.append(upos_loss) + # ner loss + if ner_loss_lambda > 0.: + ner_loss = self.ner.loss(copied_insts, enc_t, mask_t) + all_losses.append(ner_loss) + # ----- + info = self.collect_loss_and_backward(all_losses, training, loss_factor) + info.update({"fb": 1, "sent": len(insts), "tok": sum(len(z) for z in insts)}) + return info + + def inference_on_batch(self, insts: List[GeneralSentence], **kwargs): + conf = self.conf + self.refresh_batch(False) + with BK.no_grad_env(): + # special mode + # use: CUDA_VISIBLE_DEVICES=3 PYTHONPATH=../../src/ python3 -m pdb ../../src/tasks/cmd.py zmlm.main.test ${RUN_DIR}/_conf device:0 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best test:./_en.debug test_interactive:1 + if conf.test_interactive: + iinput_sent = input(">> (Interactive testing) Input sent sep by blanks: ") + iinput_tokens = iinput_sent.split() + if len(iinput_sent) > 0: + iinput_inst = GeneralSentence.create(iinput_tokens) + iinput_inst.word_seq.set_idxes([self.word_vocab.get_else_unk(w) for w in iinput_inst.word_seq.vals]) + iinput_inst.char_seq.build_idxes(self.inputter.vpack.get_voc("char")) + iinput_map = self.inputter([iinput_inst]) + iinput_erase_mask = np.asarray([[z=="Z" for z in iinput_tokens]]).astype(dtype=np.float32) + iinput_masked_map = self.inputter.mask_input(iinput_map, iinput_erase_mask, set("pos")) + emb_t, mask_t, enc_t, cache, enc_loss = self._emb_and_enc(iinput_masked_map, 
collect_loss=False, insts=[iinput_inst]) + mlm_loss = self.masklm.loss(enc_t, iinput_erase_mask, iinput_map) + dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1))) + self.dpar.predict([iinput_inst], enc_t, dpar_input_attn, mask_t) + self.upos.predict([iinput_inst], enc_t, mask_t) + # print them + import pandas as pd + cur_fields = { + "idxes": list(range(1, len(iinput_inst)+1)), + "word": iinput_inst.word_seq.vals, "pos": iinput_inst.pred_pos_seq.vals, + "head": iinput_inst.pred_dep_tree.heads[1:], "dlab": iinput_inst.pred_dep_tree.labels[1:]} + zlog(f"Result:\n{pd.DataFrame(cur_fields).to_string()}") + return {} # simply return here for interactive mode + # ----- + # test for MLM simply as in training (use special separate rand_gen to keep the masks the same for testing) + # todo(+2): do we need to keep testing/validing during training the same? Currently not! + info = self.fb_on_batch(insts, training=False, rand_gen=self.testing_rand_gen, assign_attns=conf.testing_get_attns) + # ----- + if len(insts) == 0: + return info + # decode for dpar + input_map = self.inputter(insts) + emb_t, mask_t, enc_t, cache, _ = self._emb_and_enc(input_map, collect_loss=False, insts=insts) + dpar_input_attn = self.prepr_f(cache, self._get_rel_dist(BK.get_shape(mask_t, -1))) + self.dpar.predict(insts, enc_t, dpar_input_attn, mask_t) + self.upos.predict(insts, enc_t, mask_t) + if self.ner is not None: + self.ner.predict(insts, enc_t, mask_t) + # ----- + if conf.testing_get_attns: + if conf.enc_choice == "vrec": + self._assign_attns_item(insts, "orig", cache=cache) + elif conf.enc_choice in ["original"]: + pass + else: + raise NotImplementedError() + return info + + def _assign_attns_item(self, insts, prefix, input_erase_mask_arr=None, abs_posi_arr=None, cache=None): + if cache is not None: + attn_names, attn_list = [], [] + for one_sidx, one_attn in enumerate(cache.list_attn): + attn_names.append(f"{prefix}_att{one_sidx}") + attn_list.append(one_attn) + 
if cache.accu_attn is not None: + attn_names.append(f"{prefix}_att_accu") + attn_list.append(cache.accu_attn) + for one_name, one_attn in zip(attn_names, attn_list): + # (step_idx, ) -> [bs, len_q, len_k, head] + one_attn_arr = BK.get_value(one_attn) + for bidx, inst in enumerate(insts): + save_arr = one_attn_arr[bidx] + inst.add_item(one_name, NpArrField(save_arr, float_decimal=4), assert_non_exist=False) + if abs_posi_arr is not None: + for bidx, inst in enumerate(insts): + inst.add_item(f"{prefix}_abs_posi", + NpArrField(abs_posi_arr[bidx], float_decimal=0), assert_non_exist=False) + if input_erase_mask_arr is not None: + for bidx, inst in enumerate(insts): + inst.add_item(f"{prefix}_erase_mask", + NpArrField(input_erase_mask_arr[bidx], float_decimal=4), assert_non_exist=False) + # ----- + +# b tasks/zmlm/model/mtl:56 diff --git a/tasks/zmlm/run/__init__.py b/tasks/zmlm/run/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tasks/zmlm/run/confs.py b/tasks/zmlm/run/confs.py new file mode 100644 index 0000000..f420db7 --- /dev/null +++ b/tasks/zmlm/run/confs.py @@ -0,0 +1,91 @@ +# + +# common configurations + +import json +from msp import utils, nn +from msp.nn import NIConf +from msp.utils import Conf, Logger, zopen, zfatal, zwarn + +from .vocab import MLMVocabPackageConf + +# todo(note): when testing, unless in the default setting, need to specify: "dict_dir" and "model_load_name" +class DConf(Conf): + def __init__(self): + # vocab + self.vconf = MLMVocabPackageConf() + # data paths + self.train = "" + self.dev = "" + self.test = "" + self.cache_data = True # turn off if large data + self.to_cache_shuffle = False + self.dict_dir = "./" + # cuttings for training (simply for convenience without especially preparing data...) + self.cut_train = -1 # <0 means no cut! + self.cut_dev = -1 + # save name for trainer not here!! 
+ self.model_load_name = "zmodel.best" # load name + self.output_file = "zout" + # format (conllu, plain, json) + self.input_format = "conllu" + self.dev_input_format = "" # special mixing case + self.output_format = "json" + # special processing + self.lower_case = False + self.norm_digit = False # norm digits to 0 + # todo(note): deprecated; do not mix things in this way. Use pos field and word-thresh to control things! + self.backoff_pos_idx = -1 # <0 means nope, when idx>=this, then backoff to pos ones + + def do_validate(self): + if len(self.dev_input_format)==0: + self.dev_input_format = self.input_format + +# the overall conf +class OverallConf(Conf): + def __init__(self, args): + # top-levels + self.conf_output = "" + self.log_file = Logger.MAGIC_CODE + self.msp_seed = 9341 + # + self.niconf = NIConf() # nn-init-conf + self.dconf = DConf() # data-conf + # model + from ..model.mtl import MtlMlmModelConf + self.mconf = MtlMlmModelConf() + # ===== + # + self.update_from_args(args) + self.validate() + +# get model +def build_model(conf, vpack): + mconf = conf.mconf + from ..model.mtl import MtlMlmModel + model = MtlMlmModel(mconf, vpack) + return model + +# the start of everything +def init_everything(args, ConfType=None, quite=False): + # search for basic confs + all_argv, basic_argv = Conf.search_args(args, ["conf_output", "log_file", "msp_seed"], + [str, str, int], [None, Logger.MAGIC_CODE, None]) + # for basic argvs + conf_output = basic_argv["conf_output"] + log_file = basic_argv["log_file"] + msp_seed = basic_argv["msp_seed"] + if conf_output: + with zopen(conf_output, "w") as fd: + for k,v in all_argv.items(): + # todo(note): do not save this one + if k != "conf_output": + fd.write(f"{k}:{v}\n") + utils.init(log_file, msp_seed, quite=quite) + # real init of the conf + if ConfType is None: + conf = OverallConf(args) + else: + conf = ConfType(args) + nn.init(conf.niconf) + return conf diff --git a/tasks/zmlm/run/run.py b/tasks/zmlm/run/run.py new file mode 
100644 index 0000000..9604e52 --- /dev/null +++ b/tasks/zmlm/run/run.py @@ -0,0 +1,291 @@ +# + +# some processing tools on the data stream + +from typing import List, Iterable, Dict +from collections import OrderedDict +from copy import deepcopy + +from msp.utils import zlog, zwarn, zopen, GLOBAL_RECORDER, Helper, zcheck +from msp.data import FAdapterStreamer, BatchArranger, InstCacher +from msp.zext.process_train import TrainingRunner, RecordResult +from msp.zext.process_test import TestingRunner, ResultManager +from msp.zext.dpar import ParserEvaler +from msp.data import WordNormer + +from ..data.insts import GeneralSentence +from ..data.io import BaseDataWriter, BaseDataReader, ConlluParseReader, PlainTextReader, ConllNerReader +from .vocab import MLMVocabPackage, UD2_POS_UNK_MAP + +from .seqeval import get_prf_scores + +# ===== +# preparing data + +# pre-processing (lower_case, norm_digit) +class PreprocessStreamer(FAdapterStreamer): + def __init__(self, in_stream, lower_case: bool, norm_digit: bool): + super().__init__(in_stream, self._go_prep, True) + self.normer = WordNormer(lower_case, norm_digit) + + def _go_prep(self, inst: GeneralSentence): + inst.word_seq.reset(self.normer.norm_stream(inst.word_seq.vals)) + +# for indexing instances +class IndexerStreamer(FAdapterStreamer): + def __init__(self, in_stream, vpack: MLMVocabPackage, inst_preparer, backoff_pos_idx: int): + super().__init__(in_stream, self._go_index, False) + # ----- + self.w_vocab = vpack.get_voc("word") + self.w2_vocab = vpack.get_voc("word2") # extra set + self.c_vocab = vpack.get_voc("char") + self.p_vocab = vpack.get_voc("pos") + self.l_vocab = vpack.get_voc("deplabel") + self.n_vocab = vpack.get_voc("ner") + self.inst_preparer = inst_preparer + self.backoff_pos_idx = backoff_pos_idx + + def _go_index(self, inst: GeneralSentence): + # word + # todo(warn): remember to norm word; replace singleton at model's input, not here + if inst.word_seq.has_vals(): + w_voc = self.w_vocab + word_idxes 
= [w_voc.get_else_unk(w) for w in inst.word_seq.vals] # todo(note): currently unk-idx is large + if self.backoff_pos_idx >= 0: # when using this mode, there must be backoff strings in word vocab + zwarn("The option of 'backoff_pos_idx' is deprecated, do not mix things in this way!") + word_backoff_idxes = [w_voc[UD2_POS_UNK_MAP[z]] for z in inst.pos_seq.vals] + word_idxes = [(zb if z>=self.backoff_pos_idx else z) for z, zb in zip(word_idxes, word_backoff_idxes)] + inst.word_seq.set_idxes(word_idxes) + # ===== + # aug extra word set!! + if self.w2_vocab is not None: + w2_voc = self.w2_vocab + inst.word2_seq = deepcopy(inst.word_seq) + word_idxes = [w2_voc.get_else_unk(w) for w in inst.word2_seq.vals] # todo(note): currently unk-idx is large + inst.word2_seq.set_idxes(word_idxes) + # ===== + # others + char_seq, pos_seq, dep_tree, ner_seq = [getattr(inst, z, None) for z in ["char_seq", "pos_seq", "dep_tree", "ner_seq"]] + if char_seq is not None and char_seq.has_vals(): + char_seq.build_idxes(self.c_vocab) + if pos_seq is not None and pos_seq.has_vals(): + pos_seq.build_idxes(self.p_vocab) + if dep_tree is not None and dep_tree.has_vals(): + dep_tree.build_label_idxes(self.l_vocab) + if ner_seq is not None and ner_seq.has_vals(): + ner_seq.build_idxes(self.n_vocab) + if self.inst_preparer is not None: + inst = self.inst_preparer(inst) + # in fact, inplaced if not wrapping model specific preparer + return inst + +# +def index_stream(in_stream, vpack, cached, cache_shuffle, inst_preparer, backoff_pos_idx): + i_stream = IndexerStreamer(in_stream, vpack, inst_preparer, backoff_pos_idx) + if cached: + return InstCacher(i_stream, shuffle=cache_shuffle) + else: + return i_stream + +# for arrange batches +def batch_stream(in_stream, batch_size, ticonf, training): + MIN_SENT_LEN_FOR_BSIZE = 10 + if training: + batch_size_f = (lambda x: max(len(x), MIN_SENT_LEN_FOR_BSIZE)) if ticonf.train_batch_on_len else None + b_stream = BatchArranger(in_stream, batch_size=batch_size, 
maxibatch_size=ticonf.train_maxibatch_size, batch_size_f=batch_size_f, + dump_detectors=lambda one: len(one)>=ticonf.train_max_length or len(one)<ticonf.train_min_length, + single_detectors=None, sorting_keyer=len, shuffling=ticonf.train_shuffle) + else: + batch_size_f = (lambda x: max(len(x), MIN_SENT_LEN_FOR_BSIZE)) if ticonf.test_batch_on_len else None + b_stream = BatchArranger(in_stream, batch_size=batch_size, maxibatch_size=ticonf.test_maxibatch_size, batch_size_f=batch_size_f, + dump_detectors=None, single_detectors=lambda one: len(one)>=ticonf.test_single_length, + sorting_keyer=len, shuffling=False) + return b_stream + +# ===== +# data io + +def get_data_reader(file_or_fd, input_format, **kwargs): + if input_format == "conllu": + r = ConlluParseReader(file_or_fd, **kwargs) + elif input_format == "ner": + r = ConllNerReader(file_or_fd, **kwargs) + elif input_format == "json": + r = BaseDataReader(GeneralSentence, file_or_fd) + elif input_format == "plain": + r = PlainTextReader(file_or_fd, **kwargs) + else: + raise NotImplementedError(f"Unknown input_format {input_format}") + return r + +def get_data_writer(file_or_fd, output_format, **kwargs): + if output_format == "json": + return BaseDataWriter(file_or_fd, **kwargs) + else: + raise NotImplementedError(f"Unknown output_format {output_format}") + +# ===== +# runner + +class MltResultMannager(ResultManager): + def __init__(self, vpack, outf, goldf, out_format): + self.insts = [] + # + self.vpack = vpack + self.outf = outf + self.goldf = goldf # todo(note): goldf is only for reporting which file to compare? 
+ self.out_format = out_format + + def add(self, ones: List[GeneralSentence]): + self.insts.extend(ones) + + # write, eval & summary + def end(self): + # sorting by idx of reading + self.insts.sort(key=lambda x: x.sid) + # todo(+1): write other output file + if self.outf is not None: + with zopen(self.outf, "w") as fd: + data_writer = get_data_writer(fd, self.out_format) + data_writer.write_list(self.insts) + # eval for parsing & pos & ner + evaler = ParserEvaler() + ner_golds, ner_preds = [], [] + has_syntax, has_ner = False, False + for one_inst in self.insts: + gold_pos_seq = getattr(one_inst, "pos_seq", None) + pred_pos_seq = getattr(one_inst, "pred_pos_seq", None) + gold_ner_seq = getattr(one_inst, "ner_seq", None) + pred_ner_seq = getattr(one_inst, "pred_ner_seq", None) + gold_tree = getattr(one_inst, "dep_tree", None) + pred_tree = getattr(one_inst, "pred_dep_tree", None) + if gold_pos_seq is not None and gold_tree is not None: + # todo(+N): only ARTI_ROOT in trees + gold_pos, gold_heads, gold_labels = gold_pos_seq.vals, gold_tree.heads[1:], gold_tree.labels[1:] + pred_pos = pred_pos_seq.vals if (pred_pos_seq is not None) else [""]*len(gold_pos) + pred_heads, pred_labels = (pred_tree.heads[1:], pred_tree.labels[1:]) if (pred_tree is not None) \ + else ([-1]*len(gold_heads), [""]*len(gold_labels)) + evaler.eval_one(gold_pos, gold_heads, gold_labels, pred_pos, pred_heads, pred_labels) + has_syntax = True + if gold_ner_seq is not None and pred_ner_seq is not None: + ner_golds.append(gold_ner_seq.vals) + ner_preds.append(pred_ner_seq.vals) + has_ner = True + # ----- + zlog("Results of %s vs. 
%s" % (self.outf, self.goldf), func="result") + if has_syntax: + report_str, res = evaler.summary() + zlog(report_str, func="result") + else: + res = {} + if has_ner: + for _n, _v in zip("prf", get_prf_scores(ner_golds, ner_preds)): + res[f"ner_{_n}"] = _v + res["gold"] = self.goldf # record which file + zlog("zzzzztest: testing result is " + str(res)) + # note that res['res'] will not be used + if 'res' in res: + del res['res'] + # ----- + return res + +# testing runner +class MltTestingRunner(TestingRunner): + def __init__(self, model, vpack, outf, goldf, out_format): + super().__init__(model) + self.res_manager = MltResultMannager(vpack, outf, goldf, out_format) + + # run and record for one batch + def _run_batch(self, insts): + res = self.model.inference_on_batch(insts) + self.res_manager.add(insts) + return res + + # eval & report + def _run_end(self): + x = self.test_recorder.summary() + res = self.res_manager.end() + x.update(res) + MltDevResult.calc_acc(x) + Helper.printd(x, sep=" || ") + return x + +# +class MltDevResult(RecordResult): + def __init__(self, result_ds: List[Dict]): + self.all_results = result_ds + rr = {} + for idx, vv in enumerate(result_ds): + rr["zres"+str(idx)] = vv # record all the results for printing + # res = 0. if (len(result_ds)==0) else result_ds[0]["res"] # but only use the first one as DEV-res + # todo(note): use the first dev result, and use UAS if res<=0. + if len(result_ds) == 0: + res = 0. + else: + res = result_ds[0]["res"] + super().__init__(rr, res) + + @staticmethod + def calc_acc(x): + # put results for MLM/ORP + addings = OrderedDict() + for comp_name, comp_loss_names in zip(["mlm", "plm", "orp", "dp", "pos", "ner"], + [["word", "pos"], ["l2r", "r2l"], ["d-1", "d-2"], ["head", "label"], + ["slab"], ["slab", "crf"]]): + for c in comp_loss_names: + addings[f"res_{c}.acc"] = x.get(f"loss:{comp_name}.{c}_corr", 0.) / (x.get(f"loss:{comp_name}.{c}_count", 0.) 
+ 1e-5) + addings[f"res_{c}.nll"] = x.get(f"loss:{comp_name}.{c}_sum", 0.) / (x.get(f"loss:{comp_name}.{c}_count", 0.) + 1e-5) + for k,v in addings.items(): + if v>0.: # only care active ones + x[k] = v + # ----- + # set "res" + res_cands = [-addings["res_word.nll"], -addings["res_l2r.nll"], -addings["res_r2l.nll"], + -addings["res_d-1.nll"], -addings["res_d-2.nll"], x.get("tok_las", 0.), + addings["res_slab.acc"], x.get("ner_f", 0.)] + x["res"] = sum(z for z in res_cands if z!=0.) + +# training runner +class MltTrainingRunner(TrainingRunner): + def __init__(self, rconf, model, vpack, dev_outfs, dev_goldfs, dev_out_format): + super().__init__(rconf, model) + self.vpack = vpack + self.dev_out_format = dev_out_format + # + self.dev_goldfs = dev_goldfs + if isinstance(dev_outfs, (list, tuple)): + zcheck(len(dev_outfs) == len(dev_goldfs), "Mismatched number of output and gold!") + self.dev_outfs = dev_outfs + else: + self.dev_outfs = [dev_outfs] * len(dev_goldfs) + + # to be implemented + def _run_fb(self, annotated_insts, loss_factor: float): + res = self.model.fb_on_batch(annotated_insts, loss_factor=loss_factor) + return res + + def _run_train_report(self): + x = self.train_recorder.summary() + y = GLOBAL_RECORDER.summary() + if len(y) > 0: + x.update(y) + MltDevResult.calc_acc(x) + Helper.printd(x, " || ") + return RecordResult(x, score=x.get("res", 0.)) + + def _run_update(self, lrate: float, grad_factor: float): + self.model.update(lrate, grad_factor) + + def _run_validate(self, dev_streams): + if not isinstance(dev_streams, Iterable): + dev_streams = [dev_streams] + dev_results = [] + zcheck(len(dev_streams) == len(self.dev_goldfs), "Mismatch number of streams!") + dev_idx = 0 + for one_stream, one_dev_outf, one_dev_goldf in zip(dev_streams, self.dev_outfs, self.dev_goldfs): + # todo(+2): overwrite previous ones? 
+ rr = MltTestingRunner(self.model, self.vpack, one_dev_outf+".dev"+str(dev_idx), one_dev_goldf, self.dev_out_format) + x = rr.run(one_stream) + dev_results.append(x) + dev_idx += 1 + return MltDevResult(dev_results) diff --git a/tasks/zmlm/run/seqeval.py b/tasks/zmlm/run/seqeval.py new file mode 100644 index 0000000..211f871 --- /dev/null +++ b/tasks/zmlm/run/seqeval.py @@ -0,0 +1,132 @@ +# + +# directly from https://github.com/chakki-works/seqeval/ (on 29 Apr 2019) + +"""Metrics to assess performance on sequence labeling task given prediction +Functions named as ``*_score`` return a scalar value to maximize: the higher +the better +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import numpy as np + + +def get_entities(seq, suffix=False): + """Gets entities from sequence. + + Args: + seq (list): sequence of labels. + + Returns: + list: list of (chunk_type, chunk_start, chunk_end). + + Example: + # >>> from seqeval.metrics.sequence_labeling import get_entities + # >>> seq = ['B-PER', 'I-PER', 'O', 'B-LOC'] + # >>> get_entities(seq) + [('PER', 0, 1), ('LOC', 3, 3)] + """ + # for nested list + if any(isinstance(s, list) for s in seq): + seq = [item for sublist in seq for item in sublist + ['O']] + + prev_tag = 'O' + prev_type = '' + begin_offset = 0 + chunks = [] + for i, chunk in enumerate(seq + ['O']): + if suffix: + tag = chunk[-1] + type_ = chunk.split('-')[0] + else: + tag = chunk[0] + type_ = chunk.split('-')[-1] + + if end_of_chunk(prev_tag, tag, prev_type, type_): + chunks.append((prev_type, begin_offset, i-1)) + if start_of_chunk(prev_tag, tag, prev_type, type_): + begin_offset = i + prev_tag = tag + prev_type = type_ + + return chunks + + +def end_of_chunk(prev_tag, tag, prev_type, type_): + """Checks if a chunk ended between the previous and current word. + + Args: + prev_tag: previous chunk tag. + tag: current chunk tag. 
def start_of_chunk(prev_tag, tag, prev_type, type_):
    """Checks if a chunk started between the previous and current word.

    Args:
        prev_tag: previous chunk tag.
        tag: current chunk tag.
        prev_type: previous type.
        type_: current type.

    Returns:
        chunk_start: boolean.
    """
    # 'B' and 'S' always open a new chunk
    if tag in ('B', 'S'):
        return True

    # an 'E' or 'I' that follows a closed position ('E', 'S', 'O') has no
    # chunk to continue, so it opens one
    if prev_tag in ('E', 'S', 'O') and tag in ('E', 'I'):
        return True

    # any other labelled token opens a chunk only when the type changes
    return bool(tag not in ('O', '.') and prev_type != type_)
#
class MLMVocabPackage(VocabPackage):
    """Vocab/embedding package used by the zmlm tasks.

    The package keeps the vocabularies named "word", "char", "pos",
    "deplabel" and "ner" (plus "word2" once `aug_word2_vocab` has been
    called) together with the embeddings attached to them.
    """

    def __init__(self, vocabs: Dict, embeds: Dict):
        super().__init__(vocabs, embeds)
        # -----

    @staticmethod
    def build_by_reading(prefix):
        """Create a package with empty slots for every known vocab and fill it with `load(prefix=prefix)`."""
        zlog("Load vocabs from files.")
        possible_vocabs = ["word", "char", "pos", "deplabel", "ner", "word2"]
        one = MLMVocabPackage({n: None for n in possible_vocabs}, {n: None for n in possible_vocabs})
        one.load(prefix=prefix)
        return one

    @staticmethod
    def build_from_stream(build_conf: MLMVocabPackageConf, stream, extra_stream):
        """Build all vocabularies (and optionally word embeddings) from instance streams.

        Args:
            build_conf: building options (UD2 pre-values, pretrain files, word thresholds).
            stream: instances whose words/chars (and NER tags, when present) feed the vocabularies.
            extra_stream: instances whose words are only used to pick up additional
                pre-trained vectors; it is read only when `read_from_pretrain` is set.

        Returns:
            A new `MLMVocabPackage` with the "word", "char", "pos", "deplabel" and "ner"
            vocabularies and the "word" embedding (None when no pretrain file is read).
        """
        zlog("Build vocabs from streams.")
        ret = MLMVocabPackage({}, {})
        # -----
        if build_conf.add_ud2_pos_backoffs:
            # reserve one special UNK entry per UD2 POS tag in front of the word vocab
            ud2_pos_pre_list = list(VocabBuilder.DEFAULT_PRE_LIST) + [UD2_POS_UNK_MAP[p] for p in UD2_POS_LIST]
            word_builder = VocabBuilder("word", pre_list=ud2_pos_pre_list)
        else:
            word_builder = VocabBuilder("word")
        char_builder = VocabBuilder("char")
        pos_builder = VocabBuilder("pos")
        deplabel_builder = VocabBuilder("deplabel")
        ner_builder = VocabBuilder("ner")
        if build_conf.add_ud2_prevalues:
            zlog(f"Add pre-defined UD2 values for upos({len(UD2_POS_LIST)}) and ulabel({len(UD2_LABEL_LIST)}).")
            pos_builder.feed_stream(UD2_POS_LIST)
            deplabel_builder.feed_stream(UD2_LABEL_LIST)
        for inst in stream:
            word_builder.feed_stream(inst.word_seq.vals)
            for w in inst.word_seq.vals:
                char_builder.feed_stream(w)
            # todo(+N): currently we are assuming that we are using UD pos/deps, and directly go with the default ones
            # pos and label can be optional??
            # if inst.poses.has_vals():
            #     pos_builder.feed_stream(inst.poses.vals)
            # if inst.deplabels.has_vals():
            #     deplabel_builder.feed_stream(inst.deplabels.vals)
            if hasattr(inst, "ner_seq") and inst.ner_seq.has_vals():
                ner_builder.feed_stream(inst.ner_seq.vals)
        # ===== embeddings
        w2vec = None
        if build_conf.read_from_pretrain:
            # todo(warn): for convenience, extra vocab (usually dev&test) is only used for collecting pre-train vecs
            extra_word_set = set(w for inst in extra_stream for w in inst.word_seq.vals)
            # ----- load (possibly multiple) pretrain embeddings
            # must provide build_conf.pretrain_file (there can be multiple pretrain files!)
            list_pretrain_file = build_conf.pretrain_file
            # pad default codes on a private copy: extending build_conf.pretrain_codes itself
            # would change the shared configuration and make it grow on every call
            list_code_pretrain = list(build_conf.pretrain_codes) + [""] * len(list_pretrain_file)
            w2vec = WordVectors.load(list_pretrain_file[0], aug_code=list_code_pretrain[0])
            if len(list_pretrain_file) > 1:
                w2vec.merge_others([WordVectors.load(one_file, aug_code=one_code)
                                    for one_file, one_code in zip(list_pretrain_file[1:], list_code_pretrain[1:])])
            # -----
            # first filter according to thresholds
            word_builder.filter(
                lambda ww, rank, val: (val >= build_conf.word_fthres and rank <= build_conf.word_rthres) or
                                      (build_conf.ignore_thresh_with_pretrain and w2vec.has_key(ww)))
            # then add extra ones
            if build_conf.ignore_thresh_with_pretrain:
                for w in extra_word_set:
                    if w2vec.has_key(w) and (not word_builder.has_key_currently(w)):
                        word_builder.feed_one(w)
            word_vocab = word_builder.finish()
            word_embed1 = word_vocab.filter_embed(w2vec, init_nohit=build_conf.pretrain_init_nohit,
                                                  scale=build_conf.pretrain_scale)
        else:
            word_vocab = word_builder.finish_thresh(rthres=build_conf.word_rthres, fthres=build_conf.word_fthres)
            word_embed1 = None
        #
        char_vocab = char_builder.finish()
        pos_vocab = pos_builder.finish(sort_by_count=False)
        deplabel_vocab = deplabel_builder.finish(sort_by_count=False)
        ner_vocab = ner_builder.finish()
        # assign
        ret.put_voc("word", word_vocab)
        ret.put_voc("char", char_vocab)
        ret.put_voc("pos", pos_vocab)
        ret.put_voc("deplabel", deplabel_vocab)
        ret.put_voc("ner", ner_vocab)
        ret.put_emb("word", word_embed1)
        #
        return ret

    # ======
    # special mode extend another word vocab
    # todo(+N): ugly patch!!
    def aug_word2_vocab(self, stream, extra_stream, extra_embed_file: str):
        """Add a second word vocabulary ("word2") and its embedding to this package.

        Args:
            stream: instances whose words are all put into the new vocabulary (no thresholds).
            extra_stream: instances whose words are added only when `extra_embed_file`
                has a vector for them; it is read only when an embedding file is given.
            extra_embed_file: path of the embedding file; an empty string means no
                embeddings, in which case a warning is issued and the "word2"
                embedding is stored as None.
        """
        zlog(f"Aug another word vocab from streams and extra_embed_file={extra_embed_file}")
        word_builder = VocabBuilder("word2")
        for inst in stream:
            word_builder.feed_stream(inst.word_seq.vals)
        # embeddings
        if len(extra_embed_file) > 0:
            extra_word_set = set(w for inst in extra_stream for w in inst.word_seq.vals)
            w2vec = WordVectors.load(extra_embed_file)
            for w in extra_word_set:
                if w2vec.has_key(w) and (not word_builder.has_key_currently(w)):
                    word_builder.feed_one(w)
            word_vocab = word_builder.finish()  # no filtering!!
            word_embed1 = word_vocab.filter_embed(w2vec, init_nohit=1.0, scale=1.0)
        else:
            zwarn("WARNING: No pretrain file for aug node!!")
            word_vocab = word_builder.finish()  # no filtering!!
            word_embed1 = None
        self.put_voc("word2", word_vocab)
        self.put_emb("word2", word_embed1)
    # =====

# b tasks/zmlm/run/vocab:146