From 842c8b7f9bae80e17399b4e6f5628c551cbeb705 Mon Sep 17 00:00:00 2001 From: antonio Date: Tue, 16 Jun 2020 21:02:12 -0400 Subject: [PATCH] Updates for zfp. --- msp/__init__.py | 5 +- msp/nn/backends/bktr.py | 4 +- msp/utils/color.py | 2 +- msp/zext/evaler.py | 7 +- tasks/zdpar/common/confs.py | 11 +- tasks/zdpar/common/model.py | 1 + tasks/zdpar/common/run.py | 4 +- tasks/zdpar/common/vocab.py | 19 +- tasks/zdpar/ef/analysis/ann.py | 2 +- tasks/zdpar/ef/analysis/run1217.py | 177 +++++++++++++++++ tasks/zdpar/main/build_vocab.py | 23 +++ tasks/zdpar/stat2/__init__.py | 14 +- tasks/zdpar/stat2/decode3.py | 51 ++++- tasks/zdpar/stat2/helper_basic.py | 29 ++- tasks/zdpar/stat2/helper_cluster.py | 33 +++- tasks/zdpar/stat2/helper_decode.py | 287 +++++++++++++++++++++++++++- tasks/zdpar/stat2/helper_feature.py | 62 +++++- tasks/zdpar/stat2/run.sh | 63 +++++- tasks/zdpar/zfp/__init__.py | 20 ++ tasks/zdpar/zfp/dec.py | 234 +++++++++++++++++++++++ tasks/zdpar/zfp/enc.py | 192 +++++++++++++++++++ tasks/zdpar/zfp/fp.py | 245 ++++++++++++++++++++++++ tasks/zdpar/zfp/masklm.py | 81 ++++++++ tasks/zdpar/zfp/run_zfp.sh | 98 ++++++++++ 24 files changed, 1617 insertions(+), 47 deletions(-) create mode 100644 tasks/zdpar/ef/analysis/run1217.py create mode 100644 tasks/zdpar/main/build_vocab.py create mode 100644 tasks/zdpar/zfp/__init__.py create mode 100644 tasks/zdpar/zfp/dec.py create mode 100644 tasks/zdpar/zfp/enc.py create mode 100644 tasks/zdpar/zfp/fp.py create mode 100644 tasks/zdpar/zfp/masklm.py create mode 100644 tasks/zdpar/zfp/run_zfp.sh diff --git a/msp/__init__.py b/msp/__init__.py index 5fd12cf..3e96d7a 100644 --- a/msp/__init__.py +++ b/msp/__init__.py @@ -23,7 +23,9 @@ # gru and cnn have problems? # ---- # -- Next Version principles and goals: -# nlp data types +# nlp data types & unified data formats!! +# model and submodules and ... (composition or inheritance??) !! +# more flexible namings: model/module name, loss name, data-field names, ... # use type hint # checkings and reportings # use eval for Conf @@ -31,6 +33,7 @@ # summarize more common patterns, including those in scripts # everything has (more flexible) conf # more flexible save/load for part of model; (better naming and support dynamic adding and deleting components!!) +# a small one: more flexible path finder, for example, multiple upper layers of ".." def version(): return (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, VERSION_STATUS) diff --git a/msp/nn/backends/bktr.py b/msp/nn/backends/bktr.py index 0f59d46..48d024d 100644 --- a/msp/nn/backends/bktr.py +++ b/msp/nn/backends/bktr.py @@ -118,7 +118,9 @@ def update(self, overall_lrate, grad_factor): for param_group in self.opt_.param_groups: param_group['lr'] = cur_lrate self.cached_lrate_ = cur_lrate - if cur_lrate<=0. and self.no_step_lrate0_: + # check if we need update + parameters = list(filter(lambda p: p.grad is not None, self.params_)) + if (cur_lrate<=0. 
and self.no_step_lrate0_) or (len(parameters) == 0): # no update self.opt_.zero_grad() else: diff --git a/msp/utils/color.py b/msp/utils/color.py index 0844c00..b7e0585 100644 --- a/msp/utils/color.py +++ b/msp/utils/color.py @@ -3,7 +3,7 @@ # colorful printing try: - from colorama import colorama_init + from colorama import init as colorama_init colorama_init() from colorama import Fore, Back, Style RESET_ALL = Style.RESET_ALL diff --git a/msp/zext/evaler.py b/msp/zext/evaler.py index a67653b..e136192 100644 --- a/msp/zext/evaler.py +++ b/msp/zext/evaler.py @@ -35,12 +35,13 @@ def __float__(self): return self.f1 class LabelF1Evaler: - def __init__(self, name): + def __init__(self, name, ignore_none=False): # key -> List[labels] self.name = name self.golds = {} self.preds = {} self.labels = set() + self.ignore_none = ignore_none # ===== # adding ones @@ -53,9 +54,13 @@ def _add_group(self, d: Dict, key, label): self.labels.add(label) def add_gold(self, key, label): + if key is None and self.ignore_none: + return self._add_group(self.golds, key, label) def add_pred(self, key, label): + if key is None and self.ignore_none: + return self._add_group(self.preds, key, label) # ===== diff --git a/tasks/zdpar/common/confs.py b/tasks/zdpar/common/confs.py index 5503145..da53618 100644 --- a/tasks/zdpar/common/confs.py +++ b/tasks/zdpar/common/confs.py @@ -47,7 +47,9 @@ def __init__(self): # special processing self.lower_case = False self.norm_digit = False # norm digits to 0 - self.use_label0 = False # using only first-level label + self.use_label0 = True # using only first-level label # todo(note): change the default behaviour + zwarn("Note: currently we change default value of 'use_label0' to True!") + self.vocab_add_prevalues = True # add pre-defined UDv2 values when building dicts # ===== # for multi-lingual processing (another option is to pre-processing suitable data) # language code (empty str for no effects) @@ -89,6 +91,9 @@ def __init__(self, partype, args): elif partype == "s2": from ..ef.parser import S2ParserConf self.pconf = S2ParserConf() + elif partype == "fp": + from ..zfp.fp import FpParserConf + self.pconf = FpParserConf() else: zfatal(f"Unknown parser type: {partype}, please provide correct type with the option.") # ===== @@ -125,6 +130,10 @@ def build_model(partype, conf, vpack): # two-stage parser from ..ef.parser import S2Parser parser = S2Parser(pconf, vpack) + elif partype == "fp": + # the finale parser + from ..zfp.fp import FpParser + parser = FpParser(pconf, vpack) else: zfatal("Unknown parser type: %s") return parser diff --git a/tasks/zdpar/common/model.py b/tasks/zdpar/common/model.py index 09d4ce6..0a96d2d 100644 --- a/tasks/zdpar/common/model.py +++ b/tasks/zdpar/common/model.py @@ -283,6 +283,7 @@ def __init__(self): self.load_process = False # batch arranger self.batch_size = 32 + self.train_min_length = 0 self.train_skip_length = 120 self.shuffle_train = True # optimizer and lrate factor for enc&dec&sl(mid) diff --git a/tasks/zdpar/common/run.py b/tasks/zdpar/common/run.py index 870c207..ff81877 100644 --- a/tasks/zdpar/common/run.py +++ b/tasks/zdpar/common/run.py @@ -54,8 +54,8 @@ def index_stream(in_stream, vpack, cached, cache_shuffle, inst_preparer): def batch_stream(in_stream, ticonf, training): if training: b_stream = BatchArranger(in_stream, batch_size=ticonf.batch_size, maxibatch_size=20, batch_size_f=None, - dump_detectors=lambda one: len(one)>=ticonf.train_skip_length, single_detectors=None, - sorting_keyer=len, shuffling=ticonf.shuffle_train) + 
dump_detectors=lambda one: len(one)>=ticonf.train_skip_length or len(one)=ticonf.infer_single_length, diff --git a/tasks/zdpar/common/vocab.py b/tasks/zdpar/common/vocab.py index f310ded..e75ac53 100644 --- a/tasks/zdpar/common/vocab.py +++ b/tasks/zdpar/common/vocab.py @@ -22,6 +22,12 @@ def build_by_reading(dconf): one.load(dconf.dict_dir) return one + # ===== + # pre-values for UDv2 + PRE_VALUES_ULAB = ["punct", "case", "nsubj", "det", "root", "", "nmod", "advmod", "obj", "obl", "amod", "compound", "aux", "conj", "mark", "cc", "cop", "advcl", "acl", "xcomp", "nummod", "ccomp", "appos", "flat", "parataxis", "discourse", "expl", "fixed", "list", "iobj", "csubj", "goeswith", "vocative", "reparandum", "orphan", "dep", "dislocated", "clf"] + PRE_VALUES_UPOS = ["NOUN", "PUNCT", "VERB", "PRON", "ADP", "DET", "PROPN", "", "ADJ", "AUX", "ADV", "CCONJ", "PART", "NUM", "SCONJ", "X", "INTJ", "SYM"] + # ===== + @staticmethod def build_from_stream(dconf: DConf, stream, extra_stream): zlog("Build vocabs from streams.") @@ -32,13 +38,22 @@ def build_from_stream(dconf: DConf, stream, extra_stream): pos_builder = VocabBuilder("pos") label_builder = VocabBuilder("label") word_normer = ret.word_normer + if dconf.vocab_add_prevalues: + zlog(f"Add pre-defined values for upos({len(ParserVocabPackage.PRE_VALUES_UPOS)}) and " + f"ulabel({len(ParserVocabPackage.PRE_VALUES_ULAB)}).") + pos_builder.feed_stream(ParserVocabPackage.PRE_VALUES_UPOS) + label_builder.feed_stream(ParserVocabPackage.PRE_VALUES_ULAB) for inst in stream: # todo(warn): only do special handling for words + # there must be words word_builder.feed_stream(word_normer.norm_stream(inst.words.vals)) for w in inst.words.vals: char_builder.feed_stream(w) - pos_builder.feed_stream(inst.poses.vals) - label_builder.feed_stream(inst.labels.vals) + # pos and label can be optional + if inst.poses.has_vals(): + pos_builder.feed_stream(inst.poses.vals) + if inst.labels.has_vals(): + label_builder.feed_stream(inst.labels.vals) # w2vec = None if dconf.init_from_pretrain: diff --git a/tasks/zdpar/ef/analysis/ann.py b/tasks/zdpar/ef/analysis/ann.py index 11e0d48..2255b07 100644 --- a/tasks/zdpar/ef/analysis/ann.py +++ b/tasks/zdpar/ef/analysis/ann.py @@ -97,7 +97,7 @@ def do_print(self, ocode): # ----- # looking/annotating at specific instances. protocol: target - # start new annotation task or TODO(!) recover the previous one + # start new annotation task or TODO(+N) recover the previous one def do_ann_start(self, insts_target: str) -> AnnotationTask: assert self.cur_cmd_target is not None, "Should assign this to a var to avoid accidental loss!" 
vs = self.vars diff --git a/tasks/zdpar/ef/analysis/run1217.py b/tasks/zdpar/ef/analysis/run1217.py new file mode 100644 index 0000000..e9e1f12 --- /dev/null +++ b/tasks/zdpar/ef/analysis/run1217.py @@ -0,0 +1,177 @@ +# + +# case study for the ef one +# error breakdown on labels,steps; which ones are "easy"(first-decoded) ones + +# + +import sys +from collections import Counter +from msp.zext.ana import AnalyzerConf, Analyzer, ZRecNode, AnnotationTask + +try: + from .ann import * +except: + from ann import * + +# +class NewAnalysisConf(AnalysisConf): + def __init__(self, args): + super().__init__(args) + # + self.step_div = 5 + self.use_label0 = True + +def main(args): + conf = NewAnalysisConf(args) + # ===== + if conf.load_name == "": + # recalculate them + # read them + zlog("Read them all ...") + gold_parses = list(yield_ones(conf.gold)) + sys_parses = [list(yield_ones(z)) for z in conf.fs] + if conf.use_label0: + # todo(note): force using label0 (language-independent) + zlog("Force label0 ...") + for one_parses in sys_parses + [gold_parses]: + for one_parse in one_parses: + for one_token in one_parse.get_tokens(): + one_token.label = one_token.label0 + # use vocab? + voc = Vocab.read(conf.vocab) if len(conf.vocab)>0 else None + # ===== + # stat them + zlog("Stat them all ...") + all_sents, all_tokens = get_objs(gold_parses, sys_parses, conf.getter) + analyzer = ParsingAnalyzer(conf.ana, all_sents, all_tokens, conf.labeled, vocab=voc) + analyzer.set_var("nsys", len(conf.fs), explanation="init", history_idx=-1) + if conf.save_name != "": + analyzer.do_save(conf.save_name) + else: + analyzer = ParsingAnalyzer(conf.ana, None, None, conf.labeled) + analyzer.do_load(conf.load_name) + # ===== + # special analysis + # ---- + def _num_same_sibs(_node): + # todo(note): here split label again + _lab = _node.label + if conf.use_label0: + _count = sum(z.split(":")[0]==_lab for z in _node.get_head().childs_labels) + else: + _count = sum(z==_lab for z in _node.get_head().childs_labels) + assert _count>=1 + return _count-1 + # ---- + all_sents = analyzer.get_var("sents") + nsys = analyzer.get_var("nsys") + step_div = conf.step_div # how many bins for ef-steps? 
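# -----
# A minimal sketch of the step-binning used below (step_to_bin is a hypothetical
# helper, not in the codebase): a token decoded at ef-step `step_idx` of a sentence
# of length `sent_len` falls into one of `step_div` equal-width bins, and `stepp`
# is the same quantity left un-binned (step_idx / sent_len).
def step_to_bin(step_idx: int, sent_len: int, step_div: int = 5) -> int:
    return int(step_idx * step_div / sent_len)

assert step_to_bin(0, 20, 5) == 0    # decoded first -> earliest bin
assert step_to_bin(7, 20, 5) == 1
assert step_to_bin(19, 20, 5) == 4   # decoded last -> latest bin
# -----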
+ breakdown_labels = {} # label -> {gold: {count, numsib, dist}, preds: [{count, numsib, dist, lcorr, stepp}]} + for _lab in ulabel2type["Nivre17"].keys(): + breakdown_labels[_lab] = {"gold": {"count": 0, "numsib": 0, "dist": 0}, + "preds": [{"count": 0, "numsib": 0, "dist": 0, "lcorr": 0, "stepp": 0} for _ in range(nsys)]} + breakdown_steps = {} # stepbin -> {count, dist, counter(label), acc, acc-all} + for _stepbin in range(step_div): + breakdown_steps[_stepbin] = {"count": 0, "dist": 0, "labels": Counter(), "lcorrs": [0]*nsys} + # ----- + # collect + for one_sobj in all_sents: + cur_length = one_sobj.len + for one_tobj in one_sobj.rtoks: # all real toks + # ----- + # get stat + gold_label = one_tobj.g.label + gold_numsib = _num_same_sibs(one_tobj.g) + gold_dist = abs(one_tobj.g.ddist) + # breakdown-label + breakdown_labels[gold_label]["gold"]["count"] += 1 + breakdown_labels[gold_label]["gold"]["numsib"] += gold_numsib + breakdown_labels[gold_label]["gold"]["dist"] += gold_dist + for i, p in enumerate(one_tobj.ss): + pred_label = p.label + if pred_label in ["", ""]: + pred_label = "dep" # todo(note): fix padding prediction + pred_numsib = _num_same_sibs(p) + pred_dist = abs(p.ddist) + pred_lcorr = p.lcorr + pred_stepi = getattr(p, "efi", None) + if pred_stepi is None: + pred_stepi = getattr(p, "gmi", None) + assert pred_stepi is not None + pred_stepbin = int(pred_stepi*step_div/cur_length) + pred_stepp = pred_stepi / cur_length + # breakdown-label + breakdown_labels[pred_label]["preds"][i]["count"] += 1 + breakdown_labels[pred_label]["preds"][i]["numsib"] += pred_numsib + breakdown_labels[pred_label]["preds"][i]["dist"] += pred_dist + breakdown_labels[pred_label]["preds"][i]["lcorr"] += pred_lcorr + breakdown_labels[pred_label]["preds"][i]["stepp"] += pred_stepp + # breakdown-steps + if i==0: # todo(note): only record the first one!! + breakdown_steps[pred_stepbin]["count"] += 1 + breakdown_steps[pred_stepbin]["dist"] += pred_dist + breakdown_steps[pred_stepbin]["labels"][pred_label] += 1 + for i2, p2 in enumerate(one_tobj.ss): # all nodes' correctness for this certain node! + breakdown_steps[pred_stepbin]["lcorrs"][i2] += p2.lcorr + # ----- + # summary + data_labels = [] + for k, dd in breakdown_labels.items(): + gold_count = max(dd["gold"]["count"], 1e-5) + res = {"K": k, "gold_count": gold_count, "numsib": dd["gold"]["numsib"]/gold_count, + "dist": dd["gold"]["dist"]/gold_count} + for pidx, preds in enumerate(dd["preds"]): + pred_count = max(preds["count"], 1e-5) + res[f"pred{pidx}_count"] = pred_count + res[f"pred{pidx}_numsib"] = preds["numsib"]/pred_count + res[f"pred{pidx}_dist"] = preds["dist"]/pred_count + res[f"pred{pidx}_stepp"] = preds["stepp"]/pred_count + P, R = preds["lcorr"]/pred_count, preds["lcorr"]/gold_count + F = 2*P*R/(P+R) if (P+R)>0 else 0. 
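# -----
# Worked example for the per-label scores computed just above, with made-up counts:
# P divides correct predictions by how often the label was predicted, R by how often
# the label appears in gold, and F is their harmonic mean.
_lcorr, _pred_count, _gold_count = 40, 50, 80
_P, _R = _lcorr / _pred_count, _lcorr / _gold_count      # 0.8, 0.5
_F = 2 * _P * _R / (_P + _R) if (_P + _R) > 0 else 0.    # ~0.615
# -----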
+ res.update({f"pred{pidx}_P": P, f"pred{pidx}_R": R, f"pred{pidx}_F": F}) + data_labels.append(res) + data_steps = [] + TOP_LABEL_K = 5 + for k, dd in breakdown_steps.items(): + dd_count = max(dd["count"], 1e-5) + res = {"K": k, "count": dd_count, "dist": dd["dist"]/dd_count} + for common_idx, common_p in enumerate(dd["labels"].most_common(TOP_LABEL_K)): + common_label, common_count = common_p + res[f"common{common_idx}"] = f"{common_label}({common_count/dd['count']:.3f})" + for pidx, pcorr in enumerate(dd["lcorrs"]): + res[f"pred{pidx}_acc"] = pcorr/dd_count + data_steps.append(res) + # ===== + pd_labels = pd.DataFrame({k: [d[k] for d in data_labels] for k in data_labels[0].keys()}) + pd_labels = pd_labels.sort_values(by="gold_count", ascending=False) + selections = ["K", "gold_count", "numsib", "dist", "pred0_numsib", "pred0_stepp", "pred0_F", + "pred1_numsib", "pred1_stepp", "pred1_F"] + pd_labels2 = pd_labels[selections] + pd_steps = pd.DataFrame({k: [d[k] for d in data_steps] for k in data_steps[0].keys()}) + zlog(f"#-----\nLABELS: \n{pd_labels2.to_string()}\n\n") + zlog(f"#-----\nSTEPS: \n{pd_steps.to_string()}\n\n") + # specific table + TABLE_LABEL_K = 10 + num_all_tokens = sum(z["gold_count"] for z in data_labels) + lines = [] + for i in range(TABLE_LABEL_K): + ss = pd_labels.iloc[i] + fields = [ss["K"], f"{ss['gold_count']/num_all_tokens:.2f}", f"{ss['numsib']:.2f}", f"{ss['dist']:.2f}", + f"{ss['pred0_F']*100:.2f}", f"{ss['pred0_numsib']:.2f}", + f"{ss['pred1_F']*100:.2f}", f"{ss['pred1_numsib']:.2f}"] + lines.append(" & ".join(fields)) + table_ss = "\\\\\n".join(lines) + zlog(f"#=====\n{table_ss}") + # ----- + # import pdb + # pdb.set_trace() + return + +if __name__ == '__main__': + main(sys.argv[1:]) + +# runnings +""" +PYTHONPATH=../../src/ python3 -m pdb run.py gold:en_dev.gold fs:en_dev.zef_ru.pred,en_dev.zg1_ru.pred +""" diff --git a/tasks/zdpar/main/build_vocab.py b/tasks/zdpar/main/build_vocab.py new file mode 100644 index 0000000..1e17cfe --- /dev/null +++ b/tasks/zdpar/main/build_vocab.py @@ -0,0 +1,23 @@ +# + +# building vocabs +# the initial part of training + +from ..common.confs import init_everything +from ..common.data import get_data_reader, get_multisoure_data_reader +from ..common.vocab import ParserVocabPackage + +def main(args): + conf = init_everything(args+["partype:fp"]) + dconf = conf.dconf + if dconf.multi_source: + _reader_getter = get_multisoure_data_reader + else: + _reader_getter = get_data_reader + train_streamer = _reader_getter(dconf.train, dconf.input_format, dconf.code_train, dconf.use_label0, cut=dconf.cut_train) + vpack = ParserVocabPackage.build_from_stream(dconf, train_streamer, []) # empty extra_stream + vpack.save(dconf.dict_dir) + +# SRC_DIR="../src/" +# PYTHONPATH=${SRC_DIR}/ python3 ${SRC_DIR}/tasks/cmd.py zdpar.main.build_vocab train:[input] input_format:conllu dict_dir:[output_dir] init_from_pretrain:0 pretrain_file:? +# PYTHONPATH=${SRC_DIR}/ python3 ${SRC_DIR}/tasks/cmd.py zdpar.main.build_vocab train:../data/UD_RUN/ud24/en_all.conllu input_format:conllu dict_dir:./vocab/ init_from_pretrain:1 pretrain_file:../data/UD_RUN/ud24/wiki.multi.en.filtered.vec pretrain_scale:1 pretrain_init_nohit:1 diff --git a/tasks/zdpar/stat2/__init__.py b/tasks/zdpar/stat2/__init__.py index b860811..a6402d0 100644 --- a/tasks/zdpar/stat2/__init__.py +++ b/tasks/zdpar/stat2/__init__.py @@ -1,11 +1,15 @@ # -# todo-list +# trying-list # topical influence (use actual bert's predicted output to delete special semantics?) 
-> change words: first fix non-changed, then find repeated topic words, then change topic and hard words one per segment -> (191103: change too much may hurt) # predict-leaf (use first half of order as leaf, modify the scores according to this first half?) -> (191104: still not good) # other reduce criterion? -> (191105: not too much diff, do not get it too complex...) # vocab based reduce? simply <= thresh? -> (191106: ok, help a little, but not our target) -# cky + decide direction later -# only change topical words? -# direct parse subword seq? (group and take max/avg-score) -# cluster (hierachical bag of words, direction?, influence range, grouped influence) +# cky + decide direction later -> (191111: still not good, but first pdb-debug; 191112: still not good) +# cky + two-layer by splitting puncts? -> (191113: only slightly helpful) +# stop words (<100 in voacb) as lower ones? -> (191114: worse than pos-rule) +# check wsj and phrase result? -> (191114: f1 around 40) +# only change topical words? -> (191115: change 883/25148, no obvious diff) +# direct parse subword seq? (group and take max/avg-score) -> (191115: max-score seems to be slightly helpful +1) +# cluster (hierachical bag of words, direction?, influence range, grouped influence) -> (191115: similar, but slightly worse than cky) +# -- ok, here is the end, goto next stage ... diff --git a/tasks/zdpar/stat2/decode3.py b/tasks/zdpar/stat2/decode3.py index 9caaa2e..303c27d 100644 --- a/tasks/zdpar/stat2/decode3.py +++ b/tasks/zdpar/stat2/decode3.py @@ -10,7 +10,8 @@ from msp.data.vocab import Vocab from tasks.zdpar.common.data import ParseInstance from .helper_basic import SentProcessor, main_loop, show_heatmap, SDBasicConf -from .helper_decode import get_decoder, get_fdecoder, IRConf, IterReducer +from .helper_decode import get_decoder, get_fdecoder, IRConf, IterReducer, CKYConf, CKYParser +from .helper_cluster import ClusterManager, OneCluster # class SD3Conf(SDBasicConf): @@ -23,6 +24,7 @@ def __init__(self, args): # ===== # reducer self.ir_conf = IRConf() + self.cky_conf = CKYConf() # ===== # leaf layer decoding self.use_fdec = False @@ -36,7 +38,7 @@ def __init__(self, args): self.fdec_Tlambda = 1. # if use score, lambda for transpose one # main layer decoding self.mdec_aggregate_scores = False # aggregate score from fdec - self.mdec_method = "m2" # upper layer decoding method: l2r/r2l/rand/m1/m2/m3/m3s/m4(recursive-reduce) + self.mdec_method = "cky" # upper layer decoding method: l2r/r2l/rand/m1/m2/m3/m3s/cky self.mdec_proj = True self.mdec_oper = "add" # operation between O/T self.mdec_Olambda = 1. # if use score, lambda for origin one @@ -60,12 +62,14 @@ def __init__(self, conf: SD3Conf): self.fdec_oper = oper_dict[conf.fdec_oper] self.mdec_oper = oper_dict[conf.mdec_oper] # - self.ir_reducer = IterReducer(conf.ir_conf) - # if conf.fdec_vocab_file: self.r_vocab = Vocab.read(conf.fdec_vocab_file) else: self.r_vocab = None + # + self.ir_reducer = IterReducer(conf.ir_conf) + self.cky_parser = CKYParser(conf.cky_conf, self.r_vocab) + self.cluster_manager = ClusterManager() # ===== # todo(note): no arti-ROOT in inputs @@ -128,12 +132,19 @@ def test_one_sent(self, one_inst: ParseInstance): mdec_idxes = list(range(slen)) # ===== # step 2: main layer - root_scores = np.zeros(slen) # TODO(!) + res_tree = None + root_scores = np.copy(orig_distances[:, 0]) # todo(+2): other types? if conf.mdec_aggregate_scores: - pass # TODO(!) 
+ pass # TODO(+N) if conf.mdec_method == "m3s": # special decoding starting from a middle step of m3 - pass # TODO(!) + original_scores = influence_scores + processed_scores = self.mdec_oper(conf.mdec_Olambda * original_scores, conf.mdec_Tlambda * (original_scores.T)) + ir_reduce_rank, _ = self.ir_reducer.reduce(influence_scores, orig_distances[:, 0]) # rank as the head score + # + start_clusters = [OneCluster(group_idxes[i], i) for i in mdec_idxes] + output_heads, output_root, cluster_procedure = self.cluster_manager.cluster( + len(sent), start_clusters, None, processed_scores, ir_reduce_rank) else: # decoding the contracted sent (only mdec_idxes) original_scores = influence_scores @@ -142,7 +153,11 @@ def test_one_sent(self, one_inst: ParseInstance): processed_scores = processed_scores[mdec_idxes, :][:, mdec_idxes] root_scores = root_scores[mdec_idxes] if len(processed_scores)>0: - output_heads, output_root = self.mdecoder(processed_scores, original_scores, root_scores, conf.mdec_proj) + if conf.mdec_method == "cky": + res_tree, res_deps = self.cky_parser.parse(sent, uposes, fun_masks, original_scores, root_scores) + output_heads, output_root = res_deps + else: + output_heads, output_root = self.mdecoder(processed_scores, original_scores, root_scores, conf.mdec_proj) # restore all assert len(output_heads) == len(mdec_idxes) output_root = mdec_idxes[output_root] @@ -168,8 +183,12 @@ def test_one_sent(self, one_inst: ParseInstance): # if conf.show_result: self.eval_order_verbose(dep_heads, ir_reduce_order, sent, dep_labels) + if res_tree is not None: + zlog("#-----\n"+res_tree.get_repr(False,False)+"\n"+res_tree.get_repr(True,False)+"\n#-----") self._show(one_inst, output_heads, ret_info, conf.show_heatmap, aug_influence_scores, True) # + if res_tree is not None: + ret_info["phrase_tree"] = res_tree.get_repr(add_head=False, fake_nt=True) self.infos.append(ret_info) return ret_info @@ -194,5 +213,21 @@ def main(args): # PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:en_cut.ppos.conllu # PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:_en_dev.ppos.pic already_pre_computed:1 # PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:_en_dev.ppos.pic already_pre_computed:1 show_result:1 show_heatmap:1 min_len:10 max_len:20 rand_input:1 rand_seed:100 +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:./_en_dev.all.MASK.mse.pic already_pre_computed:1 +# ===== +# basic +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:./_en_dev.all.MASK.mse.pic already_pre_computed:1 +# wsj +# PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:wsj23.conllu output_file:w23.out1 output_file_ptree:w23.out2 output_pic:w23.pic +# PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:./w23.pic already_pre_computed:1 output_file:w23.out1 output_file_ptree:w23.out2 ps_pp:1 +# ----- +# repl +# PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:en_dev.ppos.conllu output_pic:_en_dev.brepl.pic br_vocab_file:./en.voc sent_repl_times:1 +# ----- +# subword mode +# PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:en_dev.ppos.conllu output_pic:_en_dev.subword.pic b_use_subword:1 b_subword_aggr:max +# ----- +# decoding method +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:./_en_dev.all.MASK.mse.pic 
already_pre_computed:1 ps_pp:1 mdec_method:m3s # b tasks/zdpar/stat2/decode3:138 diff --git a/tasks/zdpar/stat2/helper_basic.py b/tasks/zdpar/stat2/helper_basic.py index 043c3c8..5496e97 100644 --- a/tasks/zdpar/stat2/helper_basic.py +++ b/tasks/zdpar/stat2/helper_basic.py @@ -29,6 +29,7 @@ def __init__(self): # io self.input_file = "" self.output_file = "" + self.output_file_ptree = "" # output for phrase trees self.output_pic = "" self.rand_input = False self.rand_seed = 12345 @@ -39,7 +40,7 @@ def __init__(self): self.already_pre_computed = False self.fake_scores = False # create all 0. distance scores for debugging self.fconf = FeaturerConf() - self.which_fold = -1 + self.which_fold = 8 # 7~9 generally slightly better # msp self.niconf = NIConf() # nn-init-conf # todo(+N): how to deal with strange influences of low-freq words? (topic influence?) @@ -171,12 +172,12 @@ def _show(self, one_inst, pred_heads, info, show_hmap: bool, distances, set_trac print("; ".join([f"{one_name}:{info[one_name+'_dhit']}/{info[one_name+'_uhit']}/{info[one_name+'_count']}" for one_name in ["all", "np", "ct", "fun"]])) # ===== + print(sent) + print(uposes) + print(one_inst.extra_features.get("feat_seq", None)) + for one_ri, one_seq in enumerate(one_inst.extra_features.get("sd3_repls", [])): + print(f"#R{one_ri}: {one_seq}") if show_hmap: - print(sent) - print(uposes) - print(one_inst.extra_features.get("feat_seq", None)) - for one_ri, one_seq in enumerate(one_inst.extra_features.get("sd3_repls", [])): - print(f"#R{one_ri}: {one_seq}") show_heatmap(sent, distances, True) if set_trace: import pdb @@ -246,6 +247,7 @@ def main_loop(conf: SDBasicConf, sp: SentProcessor): np.seterr(all='raise') nn_init(conf.niconf) np.random.seed(conf.rand_seed) + records = defaultdict(int) # # will trigger error otherwise, save time of loading model featurer = None if conf.already_pre_computed else Featurer(conf.fconf) @@ -297,6 +299,11 @@ def main_loop(conf: SDBasicConf, sp: SentProcessor): one_inst.extra_features["sd3_fixed"] = sent_fixed # ===== score folded_distances = featurer.get_scores(sent_repls[-1]) + assert len(sent_repls[-1]) == len(word_seq) + # --- + records["repl_count"] += len(word_seq) + records["repl_repl"] += sum(a!=b for a,b in zip(sent_repls[-1], word_seq)) + # --- one_inst.extra_features["sd2_scores"] = folded_distances one_inst.extra_features["feat_seq"] = word_seq if output_pic_fd is not None: @@ -306,6 +313,10 @@ def main_loop(conf: SDBasicConf, sp: SentProcessor): # put prediction one_inst.pred_heads.set_vals([0] + list(one_info["output"][0])) one_inst.pred_labels.set_vals(["_"] * len(one_inst.labels.vals)) + # + phrase_tree_string = one_info.get("phrase_tree") + if phrase_tree_string is not None: + one_inst.extra_pred_misc["phrase_tree"] = phrase_tree_string all_insts.append(one_inst) if output_pic_fd is not None: output_pic_fd.close() @@ -313,5 +324,11 @@ def main_loop(conf: SDBasicConf, sp: SentProcessor): with zopen(conf.output_file, 'w') as wfd: data_writer = get_data_writer(wfd, "conllu") data_writer.write(all_insts) + if conf.output_file_ptree: + with zopen(conf.output_file_ptree, 'w') as wfd: + for one_inst in all_insts: + phrase_tree_string = one_inst.extra_pred_misc.get("phrase_tree") + wfd.write(str(phrase_tree_string)+"\n") # ----- + Helper.printd(records) Helper.printd(sp.summary()) diff --git a/tasks/zdpar/stat2/helper_cluster.py b/tasks/zdpar/stat2/helper_cluster.py index 98a78a1..bd7e43b 100644 --- a/tasks/zdpar/stat2/helper_cluster.py +++ b/tasks/zdpar/stat2/helper_cluster.py @@ -13,6 
+13,9 @@ def merge(self, other_cluster: 'OneCluster'): self.nodes.extend(other_cluster.nodes) # still keep self.head + def __repr__(self): + return f"{self.head}: {self.nodes}" + # greedy nearby clustering with head # bottom up decoding class ClusterManager: @@ -24,6 +27,8 @@ def cluster(self, slen: int, start_clusters: List[OneCluster], allowed_matrix, l ret = [0] * slen procedures = [] # List[(h, m)] cur_clusters = start_clusters + if allowed_matrix is None: + allowed_matrix = np.ones([slen, slen], dtype=np.bool) while True: # merge candidates between (i, i+1) merge_candidates = [] @@ -34,17 +39,31 @@ def cluster(self, slen: int, start_clusters: List[OneCluster], allowed_matrix, l c2_nodes, c2_head = c2.nodes, c2.head if allowed_matrix[c1_head, c2_head] or allowed_matrix[c2_head, c1_head]: merge_candidates.append(midx) - # here using the linking scores - one_score = None # TODO(!) + # here using the averaged linking scores + one_score = (link_scores[c1_nodes][:,c2_nodes] + link_scores[c2_nodes][:,c1_nodes].T).mean() merge_scores.append(one_score) # break if nothing left to merge if len(merge_candidates) == 0: break # greedily pick the best score - best_idx = np.argmax(merge_scores).item() - best_c1_idx = merge_candidates[best_idx] + best_idx = np.argmax(merge_scores).item() # idx in current candidates list + best_c1_idx = merge_candidates[best_idx] # idx in cur cluster list best_c2_idx = best_c1_idx + 1 - # TODO(!) # determine the direction - # prepare for the next iter - return ret, procedures + togo_c1, togo_c2 = cur_clusters[best_c1_idx], cur_clusters[best_c2_idx] + togo_c1_head, togo_c2_head = togo_c1.head, togo_c2.head + if head_scores[togo_c1_head] <= head_scores[togo_c2_head]: + # c2 as head + cur_clusters = cur_clusters[:best_c1_idx] + cur_clusters[best_c2_idx:] + procedures.append((togo_c2_head, togo_c1_head)) + ret[togo_c1_head] = togo_c2_head + 1 # +1 head offset + else: + # c1 as head + cur_clusters = cur_clusters[:best_c1_idx+1] + cur_clusters[best_c2_idx+1:] + procedures.append((togo_c1_head, togo_c2_head)) + ret[togo_c2_head] = togo_c1_head + 1 + # + ret_root = [i for i,v in enumerate(ret) if v==0][0] + return ret, ret_root, procedures + +# b tasks/zdpar/stat2/helper_cluster:34 diff --git a/tasks/zdpar/stat2/helper_decode.py b/tasks/zdpar/stat2/helper_decode.py index c55413f..6b734d1 100644 --- a/tasks/zdpar/stat2/helper_decode.py +++ b/tasks/zdpar/stat2/helper_decode.py @@ -8,8 +8,9 @@ # for return: heads are offset-ed, but root is not import numpy as np +import string -from msp.utils import Conf +from msp.utils import Conf, zlog from ..algo.nmst import mst_unproj, mst_proj # ===== @@ -137,7 +138,7 @@ def get_decoder(method): return { "l2r": chain_l2r, "r2l": chain_r2l, "rand": decode_rand, "m1": decode_m1, "m2": decode_m2, "m3": decode_m3, - }[method] + }.get(method, None) # ===== # part 2: fun decoder @@ -200,7 +201,7 @@ def get_fdecoder(method): }[method] # ===== -# other helpers +# reducer class IRConf(Conf): def __init__(self): @@ -247,3 +248,283 @@ def reduce(self, orig_scores, extra_scores=None): for one_order, one_idx in enumerate(reduce_idxes): reduce_order[one_idx] = one_order return reduce_order, reduce_idxes + +# ===== +# CKYParser + +# nodes for the phrase tree: [a,b] +class TreeNode: + def __init__(self, node_left: 'TreeNode'=None, node_right: 'TreeNode'=None, sent=None, leaf_idx: int=None): + if leaf_idx is not None: + assert sent is not None + assert node_left is None and node_right is None + self.is_leaf = True + self.sent = sent + self.range_a, 
self.range_b = leaf_idx, leaf_idx + self.head_idx = leaf_idx # only one node + else: + assert sent is None and leaf_idx is None + self.is_leaf = False + assert node_left.sent is node_right.sent + self.sent = node_right.sent + self.node_left, self.node_right = node_left, node_right + assert node_left.range_b+1 == node_right.range_a + self.range_a, self.range_b = node_left.range_a, node_right.range_b + self.head_idx = None + + def set_head(self, head_idx): + assert self.range_a<=head_idx and head_idx<=self.range_b + self.head_idx = head_idx + + # add_head: plus head printings; fake_nt: fake X for non-terminal + def get_repr(self, add_head, fake_nt): + if self.is_leaf: + cur_word = self.sent[self.range_a] + cur_word = {"(": "-LRB-", ")": "-RRB-"}.get(cur_word, cur_word) # escape + if fake_nt: + return f"(X {cur_word})" + else: + return cur_word + else: + ra, rb = self.node_left.get_repr(add_head, fake_nt), self.node_right.get_repr(add_head, fake_nt) + fake_nt_str = "X " if fake_nt else "" + if self.head_idx is None or not add_head: + return f"({fake_nt_str}{ra} {rb})" + else: + head_str = self.sent[self.head_idx] + return f"({fake_nt_str}[{head_str}] {ra} {rb})" + + def __repr__(self): + return self.get_repr(True, False) + +# +class CKYConf(Conf): + def __init__(self): + # phrase split score: lambda: a[sizeA, sizeB], b[sizeB, sizeA] -> scalar + self.ps_method = "add_avg" + self.ps_pp = False # special two-layer processing with PUNCT as splitter + # head finding method (similar to the Reducer) + self.hf_use_left = False # always make left as head + self.hf_use_right = False # always make right as head + self.hf_use_fun_masks = False # use extra help from masks (fun?@[\\^_`{|}~„“”«»²–、。・()「」『』:;!?〜,《》·]+""") + # selecting the interesting ones + PUNCT_SET = set(r""",.:;!?、,。:;!?""") + + def __init__(self, conf: CKYConf, r_vocab): + self.conf = conf + self.r_vocab = r_vocab + # + self.score_f = { + "add_avg": lambda a,b: (a+b.T).mean(), + "max_avg": lambda a,b: np.maximum(a, b.T).mean(), + }[conf.ps_method] + + # input [N,N] original scores, and [N] extra-root-scores + def parse(self, sent, uposes, fun_masks, orig_scores, extra_scores=None): + slen = len(sent) + # first get all leaf nodes + all_leaf_nodes = [TreeNode(sent=sent, leaf_idx=z) for z in range(slen)] + # cky + if slen>1 and self.conf.ps_pp: + # split by PUNCT + is_punct = [all(z in CKYParser.PUNCT_SET for z in w) for w in sent] # whole string is punct + split_ranges = [] + prev_start = None + for i, is_p in enumerate(is_punct): + if is_p: + if prev_start is not None: + split_ranges.append((prev_start, i)) + prev_start = None + split_ranges.append((i, i+1)) + else: + if prev_start is None: + prev_start = i + if prev_start is not None: + split_ranges.append((prev_start, slen)) # add final range + # decode layer 1 + split_trees = [self._cky(all_leaf_nodes[a:b], orig_scores[a:b,a:b]) for a,b in split_ranges] + # decode layer 2 + # todo(+N): current simply average here! 
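# -----
# Sketch of the block averaging used just below, on a made-up score matrix: the
# layer-2 score between two punct-separated segments is the mean over the
# corresponding rectangular block of the original word-word scores, taken
# separately for each direction.
import numpy as np
_demo = np.arange(16, dtype=np.float64).reshape(4, 4)    # pretend [slen, slen] scores
(_ia, _ib), (_ja, _jb) = (0, 2), (2, 4)                  # two split ranges
_s_ij = _demo[_ia:_ib, _ja:_jb].mean()                   # segment i -> segment j block
_s_ji = _demo[_ja:_jb, _ia:_ib].mean()                   # segment j -> segment i block
# -----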
+ l2_len = len(split_trees) + l2_scores = np.zeros([l2_len, l2_len]) + for i in range(l2_len): + for j in range(i+1, l2_len): + ia, ib = split_ranges[i] + ja, jb = split_ranges[j] + l2_scores[i,j] = orig_scores[ia:ib, ja:jb].mean() + l2_scores[j,i] = orig_scores[ja:jb, ia:ib].mean() + best_tree = self._cky(split_trees, l2_scores) # layer-2 decode + else: + best_tree = self._cky(all_leaf_nodes, orig_scores) # directly decode + # ===== + # then decide the head + ret_heads, ret_root = self._decide_head(best_tree, orig_scores, extra_scores, uposes, fun_masks) + # + # zlog(best_tree, func="debug") + return best_tree, (ret_heads, ret_root) + + # decode on the current list of leaf nodes (can be subtrees) + def _cky(self, cur_leaf_nodes, cur_scores): + slen = len(cur_leaf_nodes) + assert slen>0 + if slen==1: + return cur_leaf_nodes[0] # directly return for singleton node + # ===== + cky_best_scores = np.zeros([slen, slen]) # [a,b] + cky_best_sp = np.zeros([slen, slen], dtype=np.int) + for step in range(2, slen+1): # do not need to calculate leaf nodes + for a in range(slen+1-step): + b = a + step - 1 # [a,b] + # loop over split points + cur_best_score = -1000. + cur_best_sp = None + for sp in range(a,b): # [a,b] -> [a,sp], [sp+1,b] + cur_edge_score = self.score_f(cur_scores[a:sp+1, sp+1:b+1], cur_scores[sp+1:b+1, a:sp+1]) + cur_score = cky_best_scores[a,sp] + cky_best_scores[sp+1,b] + cur_edge_score + if cur_score > cur_best_score: + cur_best_score = cur_score + cur_best_sp = sp + assert cur_best_sp is not None + cky_best_scores[a,b] = cur_best_score + cky_best_sp[a,b] = cur_best_sp + # traceback + best_tree = self._traceback(cky_best_sp, cur_leaf_nodes, 0, slen-1) + return best_tree + + # return TreeNode + def _traceback(self, cky_best_sp, leaf_nodes, a, b): + assert b>=a + if a==b: + return leaf_nodes[a] + else: + sp = cky_best_sp[a,b] + node_left = self._traceback(cky_best_sp, leaf_nodes, a, sp) + node_right = self._traceback(cky_best_sp, leaf_nodes, sp+1, b) + return TreeNode(node_left=node_left, node_right=node_right) + + # + def _pos_head_score(self, hidx_self, hidx_other, uposes): + # todo(note): CONJ is from UDv1 + POS_RANKS = [["VERB"], ["NOUN", "PROPN"], ["ADJ"], ["ADV"], + ["ADP", "AUX", "CCONJ", "DET", "NUM", "PART", "PRON", "SCONJ", "CONJ"], ["INTJ", "SYM", "X"], ["PUNCT"]] + POS2RANK = {} + for i, ps in enumerate(POS_RANKS): + for p in ps: + POS2RANK[p] = i + upos_self, upos_other = uposes[hidx_self], uposes[hidx_other] + rank_self, rank_other = POS2RANK[upos_self], POS2RANK[upos_other] + if rank_self < rank_other: + return 1 + elif rank_self > rank_other: + return -1 + else: + # TODO(+N): maybe specific to en + # if neighbouring, select the right one, else left + if abs(hidx_self-hidx_other) == 0: + return int(hidx_self>hidx_other) + else: + return int(hidx_selfnode_other.head_idx) else -100. + elif conf.hf_use_pos_rule: + return self._pos_head_score(node_self.head_idx, node_other.head_idx, uposes) + else: + pass + # ===== + # filter ones + if conf.hf_use_fun_masks: + if fun_masks[node_self.head_idx] and not fun_masks[node_other.head_idx]: + return -100. + elif not fun_masks[node_self.head_idx] and fun_masks[node_other.head_idx]: + return 100. + if conf.hf_use_vocab: + word_self, word_other = node_self.sent[node_self.head_idx], node_other.sent[node_other.head_idx] + vthr = conf.hf_vr_thresh + rank_self, rank_other = self.r_vocab.get(word_self.lower(), vthr+1), self.r_vocab.get(word_other.lower(), vthr+1) + if rank_self<=vthr and rank_other>vthr: + return -100. 
+ elif rank_self>vthr and rank_other<=vthr: + return 100. + # ===== + # fallback: score based ones + if conf.hf_use_ir: + self_head_idx = node_self.head_idx + return orig_scores[self_head_idx].sum() * conf.hf_win_row \ + + orig_scores[:,self_head_idx].sum() * conf.hf_win_col \ + + extra_scores[self_head_idx] * conf.hf_w_extra + else: + x0_self, x1_self = (node_self.head_idx, node_self.head_idx+1) if conf.hf_self_headed \ + else (node_self.range_a, node_self.range_b+1) + x0_other, x1_other = (node_other.head_idx, node_other.head_idx+1) if conf.hf_other_headed \ + else (node_other.range_a, node_other.range_b+1) + x0_par, x1_par = node_parent.range_a, node_parent.range_b+1 + # inner score + s_inner = (orig_scores[x0_self:x1_self, x0_other:x1_other]*conf.hf_win_row + + orig_scores[x0_other:x1_other, x0_self:x1_self].T*conf.hf_win_col).mean() + # outer score + s_outer_row = np.concatenate([orig_scores[x0_self:x1_self, :x0_par], + orig_scores[x0_self:x1_self, x1_par:]], -1) * conf.hf_wout_row + s_outer_col = np.concatenate([orig_scores[:x0_par, x0_self:x1_self], + orig_scores[x1_par:, x0_self:x1_self]], 0).T * conf.hf_wout_col + s_outer_sum = s_outer_row + s_outer_col + s_outer = s_outer_sum.mean() if s_outer_sum.size>0 else 0. + # extra score + s_extra = extra_scores[x0_self:x1_self].mean() * conf.hf_w_extra + return s_inner + s_outer + s_extra + # ===== + # recursively decide + def _rdec(cur_tree: TreeNode): + if not cur_tree.is_leaf: + _rdec(cur_tree.node_left) + _rdec(cur_tree.node_right) + score_left = _score(cur_tree.node_left, cur_tree.node_right, cur_tree) + score_right = _score(cur_tree.node_right, cur_tree.node_left, cur_tree) + head_left, head_right = cur_tree.node_left.head_idx, cur_tree.node_right.head_idx + if score_left < score_right: + cur_tree.set_head(head_right) + ret_heads[head_left] = head_right + 1 # +1 offset + else: + cur_tree.set_head(head_left) + ret_heads[head_right] = head_left + 1 + # ===== + _rdec(top_tree) + ret_root = top_tree.head_idx + ret_heads[ret_root] = 0 + return ret_heads, ret_root + +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:_en_dev.ppos.pic already_pre_computed:1 mdec_method:cky hf_use_fun_masks:1 hf_use_ir:1 ps_pp:1 +# b tasks/zdpar/stat2/helper_decode:388 diff --git a/tasks/zdpar/stat2/helper_feature.py b/tasks/zdpar/stat2/helper_feature.py index b1b21d6..1649fe6 100644 --- a/tasks/zdpar/stat2/helper_feature.py +++ b/tasks/zdpar/stat2/helper_feature.py @@ -20,16 +20,21 @@ def __init__(self): self.use_bert = True self.bconf = BerterConf().init_from_kwargs(bert_sent_extend=0, bert_root_mode=-1, bert_zero_pademb=True, bert_layers="[0,1,2,3,4,5,6,7,8,9,10,11,12]") + # todo(note): generally 'first' strategy is slightly worse, finally using all+MASK+mse+layer8 # for bert - self.b_mask_mode = "first" # all/first/one/pass + self.b_mask_mode = "all" # all/first/one/pass self.b_mask_repl = "[MASK]" # or other things like [UNK], [PAD], ... 
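# -----
# Sketch (with random made-up features) of the "mse" influence scores computed in
# get_scores below, assuming the bert features come back as [1+slen, 1+slen, D]
# with index 0 on the first axis acting as the reference run and index 1+i the run
# where word i is masked; the distance is an RMSE over the hidden dimension.
import numpy as np
_slen, _D = 4, 8
_feats = np.random.randn(1 + _slen, 1 + _slen, _D)   # [which-word-masked, position, D]
_diff = _feats[1:] - _feats[0][None]                 # [slen, 1+slen, D]
_influence = np.sqrt((_diff ** 2).mean(-1))          # [slen, 1+slen] distance scores
# -----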
+ # subword mode for bert + self.b_use_subword = False # use subword mode + self.b_subword_aggr = "max" # how to aggregate distances in subword mode for words # for g1p self.g1p_replace_unk = True # otherwise replace with 0 # distance (influence scores) - self.dist_f = "cos" # cos or mse + self.dist_f = "mse" # cos or mse # for bert repl self.br_vocab_file = "" self.br_rank_thresh = 100 # rank threshold >=this for non-topical words + self.br_only_topical = True # only change topical words def do_validate(self): assert self.use_bert, "currently only implemented the bert mode" @@ -43,6 +48,7 @@ def __init__(self, conf: FeaturerConf): # used for bert-repl: self.repl_sent() self.br_vocab = Vocab.read(conf.br_vocab_file) if conf.br_vocab_file else None self.br_rthresh = conf.br_rank_thresh + self.br_only_topical = conf.br_only_topical # replace (simplify) sentence by replacing with bert predictions # todo(note): remember that the main purpose is to remove surprising topical words but retain syntax structure @@ -89,12 +95,13 @@ def repl_sent(self, sent: List[str], fixed_arr): if one_idx>=sent_len or is_fixed[one_idx]: # hit boundary if len(prev_cands)>0: prev_cands.sort(key=lambda x: -x[-1]) - togo_idx, togo_score = prev_cands[0] - if togo_score > 0: # only change it if not forbidden + togo_idx, togo_word, togo_score = prev_cands[0] + if togo_score > 0 and (not self.br_only_topical or togo_word in topical_words_count): + # only change it if not forbidden and in topical input words if flag set ret_seq[togo_idx] = pred_strs[togo_idx] prev_cands.clear() else: - prev_cands.append((one_idx, changing_scores[one_idx])) + prev_cands.append((one_idx, sent[one_idx], changing_scores[one_idx])) # todo(+N): whether fix changed word? currently nope return ret_seq, np.asarray(is_fixed).astype(np.bool) @@ -102,7 +109,29 @@ def repl_sent(self, sent: List[str], fixed_arr): def get_scores(self, sent: List[str]): # get features conf = self.conf - features = encode_bert(self.berter, sent, conf.b_mask_mode, conf.b_mask_repl) # [1+slen, 1+slen, fold*D] + # ===== + if conf.b_use_subword: + range_base_idx = 0 + input_sent = [] + tokenizer = self.berter.tokenizer + all_subwords = [] + word_ranges = [] + for i, t in enumerate(sent): + cur_toks = tokenizer.tokenize(t) + # in some cases, there can be empty strings -> put the original word + if len(cur_toks) == 0: + cur_toks = [t] + cur_len = len(cur_toks) + all_subwords.append(cur_toks) + word_ranges.append((range_base_idx, range_base_idx+cur_len)) + range_base_idx += cur_len + input_sent.extend(cur_toks) + else: + word_ranges = None + input_sent = sent + # ===== + # normal calculating + features = encode_bert(self.berter, input_sent, conf.b_mask_mode, conf.b_mask_repl) # [1+slen, 1+slen, fold*D] orig_feature_shape = BK.get_shape(features) folded_shape = orig_feature_shape[:-1] + [self.fold, orig_feature_shape[-1]//self.fold] all_features = features.view(folded_shape) # [1+slen, 1+slen, fold, D] @@ -115,7 +144,26 @@ def get_scores(self, sent: List[str]): scores = BK.sqrt(((all_features[1:] - all_features[0].unsqueeze(0)) ** 2).mean(-1)) else: raise NotImplementedError() - return scores.cpu().numpy() + scores_arr = scores.cpu().numpy() + # ===== + # convert back to word mode if needed + aggr_f = {"max": np.max, "avg": np.mean}[conf.b_subword_aggr] + if conf.b_use_subword: + slen = len(sent) + assert slen == len(word_ranges) + # repack scores_arr + new_scores_arr = np.zeros([slen, slen+1, self.fold]) + for idx_i in range(slen): + ri_a, ri_b = word_ranges[idx_i] + # to CLS + 
new_scores_arr[idx_i, 0] = aggr_f(scores_arr[ri_a:ri_b, 0], axis=0) + # to others + for idx_j in range(slen): + rj_a, rj_b = word_ranges[idx_j] + new_scores_arr[idx_i, idx_j+1] = aggr_f(aggr_f(scores_arr[ri_a:ri_b, rj_a+1:rj_b+1], axis=0), axis=0) + return new_scores_arr + else: + return scores_arr def output_shape(self, slen): return [slen, slen+1, len(self.conf.bconf.bert_layers)] diff --git a/tasks/zdpar/stat2/run.sh b/tasks/zdpar/stat2/run.sh index 47b6fe4..5f8d710 100644 --- a/tasks/zdpar/stat2/run.sh +++ b/tasks/zdpar/stat2/run.sh @@ -16,7 +16,7 @@ pic_name="_${CUR_LANG}_dev.${b_mask_mode}.${b_mask_repl}.${dist_f}.pic" OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu output_pic:${pic_name} b_mask_mode:${b_mask_mode} "b_mask_repl:[${b_mask_repl}]" dist_f:${dist_f} done; done; done |& tee _log1028 -# _run1029.sh +# _run1029.sh -> fold 7/8/9 slightly better SRC_DIR="../src/" #DATA_DIR="../data/UD_RUN/ud24/" DATA_DIR="../data/ud24/" @@ -31,3 +31,64 @@ for use_fdec in 0 1; do done done done |& tee _log1029 + +# _run1113.sh +# bash _run1113.sh |& tee _log1113 +SRC_DIR="../src/" +#DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU= +DEVICE=-1 +for hf_use_pos_rule in 0 1; do +for fpic in *.pic; do +for which_fold in 12 9 8 7; do + echo "ZRUN POS_RULE=${hf_use_pos_rule} ${fpic} FOLD=${which_fold}" + CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:$fpic already_pre_computed:1 mdec_method:cky hf_use_ir:1 ps_pp:1 hf_use_pos_rule:${hf_use_pos_rule} which_fold:${which_fold} +done +done +done + +# _run1114a.sh -> row=1 col=0/-1 +# bash _run1114a.sh |& tee _log1114a +SRC_DIR="../src/" +#DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU= +DEVICE=-1 +for hf_use_pos_rule in 0 1; do +fpic="_en_dev.all.MASK.mse.pic" +which_fold=8 +for ps_method in "add_avg" "max_avg"; do +for hf_win_row in 0 1; do +for hf_win_col in 1 0 -1; do + echo "ZRUN hf_use_pos_rule:${hf_use_pos_rule} ps_method:${ps_method} hf_win_row:${hf_win_row} hf_win_col:${hf_win_col}" + CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:$fpic already_pre_computed:1 mdec_method:cky which_fold:${which_fold} hf_use_ir:1 ps_pp:1 hf_use_pos_rule:${hf_use_pos_rule} ps_method:${ps_method} hf_win_row:${hf_win_row} hf_win_col:${hf_win_col} +done +done +done +done + +# _run1114b.sh -> not much differences on the good ones ~31 +# bash _run1114b.sh |& tee _log1114b +SRC_DIR="../src/" +#DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU= +DEVICE=-1 +# +hf_use_pos_rule=0 +fpic="_en_dev.all.MASK.mse.pic" +which_fold=8 +for ps_method in "add_avg" "max_avg"; do +for hf_win_row in 0 1; do +for hf_win_col in 1 0 -1; do +for hf_wout_row in 0 1 -1; do +for hf_wout_col in 0 1 -1; do +for hf_self_headed in 0 1; do +for hf_other_headed in 0 1; do + echo "ZRUN ps_method:${ps_method} hf_win_row:${hf_win_row} hf_win_col:${hf_win_col} hf_wout_row:${hf_wout_row} hf_wout_col:${hf_wout_col} hf_self_headed:${hf_self_headed} hf_other_headed:${hf_other_headed}" + CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:$fpic already_pre_computed:1 mdec_method:cky which_fold:${which_fold} hf_use_ir:0 ps_pp:1 
ps_method:${ps_method} hf_win_row:${hf_win_row} hf_win_col:${hf_win_col} hf_wout_row:${hf_wout_row} hf_wout_col:${hf_wout_col} hf_self_headed:${hf_self_headed} hf_other_headed:${hf_other_headed} +done; done; done; done; done; done; done diff --git a/tasks/zdpar/zfp/__init__.py b/tasks/zdpar/zfp/__init__.py new file mode 100644 index 0000000..832326f --- /dev/null +++ b/tasks/zdpar/zfp/__init__.py @@ -0,0 +1,20 @@ +# + +# the final parser (hopefully a simple one) + +# +# 0. labeled+recursive, direct embeddings; what object? +# 1. score as confidence for the combination and used for later calculation +# 2. encourage sparsity (with NOP op) +# 3. how to reduce? +# 4. the connections should also be distributed? + +# todo(note): +# ===== +# 20.01 +# look at word pairs for syntax path -> tools/see_pairs.py (co-heads?) +# word drop and word pred loss: joint-training as aux loss -> directly multi-task training is not helpful +# pre-training + fine-tune? +# meaning grouping by picking and cutting the important dep links +# predict combined path as new label? +# slayer: d-self-att diff --git a/tasks/zdpar/zfp/dec.py b/tasks/zdpar/zfp/dec.py new file mode 100644 index 0000000..af4a998 --- /dev/null +++ b/tasks/zdpar/zfp/dec.py @@ -0,0 +1,234 @@ +# + +# the decoders + +from typing import List +import numpy as np +from msp.utils import Conf, Helper, Constants +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, NoDropRop, PairScorerConf, PairScorer +from msp.zext.seq_helper import DataPadder + +from ..common.data import ParseInstance +from ..algo import nmst_unproj, nmarginal_unproj + +# ===== +# scorers + +class FpScorerConf(Conf): + def __init__(self): + self._input_dim = -1 # enc's (input) last dimension + self._num_label = -1 # number of labels + # space transferring (0 means no this layer) + self.arc_space = 512 + self.lab_space = 128 + self.transform_act = "elu" + # pairwise conf + self.ps_conf = PairScorerConf() + +class FpPairedScorer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: FpScorerConf): + super().__init__(pc, None, None) + # options + input_dim = sconf._input_dim + arc_space = sconf.arc_space + lab_space = sconf.lab_space + transform_act = sconf.transform_act + # + self.input_dim = input_dim + self.num_label = sconf._num_label + # attach/arc + if arc_space>0: + self.arc_m = self.add_sub_node("am", Affine(pc, input_dim, arc_space, act=transform_act)) + self.arc_h = self.add_sub_node("ah", Affine(pc, input_dim, arc_space, act=transform_act)) + else: + self.arc_h = self.arc_m = None + arc_space = input_dim + self.arc_scorer = self.add_sub_node("as", PairScorer(pc, arc_space, arc_space, 1, sconf.ps_conf)) + # labeling + if lab_space>0: + self.lab_m = self.add_sub_node("lm", Affine(pc, input_dim, lab_space, act=transform_act)) + self.lab_h = self.add_sub_node("lh", Affine(pc, input_dim, lab_space, act=transform_act)) + else: + self.lab_h = self.lab_m = None + lab_space = input_dim + self.lab_scorer = self.add_sub_node("ls", PairScorer(pc, lab_space, lab_space, self.num_label, sconf.ps_conf)) + + # ===== + # score + + def transform_and_arc_score(self, senc_expr, mask_expr=None): + ah_expr = self.arc_h(senc_expr) if self.arc_h else senc_expr + am_expr = self.arc_m(senc_expr) if self.arc_m else senc_expr + arc_full_score = self.arc_scorer.paired_score(am_expr, ah_expr, mask_expr, mask_expr) + return arc_full_score # [*, m, h, 1] + + def transform_and_lab_score(self, senc_expr, mask_expr=None): + lh_expr = self.lab_h(senc_expr) if self.lab_h else senc_expr + 
lm_expr = self.lab_m(senc_expr) if self.lab_m else senc_expr + lab_full_score = self.lab_scorer.paired_score(lm_expr, lh_expr, mask_expr, mask_expr) + return lab_full_score # [*, m, h, Lab] + + def transform_and_arc_score_plain(self, mod_srepr, head_srepr, mask_expr=None): + ah_expr = self.arc_h(head_srepr) if self.arc_h else head_srepr + am_expr = self.arc_m(mod_srepr) if self.arc_m else mod_srepr + arc_full_score = self.arc_scorer.plain_score(am_expr, ah_expr, mask_expr, mask_expr) + return arc_full_score + + def transform_and_lab_score_plain(self, mod_srepr, head_srepr, mask_expr=None): + lh_expr = self.lab_h(head_srepr) if self.lab_h else head_srepr + lm_expr = self.lab_m(mod_srepr) if self.lab_m else mod_srepr + lab_full_score = self.lab_scorer.plain_score(lm_expr, lh_expr, mask_expr, mask_expr) + return lab_full_score + +class FpSingleScorer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: FpScorerConf): + super().__init__(pc, None, None) + # options + input_dim = sconf._input_dim + arc_space = sconf.arc_space + lab_space = sconf.lab_space + transform_act = sconf.transform_act + # + self.input_dim = input_dim + self.num_label = sconf._num_label + self.mask_value = Constants.REAL_PRAC_MIN + # attach/arc + if arc_space>0: + self.arc_f = self.add_sub_node("af", Affine(pc, input_dim, arc_space, act=transform_act)) + else: + self.arc_f = None + arc_space = input_dim + self.arc_scorer = self.add_sub_node("as", Affine(pc, arc_space, 1, init_rop=NoDropRop())) + # labeling + if lab_space>0: + self.lab_f = self.add_sub_node("lf", Affine(pc, input_dim, lab_space, act=transform_act)) + else: + self.lab_f = None + lab_space = input_dim + self.lab_scorer = self.add_sub_node("ls", Affine(pc, lab_space, self.num_label, init_rop=NoDropRop())) + + # ===== + # score + + def transform_and_arc_score(self, senc_expr, mask_expr=None): + a_expr = self.arc_f(senc_expr) if self.arc_f else senc_expr + arc_full_score = self.arc_scorer(a_expr) + if mask_expr is not None: + arc_full_score += self.mask_value * (1. - mask_expr).unsqueeze(-1) + return arc_full_score # [*, 1] + + def transform_and_lab_score(self, senc_expr, mask_expr=None): + l_expr = self.lab_f(senc_expr) if self.lab_f else senc_expr + lab_full_score = self.lab_scorer(l_expr) + if mask_expr is not None: + lab_full_score += self.mask_value * (1. - mask_expr).unsqueeze(-1) + return lab_full_score # [*, Lab] + +# ===== +# decoders + +class FpDecConf(Conf): + def __init__(self): + self.sconf = FpScorerConf() + self.use_ablpair = False # special mode? + # loss weight + self.lambda_arc = 1. + self.lambda_lab = 1. 
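# -----
# A plain-PyTorch sketch of the shape flow in FpPairedScorer above (assuming BK
# wraps torch and PairScorer is biaffine-like; the real PairScorer likely adds
# bias/linear terms, and the dimensions here are made up): project modifier and
# head views, then score every (mod, head) pair.
import torch
import torch.nn as nn

_bs, _slen, _D, _arc = 2, 5, 16, 8
_senc = torch.randn(_bs, _slen, _D)                      # encoder output [bs, slen, D]
_arc_m, _arc_h = nn.Linear(_D, _arc), nn.Linear(_D, _arc)
_W = torch.randn(_arc, _arc)                             # biaffine weight (no bias here)
_am, _ah = _arc_m(_senc), _arc_h(_senc)                  # [bs, slen, arc] each
_arc_score = torch.einsum("bma,ac,bhc->bmh", _am, _W, _ah).unsqueeze(-1)  # [bs, m, h, 1]
# -----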
+ +class FpDecoder(BasicNode): + def __init__(self, pc, conf: FpDecConf, label_vocab): + super().__init__(pc, None, None) + self.conf = conf + self.label_vocab = label_vocab + self.predict_padder = DataPadder(2, pad_vals=0) + # the scorer + self.use_ablpair = conf.use_ablpair + conf.sconf._input_dim = conf._input_dim + conf.sconf._num_label = conf._num_label + if self.use_ablpair: + self.scorer = self.add_sub_node("s", FpSingleScorer(pc, conf.sconf)) + else: + self.scorer = self.add_sub_node("s", FpPairedScorer(pc, conf.sconf)) + + # ----- + # scoring + def _score(self, enc_expr, mask_expr): + # ----- + def _special_score(one_score): # specially change ablpair scores into [bs,m,h,*] + root_score = one_score[:,:,0].unsqueeze(2) # [bs, rlen, 1, *] + tmp_shape = BK.get_shape(root_score) + tmp_shape[1] = 1 # [bs, 1, 1, *] + padded_root_score = BK.concat([BK.zeros(tmp_shape), root_score], dim=1) # [bs, rlen+1, 1, *] + final_score = BK.concat([padded_root_score, one_score.transpose(1,2)], dim=2) # [bs, rlen+1[m], rlen+1[h], *] + return final_score + # ----- + if self.use_ablpair: + input_mask_expr = (mask_expr.unsqueeze(-1) * mask_expr.unsqueeze(-2))[:, 1:] # [bs, rlen, rlen+1] + arc_score = self.scorer.transform_and_arc_score(enc_expr, input_mask_expr) # [bs, rlen, rlen+1, 1] + lab_score = self.scorer.transform_and_lab_score(enc_expr, input_mask_expr) # [bs, rlen, rlen+1, Lab] + # put root-scores for both directions + arc_score = _special_score(arc_score) + lab_score = _special_score(lab_score) + else: + # todo(+2): for training, we can simply select and lab-score + arc_score = self.scorer.transform_and_arc_score(enc_expr, mask_expr) # [bs, m, h, 1] + lab_score = self.scorer.transform_and_lab_score(enc_expr, mask_expr) # [bs, m, h, Lab] + # mask out diag scores + diag_mask = BK.eye(BK.get_shape(arc_score, 1)) + diag_mask[0,0] = 0. + diag_add = Constants.REAL_PRAC_MIN * (diag_mask.unsqueeze(-1).unsqueeze(0)) # [1, m, h, 1] + arc_score += diag_add + lab_score += diag_add + return arc_score, lab_score + + # loss + # todo(note): no margins here, simply using target-selection cross-entropy + def loss(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs): + conf = self.conf + # scoring + arc_score, lab_score = self._score(enc_expr, mask_expr) # [bs, m, h, *] + # loss + bsize, max_len = BK.get_shape(mask_expr) + # gold heads and labels + gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in insts]) + # todo(note): here use the original idx of label, no shift! 
+ gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in insts]) + gold_heads_expr = BK.input_idx(gold_heads_arr) # [bs, Len] + gold_labels_expr = BK.input_idx(gold_labels_arr) # [bs, Len] + # collect the losses + arange_bs_expr = BK.arange_idx(bsize).unsqueeze(-1) # [bs, 1] + arange_m_expr = BK.arange_idx(max_len).unsqueeze(0) # [1, Len] + # logsoftmax and losses + arc_logsoftmaxs = BK.log_softmax(arc_score.squeeze(-1), -1) # [bs, m, h] + lab_logsoftmaxs = BK.log_softmax(lab_score, -1) # [bs, m, h, Lab] + arc_sel_ls = arc_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr] # [bs, Len] + lab_sel_ls = lab_logsoftmaxs[arange_bs_expr, arange_m_expr, gold_heads_expr, gold_labels_expr] # [bs, Len] + # head selection (no root) + arc_loss_sum = (- arc_sel_ls * mask_expr)[:, 1:].sum() + lab_loss_sum = (- lab_sel_ls * mask_expr)[:, 1:].sum() + final_loss = conf.lambda_arc * arc_loss_sum + conf.lambda_lab * lab_loss_sum + final_loss_count = mask_expr[:, 1:].sum() + return [[final_loss, final_loss_count]] + + # decode + def predict(self, insts: List[ParseInstance], enc_expr, mask_expr, **kwargs): + conf = self.conf + # scoring + arc_score, lab_score = self._score(enc_expr, mask_expr) # [bs, m, h, *] + full_score = BK.log_softmax(arc_score, -2) + BK.log_softmax(lab_score, -1) # [bs, m, h, Lab] + # decode + mst_lengths = [len(z) + 1 for z in insts] # +1 to include ROOT for mst decoding + mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32) + mst_heads_arr, mst_labels_arr, mst_scores_arr = \ + nmst_unproj(full_score, mask_expr, mst_lengths_arr, labeled=True, ret_arr=True) + # ===== assign, todo(warn): here, the labels are directly original idx, no need to change + misc_prefix = "g" + for one_idx, one_inst in enumerate(insts): + cur_length = mst_lengths[one_idx] + one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length]) # directly int-val for heads + one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab) + one_scores = mst_scores_arr[one_idx][:cur_length] + one_inst.pred_par_scores.set_vals(one_scores) + # extra output + one_inst.extra_pred_misc[misc_prefix+"_score"] = one_scores.tolist() diff --git a/tasks/zdpar/zfp/enc.py b/tasks/zdpar/zfp/enc.py new file mode 100644 index 0000000..98754ec --- /dev/null +++ b/tasks/zdpar/zfp/enc.py @@ -0,0 +1,192 @@ +# + +from typing import List +import numpy as np + +from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn +from msp.data import VocabPackage, MultiHelper +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, RefreshOptions, NoDropRop +from msp.nn.modules import EmbedConf, MyEmbedder, EncConf, MyEncoder, Berter2Conf, Berter2 +from msp.zext.seq_helper import DataPadder + +# todo(note): most parts are adopted from model.py:ParserBT + +# conf +class FpEncConf(Conf): + def __init__(self): + # embedding + self.emb_conf = EmbedConf().init_from_kwargs(dim_word=0) # by default no embedding inputs + # bert + self.bert_conf = Berter2Conf().init_from_kwargs(bert2_retinc_cls=True, bert2_training_mask_rate=0., bert2_output_mode="concat") + # middle layer to reduce dim + self.middle_dim = 0 # 0 means no middle one + # encoder + self.enc_conf = EncConf().init_from_kwargs(enc_rnn_type="lstm2", enc_hidden=1024, enc_rnn_layer=0) + # ===== + # other options + # inputs + self.char_max_length = 45 + # dropouts + self.drop_embed = 0.33 + self.dropmd_embed = 0. 
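+        # dropouts on the hidden/encoder side (the *_rnn ones apply to the recurrent encoder layers)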
+ self.drop_hidden = 0.33 + self.gdrop_rnn = 0.33 # gdrop (always fixed for recurrent connections) + self.idrop_rnn = 0.33 # idrop for rnn + self.fix_drop = True # fix drop for one run for each dropout + +# node +class FpEncoder(BasicNode): + def __init__(self, pc: BK.ParamCollection, conf: FpEncConf, vpack: VocabPackage): + super().__init__(pc, None, None) + self.conf = conf + # ===== Vocab ===== + self.word_vocab = vpack.get_voc("word") + self.char_vocab = vpack.get_voc("char") + self.pos_vocab = vpack.get_voc("pos") + # avoid no params error + self._tmp_v = self.add_param("nope", (1,)) + # ===== Model ===== + # embedding + self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, conf.emb_conf, vpack)) + self.emb_output_dim = self.emb.get_output_dims()[0] + # bert + self.bert = self.add_sub_node("bert", Berter2(self.pc, conf.bert_conf)) + self.bert_output_dim = self.bert.get_output_dims()[0] + # make sure there are inputs + assert self.emb_output_dim>0 or self.bert_output_dim>0 + # middle? + if conf.middle_dim > 0: + self.middle_node = self.add_sub_node("mid", Affine(self.pc, self.emb_output_dim + self.bert_output_dim, + conf.middle_dim, act="elu")) + self.enc_input_dim = conf.middle_dim + else: + self.middle_node = None + self.enc_input_dim = self.emb_output_dim + self.bert_output_dim # concat the two parts (if needed) + # encoder? + # todo(note): feed compute-on-the-fly hp + conf.enc_conf._input_dim = self.enc_input_dim + self.enc = self.add_sub_node("enc", MyEncoder(self.pc, conf.enc_conf)) + self.enc_output_dim = self.enc.get_output_dims()[0] + # ===== Input Specification ===== + # inputs (word, char, pos) and vocabulary + self.need_word = self.emb.has_word + self.need_char = self.emb.has_char + # todo(warn): currently only allow extra fields for POS + self.need_pos = False + if len(self.emb.extra_names) > 0: + assert len(self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos" + self.need_pos = True + # + self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2) + self.char_padder = DataPadder(3, pad_lens=(0, 0, conf.char_max_length), pad_vals=self.char_vocab.pad) + self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad) + + def get_output_dims(self, *input_dims): + return (self.enc_output_dim, ) + + # + def refresh(self, rop=None): + zfatal("Should call special_refresh instead!") + + # ==== + # special routines + + def special_refresh(self, embed_rop, other_rop): + self.emb.refresh(embed_rop) + self.enc.refresh(other_rop) + self.bert.refresh(other_rop) + if self.middle_node is not None: + self.middle_node.refresh(other_rop) + + def prepare_training_rop(self): + mconf = self.conf + embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop) + other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, + fix_drop=mconf.fix_drop) + return embed_rop, other_rop + + def aug_words_and_embs(self, aug_vocab, aug_wv): + orig_vocab = self.word_vocab + if self.emb.has_word: + orig_arr = self.emb.word_embed.E.detach().cpu().numpy() + # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed? 
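+            # merge the filtered aug embeddings into the current word vocab and embedding matrix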
+ # todo(warn): here aug_vocab should be find in aug_wv + aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True) + new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True) + # assign + self.word_vocab = new_vocab + self.emb.word_embed.replace_weights(new_arr) + else: + zwarn("No need to aug vocab since delexicalized model!!") + new_vocab = orig_vocab + return new_vocab + + # ===== + # run + + def prepare_inputs(self, insts, training, input_word_mask_repl=None): + word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], [] + # ===== specially prepare for the words + wv = self.word_vocab + W_UNK = wv.unk + word_act_idxes = [z.words.idxes for z in insts] + # todo(warn): still need the masks + word_arr, mask_arr = self.word_padder.pad(word_act_idxes) + if input_word_mask_repl is not None: + input_word_mask_repl = input_word_mask_repl.astype(np.int) + word_arr = word_arr * (1-input_word_mask_repl) + W_UNK * input_word_mask_repl # replace with UNK + # ===== + if not self.need_word: + word_arr = None + if self.need_char: + chars = [z.chars.idxes for z in insts] + char_arr, _ = self.char_padder.pad(chars) + if self.need_pos: + poses = [z.poses.idxes for z in insts] + pos_arr, _ = self.pos_padder.pad(poses) + extra_arrs.append(pos_arr) + return word_arr, char_arr, extra_arrs, aux_arrs, mask_arr + + # todo(note): for rnn, need to transpose masks, thus need np.array + # return input_repr, enc_repr, mask_arr + def run(self, insts, training, input_word_mask_repl=None): + self._cache_subword_tokens(insts) + # prepare inputs + word_arr, char_arr, extra_arrs, aux_arrs, mask_arr = \ + self.prepare_inputs(insts, training, input_word_mask_repl=input_word_mask_repl) + # layer0: emb + bert + layer0_reprs = [] + if self.emb_output_dim>0: + emb_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs) # [BS, Len, Dim] + layer0_reprs.append(emb_repr) + if self.bert_output_dim>0: + # prepare bert inputs + BERT_MASK_ID = self.bert.tokenizer.mask_token_id + batch_subword_ids, batch_subword_is_starts = [], [] + for bidx, one_inst in enumerate(insts): + st = one_inst.extra_features["st"] + if input_word_mask_repl is not None: + cur_subword_ids, cur_subword_is_start, _ = \ + st.mask_and_return(input_word_mask_repl[bidx][1:], BERT_MASK_ID) # todo(note): exclude ROOT for bert tokens + else: + cur_subword_ids, cur_subword_is_start = st.subword_ids, st.subword_is_start + batch_subword_ids.append(cur_subword_ids) + batch_subword_is_starts.append(cur_subword_is_start) + bert_repr, _ = self.bert.forward_batch(batch_subword_ids, batch_subword_is_starts, + batched_typeids=None, training=training) # [BS, Len, D'] + layer0_reprs.append(bert_repr) + # layer1: enc + enc_input_repr = BK.concat(layer0_reprs, -1) # [BS, Len, D+D'] + if self.middle_node is not None: + enc_input_repr = self.middle_node(enc_input_repr) # [BS, Len, D??] 
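+        # run the top encoder; mask_arr stays an np.array (see the note above run())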
+ enc_repr = self.enc(enc_input_repr, mask_arr) + mask_repr = BK.input_real(mask_arr) + return enc_repr, mask_repr # [bs, len, *], [bs, len] + + # ===== + # caching + def _cache_subword_tokens(self, insts): + for one_inst in insts: + if "st" not in one_inst.extra_features: + one_inst.extra_features["st"] = self.bert.subword_tokenize2(one_inst.words.vals[1:], True) diff --git a/tasks/zdpar/zfp/fp.py b/tasks/zdpar/zfp/fp.py new file mode 100644 index 0000000..9814313 --- /dev/null +++ b/tasks/zdpar/zfp/fp.py @@ -0,0 +1,245 @@ +# + +# the parser + +from typing import List +import numpy as np + +from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn +from msp.data import VocabPackage, MultiHelper +from msp.model import Model +from msp.nn import BK +from msp.nn import refresh as nn_refresh +from msp.nn.layers import BasicNode, Affine, RefreshOptions, NoDropRop +# +from msp.zext.seq_helper import DataPadder +from msp.zext.process_train import RConf, SVConf, ScheduledValue, OptimConf + +from ..common.data import ParseInstance +from .enc import FpEncConf, FpEncoder +from .dec import FpDecConf, FpDecoder +from .masklm import MaskLMNodeConf, MaskLMNode + +# ===== +# confs + +# decoding conf +class FpInferenceConf(Conf): + def __init__(self): + # overall + self.batch_size = 32 + self.infer_single_length = 80 # single-inst batch if >= this length + +# training conf +class FpTrainingConf(RConf): + def __init__(self): + super().__init__() + # about files + self.no_build_dict = False + self.load_model = False + self.load_process = False + # batch arranger + self.batch_size = 32 + self.train_min_length = 0 + self.train_skip_length = 100 + self.shuffle_train = True + # optimizer and lrate factor for enc&dec&sl(mid) + self.enc_optim = OptimConf() + self.enc_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + self.dec_optim = OptimConf() + self.dec_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + self.mid_optim = OptimConf() + self.mid_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) 
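+        # (enc/dec/mid each keep their own optimizer conf and lrate factor; they are wired up in FpParser.__init__)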
+        # margin
+        self.margin = SVConf().init_from_kwargs(val=0.0)
+        # ===
+        # overwrite default ones
+        self.patience = 3
+        self.anneal_times = 5
+        self.max_epochs = 50
+
+# base conf
+class FpParserConf(Conf):
+    def __init__(self, iconf: FpInferenceConf=None, tconf: FpTrainingConf=None):
+        self.iconf = FpInferenceConf() if iconf is None else iconf
+        self.tconf = FpTrainingConf() if tconf is None else tconf
+        # Model
+        self.encoder_conf = FpEncConf()
+        self.decoder_conf = FpDecConf()
+        self.masklm_conf = MaskLMNodeConf()
+        # lambdas for training
+        self.lambda_parse = SVConf().init_from_kwargs(val=1.0)
+        self.lambda_masklm = SVConf().init_from_kwargs(val=0.0)
+
+# =====
+# model
+
+class FpParser(Model):
+    def __init__(self, conf: FpParserConf, vpack: VocabPackage):
+        self.conf = conf
+        self.vpack = vpack
+        tconf = conf.tconf
+        # ===== Vocab =====
+        self.label_vocab = vpack.get_voc("label")
+        # ===== Model =====
+        self.pc = BK.ParamCollection(True)
+        # bottom-part: input + encoder
+        self.enc = FpEncoder(self.pc, conf.encoder_conf, vpack)
+        self.enc_output_dim = self.enc.get_output_dims()[0]
+        self.enc_lrf_sv = ScheduledValue("enc_lrf", tconf.enc_lrf)
+        self.pc.optimizer_set(tconf.enc_optim.optim, self.enc_lrf_sv, tconf.enc_optim,
+                              params=self.enc.get_parameters(), check_repeat=True, check_full=True)
+        # middle-part: structured layer at the middle (build later for convenient re-loading)
+        self.slayer = self.build_slayer()
+        self.mid_lrf_sv = ScheduledValue("mid_lrf", tconf.mid_lrf)
+        if self.slayer is not None:
+            self.pc.optimizer_set(tconf.mid_optim.optim, self.mid_lrf_sv, tconf.mid_optim,
+                                  params=self.slayer.get_parameters(), check_repeat=True, check_full=True)
+        # upper-part: decoder
+        self.dec = self.build_decoder()
+        self.dec_lrf_sv = ScheduledValue("dec_lrf", tconf.dec_lrf)
+        self.pc.optimizer_set(tconf.dec_optim.optim, self.dec_lrf_sv, tconf.dec_optim,
+                              params=self.dec.get_parameters(), check_repeat=True, check_full=True)
+        # extra aux loss
+        conf.masklm_conf._input_dim = self.enc_output_dim
+        self.masklm = MaskLMNode(self.pc, conf.masklm_conf, vpack)
+        self.pc.optimizer_set(tconf.dec_optim.optim, self.dec_lrf_sv, tconf.dec_optim,
+                              params=self.masklm.get_parameters(), check_repeat=True, check_full=True)
+        # ===== For training =====
+        # schedule values
+        self.margin = ScheduledValue("margin", tconf.margin)
+        self.lambda_parse = ScheduledValue("lambda_parse", conf.lambda_parse)
+        self.lambda_masklm = ScheduledValue("lambda_masklm", conf.lambda_masklm)
+        self._scheduled_values = [self.margin, self.enc_lrf_sv, self.mid_lrf_sv, self.dec_lrf_sv,
+                                  self.lambda_parse, self.lambda_masklm]
+        # for refreshing dropouts
+        self.previous_refresh_training = True
+
+    # to be implemented
+    def build_decoder(self):
+        conf = self.conf
+        # todo(note): might need to change if using slayer
+        conf.decoder_conf._input_dim = self.enc_output_dim
+        conf.decoder_conf._num_label = self.label_vocab.trg_len(True)  # todo(note): use the original idx
+        dec = FpDecoder(self.pc, conf.decoder_conf, self.label_vocab)
+        return dec
+
+    def build_slayer(self):
+        return None  # None means no slayer
+
+    # called before each mini-batch
+    def refresh_batch(self, training: bool):
+        # refresh graph
+        # todo(warn): make sure to remember clear this one
+        nn_refresh()
+        # refresh nodes
+        if not training:
+            if not self.previous_refresh_training:
+                # todo(+1): currently no need to refresh testing mode multiple times
+                return
+            self.previous_refresh_training = False
+            embed_rop = other_rop = RefreshOptions(training=False)  # default no dropout
+        else:
+            embed_rop, other_rop = self.enc.prepare_training_rop()
+            # todo(warn): once-bug, don't forget this one!!
+            self.previous_refresh_training = True
+        # manually refresh
+        self.enc.special_refresh(embed_rop, other_rop)
+        for node in [self.dec, self.slayer, self.masklm]:
+            if node is not None:
+                node.refresh(other_rop)
+
+    def update(self, lrate, grad_factor):
+        self.pc.optimizer_update(lrate, grad_factor)
+
+    def add_scheduled_values(self, v):
+        self._scheduled_values.append(v)
+
+    def get_scheduled_values(self):
+        return self._scheduled_values
+
+    # == load and save models
+    # todo(warn): no need to load confs here
+    def load(self, path, strict=True):
+        self.pc.load(path, strict)
+        # self.conf = JsonRW.load_from_file(path+".json")
+        zlog(f"Load {self.__class__.__name__} model from {path}.", func="io")
+
+    def save(self, path):
+        self.pc.save(path)
+        JsonRW.to_file(self.conf, path + ".json")
+        zlog(f"Save {self.__class__.__name__} model to {path}.", func="io")
+
+    def aug_words_and_embs(self, aug_vocab, aug_wv):
+        return self.enc.aug_words_and_embs(aug_vocab, aug_wv)
+
+    # common routines
+    def collect_loss_and_backward(self, loss_names, loss_ts, loss_lambdas, info, training, loss_factor):
+        final_losses = []
+        for one_name, one_losses, one_lambda in zip(loss_names, loss_ts, loss_lambdas):
+            if one_lambda>0. and len(one_losses)>0:
+                num_sub_losses = len(one_losses[0])
+                coll_sub_losses = []
+                for i in range(num_sub_losses):
+                    # todo(note): (loss_sum, loss_count, gold_count[opt])
+                    this_loss_sum = BK.stack([z[i][0] for z in one_losses]).sum()
+                    this_loss_count = BK.stack([z[i][1] for z in one_losses]).sum()
+                    info[f"loss_sum_{one_name}{i}"] = this_loss_sum.item()
+                    info[f"loss_count_{one_name}{i}"] = this_loss_count.item()
+                    # optional extra count
+                    if len(one_losses[0][i]) >= 3:  # has gold count
+                        info[f"loss_count_extra_{one_name}{i}"] = BK.stack([z[i][2] for z in one_losses]).sum().item()
+                    # todo(note): any case that loss-count can be 0?
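+                    # (the +1e-5 below keeps the division safe even if the count is 0)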
+ coll_sub_losses.append(this_loss_sum / (this_loss_count + 1e-5)) + # sub losses are already multiplied by sub-lambdas + weighted_sub_loss = BK.stack(coll_sub_losses).sum() * one_lambda + final_losses.append(weighted_sub_loss) + if len(final_losses)>0: + final_loss = BK.stack(final_losses).sum() + if training and final_loss.requires_grad: + BK.backward(final_loss, loss_factor) + + # ===== + # training + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): + self.refresh_batch(training) + cur_lambda_parse, cur_lambda_masklm = self.lambda_parse.value, self.lambda_masklm.value + # prepare for masklm todo(note): always prepare for input-mask-dropout-like + input_word_mask_repl_arr, output_pred_mask_repl_arr, ouput_pred_idx_arr = self.masklm.prepare(annotated_insts, training) + # encode + enc_expr, mask_expr = self.enc.run(annotated_insts, training, input_word_mask_repl=input_word_mask_repl_arr) + # get loss + if cur_lambda_parse > 0.: + parsing_loss = [self.dec.loss(annotated_insts, enc_expr, mask_expr)] + else: + parsing_loss = [] + if cur_lambda_masklm > 0.: + masklm_loss = [self.masklm.loss(enc_expr, output_pred_mask_repl_arr, ouput_pred_idx_arr)] + else: + masklm_loss = [] + # ----- + info = {"fb": 1, "sent": len(annotated_insts), "tok": sum(map(len, annotated_insts))} + self.collect_loss_and_backward(["parse", "masklm"], [parsing_loss, masklm_loss], + [cur_lambda_parse, cur_lambda_masklm], info, training, loss_factor) + return info + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + with BK.no_grad_env(): + self.refresh_batch(False) + # encode + enc_expr, mask_expr = self.enc.run(insts, False, input_word_mask_repl=None) # no masklm in testing + # decode + self.dec.predict(insts, enc_expr, mask_expr) + # ===== + # test for masklm + input_word_mask_repl_arr, output_pred_mask_repl_arr, ouput_pred_idx_arr = self.masklm.prepare(insts, False) + enc_expr2, mask_expr2 = self.enc.run(insts, False, input_word_mask_repl=input_word_mask_repl_arr) + masklm_loss = self.masklm.loss(enc_expr2, output_pred_mask_repl_arr, ouput_pred_idx_arr) + masklm_loss_val, masklm_loss_count, masklm_corr_count = [z.item() for z in masklm_loss[0]] + # ===== + info = {"sent": len(insts), "tok": sum(map(len, insts)), + "masklm_loss_val": masklm_loss_val, "masklm_loss_count": masklm_loss_count, "masklm_corr_count": masklm_corr_count} + return info + +# b tasks/zdpar/zfp/fp:223 diff --git a/tasks/zdpar/zfp/masklm.py b/tasks/zdpar/zfp/masklm.py new file mode 100644 index 0000000..e173969 --- /dev/null +++ b/tasks/zdpar/zfp/masklm.py @@ -0,0 +1,81 @@ +# + +# the module for masked lm + +from typing import List +import numpy as np + +from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn +from msp.data import VocabPackage, MultiHelper +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, RefreshOptions, NoDropRop +from msp.zext.seq_helper import DataPadder + +from ..common.data import ParseInstance + +class MaskLMNodeConf(Conf): + def __init__(self): + self._input_dim = -1 + # ---- + # hidden layer + self.hid_dim = 300 + self.hid_act = "elu" + # mask/pred options + self.mask_rate = 0. 
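+        # only words with vocab rank >= min_mask_rank get masked; only those with rank <= max_pred_rank get predicted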
+        self.min_mask_rank = 2
+        self.max_pred_rank = 2000
+        self.init_pred_from_pretrain = False
+
+class MaskLMNode(BasicNode):
+    def __init__(self, pc: BK.ParamCollection, conf: MaskLMNodeConf, vpack: VocabPackage):
+        super().__init__(pc, None, None)
+        self.conf = conf
+        # vocab and padder
+        self.word_vocab = vpack.get_voc("word")
+        self.padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2)  # todo(note): -id is very large
+        # models
+        self.hid_layer = self.add_sub_node("hid", Affine(pc, conf._input_dim, conf.hid_dim, act=conf.hid_act))
+        self.pred_layer = self.add_sub_node("pred", Affine(pc, conf.hid_dim, conf.max_pred_rank+1, init_rop=NoDropRop()))
+        if conf.init_pred_from_pretrain:
+            npvec = vpack.get_emb("word")
+            if npvec is None:
+                zwarn("Pretrained vector not provided, skip init pred embeddings!!")
+            else:
+                with BK.no_grad_env():
+                    self.pred_layer.ws[0].copy_(BK.input_real(npvec[:conf.max_pred_rank+1].T))
+                zlog(f"Init pred embeddings from pretrained vectors (size={conf.max_pred_rank+1}).")
+
+    # return (input_word_mask_repl, output_pred_mask_repl, output_pred_idx)
+    def prepare(self, insts: List[ParseInstance], training):
+        conf = self.conf
+        word_idxes = [z.words.idxes for z in insts]
+        word_arr, input_mask = self.padder.pad(word_idxes)  # [bsize, slen]
+        # prepare for the masks
+        input_word_mask = (Random.random_sample(word_arr.shape) < conf.mask_rate) & (input_mask>0.)
+        input_word_mask &= (word_arr >= conf.min_mask_rank)
+        input_word_mask[:, 0] = False  # no masking for special ROOT
+        output_pred_mask = (input_word_mask & (word_arr <= conf.max_pred_rank))
+        return input_word_mask.astype(np.float32), output_pred_mask.astype(np.float32), word_arr
+
+    # [bsize, slen, *]
+    def loss(self, repr_t, pred_mask_repl_arr, pred_idx_arr):
+        mask_idxes, mask_valids = BK.mask2idx(BK.input_real(pred_mask_repl_arr))  # [bsize, ?]
+        if BK.get_shape(mask_idxes, -1) == 0:  # no loss
+            zzz = BK.zeros([])
+            return [[zzz, zzz, zzz]]
+        else:
+            target_reprs = BK.gather_first_dims(repr_t, mask_idxes, 1)  # [bsize, ?, *]
+            target_hids = self.hid_layer(target_reprs)
+            target_scores = self.pred_layer(target_hids)  # [bsize, ?, V]
+            pred_idx_t = BK.input_idx(pred_idx_arr)  # [bsize, slen]
+            target_idx_t = pred_idx_t.gather(-1, mask_idxes)  # [bsize, ?]
+            target_idx_t[(mask_valids<1.)] = 0  # make sure invalid ones in range
+            # get loss
+            pred_losses = BK.loss_nll(target_scores, target_idx_t)  # [bsize, ?]
+            pred_loss_sum = (pred_losses * mask_valids).sum()
+            pred_loss_count = mask_valids.sum()
+            # argmax
+            _, argmax_idxes = target_scores.max(-1)
+            pred_corrs = (argmax_idxes == target_idx_t).float() * mask_valids
+            pred_corr_count = pred_corrs.sum()
+            return [[pred_loss_sum, pred_loss_count, pred_corr_count]]
diff --git a/tasks/zdpar/zfp/run_zfp.sh b/tasks/zdpar/zfp/run_zfp.sh
new file mode 100644
index 0000000..e499401
--- /dev/null
+++ b/tasks/zdpar/zfp/run_zfp.sh
@@ -0,0 +1,98 @@
+#!/usr/bin/env bash
+
+# running for zfp
+
+# --
+# go: create new dir?
+if [[ ${USE_RDIR} != "" ]]; then
+echo "Going into a new dir for the running: ${USE_RDIR}"
+mkdir -p ${USE_RDIR}
+cd ${USE_RDIR}
+fi
+# --
+
+# =====
+# find the dir
+SRC_DIR="../src/"
+if [[ ! -d ${SRC_DIR} ]]; then
+SRC_DIR="../../src/"
+if [[ ! -d ${SRC_DIR} ]]; then
+SRC_DIR="../../../src/"
+fi
+fi
+DATA_DIR="${SRC_DIR}/../data/"  # set data dir
+
+# =====
+function decode_one () {
+if [[ -z ${WSET} ]]; then WSET="test"; fi
+for cl in en ar eu zh 'fi' he hi it ja ko ru sv tr;
+do
+# running for one
+  CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.main.test ${MDIR}/_conf conf_output: test:${DATA_DIR}/${cl}_${WSET}.conllu output_file:crout_${cl}${OUTINFFIX}.out model_load_name:${MDIR}/zmodel.best dict_dir:${MDIR} log_file:
+done
+}
+# RGPU=? MDIR=../? OUTINFFIX= WSET= decode_one |& tee log.decode
+
+# =====
+function run_one () {
+# check envs
+echo "The environment variables: RGPU=${RGPU} CUR_LANG=${CUR_LANG} CUR_RUN=${CUR_RUN}";
+
+# =====
+# basis
+base_opt="conf_output:_conf init_from_pretrain:0"
+base_opt+=" drop_hidden:0.2 drop_embed:0.2 fix_drop:0 dropmd_embed:0."
+base_opt+=" lrate.val:0.00005 lrate.min_val:0.000001 tconf.batch_size:16 lrate.m:0.5 min_save_epochs:10 max_epochs:100 patience:5 anneal_times:5"
+
+# which system
+base_opt+=" partype:fp use_label0:1"
+
+# fine-tune bert
+base_opt+=" bert2_model:bert-base-multilingual-cased bert2_output_layers:[-1] bert2_trainable_layers:13"
+# or simply use features
+#base_opt+=" bert2_model:bert-base-multilingual-cased bert2_output_layers:[4,6,8]"
+
+# decoder
+base_opt+=" arc_space:512 lab_space:128"
+base_opt+=" use_biaffine:1 biaffine_div:0. use_ff1:1 use_ff2:0 ff2_hid_layer:0"
+
+# encoder
+# by default no encoder
+if [[ ${USERNN} == 1 ]]; then
+base_opt+=" enc_conf.enc_rnn_layer:3 enc_conf.enc_att_layer:0 lrate.val:0.001"
+fi
+if [[ ${USEATT} == 1 ]]; then
+# specific for self-att encoder
+base_opt+=" middle_dim:512 enc_conf.enc_hidden:512"
+base_opt+=" enc_conf.enc_rnn_layer:0 enc_conf.enc_att_layer:6 enc_conf.enc_att_add_wrapper:addnorm enc_conf.enc_att_conf.clip_dist:10 enc_conf.enc_att_conf.use_neg_dist:0 enc_conf.enc_att_conf.att_dropout:0.1 enc_conf.enc_att_conf.use_ranges:1 enc_conf.enc_att_fixed_ranges:5,10,20,30,50,100 enc_conf.enc_att_final_act:linear"
+base_opt+=" max_epochs:100 train_skip_length:80 tconf.batch_size:80 split_batch:2 biaffine_div:0."
+base_opt+=" lrate.val:0.0002 lrate_warmup:-2 drop_embed:0.1 dropmd_embed:0. drop_hidden:0.1"
+fi
+
+# run
+data_opt="train:${DATA_DIR}/${CUR_LANG}_train.conllu dev:${DATA_DIR}/${CUR_LANG}_dev.conllu test:${DATA_DIR}/${CUR_LANG}_test.conllu"
+LOG_PREFIX="_log"
+# train
+if [[ -n ${DEBUG} ]]; then
+  CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 -m pdb ${SRC_DIR}/tasks/cmd.py zdpar.main.train device:0 ${data_opt} ${base_opt} "niconf.random_seed:9347${CUR_RUN}" "niconf.random_cuda_seed:9349${CUR_RUN}" ${EXTRA_ARGS}
+else
+  CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.main.train device:0 ${data_opt} ${base_opt} "niconf.random_seed:9347${CUR_RUN}" "niconf.random_cuda_seed:9349${CUR_RUN}" ${EXTRA_ARGS} |& tee ${LOG_PREFIX}.train
+fi
+# decode targets
+RGPU=${RGPU} MDIR=./ OUTINFFIX= WSET= decode_one |& tee ${LOG_PREFIX}.decode
+}
+
+# actual run
+if [[ -z "${CUR_LANG}" ]]; then
+  CUR_LANG="en"
+fi
+# --
+run_one
+# --
+if [[ ${USE_RDIR} != "" ]]; then
+cd $OLDPWD
+fi
+
+
+# =====
+# RGPU=1 DEBUG= CUR_LANG=en CUR_RUN=1 USERNN=0 USEATT=0 USE_RDIR= bash _go.sh
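+# example (untested; option key names assumed from fp.py/masklm.py): enable the auxiliary
+# masked-LM loss by passing extra options through EXTRA_ARGS, e.g.
+# RGPU=1 CUR_LANG=en CUR_RUN=1 EXTRA_ARGS="lambda_masklm.val:0.5 mask_rate:0.15" bash run_zfp.sh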