diff --git a/msp/__init__.py b/msp/__init__.py index 12bae39..54ffea4 100644 --- a/msp/__init__.py +++ b/msp/__init__.py @@ -1,18 +1,28 @@ # -# The Mingle Structured Prediction (v0plus) package +# The Mingled Structured Prediction (v0plus) package # by zzs (from 2018.02 - now) -# dependencies: pytorch, numpy, scipy, gensim, cython +# dependencies: pytorch, numpy, scipy, gensim, cython, pybind11 +# conda install pytorch numpy scipy gensim cython pybind11 VERSION_MAJOR = 0 -VERSION_MINOR = 0 +VERSION_MINOR = 1 VERSION_PATCH = 1 - +VERSION_STATUS = "dev" + +# TODO(!) +# nn optimizer / param groups? +# check nn module (for simplification?) +# new model/training/testing scheme -> make it more modularized +# Conf's init: what types +# dropout setting: use training/testing(with/wo)-mode +# easy-to-use calculations result-representing tools for analysis +# various tools for python as the replacement of direct bash shell +# gru and cnn have problems? def version(): - return (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) - + return (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH, VERSION_STATUS) __version__ = ".".join(str(z) for z in version()) @@ -26,9 +36,9 @@ def version(): # !!: (again-corrected) The goal is not to build a whole framework, but several useful pieces. # conventions -# todo(0)/todo(warn): simply warning -# todo(+N): need what level of efforts -# TODO(!): unfinished, real todo +# todo(0)/todo(warn)/todo(note): simply warning or noting +# todo(+N): need what level of efforts, +N means lots of efforts +# TODO: unfinished, real todo # hierarchically: msp -> scripts / tasks, no reverse ref allowed! diff --git a/msp/cmp/zarray.py b/msp/cmp/zarray.py index b5c3d4b..99f2328 100644 --- a/msp/cmp/zarray.py +++ b/msp/cmp/zarray.py @@ -5,6 +5,8 @@ import numpy as np +# todo(warn): should use pandas! + class ZArray: def __init__(self, arr: np.ndarray, names=None): self.arr = arr diff --git a/msp/cmp/ztab.py b/msp/cmp/ztab.py new file mode 100644 index 0000000..b4aed21 --- /dev/null +++ b/msp/cmp/ztab.py @@ -0,0 +1,34 @@ +# + +# +class ZTab: + @staticmethod + def read(ss, fsep='\t', lsep='\n'): + pieces = [z.split(fsep) for z in ss.split(lsep)] + num_col = len(pieces[0]) - 1 + num_row = len(pieces) - 1 + row_names = [z[0] for z in pieces[1:]] + col_names = pieces[0][1:] + ret = {} + # add entries + for one_row in pieces[1:]: + k = one_row[0] + assert k not in ret + cur_m = {} + assert len(one_row)-1 == num_col + for n,v in zip(col_names, one_row[1:]): + cur_m[n] = v + ret[k] = cur_m + return row_names, col_names, ret + + @staticmethod + def write(row_names, col_names, value_f, fsep='\t', lsep='\n'): + pieces = [] + # add head line + pieces.append([""] + col_names) + # add each rows + for r in row_names: + pieces.append([r] + [value_f(r, c) for c in col_names]) + # finally concat + ss = lsep.join([fsep.join(z) for z in pieces]) + return ss diff --git a/msp/data/__init__.py b/msp/data/__init__.py index 9afae99..9dd37fb 100644 --- a/msp/data/__init__.py +++ b/msp/data/__init__.py @@ -3,4 +3,4 @@ from .data import Instance, TextReader, FdReader, WordNormer from .vocab import Vocab, VocabHelper, VocabBuilder, WordVectors, VocabPackage, MultiHelper from .streamer import Streamer, AdapterStreamer, FAdapterStreamer, FileOrFdStreamer, BatchArranger, InstCacher, \ - MultiCatStreamer, IterStreamer, MultiZipStreamer + MultiCatStreamer, IterStreamer, MultiZipStreamer, FListAdapterStream, MultiJoinStreamer, ShuffleStreamer diff --git a/msp/data/streamer.py b/msp/data/streamer.py index 8361366..70e9a9d 100644 --- a/msp/data/streamer.py +++ b/msp/data/streamer.py @@ -1,6 +1,6 @@ # -from msp.utils import zopen, zcheck, zfatal, ZException +from msp.utils import zopen, zcheck, zfatal from msp.utils import Constants, Random from typing import Iterable, Sequence @@ -147,10 +147,25 @@ def _next(self): else: return z -# f can returns a list (augment or filter) -# TODO(+1) +# f: Inst -> List[Inst] (augment or filter) class FListAdapterStream(AdapterStreamer): - pass + def __init__(self, base_streamer, f): + super().__init__(base_streamer) + self.f = f + self.cache = [] + self.cache_idx = 0 + + def _next(self): + while self.cache_idx >= len(self.cache): + one = self.base_streamer_.next() + if self.base_streamer_.is_eos(one): + return None + self.cache.clear() + self.cache_idx = 0 + self.cache.extend(self.f(one)) + r = self.cache[self.cache_idx] + self.cache_idx += 1 + return r # ===== # Random Streams @@ -202,13 +217,6 @@ def _next(self): if self.rs_.next(): return one -# TODO(+2) -class ShuffleStreamer(AdapterStreamer): - # -1 means read all and shuffle - def __init__(self, base_streamer, batch_size=-1): - super().__init__(base_streamer) - raise NotImplementedError() - # ===== # Multi Streamers @@ -247,6 +255,45 @@ def _restart(self): one.restart() self._mc_set_cur(0) +# +# like MultiCat, but join the streams (horizontally) rather than concat (vertically) +# todo(note): currently mixing them 1 by 1; maybe not efficient enough if there are too many streamers +class MultiJoinStreamer(MultiStreamer): + def __init__(self, base_streamers: Sequence): + super().__init__(base_streamers) + self.cur_pointer = 0 + self.ended = [False] * len(self.base_streamers_) + + def _next(self): + # find the next active streamer + starting_point = self.cur_pointer + num_streamers = self.num_streamers_ + cur_point = starting_point + flag_success = False + one = None + for i in range(num_streamers): # at most travel one round + if not self.ended[cur_point]: + new_streamer = self.base_streamers_[cur_point] + one = new_streamer.next() + if new_streamer.is_eos(one): + self.ended[cur_point] = True + else: + flag_success = True + cur_point = (cur_point + 1) % num_streamers + if flag_success: + break + self.cur_pointer = cur_point + if flag_success: + return one + else: + return None + + def _restart(self): + for one in self.base_streamers_: + one.restart() + self.cur_pointer = 0 + self.ended = [False] * len(self.base_streamers_) + # # zip-like multi-streamer, with various modes, return a list of instances # -- can be useful to combine with LoopStreamer for wrapping-around @@ -301,6 +348,9 @@ def get(self): def reset(self): raise NotImplementedError() + def clear(self): + raise NotImplementedError() + def __len__(self): raise NotImplementedError() @@ -325,6 +375,10 @@ def get(self): self.ptr += 1 return self.c[self.ptr-1] + def clear(self): + self.c.clear() + self.ptr = 0 + def reset(self): self.ptr = 0 if self.shuffle: @@ -347,6 +401,24 @@ def _restart(self): self.cache.put(one) self.cache.reset() +# todo(+1): partial shuffle to avoid read all? +class ShuffleStreamer(AdapterStreamer): + # -1 means read all and shuffle + def __init__(self, src_stream, cache_builder=InplacedCache): + super().__init__(src_stream) + self.src = src_stream + self.cache = cache_builder(shuffle=True) + + def _next(self): + return self.cache.get() + + def _restart(self): + # todo(note): rebuild cache each time (do not need to call super, since for-loop will trigger the src.restart) + self.cache.clear() + for one in self.src: + self.cache.put(one) + self.cache.reset() + # ===== # BatchHelper diff --git a/msp/data/vocab.py b/msp/data/vocab.py index e3675fb..3d08eeb 100644 --- a/msp/data/vocab.py +++ b/msp/data/vocab.py @@ -33,12 +33,12 @@ def __init__(self, name=""): @staticmethod def read(fname): one = Vocab() - JsonRW.load_from_file(one, fname) + JsonRW.from_file(one, fname) zlog("-- Read Dictionary from %s: Finish %s." % (fname, len(one)), func="io") return one def write(self, fname): - JsonRW.save_to_file(self, fname) + JsonRW.to_file(self, fname) zlog("-- Write Dictionary to %s: Finish %s." % (fname, len(self)), func="io") # ------------- @@ -104,8 +104,12 @@ def idx2val(self, idx): # ===== # ranges if used as target vocab - def trg_len(self): - return self.target_length_ + def trg_len(self, idx_orig): + # if idx_orig, return more len including the start ranges to make the idxes the same + if idx_orig: + return self.target_length_ + self.target_range_[0] + else: + return self.target_length_ # target idx to real idx def trg_pred2real(self, idx): @@ -159,7 +163,7 @@ def i2w(dicts, ii, rm_eos=True, factor_split='|'): return tmp # filter inits for embeddings - def filter_embed(self, wv, init_nohit=0., scale=1.0): + def filter_embed(self, wv: 'WordVectors', init_nohit=0., scale=1.0, assert_all_hit=False): if init_nohit <= 0.: get_nohit = lambda s: np.zeros((s,), dtype=np.float32) else: @@ -177,6 +181,9 @@ def filter_embed(self, wv, init_nohit=0., scale=1.0): # value = np.zeros((wv.embed_size,), dtype=np.float32) res["no-hit"] += 1 ret.append(value) + # + if assert_all_hit: + zcheck(res["no-hit"]==0, f"Filter-embed error: assert all-hit but get no-hit of {res['no-hit']}") printing("Filter pre-trained embed: %s, no-hit is inited with %s." % (res, init_nohit)) return np.asarray(ret, dtype=np.float32) * scale @@ -216,11 +223,11 @@ def __init__(self, vname="", pre_list=DEFAULT_PRE_LIST, post_list=DEFAULT_POST_L @staticmethod def _build_check(v): # check specials are included - zcheck(lambda: all(x in v.v for x in v.pre_list), "Not including pre_specials", func="warn") - zcheck(lambda: all(x in v.v for x in v.post_list), "Not including post_specials", func="warn") + zcheck(lambda: all(x in v.v for x in v.pre_list), "Not including pre_specials") + zcheck(lambda: all(x in v.v for x in v.post_list), "Not including post_specials") # check special tokens - zcheck(lambda: all(v.v[x]=len(v.v)-len(v.post_list) for x in v.post_list), "Get unexpected post_special words in plain words!!", func="warn") + zcheck(lambda: all(v.v[x]=len(v.v)-len(v.post_list) for x in v.post_list), "Get unexpected post_special words in plain words!!") @staticmethod def _build_prop(v): @@ -362,18 +369,28 @@ def __init__(self, sep=" "): self.wmap = {} # str -> idx self.vecs = [] # idx -> vec self.sep = sep + # + self.hits = {} # hit keys by any queries with __contains__: norm_until_hit, has_key, get_vec + + # + def __contains__(self, item): + if item in self.wmap: + self.hits[item] = self.hits.get(item, 0) + 1 + return True + else: + return False # return (hit?, norm_name, normed_w) def norm_until_hit(self, w): orig_w = w for norm_name, norm_f in WordVectors.WORD_NORMERS: w = norm_f(w) - if w in self.wmap: + if w in self: return True, norm_name, w return False, "", orig_w def has_key(self, k, norm=True): - if k in self.wmap: + if k in self: return True elif norm: return self.norm_until_hit(k)[0] @@ -381,7 +398,7 @@ def has_key(self, k, norm=True): return False def get_vec(self, k, df=None, norm=True): - if k in self.wmap: + if k in self: return self.vecs[self.wmap[k]] elif norm: hit, _, w = self.norm_until_hit(k) @@ -390,11 +407,19 @@ def get_vec(self, k, df=None, norm=True): return df def save(self, fname): - printing("Saving w2v num_words=%d, embed_size=%d to %s." % (self.num_words, self.embed_size, fname)) + printing(f"Saving w2v num_words={self.num_words:d}, embed_size={self.embed_size:d} to {fname}.") zcheck(self.num_words == len(self.vecs), "Internal error: unmatched number!") with zopen(fname, "w") as fd: WordVectors.save_txt(fd, self.words, self.vecs, self.sep) + def save_hits(self, fname): + num_hits = len(self.hits) + printing(f"Saving hit w2v num_words={num_hits:d}, embed_size={self.embed_size:d} to {fname}.") + with zopen(fname, "w") as fd: + tmp_words = sorted(self.hits.keys(), key=lambda k: self.wmap[k]) # use original ordering + tmp_vecs = [self.vecs[self.wmap[k]] for k in tmp_words] + WordVectors.save_txt(fd, tmp_words, tmp_vecs, self.sep) + @staticmethod def save_txt(fd, words, vecs, sep): num_words = len(words) @@ -419,7 +444,8 @@ def merge_others(self, others): this_all_num = other.num_words this_added_num = 0 for one_w, one_vec in zip(other.words, other.vecs): - if one_w not in self.wmap: + # keep the old one! + if one_w not in self.wmap: # here, does not record as hits! this_added_num += 1 self.wmap[one_w] = len(self.words) self.words.append(one_w) @@ -443,6 +469,7 @@ def load(fname, binary=False, txt_sep=" ", aug_code=""): def _load_txt(fname, sep=" "): printing("Going to load pre-trained (txt) w2v from %s ..." % fname) one = WordVectors(sep=sep) + repeated_count = 0 with zopen(fname) as fd: # first line line = fd.readline() @@ -457,7 +484,14 @@ def _load_txt(fname, sep=" "): line = line.rstrip() fields = line.split(sep) word, vec = fields[0], [float(x) for x in fields[1:]] - zcheck(word not in one.wmap, "Repeated key.") + # zcheck(word not in one.wmap, "Repeated key.") + # keep the old one + if word in one.wmap: + repeated_count += 1 + zwarn(f"Repeat key {word}") + line = fd.readline() + continue + # if one.embed_size is None: one.embed_size = len(vec) else: @@ -467,11 +501,10 @@ def _load_txt(fname, sep=" "): one.words.append(word) line = fd.readline() # final - if one.num_words is None: - one.num_words = len(one.vecs) - printing("Read ok: w2v num_words=%d, embed_size=%d." % (one.num_words, one.embed_size)) - else: - zcheck(one.num_words == len(one.vecs), "Unmatched num of words.") + if one.num_words is not None: + zcheck(one.num_words == len(one.vecs)+repeated_count, "Unmatched num of words.") + one.num_words = len(one.vecs) + printing(f"Read ok: w2v num_words={one.num_words:d}, embed_size={one.embed_size:d}, repeat={repeated_count:d}") return one @staticmethod @@ -480,6 +513,7 @@ def _load_bin(fname): one = WordVectors() # kv = KeyedVectors.load_word2vec_format(fname, binary=True) + # KeyedVectors.save_word2vec_format() one.num_words, one.embed_size = len(kv.vectors), len(kv.vectors[0]) for w, z in kv.vocab.items(): one.vecs.append(kv.vectors[z.index]) @@ -497,11 +531,11 @@ def __init__(self, vocabs: Dict, embeds: Dict): self.vocabs = vocabs self.embeds = embeds - def get_voc(self, name): - return self.vocabs[name] + def get_voc(self, name, df=None): + return self.vocabs.get(name, df) - def get_emb(self, name): - return self.embeds[name] + def get_emb(self, name, df=None): + return self.embeds.get(name, df) def put_voc(self, name, v): self.vocabs[name] = v @@ -520,7 +554,7 @@ def load(self, prefix="./"): for name in self.embeds: fname = prefix+"ve_"+name+".pic" if FileHelper.exists(fname): - self.embeds[name] = PickleRW.read(fname) + self.embeds[name] = PickleRW.from_file(fname) else: self.embeds[name] = None @@ -532,7 +566,7 @@ def save(self, prefix="./"): for name, vv in self.embeds.items(): fname = prefix+"ve_"+name+".pic" if vv is not None: - PickleRW.write(fname, vv) + PickleRW.to_file(vv, fname) # helpers for handling multi-lingual data, simply adding lang-code prefix @@ -564,22 +598,27 @@ def strip_aug_prefix(w): # keep main's pre&post, but put aug's true words before post and make corresponding changes to the arr # -> return (new_vocab, new_arr) @staticmethod - def aug_vocab_and_arr(main_vocab, main_arr, aug_vocab, aug_arr): + def aug_vocab_and_arr(main_vocab, main_arr, aug_vocab, aug_arr, aug_override): # first merge the vocab new_vocab = VocabBuilder.merge_vocabs([main_vocab, aug_vocab], sort_by_count=False, pre_list=main_vocab.pre_list, post_list=main_vocab.post_list) # then find the arrays + # todo(+1): which order to find words + assert aug_override, "To be implemented for other ordering!" + # new_arr = [] main_hit = aug_hit = 0 for idx in range(len(new_vocab)): word = new_vocab.idx2word(idx) - # todo(warn): selecting the embeds according to order - main_orig_idx = main_vocab.get(word) - if main_orig_idx is None: - aug_orig_idx = aug_vocab[word] # must be there! - new_arr.append(aug_arr[aug_orig_idx]) - else: + # todo(warn): selecting the embeds in aug first (make it possible to override original ones!) + aug_orig_idx = aug_vocab.get(word) + if aug_orig_idx is None: + main_orig_idx = main_vocab[word] # must be there! new_arr.append(main_arr[main_orig_idx]) + main_hit += 1 + else: + new_arr.append(aug_arr[aug_orig_idx]) + aug_hit += 1 zlog(f"For the final merged arr, the composition is all={len(new_arr)},main={main_hit},aug={aug_hit}") ret_arr = np.asarray(new_arr) return new_vocab, ret_arr diff --git a/msp/model/model.py b/msp/model/model.py index e9b436c..dfcecee 100644 --- a/msp/model/model.py +++ b/msp/model/model.py @@ -6,12 +6,12 @@ class Model(object): # list(mini-batch) of to-decode instances, # results are written in-place, return info. - def inference_on_batch(self, insts): + def inference_on_batch(self, insts, **kwargs): raise NotImplementedError() # list(mini-batch) of annotated instances # optional results are written in-place? return info. - def fb_on_batch(self, annotated_insts, training=True): + def fb_on_batch(self, annotated_insts, **kwargs): raise NotImplementedError() # called before each mini-batch @@ -19,9 +19,13 @@ def refresh_batch(self, training): raise NotImplementedError() # called for one step of paramter-update - def update(self, lrate): + def update(self, lrate, grad_factor): raise NotImplementedError() + # instance preparer(inst -> inst) None means no specific preparer + def get_inst_preper(self, training, **kwargs): + return None + # get all values need to be schedules, like lrate def get_scheduled_values(self): raise NotImplementedError() diff --git a/msp/nn/__init__.py b/msp/nn/__init__.py index 3838e1f..d133af6 100644 --- a/msp/nn/__init__.py +++ b/msp/nn/__init__.py @@ -17,7 +17,7 @@ from .backends import BK from .backends.common import COMMON_CONFIG, NIConf -from .expr import ExprWrapperPool +from .expr import ExprWrapperPool, ExprWrapper, SliceManager, SlicedExpr from msp.utils import zlog diff --git a/msp/nn/backends/bktr.py b/msp/nn/backends/bktr.py index 149cec8..b41998a 100644 --- a/msp/nn/backends/bktr.py +++ b/msp/nn/backends/bktr.py @@ -6,12 +6,20 @@ from torch.nn import functional as F from torch.nn.utils import clip_grad_norm_ import numpy as np +from typing import List from .common import COMMON_CONFIG, get_unique_name, _my_get_params_init -from msp.utils import zwarn Expr = torch.Tensor -DEFAULT_DEVICE = torch.device("cpu") +CPU_DEVICE = torch.device("cpu") +DEFAULT_DEVICE = CPU_DEVICE + +# types +float32 = torch.float32 +float64 = torch.float64 +int32 = torch.int32 +int64 = torch.int64 +uint8 = torch.uint8 def init(): torch.set_num_threads(COMMON_CONFIG.num_threads) @@ -58,20 +66,90 @@ def get_params_init(shape, init, lookup): elif init == "default" or init == "glorot": nn.init.xavier_uniform_(x) elif init == "ortho": - nn.init.orthogonal_(x) + # todo(note): assume squared matrices + assert len(shape)==2 and (shape[0]%shape[1]==0 or shape[1]%shape[0]==0), "Invalid shape for ortho init" + s0, s1 = shape + if s0>s1: + for i in range(s0//s1): + nn.init.orthogonal_(x[i*s1:(i+1)*s1,:]) + else: + for i in range(s1//s0): + nn.init.orthogonal_(x[:,i*s0:(i+1)*s0]) else: assert False, "Unknown init method for BKTR: "+init return x +# todo(note): similar to Torch.Optimizer, but with extra settings +class Optim: + def __init__(self, optim_type, lrf_sv, oconf, params): + self.params_ = params + self.lrf_sv_ = lrf_sv + if optim_type == "sgd": + opt_ = optim.SGD(params, lr=0., momentum=oconf.sgd_momentum, weight_decay=oconf.weight_decay) + elif optim_type == "adagrad": + opt_ = optim.Adagrad(params, lr=0., weight_decay=oconf.weight_decay) + elif optim_type == "adam": + opt_ = optim.Adam(params, lr=0., betas=oconf.adam_betas, eps=oconf.adam_eps, weight_decay=oconf.weight_decay) + elif optim_type == "adadelta": + opt_ = optim.Adadelta(params, lr=0., rho=oconf.adadelta_rho, weight_decay=oconf.weight_decay) + else: + raise NotImplementedError("Unknown optim %s." % optim_type) + self.opt_ = opt_ + # + self.no_step_lrate0_ = oconf.no_step_lrate0 + self.cached_lrate_ = 0. + self.grad_clip_ = oconf.grad_clip + + def update(self, overall_lrate, grad_factor): + cur_lrate = overall_lrate * float(self.lrf_sv_) + if self.cached_lrate_ != cur_lrate: + # schedule lrate, do as lr_scheduler does + for param_group in self.opt_.param_groups: + param_group['lr'] = cur_lrate + self.cached_lrate_ = cur_lrate + if cur_lrate<=0. and self.no_step_lrate0_: + # no update + self.opt_.zero_grad() + else: + # todo(warn): useful for batch-split, div grad by splits + if grad_factor != 1.: + for p in self.params_: + if p.grad is not None: + p.grad.data.mul_(grad_factor) + if self.grad_clip_ > 0.: + clip_grad_norm_(self.params_, self.grad_clip_) + self.opt_.step() + self.opt_.zero_grad() + # todo(warn): here nn.Module simply used for Param Collection -class ParamCollection(object): - def __init__(self): +class ParamCollection: + def __init__(self, new_name_conv=True): self.model_ = nn.Module() - self.optim_ = None - self.prev_lrate_ = None - self.grad_clip_ = None + self.optims_ = [] + self.paramid2optid_ = {} # id -> list # self.name_dict = {} + self.new_name_conv = new_name_conv + self.new_name_conv_stack = [] + + # ===== + def nnc_push(self, name): + self.new_name_conv_stack.append(name) + + def nnc_pop(self, name): + # todo(note): there can be out-of-order because of wrappers + ridx = len(self.new_name_conv_stack) - 1 - self.new_name_conv_stack[::-1].index(name) + x = self.new_name_conv_stack.pop(ridx) + assert x == name, "Name unmatched!!" + + def nnc_name(self, name, check_stack=True): + if check_stack: + assert name == self.new_name_conv_stack[-1] + if self.new_name_conv: + return "/".join(self.new_name_conv_stack) + else: + return name + # ===== def get_unique_name(self, name): return get_unique_name(self.name_dict, name) @@ -83,6 +161,13 @@ def param_new(self, name, shape, init_weights, lookup=False): self.model_.register_parameter(name, p) return p + # special mode + # todo(WARN): for changing names while stacking modules; not elegant! + def param_rename(self, old_name, new_name): + p = self.model_.__getattr__(old_name) + self.model_.__delattr__(old_name) + self.model_.register_parameter(new_name, p) + # freeze param def param_set_trainable(self, p, trainable): bool_trainable = bool(trainable) @@ -90,73 +175,86 @@ def param_set_trainable(self, p, trainable): p.requires_grad = bool_trainable # tconf should have other properties: momentum, grad_clip, - def optimizer_set(self, optim_type, init_lrate, tconf): - if optim_type == "sgd": - self.optim_ = optim.SGD(self.model_.parameters(), lr=init_lrate, momentum=tconf.sgd_momentum) - elif optim_type == "adagrad": - self.optim_ = optim.Adagrad(self.model_.parameters(), lr=init_lrate) - elif optim_type == "adam": - self.optim_ = optim.Adam(self.model_.parameters(), lr=init_lrate, betas=tconf.adam_betas, eps=tconf.adam_eps) - else: - raise NotImplementedError("Unknown optim %s." % optim_type) - self.prev_lrate_ = init_lrate - self.grad_clip_ = tconf.grad_clip - - def optimizer_update(self, lrate): - if self.prev_lrate_ != lrate: - # schedule lrate, do as lr_scheduler does - for param_group in self.optim_.param_groups: - param_group['lr'] = lrate - self.prev_lrate_ = lrate - if self.grad_clip_ > 0.: - clip_grad_norm_(self.model_.parameters(), self.grad_clip_) - self.optim_.step() - self.model_.zero_grad() + def optimizer_set(self, optim_type, lrf_sv, oconf, params: List = None, check_repeat=True, check_full=False): + if params is None: + params = self.model_.parameters() + optim = Optim(optim_type, lrf_sv, oconf, params) + cur_optid = len(self.optims_) + self.optims_.append(optim) + # track all params + for p in params: + paramid = id(p) + if paramid not in self.paramid2optid_: + self.paramid2optid_[paramid] = [cur_optid] + else: + assert not check_repeat, "Err: repeated params in multiple optimizer!" + self.paramid2optid_[paramid].append(paramid) + if check_full: + for p in self.model_.parameters(): + assert id(p) in self.paramid2optid_, "Err: failed checking full, there is unadded param" + + def optimizer_update(self, overall_lrate, grad_factor): + for optim in self.optims_: + optim.update(overall_lrate, grad_factor) def save(self, path): torch.save(self.model_.state_dict(), path) - def load(self, path): - self.model_.load_state_dict(torch.load(path)) + def load(self, path, strict=True): + model = torch.load(path, map_location=DEFAULT_DEVICE) + self.model_.load_state_dict(model, strict=strict) # ===== the functions # ----- inputs # (inputs: python data type) -> FloatTensor -def input_real(inputs): +def input_real(inputs, device=None): if is_expr(inputs): return inputs - return torch.tensor(inputs, dtype=torch.float32, device=DEFAULT_DEVICE) + return torch.tensor(inputs, dtype=torch.float32, device=(DEFAULT_DEVICE if device is None else device)) # (inputs: python data type of indexes) -> LongTensor -def input_idx(inputs): +def input_idx(inputs, device=None): if is_expr(inputs): return inputs - return torch.tensor(inputs, dtype=torch.long, device=DEFAULT_DEVICE) + return torch.tensor(inputs, dtype=torch.long, device=(DEFAULT_DEVICE if device is None else device)) # (shape: ..., value: float) -> FloatTensor -def constants(shape, value=0.): - return torch.full(shape, value, dtype=torch.float32, device=DEFAULT_DEVICE) +def constants(shape, value=0., dtype=torch.float32, device=None): + return torch.full(shape, value, dtype=dtype, device=(DEFAULT_DEVICE if device is None else device)) # -def constants_idx(shape, value=0): - return torch.full(shape, value, dtype=torch.long, device=DEFAULT_DEVICE) +def constants_idx(shape, value=0, device=None): + return torch.full(shape, value, dtype=torch.long, device=(DEFAULT_DEVICE if device is None else device)) # return 2D eye matrix -def eye(n): - return torch.eye(n, dtype=torch.float32, device=DEFAULT_DEVICE) +def eye(n, device=None): + return torch.eye(n, dtype=torch.float32, device=(DEFAULT_DEVICE if device is None else device)) # -def arange_idx(*args): - return torch.arange(*args, dtype=torch.long, device=DEFAULT_DEVICE) +def arange_idx(*args, device=None): + return torch.arange(*args, dtype=torch.long, device=(DEFAULT_DEVICE if device is None else device)) # (shape: ..., p: rate of 1., mul: multiply) -> Tensor -def random_bernoulli(shape, p, mul): - x = torch.full(shape, p, dtype=torch.float32, device=DEFAULT_DEVICE) +def random_bernoulli(shape, p, mul, device=None): + x = torch.full(shape, p, dtype=torch.float32, device=(DEFAULT_DEVICE if device is None else device)) r = torch.bernoulli(x) * mul return r +# +def rand(shape, dtype=torch.float32, device=None): + return torch.rand(shape, dtype=dtype, device=(DEFAULT_DEVICE if device is None else device)) + +# move to other device +def to_device(x, device=None): + return x.to(device=(DEFAULT_DEVICE if device is None else device)) + +# +def copy(x, device=None): + # return torch.tensor(x, dtype=x.dtype, device=(x.device if device is None else device)) + return x.clone().detach() + # ----- ops # (input_list: list of tensors [bias, weight1, input1, ...]) -> Tensor @@ -186,6 +284,15 @@ def affine2(input_list): prefidx_dimensions.append(last_dim) return base.view(prefidx_dimensions) +# allow >2d input and unsqueezed shapes, simply use matmul +def affine3(input_list): + bias, base_idx = input_list[0], 1 + base = 0 + while base_idx < len(input_list): + base = base + torch.matmul(input_list[base_idx + 1], input_list[base_idx].t()) + base_idx += 2 + return base + bias + # (t: Tensor, size: how many pieces, dim: ...) -> list of Tensors def chunk(t, size, dim=-1): return torch.chunk(t, size, dim) @@ -230,6 +337,7 @@ def select(t, idxes, dim=0): abs = torch.abs avg = torch.mean bilinear = F.bilinear # (N,*,in1_features), (N,*,in2_features), (out_features, in1_features, in2_features), (out_features,) -> (N,*,out_features) +binary_cross_entropy = F.binary_cross_entropy binary_cross_entropy_with_logits = F.binary_cross_entropy_with_logits clamp = torch.clamp diagflat = torch.diagflat @@ -250,10 +358,11 @@ def select(t, idxes, dim=0): softmax = F.softmax squeeze = torch.squeeze split = torch.split +sqrt = torch.sqrt stack = torch.stack sum = torch.sum tanh = torch.tanh -topk = torch.topk +topk = torch.topk # todo(warn): return (val, idx) transpose = torch.transpose unsqueeze = torch.unsqueeze where = torch.where @@ -310,10 +419,38 @@ def multinomial_select(prob, num=1): values = torch.gather(prob, -1, idxes) return idxes, values +# sampling 1 with gumble +# argmax(logprob + -log(-log(Unif[0,1]))), return (val, idx) +# todo(note): keepdim for the output +def category_sample(logprob, dim=-1, keep_dim=True): + G = torch.rand(logprob.shape, dtype=torch.float32, device=DEFAULT_DEVICE) + X = logprob-(-G.log()).log() + V, I = X.max(dim) + if keep_dim: + V, I = V.unsqueeze(-1), I.unsqueeze(-1) + return V, I + def gather(t, idxes, dim=-1): idxes_t = input_idx(idxes) return torch.gather(t, dim, idxes_t) +# special gather for the first several dims +# t: [s1, s2, ..., sn-1, sn, ...]; idx: [s1, s2, ..., sn-1, k] +def gather_first_dims(t, idxes, dim): + t_shape = get_shape(t) + if dim < 0: + dim = len(t_shape) + dim + idxes_t = input_idx(idxes) + idx_shape = get_shape(idxes_t) + assert t_shape[:dim] == idx_shape[:-1] + # flatten and index select + t_shape0, t_shape1 = t_shape[:dim+1], t_shape[dim+1:] + flatten_t = t.view([-1] + t_shape1) # [s1*...*sn, ...] + basis_t = arange_idx(np.prod(t_shape0[:-1])).view(t_shape0[:-1]) * t_shape0[-1] # [s1, ..., sn-1] + basis_t = (basis_t.unsqueeze(-1) + idxes_t).view(-1) # [*] + output_t0 = torch.index_select(flatten_t, dim=0, index=basis_t) # [*, ...] + return output_t0.view(idx_shape + t_shape1) + # ===== # values & backward @@ -327,7 +464,12 @@ def set_value(t, val): # avoid recording grad_fn t.set_(input_real(val)) -def backward(loss): +def get_cpu_tensor(t): + return t.detach().cpu() + +def backward(loss, loss_factor: float): + if loss_factor != 1.: + loss = loss * loss_factor loss.backward() # directly setting param @@ -347,7 +489,7 @@ def has_nan(t): # nn.Conv1d # parameters already added here -class CnnOper(object): +class CnnOper: def __init__(self, node, n_input, n_output, n_window): # todo(warn): still belong to node self.w = node.add_param("w", (n_output, n_input, n_window)) @@ -379,3 +521,30 @@ def conv(self, input_expr): from torch.nn.backends.thnn import backend as thnn_backend gru_oper = thnn_backend.GRUCell lstm_oper = thnn_backend.LSTMCell + +# specific one for batched inv +if int(torch.__version__.split(".")[0])>=1: + # only available torch>=1.0.0 + get_inverse = lambda M, DiagM: M.inverse() +else: + # otherwise use gesv, which is deprecated in some later version + get_inverse = lambda M, DiagM: DiagM.gesv(M)[0] + +# special routines +# todo(note): for mask->idx: 1) topk+sort, 2) padding with extra 1s; currently using 2) +# the inputs should be 1. or 0. (float); [*, L] -> [*, max-count] +def mask2idx(mask_t, padding_idx=0): + mask_shape = get_shape(mask_t) # [*, L] + counts = mask_t.sum(-1).long() # [*] + max_count = counts.max().item() # int, the max expanding + padding_counts = max_count - counts # [*] + max_padding_count = padding_counts.max().item() # int, the max count of padding + pad_t = (arange_idx(max_padding_count) < padding_counts.unsqueeze(-1)).float() # [*, max_pad] + concat_t = concat([mask_t, pad_t], -1) # [*, L+max_pad] + final_shape = mask_shape[:-1] + [max_count] + ret_idxes = concat_t.nonzero()[:, -1].reshape(final_shape) + # get valid mask and set 0 for invalid ones + slen = mask_shape[-1] + valid_mask = (ret_idxes < slen).float() + ret_idxes[ret_idxes >= slen] = padding_idx + return ret_idxes, valid_mask diff --git a/msp/nn/backends/common.py b/msp/nn/backends/common.py index d4c87d3..d81ce89 100644 --- a/msp/nn/backends/common.py +++ b/msp/nn/backends/common.py @@ -53,12 +53,19 @@ def _my_get_params_init(shape, init, lookup): w0 = Random.randn_clip(shape, "winit") return w0*poss_scale elif init == "ortho": - assert len(shape)==2 and shape[0] % shape[1] == 0, "Bad shape %s for ortho_init" % shape - num = shape[0] // shape[1] + # todo(note): always assume init square matrices + assert len(shape)==2 and (shape[0] % shape[1] == 0 or shape[1] % shape[0] == 0), f"Bad shape {shape} for ortho_init!" + orig_num = shape[0] // shape[1] + if orig_num == 0: + num = shape[1] // shape[0] + else: + num = orig_num if num == 1: w0 = Random.ortho_weight(shape[1], "winit") else: w0 = np.concatenate([Random.ortho_weight(shape[1], "winit") for _ in range(num)]) + if orig_num == 0: # reverse it! + w0 = np.transpose(w0) return w0*poss_scale elif init == "zeros": return np.zeros(shape) diff --git a/msp/nn/expr.py b/msp/nn/expr.py index 65cde15..21a5e1e 100644 --- a/msp/nn/expr.py +++ b/msp/nn/expr.py @@ -38,6 +38,13 @@ def __init__(self, val, bsize): def val(self): return ExprWrapperPool.get(self.id) + # bsize == 1 + @staticmethod + def make_unit(val): + ew = ExprWrapper(val, 1) + s = SlicedExpr(ew, 0) + return s + # get raw value def get_val(self, item): return self.val[item] @@ -54,42 +61,47 @@ def __init__(self, ew, slice_idx): self.ew = ew self.slice_idx = slice_idx + def __repr__(self): + return f"[{self.ew.id}]({self.slice_idx}/{self.ew.bsize})" + # ===== Select & Combine ===== class SliceManager(object): - @staticmethod - def _check_full(idxes, length): - return len(idxes) == length and all(i==v for i,v in enumerate(idxes)) + # -- + # todo(warn): deprecated by utils.Helper.check_is_range + # @staticmethod + # def _check_full(idxes, length): + # return len(idxes) == length and all(i==v for i,v in enumerate(idxes)) # return: list of ew, list of list of slice-idx, tracking-idx(original) - # todo(warn): lose original order of slices - @staticmethod - def collect_slices(slices): - ews = [] # list of EW - idxes = [] # list of list of (sidx, ori_idx) - # scan slices and group - tmp_id2idx = {} - for ori_idx, s in enumerate(slices): - one_ew, one_sidx = s.ew, s.slice_idx - ew_id = one_ew.id - # - if ew_id not in tmp_id2idx: - tmp_id2idx[ew_id] = len(ews) - ews.append(one_ew) - idxes.append([]) - # - idx_in_vals = tmp_id2idx[ew_id] - idxes[idx_in_vals].append((one_sidx, ori_idx)) - # sort and split - slice_idxes, origin_idxes = [], [] - for one_idx, one_ew in zip(idxes, ews): - one_idx.sort() # low-idx to high-idx - one_slice_idxes = [x[0] for x in one_idx] - if SliceManager._check_full(one_slice_idxes, one_ew.bsize): - slice_idxes.append(None) - else: - slice_idxes.append(one_slice_idxes) - origin_idxes.append([x[1] for x in one_idx]) - return ews, slice_idxes, origin_idxes + # todo(warn): lose original order of slices; todo(warn): might be too complex + # @staticmethod + # def collect_slices(slices): + # ews = [] # list of EW + # idxes = [] # list of list of (sidx, ori_idx) + # # scan slices and group + # tmp_id2idx = {} + # for ori_idx, s in enumerate(slices): + # one_ew, one_sidx = s.ew, s.slice_idx + # ew_id = one_ew.id + # # + # if ew_id not in tmp_id2idx: + # tmp_id2idx[ew_id] = len(ews) + # ews.append(one_ew) + # idxes.append([]) + # # + # idx_in_vals = tmp_id2idx[ew_id] + # idxes[idx_in_vals].append((one_sidx, ori_idx)) + # # sort and split + # slice_idxes, origin_idxes = [], [] + # for one_idx, one_ew in zip(idxes, ews): + # one_idx.sort() # low-idx to high-idx + # one_slice_idxes = [x[0] for x in one_idx] + # if Helper.check_is_range(one_slice_idxes, one_ew.bsize): + # slice_idxes.append(None) + # else: + # slice_idxes.append(one_slice_idxes) + # origin_idxes.append([x[1] for x in one_idx]) + # return ews, slice_idxes, origin_idxes # return list of ew, list of combined indexes (retaining the original order) @staticmethod @@ -109,7 +121,7 @@ def _arrange_idxes(slices): idx_in_vals = tmp_id2idx[ew_id] bidxes.append(tmp_bidx_bases[idx_in_vals]+one_sidx) # check for perfect match - if SliceManager._check_full(bidxes, tmp_bidx_bases[-1]): + if Helper.check_is_range(bidxes, tmp_bidx_bases[-1]): bidxes = None return values, bidxes @@ -122,11 +134,12 @@ def _combine_recursive(values, bidxes): for name in v0: next_values = [z[name] for z in values] ret[name] = SliceManager._combine_recursive(next_values, bidxes) - elif isinstance(v0, list): + elif isinstance(v0, (list, tuple)): ret = [] for idx in range(len(v0)): next_values = [z[idx] for z in values] ret.append(SliceManager._combine_recursive(next_values, bidxes)) + # todo(+2): need to revert back to tuple if tuple? else: zcheck(BK.is_expr(v0), "Illegal combine value type.") # todo(warn): first concat and then select, may use more memory diff --git a/msp/nn/layers/__init__.py b/msp/nn/layers/__init__.py index a6c6d80..37f6a08 100644 --- a/msp/nn/layers/__init__.py +++ b/msp/nn/layers/__init__.py @@ -2,11 +2,12 @@ from .basic import BasicNode, RefreshOptions, ActivationHelper, Dropout, DropoutLastN from .basic import NoDropRop, NoFixRop, FreezeRop -from .ff import Affine, LayerNorm, MatrixNode, Embedding, PosiEmbedding +from .ff import Affine, LayerNorm, MatrixNode, Embedding, PosiEmbedding, RelPosiEmbedding from .multi import Sequential, Summer, Concater, Joiner, \ NodeWrapper, AddNormWrapper, AddActWrapper, HighWayWrapper, get_mlp from .enc import RnnNode, GruNode, LstmNode, RnnLayer, RnnLayerBatchFirstWrapper, CnnNode, CnnLayer, \ - TransformerEncoderLayer, TransformerEncoder -from .att import AttentionNode, FfAttentionNode, MultiHeadAttention + TransformerEncoderLayer, TransformerEncoder, Transformer2EncoderLayer, Transformer2Encoder +from .att import AttentionNode, FfAttentionNode, MultiHeadAttention, \ + MultiHeadRelationalAttention, MultiHeadSelfDistAttention, AttConf, AttDistHelper from .dec import * from .biaffine import BiAffineScorer diff --git a/msp/nn/layers/att.py b/msp/nn/layers/att.py index 5f52d2d..270b7ec 100644 --- a/msp/nn/layers/att.py +++ b/msp/nn/layers/att.py @@ -6,45 +6,73 @@ import numpy as np import math +from msp.utils import Constants, Conf from ..backends import BK -from .basic import ActivationHelper, BasicNode, Dropout, NoDropRop, NoFixRop, RefreshOptions -from .ff import Affine, Embedding -from msp.utils import Constants +from .basic import BasicNode, Dropout, NoDropRop, NoFixRop +from .ff import Affine, Embedding, PosiEmbedding +# ===== +# Attention Node Configure +# todo(warn): no keep of this conf, since +class AttConf(Conf): + def __init__(self): + self.type = "mh" + self.d_kqv = 64 # also as hidden layer size if needed + self.att_dropout = 0.1 + # mainly for multi-head + self.head_count = 8 + self.use_ranges = False + # + self.clip_dist = 0 + self.use_neg_dist = False + self.use_fix_dist = False + # + self.out_act = 'linear' + # for relational one + self.dim_r = 5 + # should be set specifically + self._fixed_range_val = -1 + +# ===== # (key, query, value) # todo(+2): use cache? class AttentionNode(BasicNode): - def __init__(self, pc, dim_k, dim_q, dim_v, att_dropout, init_rop=None, name=None): + def __init__(self, pc, dim_k, dim_q, dim_v, aconf: AttConf, init_rop=None, name=None): super().__init__(pc, name, init_rop) - # self.dim_k = dim_k self.dim_q = dim_q self.dim_v = dim_v # todo(+3): here no fixed shape rr = NoFixRop() - rr.add_fixed_value("hdrop", att_dropout) + rr.add_fixed_value("hdrop", aconf.att_dropout) self.adrop = self.add_sub_node("adrop", Dropout(pc, (), init_rop=rr)) + def __repr__(self): + return f"# AttNode({self.__class__.__name__}): k={self.dim_k},q={self.dim_q},v={self.dim_v}" + # return re-weighted vs def get_output_dims(self, *input_dims): return (self.dim_v, ) - # # - # def obtain_mask(self, mask_k, mask_q): - # if mask_k is None: - # if mask_q is None: return None - # else: return mask_q.unsqueeze(-1) - # else: - # if mask_q is None: mask_k.unsqueeze(-2) - # else: return mask_q.unsqueeze(-1)*mask_k.unsqueeze(-2) + @staticmethod + def get_att_node(node_type, pc, dim_k, dim_q, dim_v, aconf: AttConf, name=None, init_rop=None): + _ATT_TYPES = {"ff": FfAttentionNode, "mh": MultiHeadAttention, + "mhr": MultiHeadRelationalAttention, "mhrs": MultiHeadSelfDistAttention} + if isinstance(node_type, str): + node_c = _ATT_TYPES[node_type] + else: + node_c = node_type + ret = node_c(pc, dim_k, dim_q, dim_v, aconf, name=name, init_rop=init_rop) + return ret # feed-forward attention class FfAttentionNode(AttentionNode): - def __init__(self, pc, dim_k, dim_q, dim_v, hidden_size, att_dropout, init_rop=None, name=None): - super().__init__(pc, dim_k, dim_q, dim_v, att_dropout, init_rop, name) + def __init__(self, pc, dim_k, dim_q, dim_v, aconf: AttConf, init_rop=None, name=None): + super().__init__(pc, dim_k, dim_q, dim_v, aconf, init_rop, name) # parameters -- split mlp for possible cache-usage - self.affine_k = self.add_sub_node("ak", Affine(pc, dim_k, hidden_size, bias=True, init_rop=NoDropRop(), affine2=True)) - self.affine_q = self.add_sub_node("aq", Affine(pc, dim_q, hidden_size, bias=False, init_rop=NoDropRop(), affine2=True)) + hidden_size = aconf.d_kqv + self.affine_k = self.add_sub_node("ak", Affine(pc, dim_k, hidden_size, bias=True, init_rop=NoDropRop())) + self.affine_q = self.add_sub_node("aq", Affine(pc, dim_q, hidden_size, bias=False, init_rop=NoDropRop())) self.affine_top_w = self.add_param("w", (hidden_size, 1)) # special transpose # [*, len] @@ -57,40 +85,110 @@ def __call__(self, key, value, query, mask_k=None): scores = BK.matmul(att_hidden, self.affine_top_w).squeeze(-1) # [*, len_q, len_k] if mask_k is not None: scores += (1. - mask_k).unsqueeze(-2).unsqueeze(-3) * Constants.REAL_PRAC_MIN - # mask_bytes = (mask_k==0.) - # mask = mask_bytes.unsqueeze(-2).expand_as(scores) # mask: [*, len_q, len_k] - # scores = scores.masked_fill(mask, Constants.REAL_PRAC_MIN) attn = BK.softmax(scores, -1) # [*, len_q, len_k] drop_attn = self.adrop(attn) # [*, len_q, len_k] context = BK.matmul(drop_attn, value) # [*, len_q, dim_v] return context -# multi-head +# ===== +# MultiHeadAttention helpers + +class AttRangeHelper(BasicNode): + def __init__(self, pc, aconf: AttConf): + super().__init__(pc, None, None) + self.head_count = aconf.head_count + # exclude if dist>rr + fixed_range_val = aconf._fixed_range_val + if fixed_range_val >= 0: + self.att_ranges = [fixed_range_val] * self.head_count + else: + self.att_ranges = [2**z - 1 for z in range(self.head_count)] + self._att_ranges_t = None + + def refresh(self, rop=None): + super().refresh(rop) + # refresh reuse-able tensors: [*, query, key] + if self.att_ranges: + self._att_ranges_t = BK.input_real(self.att_ranges).unsqueeze(-1).unsqueeze(-1) + + def __call__(self, query_len, key_len, att_scores_t): + # -- ranged clip + # todo(+3): can have repeated calculations with distances + last_dims = [self.head_count, query_len, key_len] + rr_x = BK.unsqueeze(BK.arange_idx(0, key_len), 0) + rr_y = BK.unsqueeze(BK.arange_idx(0, query_len), 1) + rr_xy = BK.abs(rr_x - rr_y).float().unsqueeze(0).expand(last_dims) # [head, query, key] + scores = att_scores_t.masked_fill(rr_xy > self._att_ranges_t, Constants.REAL_PRAC_MIN) + return scores + +class AttDistHelper(BasicNode): + def __init__(self, pc, aconf: AttConf, dim: int): + super().__init__(pc, None, None) + self.use_neg_dist = aconf.use_neg_dist + self.clip_dist = aconf.clip_dist + if aconf.use_fix_dist: + assert not self.use_neg_dist, "Currently does not support Neg fixed distance" + posi_max_len = self.clip_dist + 1 + self.edge_atts = self.add_sub_node("ek", PosiEmbedding(pc, dim, posi_max_len)) + self.edge_values = self.add_sub_node("ev", PosiEmbedding(pc, dim, posi_max_len)) + else: + self.edge_atts = self.add_sub_node("ek", Embedding(pc, 2 * self.clip_dist + 1, dim, fix_row0=False, + init_rop=NoDropRop())) + self.edge_values = self.add_sub_node("ev", Embedding(pc, 2 * self.clip_dist + 1, dim, fix_row0=False, + init_rop=NoDropRop())) + + def obatin_from_distance(self, distance): + # use different ranges! + if not self.use_neg_dist: + distance = BK.clamp(BK.abs(distance), max=self.clip_dist) + else: + distance = BK.clamp(distance, min=-self.clip_dist, max=self.clip_dist) + self.clip_dist + dist_atts = self.edge_atts(distance) + dist_values = self.edge_values(distance) + return dist_atts, dist_values + + def __call__(self, query_len, key_len): + dist_x = BK.unsqueeze(BK.arange_idx(0, key_len), 0) + dist_y = BK.unsqueeze(BK.arange_idx(0, query_len), 1) + distance = dist_x - dist_y # [query, key] + dist_atts, dist_values = self.obatin_from_distance(distance) + # [lenq, lenk], [lenq, lenk, dim], [lenq, lenk, dim] + return distance, dist_atts, dist_values + +# ===== +# basic multi-head class MultiHeadAttention(AttentionNode): - def __init__(self, pc, dim_k, dim_q, dim_v, d_kqv, head_count, att_dropout, use_ranges, clip_dist, use_neg_dist, init_rop=None, name=None): - super().__init__(pc, dim_k, dim_q, dim_v, att_dropout, init_rop, name) - # - self.head_count = head_count + def __init__(self, pc, dim_k, dim_q, dim_v, aconf: AttConf, init_rop=None, name=None): + super().__init__(pc, dim_k, dim_q, dim_v, aconf, init_rop, name) self.model_dim = dim_v - self.d_kqv = d_kqv - self.clip_dist = clip_dist - self.use_neg_dist = use_neg_dist - self.use_ranges = use_ranges - self.att_ranges = [2**z-1 for z in range(self.head_count)] # exclude if dist>rr + self.head_count = aconf.head_count + self.d_kqv = aconf.d_kqv + self.clip_dist = aconf.clip_dist + self.use_neg_dist = aconf.use_neg_dist + # ===== + # pre-fixed values for later computation + self._att_scale = math.sqrt(self.d_kqv) + if aconf.use_ranges: + self.range_clipper = self.add_sub_node("rc", AttRangeHelper(pc, aconf)) + else: + self.range_clipper = None + # ===== + eff_hidden_size = self.head_count * self.d_kqv # todo(+2): should we use more dropouts here? # no dropouts here - self.affine_k = self.add_sub_node("ak", Affine(pc, dim_k, head_count*d_kqv, init_rop=NoDropRop(), bias=True, affine2=True)) - self.affine_q = self.add_sub_node("aq", Affine(pc, dim_q, head_count*d_kqv, init_rop=NoDropRop(), bias=True, affine2=True)) - self.affine_v = self.add_sub_node("av", Affine(pc, dim_v, head_count*d_kqv, init_rop=NoDropRop(), bias=True, affine2=True)) + self.affine_k = self.add_sub_node("ak", Affine(pc, dim_k, eff_hidden_size, init_rop=NoDropRop())) + self.affine_q = self.add_sub_node("aq", Affine(pc, dim_q, eff_hidden_size, init_rop=NoDropRop())) + self.affine_v = self.add_sub_node("av", Affine(pc, dim_v, eff_hidden_size, init_rop=NoDropRop())) # rel dist keys self.use_distance = (self.clip_dist > 0) if self.use_distance: - self.edge_keys = self.add_sub_node("ek", Embedding(pc, 2*self.clip_dist+1, d_kqv, fix_row0=False, init_rop=NoDropRop())) - self.edge_values = self.add_sub_node("ev", Embedding(pc, 2*self.clip_dist+1, d_kqv, fix_row0=False, init_rop=NoDropRop())) - # todo(+2): with drops(y) & act(n) & bias(y)? - self.final_linear = self.add_sub_node("fl", Affine(pc, head_count*d_kqv, dim_v, bias=True, affine2=True)) + self.dist_helper = self.add_sub_node("dh", AttDistHelper(pc, aconf, self.d_kqv)) + else: + self.dist_helper = None + # todo(+2): with drops(y) & act(?) & bias(y)? + self.final_linear = self.add_sub_node("fl", Affine(pc, eff_hidden_size, dim_v, act=aconf.out_act)) - # + # [*, len, head*dim] <-> [*, head, len, dim] def _shape_project(self, x): x_size = BK.get_shape(x)[:-1] + [self.head_count, -1] return BK.transpose(BK.reshape(x, x_size), -2, -3) @@ -100,32 +198,28 @@ def _unshape_project(self, x): return BK.reshape(BK.transpose(x, -2, -3), x_size[:-3]+[x_size[-2], -1]) # todo(+2): assuming (batch, len, dim) - def __call__(self, key, value, query, mask_k=None): + # prob_qk is outside extra-prob for mixing by prob_mix_rate + def __call__(self, key, value, query, mask_k=None, mask_qk=None, eprob_qk=None, eprob_mix_rate=0., eprob_head_count=0): batch_dim_size = len(BK.get_shape(key))-2 - key_len = BK.get_shape(key, -2) query_len = BK.get_shape(query, -2) + key_len = BK.get_shape(key, -2) # # prepare distances if needed if self.use_distance: - dist_x = BK.unsqueeze(BK.arange_idx(0, key_len), 0) - dist_y = BK.unsqueeze(BK.arange_idx(0, query_len), 1) - distance = dist_x - dist_y # [query, key] - if not self.use_neg_dist: - distance = BK.abs(distance) - distance = BK.clamp(distance, min=-self.clip_dist, max=self.clip_dist) + self.clip_dist + distance, dist_atts, dist_values = self.dist_helper(query_len, key_len) else: - distance = None + distance, dist_atts, dist_values = None, None, None # # 1. project the three key_up = self._shape_project(self.affine_k(key)) # [*, head, len_k, d_k] value_up = self._shape_project(self.affine_v(value)) # [*, head, len_k, d_v] query_up = self._shape_project(self.affine_q(query)) # [*, head, len_q, d_k] # 2. calculate and scale scores - query_up = query_up / math.sqrt(self.d_kqv) + query_up = query_up / self._att_scale scores = BK.matmul(query_up, BK.transpose(key_up, -1, -2)) # [*, head, len_q, len_k] # todo(+2): for convenience, assuming using pytorch here and *==batch_size if self.use_distance: - distance_out = self.edge_keys(distance) # [query, key, d_kqv] + distance_out = dist_atts # [query, key, d_kqv] out = distance_out.view([1]*batch_dim_size + [1, query_len, key_len, self.d_kqv]) # [*, head, len_q, 1, d_k] * [*, 1, query, key, d_kqv] = [*, head, len_q, 1, len_k] # add_term = BK.matmul(query_up.unsqueeze(-2), out.transpose(-1, -2)).squeeze(-2) @@ -135,24 +229,24 @@ def __call__(self, key, value, query, mask_k=None): if mask_k is not None: # todo(warn): mask as [*, len] scores += (1.-mask_k).unsqueeze(-2).unsqueeze(-3) * Constants.REAL_PRAC_MIN - # mask_bytes = (mask_k==0.) - # mask = mask_bytes.unsqueeze(-2).unsqueeze(-3).expand_as(scores) # mask: [*, head, len_q, len_k] - # scores = scores.masked_fill(mask, Constants.REAL_PRAC_MIN) - # -- ranged clip - if self.use_ranges: - last_dims = [self.head_count, query_len, key_len] - rr_x = BK.unsqueeze(BK.arange_idx(0, key_len), 0) - rr_y = BK.unsqueeze(BK.arange_idx(0, query_len), 1) - rr_xy = BK.abs(rr_x - rr_y).float().unsqueeze(0).expand(last_dims) # [head, query, key] - scores = scores.masked_fill( - (rr_xy > BK.input_real(self.att_ranges).unsqueeze(-1).unsqueeze(-1)), Constants.REAL_PRAC_MIN) - # + if mask_qk is not None: + # todo(note): mask as [*, len-q, len-k] + scores += (1.-mask_qk).unsqueeze(-3) * Constants.REAL_PRAC_MIN + if self.range_clipper: + scores = self.range_clipper(query_len, key_len, scores) attn = BK.softmax(scores, -1) # [*, head, len_q, len_k] + # mix eprob for the first several heads? + if eprob_qk is not None and eprob_mix_rate>0. and eprob_head_count>0: + # todo(warn): here assume that [*] == [bs]; cannot do inplaced operation here, otherwise backward error + # attn[:, :eprob_head_count] = eprob_mix_rate*(eprob_qk.unsqueeze(-3)) + (1.-eprob_mix_rate)*(attn[:, :eprob_head_count]) + attn0 = eprob_mix_rate*(eprob_qk.unsqueeze(-3)) + (1.-eprob_mix_rate)*(attn[:, :eprob_head_count]) + attn1 = attn[:, eprob_head_count:] + attn = BK.concat([attn0, attn1], 1) drop_attn = self.adrop(attn) # [*, head, len_q, len_k] context = BK.matmul(drop_attn, value_up) # [*, head, len_q, d_v] if self.use_distance: # specific rel-position values as in https://arxiv.org/pdf/1803.02155.pdf - distance_out2 = self.edge_values(distance) # [query, key, dim] + distance_out2 = dist_values # [query, key, dim] out2 = distance_out2.view([1]*batch_dim_size + [1, query_len, key_len, self.d_kqv]) add_term2 = BK.matmul(drop_attn.unsqueeze(-2), out2).squeeze(-2) # [*, head, len_q, d_v] context += add_term2 @@ -161,67 +255,92 @@ def __call__(self, key, value, query, mask_k=None): output = self.final_linear(context_merge) # [*, len_q, mdim] return output -# special form of multihead attention (fixed attention) -class MultiHeadFixedAttention(AttentionNode): - def __init__(self, pc, dim_k, dim_q, dim_v, d_kqv, head_count, att_dropout, use_ranges, self_weight, init_rop=None, name=None): - super().__init__(pc, dim_k, dim_q, dim_v, att_dropout, init_rop, name) - # - self.head_count = head_count +# ===== +# relational attention: adopt another input for the relational embeddings +# todo(+N): repeated codes with the class of "MultiHeadAttention" +# TODO(+N): still require too much memory +class MultiHeadRelationalAttention(AttentionNode): + def __init__(self, pc, dim_k, dim_q, dim_v, aconf: AttConf, init_rop=None, name=None): + super().__init__(pc, dim_k, dim_q, dim_v, aconf, init_rop, name) self.model_dim = dim_v - self.d_kqv = d_kqv - # todo(+3): is this a good choice, or make it tunable or parameterized? - self.att_scales = [0.1] * self.head_count - self.att_ranges = [2**z-1 for z in range(self.head_count)] # exclude if dist>rr - self.use_ranges = use_ranges - self.self_weight = self_weight # should be negative! + self.head_count = aconf.head_count + self.d_kqv = aconf.d_kqv + self.d_r = aconf.dim_r + # ===== + # pre-fixed values for later computation + self._att_scale = math.sqrt(self.d_kqv) + if aconf.use_ranges: + self.range_clipper = self.add_sub_node("rc", AttRangeHelper(pc, aconf)) + else: + self.range_clipper = None + # ===== + eff_hidden_size = self.head_count * self.d_kqv + # todo(+2): should we use more dropouts here? # no dropouts here - self.affine_v = self.add_sub_node("av", Affine(pc, dim_v, head_count*d_kqv, init_rop=NoDropRop(), bias=True, affine2=True)) - # todo(+2): with drops(y) & act(n) & bias(y)? - self.final_linear = self.add_sub_node("fl", Affine(pc, head_count*d_kqv, dim_v, bias=True, affine2=True)) + self.affine_k = self.add_sub_node("ak", Affine(pc, dim_k, eff_hidden_size, init_rop=NoDropRop())) + # todo(+N): need to add rel-aware values as well? + self.affine_v = self.add_sub_node("av", Affine(pc, dim_v, eff_hidden_size, init_rop=NoDropRop())) + # special relational one for query + self.affine_q = self.add_sub_node("aq", Affine(pc, dim_q, eff_hidden_size*self.d_r, init_rop=NoDropRop())) + # todo(+2): with drops(y) & act(?) & bias(y)? + self.final_linear = self.add_sub_node("fl", Affine(pc, eff_hidden_size, dim_v, act=aconf.out_act)) - # def _shape_project(self, x): x_size = BK.get_shape(x)[:-1] + [self.head_count, -1] return BK.transpose(BK.reshape(x, x_size), -2, -3) def _unshape_project(self, x): x_size = BK.get_shape(x) - return BK.reshape(BK.transpose(x, -2, -3), x_size[:-3] + [x_size[-2], -1]) + return BK.reshape(BK.transpose(x, -2, -3), x_size[:-3]+[x_size[-2], -1]) - # - def __call__(self, key, value, query, mask_k=None): + # todo(+2): kqv=[*, len, dim], rel=[*, len_q, len_k, dim_r] + def __call__(self, key, value, query, rel, mask_k=None): batch_dim_size = len(BK.get_shape(key)) - 2 - value_len = BK.get_shape(value, -2) - # 1. only project the value + query_len = BK.get_shape(query, -2) + key_len = BK.get_shape(key, -2) + # + # 1. project the three + key_up = self._shape_project(self.affine_k(key)) # [*, head, len_k, d_k] value_up = self._shape_project(self.affine_v(value)) # [*, head, len_k, d_v] - # 2. get fixed attention scores (no parameters here) - with BK.no_grad_env(): - last_dims = [self.head_count, value_len, value_len] - # - dist_x = BK.unsqueeze(BK.arange_idx(0, value_len), 0) - dist_y = BK.unsqueeze(BK.arange_idx(0, value_len), 1) - distance = BK.abs(dist_x-dist_y).float().unsqueeze(0).expand(last_dims) # [head, query, key] - # - multi range clip - if self.use_ranges: - clip_distance = distance.masked_fill((distance>BK.input_real(self.att_ranges).unsqueeze(-1).unsqueeze(-1)), - Constants.REAL_PRAC_MAX) - else: - clip_distance = distance - # - settable self-weight (use special weight instead of 0.) - neg_distance = BK.eye(value_len) * self.self_weight - clip_distance - # - multi scales: [head, query, key] - fixed_scores = neg_distance.unsqueeze(0) * BK.input_real(self.att_scales).unsqueeze(-1).unsqueeze(-1) - # [*, head, q-len, k-len] - fixed_scores = fixed_scores.view([1]*batch_dim_size + last_dims).expand(BK.get_shape(key)[:batch_dim_size] + last_dims) - # 3. apply attention + # special for query, select parameters by relation + query_up = self._shape_project(self.affine_q(query)) # [*, head, len_q, d_k*d_r] + # 2. calculate relational att scores + key_up = (key_up / self._att_scale).unsqueeze(-3).unsqueeze(-2) # [*, head, 1, len_k, 1, d_k] + query_up = query_up.view(BK.get_shape(query_up)[:-1]+[1,self.d_kqv,self.d_r]) # [*, head, len_q, 1, d_k, d_r] + scores0 = BK.matmul(key_up, query_up).squeeze(-2) # [*, head, len_q, len_k, d_r] + scores = BK.sum(scores0 * rel.unsqueeze(-4), -1) # [*, head, len_q, len_k] + # 3. attention if mask_k is not None: - # mask is [*, len], thus no need to expand here - fixed_scores += (1.-mask_k).unsqueeze(-2).unsqueeze(-3) * Constants.REAL_PRAC_MIN - # - attn = BK.softmax(fixed_scores, -1) # [*, head, len_q, len_k] + # todo(warn): mask as [*, len] + scores += (1.-mask_k).unsqueeze(-2).unsqueeze(-3) * Constants.REAL_PRAC_MIN + if self.range_clipper: + scores = self.range_clipper(query_len, key_len, scores) + attn = BK.softmax(scores, -1) # [*, head, len_q, len_k] drop_attn = self.adrop(attn) # [*, head, len_q, len_k] context = BK.matmul(drop_attn, value_up) # [*, head, len_q, d_v] + # todo(+N): do the values need to be aware of rel? # 4. final context_merge = self._unshape_project(context) # [*, len_q, d_v*head] output = self.final_linear(context_merge) # [*, len_q, mdim] return output + +# ===== +class MultiHeadSelfDistAttention(AttentionNode): + def __init__(self, pc, dim_k, dim_q, dim_v, aconf: AttConf, init_rop=None, name=None): + super().__init__(pc, dim_k, dim_q, dim_v, aconf, init_rop, name) + self.dim_r = aconf.dim_r + self.rel_att = self.add_sub_node("ra", MultiHeadRelationalAttention(pc, dim_k, dim_q, dim_v, aconf)) + assert (aconf.clip_dist > 0), "Must use distance here!" + self.dist_helper = self.add_sub_node("dh", AttDistHelper(pc, aconf, aconf.dim_r)) + + # todo(+2): assuming (batch, len, dim) + def __call__(self, key, value, query, mask_k=None): + batch_dim_size = len(BK.get_shape(key)) - 2 + key_len = BK.get_shape(key, -2) + query_len = BK.get_shape(query, -2) + # get distance embeddings + distance, dist_atts, dist_values = self.dist_helper(query_len, key_len) # [query, key, d_r] + # todo(+N): make use of dist_values? + # expand batch dimensions + rel_input = dist_atts.view([1] * batch_dim_size + [query_len, key_len, self.dim_r]) + return self.rel_att(key, value, query, rel_input, mask_k=mask_k) diff --git a/msp/nn/layers/basic.py b/msp/nn/layers/basic.py index 447d80b..a487dc2 100644 --- a/msp/nn/layers/basic.py +++ b/msp/nn/layers/basic.py @@ -3,6 +3,8 @@ from ..backends import BK from ..backends.common import get_unique_name +from msp.utils import extract_stack + # with special None-Rule for turning off refresh class RefreshOptions(object): def __init__(self, training=True, hdrop=0., idrop=0., gdrop=0., edrop=0., dropmd=0., fix_drop=False, trainable=True, fix_set=()): @@ -38,7 +40,8 @@ def NoFixRop(): return RefreshOptions(fix_drop=False, fix_set=("fix_drop", )) # helpers class ActivationHelper(object): - ACTIVATIONS = {"tanh": BK.tanh, "softmax": BK.softmax, "relu": BK.relu, "elu": BK.elu, "linear": lambda x:x} + ACTIVATIONS = {"tanh": BK.tanh, "softmax": BK.softmax, "relu": BK.relu, "elu": BK.elu, + "sigmoid": BK.sigmoid, "linear": lambda x:x} # reduction for seq after conv POOLINGS = {"max": lambda x: BK.max(x, -2)[0], "avg": lambda x: BK.avg(x, -2)} @@ -72,6 +75,7 @@ def __init__(self, pc, name, init_rop): self.rop = init_rop # self.parent = None # only one parent who owns this node (controlling refresh)! + self.pc.nnc_push(self.name) def get_unique_name(self, name): return get_unique_name(self.name_dict, name) @@ -105,7 +109,7 @@ def get_output_dims(self, *input_dims): return input_dims # create param from PC - def add_param(self, name, shape, init=None, lookup=False): + def add_param(self, name, shape, init=None, lookup=False, check_stack=True): if init is None: w = BK.get_params_init(shape, "default", lookup) elif isinstance(init, str): @@ -113,13 +117,14 @@ def add_param(self, name, shape, init=None, lookup=False): else: w = init name = self.get_unique_name(name) - combined_name = self.name + "/" + name + combined_name = self.pc.nnc_name(self.name, check_stack) + "/" + name ret = self.pc.param_new(combined_name, shape, w, lookup) self.params[name] = ret return ret # only recording the sub-node, which is built previously def add_sub_node(self, name, node): + self.pc.nnc_pop(node.name) assert isinstance(node, BasicNode), "Subnodes should be a Node!" name = self.get_unique_name(name) if node.parent is None: @@ -129,6 +134,14 @@ def add_sub_node(self, name, node): self.temp_nodes[name] = node return node + # get params recursively (owned subnodes) + def get_parameters(self, recursively=True): + ret = list(self.params.values()) + if recursively: + for node in self.sub_nodes.values(): + ret.extend(node.get_parameters(recursively)) + return ret + # commonly used Nodes class Dropout(BasicNode): def __init__(self, pc, shape, which_drop="hdrop", name=None, init_rop=None): @@ -148,7 +161,8 @@ def refresh(self, rop=None): # r = self.rop drop = self.drop_getter_(r) - if not r.training: # another overall switch + # todo(+3): another overall switch, not quite elegant! + if not r.training: self.f_ = lambda x: x else: self.f_ = Dropout._dropout_f_obtain(r.fix_drop, drop, self.shape) diff --git a/msp/nn/layers/biaffine.py b/msp/nn/layers/biaffine.py index 377e70a..d5bdeb8 100644 --- a/msp/nn/layers/biaffine.py +++ b/msp/nn/layers/biaffine.py @@ -4,8 +4,8 @@ import numpy as np from collections import Iterable from ..backends import BK -from . import Affine, get_mlp, BasicNode, ActivationHelper, Dropout, NoDropRop, NoFixRop -from msp.utils import zcheck, zwarn, zlog, Random, Constants +from . import Affine, Embedding, get_mlp, BasicNode, ActivationHelper, Dropout, NoDropRop, NoFixRop +from msp.utils import zcheck, zfatal, zwarn, zlog, Random, Constants # # biaffine (inputs are already paired) @@ -27,8 +27,8 @@ def __init__(self, pc, n_ins0, n_ins1, n_out, n_hidden, act="linear", bias=True, self.act = act self._act_f = ActivationHelper.get_act(act) # params - self.affine0 = self.add_sub_node("a0", Affine(pc, n_ins0, n_hidden, affine2=False, init_rop=hidden_init_rop)) - self.affine1 = self.add_sub_node("a1", Affine(pc, n_ins1, n_hidden, affine2=False, init_rop=hidden_init_rop)) + self.affine0 = self.add_sub_node("a0", Affine(pc, n_ins0, n_hidden, which_affine=1, init_rop=hidden_init_rop)) + self.affine1 = self.add_sub_node("a1", Affine(pc, n_ins1, n_hidden, which_affine=1, init_rop=hidden_init_rop)) self.W = self.add_param(name="W", shape=(n_out, n_hidden, n_hidden)) if bias: self.b = self.add_param(name="B", shape=(n_out, )) @@ -63,7 +63,7 @@ def get_output_dims(self, *input_dims): # # the final layer scorer (no dropout if used as the final scoring layer) class BiAffineScorer(BasicNode): - def __init__(self, pc, in_size0, in_size1, out_size, ff_hid_size=-1, ff_hid_layer=0, use_bias=True, use_biaffine=True, use_ff=True, no_final_drop=True, name=None, init_rop=None, mask_value=Constants.REAL_PRAC_MIN): + def __init__(self, pc, in_size0, in_size1, out_size, ff_hid_size=-1, ff_hid_layer=0, use_bias=True, use_biaffine=True, biaffine_div: float=None, use_ff=True, use_ff2=False, no_final_drop=True, name=None, init_rop=None, mask_value=Constants.REAL_PRAC_MIN, biaffine_init_ortho=False): super().__init__(pc, name, init_rop) # self.in_size0 = in_size0 @@ -72,21 +72,36 @@ def __init__(self, pc, in_size0, in_size1, out_size, ff_hid_size=-1, ff_hid_laye self.ff_hid_size = ff_hid_size self.use_bias = use_bias self.use_biaffine = use_biaffine + self.biaffine_init_ortho = biaffine_init_ortho + # + # divide the biaffine scores by `biaffine_div' + if biaffine_div is None: + biaffine_div = 1. + elif biaffine_div <= 0.: + biaffine_div = (in_size0*in_size1) ** 0.25 # sqrt(sqrt(in1*in2)) + zlog(f"Adopt biaffine_div of {biaffine_div} for the current BiaffineScorer") + self.biaffine_div = biaffine_div + # self.use_ff = use_ff + self.use_ff2 = use_ff2 self.no_final_drop = no_final_drop self.mask_value = mask_value # function practically as -inf # if self.use_bias: self.B = self.add_param(name="B", shape=(out_size, )) - zcheck(use_ff or use_biaffine, "No real calculations here!") + zcheck(use_ff or use_ff2 or use_biaffine, "No real calculations here!") if self.use_ff: self.A0 = self.add_sub_node("A0", get_mlp(pc, in_size0, out_size, ff_hid_size, n_hidden_layer=ff_hid_layer, final_bias=False, final_init_rop=NoDropRop())) self.A1 = self.add_sub_node("A1", get_mlp(pc, in_size1, out_size, ff_hid_size, n_hidden_layer=ff_hid_layer, final_bias=False, final_init_rop=NoDropRop())) + if self.use_ff2: + # todo(+2): concat previously, not efficient!! + self.AA = self.add_sub_node("AA", get_mlp(pc, in_size0+in_size1, out_size, ff_hid_size, n_hidden_layer=ff_hid_layer, + final_bias=False, final_init_rop=NoDropRop())) if self.use_biaffine: # this is different than BK.bilinear or layers.BiAffine - self.W = self.add_param(name="W", shape=(in_size0, in_size1*out_size)) + self.W = self.add_param(name="W", shape=(in_size0, in_size1*out_size), init=("ortho" if biaffine_init_ortho else "default")) # todo(0): meaningless to use fix-drop if no_final_drop: self.drop_node = lambda x: x @@ -94,17 +109,26 @@ def __init__(self, pc, in_size0, in_size1, out_size, ff_hid_size=-1, ff_hid_laye self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.out_size,), init_rop=NoFixRop())) # plain call - # io: [*, in_size0], [*, in_size1] -> [*, out_size]; masks: [*] - def plain_score(self, input0, input1, mask0=None, mask1=None): + # io: [*, in_size0], [*, in_size1] -> [*, out_size]; masks: [*]; rel1: [*] + def plain_score(self, input0, input1, mask0=None, mask1=None, rel1_t=None): ret = 0 + if rel1_t is not None: + input1 = input1 + rel1_t if self.use_ff: ret = self.A0(input0) + self.A1(input1) + if self.use_ff2: + shape0 = BK.get_shape(input0) + shape1 = BK.get_shape(input1) + cur_shape = [max(a,b) for a,b in zip(shape0, shape1)] + cur_shape[-1] = -1 + cur_input = BK.concat([input0.expand(cur_shape), input1.expand(cur_shape)], dim=-1) + ret += self.AA(cur_input) if self.use_biaffine: # [*, in0] * [in0, in1*out] -> [*, in1*out] -> [*, in1, out] expr0 = BK.matmul(input0, self.W).view(BK.get_shape(input0)[:-1]+[self.in_size1, self.out_size]) # [*, 1, in1] * [*, in1, out] -> [*, 1, out] -> [*, out] expr1 = BK.matmul(input1.unsqueeze(-2), expr0).squeeze(-2) - ret += expr1 + ret += expr1 / self.biaffine_div if self.use_bias: ret += self.B # mask @@ -116,21 +140,31 @@ def plain_score(self, input0, input1, mask0=None, mask1=None): # special call # io: [*, len0, in_size0], [*, len1, in_size1] -> [*, len0, len1, out_size], masks: [*, len0], [*, len1] - def paired_score(self, input0, input1, mask0=None, mask1=None): + def paired_score(self, input0, input1, mask0=None, mask1=None, rel1_t=None): ret = 0 - # [*, 1, len1, in_size1] - expand1 = input1.unsqueeze(-3) + # [*, len0, 1, in_size0] + expand0 = input0.unsqueeze(-2) + # [*, ?, len1, in_size1] + if rel1_t is None: + expand1 = input1.unsqueeze(-3) + else: + expand1 = input1.unsqueeze(-3) + rel1_t if self.use_ff: - # [*, len0, 1, in_size0] - expand0 = input0.unsqueeze(-2) ret = self.A0(expand0) + self.A1(expand1) + if self.use_ff2: + shape0 = BK.get_shape(expand0) + shape1 = BK.get_shape(expand1) + shape0[-2] = shape1[-2] + shape1[-3] = shape0[-3] + cur_input = BK.concat([expand0.expand(shape0), expand1.expand(shape1)], dim=-1) + ret += self.AA(cur_input) if self.use_biaffine: # [*, len0, in0] * [in0, in1*out] -> [*, len0, in1*out] -> [*, len0, in1, out] expr0 = BK.matmul(input0, self.W).view(BK.get_shape(input0)[:-1]+[self.in_size1, self.out_size]) # [*, 1, len1, in1] * [*, len0, in1, out] expr1 = BK.matmul(expand1, expr0) # [*, len0, len1, out] - ret += expr1 + ret += expr1 / self.biaffine_div if self.use_bias: ret += self.B # mask @@ -146,3 +180,49 @@ def __call__(self, *args, **kwargs): def get_output_dims(self, *input_dims): return (self.out_size, ) + + # ===== + # special runnings (pre-computating half of them) + # todo(note): always plain mode here. + + # [*, in_size0], [*] -> (ff: [*, out_size], ff2: [*, ff_hid_size], biaffine: [*, in_size1*out_size]) + def precompute_input0(self, input0): + ff_score0 = None + ff2_hid0 = None + biaff_hid0 = None + if self.use_ff: + ff_score0 = self.A0(input0) + if self.use_ff2: + # todo(+3) + zfatal("Not supported in this mode!") + if self.use_biaffine: + biaff_hid0 = BK.matmul(input0, self.W) + # return (ff_score0, ff2_hid0, biaff_hid0) + return (ff_score0, biaff_hid0) + + # basically following self.plain_score + # input0_package should be the tuple and already index_selected, input1: [*, input1] + def postcompute_input1(self, input0_package, input1, mask0=None, mask1=None, rel1_t=None): + # ff_score0, ff2_hid0, biaff_hid0 = input0_package + ff_score0, biaff_hid0 = input0_package + ret = 0 + if rel1_t is not None: + input1 = input1 + rel1_t + if self.use_ff: + ret = ff_score0 + self.A1(input1) + if self.use_ff2: + zfatal("Not supported in this mode!") + if self.use_biaffine: + # [*, in1, out] + expr0 = biaff_hid0.view(BK.get_shape(biaff_hid0)[:-1]+[self.in_size1, self.out_size]) + # [*, 1, in1] * [*, in1, out] -> [*, 1, out] -> [*, out] + expr1 = BK.matmul(input1.unsqueeze(-2), expr0).squeeze(-2) + ret += expr1 / self.biaffine_div + if self.use_bias: + ret += self.B + # mask + if mask0 is not None: + ret += self.mask_value*(1.-mask0).unsqueeze(-1) + if mask1 is not None: + ret += self.mask_value*(1.-mask1).unsqueeze(-1) + return self.drop_node(ret) diff --git a/msp/nn/layers/enc.py b/msp/nn/layers/enc.py index e5fc973..44b3f2e 100644 --- a/msp/nn/layers/enc.py +++ b/msp/nn/layers/enc.py @@ -2,16 +2,18 @@ # mostly used for encoders +import numpy as np +from typing import List + from ..backends import BK from .basic import BasicNode, Dropout, ActivationHelper -from .att import MultiHeadAttention, MultiHeadFixedAttention +from .att import AttentionNode, AttConf from .ff import LayerNorm -from .multi import AddNormWrapper, AddActWrapper, get_mlp - -import numpy as np +from .multi import AddNormWrapper, AddActWrapper, get_mlp, GatedMixer # # rnn nodes -# [inputs] or input + {"H":hidden, "C":(optional)cell} -> {"H":hidden, "C":(optional)cell} +# OLD: [inputs] or input + {"H":hidden, "C":(optional)cell} -> {"H":hidden, "C":(optional)cell} +# todo(note): change to new format of tuple (hidden, (optional) cell)! # no output drops class RnnNode(BasicNode): def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): @@ -25,8 +27,7 @@ def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): def zero_init_hidden(self, bsize): raise NotImplementedError() - def __call__(self, input_exp, hidden_exp, mask): - # todo(warn): return a {} + def __call__(self, input_exp, hidden_exp_tuple, mask): raise NotImplementedError() def __repr__(self): @@ -65,10 +66,11 @@ def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): self.bh = self.add_param("bh", (self.n_hidden, )) # input: value, hidden: dict of value - def __call__(self, input_exp, hidden_exp, mask): + def __call__(self, input_exp, hidden_exp_tuple, mask): input_exp = self.idrop_node(input_exp) - # todo(warn): only using "H" - hidden_exp_h = self.gdrop_node(hidden_exp["H"]) + # todo(warn): only using [0] + orig_hidden_exp_h = hidden_exp_tuple[0] + hidden_exp_h = self.gdrop_node(orig_hidden_exp_h) # rzt = BK.affine([self.brz, self.x2rz, input_exp, self.h2rz, hidden_exp_h]) rzt = BK.sigmoid(rzt) @@ -77,14 +79,15 @@ def __call__(self, input_exp, hidden_exp, mask): ht = BK.affine([self.bh, self.x2h, input_exp, self.h2h, h_reset]) ht = BK.tanh(ht) # hidden = BK.cmult(zt, hidden_exp["H"]) + BK.cmult((1. - zt), ht) # first one use original hh - hidden = ht + zt * (hidden_exp["H"] - ht) + hidden = ht + zt * (orig_hidden_exp_h - ht) if mask is not None: - hidden = self._apply_mask(mask, hidden, hidden_exp["H"]) - return {"H": hidden} + hidden = self._apply_mask(mask, hidden, orig_hidden_exp_h) + return (hidden, ) def zero_init_hidden(self, bsize): - return {"H": BK.zeros((bsize, self.n_hidden))} + return (BK.zeros((bsize, self.n_hidden)), ) +# todo(+N): get strangely large scaled values from this one! class GruNode2(RnnNode): def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): super(GruNode2, self).__init__(pc, n_input, n_hidden, name, init_rop) @@ -93,17 +96,18 @@ def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): self.xb = self.add_param("xb", (3*self.n_hidden, )) self.hb = self.add_param("hb", (3*self.n_hidden, )) - def __call__(self, input_exp, hidden_exp, mask): + def __call__(self, input_exp, hidden_exp_tuple, mask): input_exp = self.idrop_node(input_exp) - hidden_exp_h = self.gdrop_node(hidden_exp["H"]) + orig_hidden_exp_h = hidden_exp_tuple[0] + hidden_exp_h = self.gdrop_node(orig_hidden_exp_h) # - hidden = BK.lstm_oper(input_exp, hidden_exp_h, self.x2h, self.h2h, self.xb, self.hb) + hidden = BK.gru_oper(input_exp, hidden_exp_h, self.x2h, self.h2h, self.xb, self.hb) if mask is not None: - hidden = self._apply_mask(mask, hidden, hidden_exp["H"]) - return {"H": hidden} + hidden = self._apply_mask(mask, hidden, orig_hidden_exp_h) + return (hidden, ) def zero_init_hidden(self, bsize): - return {"H": BK.zeros((bsize, self.n_hidden))} + return (BK.zeros((bsize, self.n_hidden)), ) class LstmNode(RnnNode): def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): @@ -113,9 +117,10 @@ def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): self.hw = self.add_param("hw", (4*self.n_hidden, n_hidden), init="ortho") self.b = self.add_param("b", (4*self.n_hidden, )) - def __call__(self, input_exp, hidden_exp, mask): + def __call__(self, input_exp, hidden_exp_tuple, mask): input_exp = self.idrop_node(input_exp) - hidden_exp_h = self.gdrop_node(hidden_exp["H"]) + orig_h, orig_c = hidden_exp_tuple + hidden_exp_h = self.gdrop_node(orig_h) # ifco = BK.affine([self.b, self.xw, input_exp, self.hw, hidden_exp_h]) i_t, f_t, g_t, o_t = BK.chunk(ifco, 4) @@ -123,16 +128,16 @@ def __call__(self, input_exp, hidden_exp, mask): f_t = BK.sigmoid(f_t) g_t = BK.tanh(g_t) o_t = BK.sigmoid(o_t) - c_t = BK.cmult(f_t, hidden_exp["C"]) + BK.cmult(i_t, g_t) + c_t = BK.cmult(f_t, orig_c) + BK.cmult(i_t, g_t) hidden = BK.cmult(o_t, BK.tanh(c_t)) if mask is not None: - hidden = self._apply_mask(mask, hidden, hidden_exp["H"]) - c_t = self._apply_mask(mask, c_t, hidden_exp["C"]) - return {"H": hidden, "C": c_t} + hidden = self._apply_mask(mask, hidden, orig_h) + c_t = self._apply_mask(mask, c_t, orig_c) + return (hidden, c_t) def zero_init_hidden(self, bsize): z0 = BK.zeros((bsize, self.n_hidden)) - return {"H": z0, "C": z0} + return (z0, z0) class LstmNode2(RnnNode): def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): @@ -143,19 +148,20 @@ def __init__(self, pc, n_input, n_hidden, name=None, init_rop=None): self.xb = self.add_param("xb", (4*self.n_hidden, )) self.hb = self.add_param("hb", (4*self.n_hidden, )) - def __call__(self, input_exp, hidden_exp, mask): + def __call__(self, input_exp, hidden_exp_tuple, mask): input_exp = self.idrop_node(input_exp) - hidden_exp_h = self.gdrop_node(hidden_exp["H"]) + orig_h, orig_c = hidden_exp_tuple + hidden_exp_h = self.gdrop_node(orig_h) # - hidden, c_t = BK.lstm_oper(input_exp, (hidden_exp_h, hidden_exp["C"]), self.xw, self.hw, self.xb, self.hb) + hidden, c_t = BK.lstm_oper(input_exp, (hidden_exp_h, orig_c), self.xw, self.hw, self.xb, self.hb) if mask is not None: - hidden = self._apply_mask(mask, hidden, hidden_exp["H"]) - c_t = self._apply_mask(mask, c_t, hidden_exp["C"]) - return {"H": hidden, "C": c_t} + hidden = self._apply_mask(mask, hidden, orig_h) + c_t = self._apply_mask(mask, c_t, orig_c) + return (hidden, c_t) def zero_init_hidden(self, bsize): z0 = BK.zeros((bsize, self.n_hidden)) - return {"H": z0, "C": z0} + return (z0, z0) # stateless encoder # here only one layer, to join by Seq @@ -201,22 +207,23 @@ def __call__(self, embeds, masks=None): if masks is None: masks = [None for _ in embeds] bsize = BK.get_shape(embeds[0], 0) - init_hidden = BK.zeros((bsize, self.n_hidden)) # broadcast + init_hidden = BK.zeros((bsize, self.n_hidden)) # broadcast + init_states = (init_hidden, init_hidden) # todo(note): inner side of the RNN-units for layer_idx in range(self.n_layers): f_node = self.fnodes[layer_idx] tmp_f = [] # forward - tmp_f_prev = {"H":init_hidden, "C":init_hidden} + tmp_f_prev = init_states for e, m in zip(outputs[-1], masks): one_output = f_node(e, tmp_f_prev, m) - tmp_f.append(one_output["H"]) + tmp_f.append(one_output[0]) # todo(note): inner side of the RNN-units tmp_f_prev = one_output if self.bidirection: b_node = self.bnodes[layer_idx] tmp_b = [] # backward - tmp_b_prev = {"H":init_hidden, "C":init_hidden} + tmp_b_prev = init_states for e, m in zip(reversed(outputs[-1]), reversed(masks)): one_output = b_node(e, tmp_b_prev, m) - tmp_b.append(one_output["H"]) + tmp_b.append(one_output[0]) # todo(note): inner side of the RNN-units tmp_b_prev = one_output # concat ctx = [BK.concat([f,b]) for f,b in zip(tmp_f, reversed(tmp_b))] @@ -236,7 +243,7 @@ def __call__(self, embeds, masks=None): class RnnLayerBatchFirstWrapper(BasicNode): def __init__(self, pc, rnn_node): super().__init__(pc, None, None) - self.rnn_node = self.add_sub_node("z", rnn_node) + self.rnn_node: RnnLayer = self.add_sub_node("z", rnn_node) # mask>0. means valid, mask=0. means padding @staticmethod @@ -277,12 +284,12 @@ def __init__(self, pc, n_input, n_output, n_windows, pooling=None, act="linear", n_windows = [n_windows] self.n_out = n_output * len(n_windows) # simply concat self.conv_opers = [BK.CnnOper(self, n_input, n_output, z) for z in n_windows] # add params here - # pooling? -> act -> dropout + # act -> pooling -> dropout + self._act_f = ActivationHelper.get_act(act) if pooling is not None: self._pool_f = ActivationHelper.get_pool(pooling) else: self._pool_f = lambda x: x - self._act_f = ActivationHelper.get_act(act) self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.n_out,))) # input should be a single Expr @@ -290,10 +297,10 @@ def __init__(self, pc, n_input, n_output, n_windows, pooling=None, act="linear", def __call__(self, input_expr, mask_arr=None): output_exprs = [oper.conv(input_expr) for oper in self.conv_opers] output_expr = BK.concat(output_exprs, -1) - # pooling? -> act -> dropout - expr_p = self._pool_f(output_expr) - expr_a = self._act_f(expr_p) - expr_d = self.drop_node(expr_a) + # act -> pooling -> dropout + expr_a = self._act_f(output_expr) + expr_p = self._pool_f(expr_a) + expr_d = self.drop_node(expr_p) return expr_d def get_output_dims(self, *input_dims): @@ -301,18 +308,29 @@ def get_output_dims(self, *input_dims): CnnLayer = CnnNode +# ===== +# self attentional encoders + +# common helper +def get_selfatt_node(pc, d_model, aconf: AttConf, fix_range_val): + orig_range_val = aconf._fixed_range_val + if fix_range_val is not None: + aconf._fixed_range_val = fix_range_val + node = AttentionNode.get_att_node(aconf.type, pc, d_model, d_model, d_model, aconf) + aconf._fixed_range_val = orig_range_val + return node + +# ===== # transformer (self-attention) # todo(warn): layer-norm at the output rather than input class TransformerEncoderLayer(BasicNode): - def __init__(self, pc, d_model, d_ff, d_kqv=64, head_count=8, att_type="mh", add_wrapper="addnorm", att_dropout=0.1, att_use_ranges=False, attf_selfw=0., clip_dist=0, use_neg_dist=True, name=None, init_rop=None): + def __init__(self, pc, d_model, d_ff, add_wrapper, aconf: AttConf, fix_range_val, name=None, init_rop=None): super().__init__(pc, name, init_rop) # todo(note): residual adding is like preserving 0.5 attention for self wrapper_getf = {"addnorm": AddNormWrapper, "addtanh": lambda *args: AddActWrapper(*args, act="tanh")}[add_wrapper] # - att_getf = {"mh": lambda: MultiHeadAttention(pc, d_model, d_model, d_model, d_kqv, head_count, att_dropout, att_use_ranges, clip_dist, use_neg_dist), "mhf": lambda: MultiHeadFixedAttention(pc, d_model, d_model, d_model, d_kqv, head_count, att_dropout, att_use_ranges, attf_selfw)}[att_type] - # self.d_model = d_model - self.self_att = self.add_sub_node("s", wrapper_getf(att_getf(), [d_model])) + self.self_att = self.add_sub_node("s", wrapper_getf(get_selfatt_node(pc, d_model, aconf, fix_range_val), [d_model])) self.feed_forward = self.add_sub_node("ff", wrapper_getf( get_mlp(pc, d_model, d_model, d_ff, n_hidden_layer=1, hidden_act="relu", final_act="linear"), [d_model])) @@ -327,17 +345,23 @@ def __call__(self, input_expr, mask_arr=None): # transformer (multiple-layers) class TransformerEncoder(BasicNode): - def __init__(self, pc, n_layers, d_model, d_ff, d_kqv=64, head_count=8, att_type="mh", add_wrapper="addnorm", att_dropout=0.1, att_use_ranges=False, attf_selfw=0., clip_dist=0, use_neg_dist=True, name=None, init_rop=None): + def __init__(self, pc, n_layers, d_model, d_ff, add_wrapper, aconf: AttConf, fixed_range_vals, final_act="linear", name=None, init_rop=None): super().__init__(pc, name, init_rop) # self.d_model = d_model self.n_layers = n_layers - # - in_getf = {"addnorm": lambda *args: self.add_sub_node("in", LayerNorm(*args)), + # todo(note): here do not make std back since there can be zero input problems! + in_getf = {"addnorm": lambda *args: self.add_sub_node("in", LayerNorm(*args, std_no_grad=True)), "addtanh": lambda *args: ActivationHelper.get_act("tanh")}[add_wrapper] self.input_f = in_getf(pc, d_model) # - self.layers = [self.add_sub_node("one", TransformerEncoderLayer(pc, d_model, d_ff, d_kqv, head_count, att_type, add_wrapper, att_dropout, att_use_ranges, attf_selfw, clip_dist, use_neg_dist)) for _ in range(n_layers)] + if fixed_range_vals is None: + fixed_range_vals = [None] * self.n_layers + else: + assert len(fixed_range_vals) == self.n_layers, "Unmatched ranges and number of layers" + # todo(+N): not elegant here! + self.layers = [self.add_sub_node("one", TransformerEncoderLayer(pc, d_model, d_ff, add_wrapper, aconf, fix_range_val=fixed_range_vals[i])) for i in range(n_layers)] + self._act_f = ActivationHelper.get_act(final_act) def get_output_dims(self, *input_dims): return (self.d_model, ) @@ -345,6 +369,53 @@ def get_output_dims(self, *input_dims): def __call__(self, input_expr, mask_arr=None): mask_expr = None if mask_arr is None else BK.input_real(mask_arr) x = self.input_f(input_expr) + for layer in self.layers: + x = layer(x, mask_expr) + return self._act_f(x) + +# ====== +# my-transformer: with slight modifications +# -- fat-transformer: combining more horizontally + +class Transformer2EncoderLayer(BasicNode): + def __init__(self, pc, d_model, aconf: AttConf, short_range, long_range, name=None, init_rop=None): + super().__init__(pc, name, init_rop) + self.d_model = d_model + # use range or not depends on the one argument + self.short_att = self.add_sub_node("sa", get_selfatt_node(pc, d_model, aconf, short_range)) + self.long_att = self.add_sub_node("la", get_selfatt_node(pc, d_model, aconf, long_range)) + # final gated combiner (mix input/short-att/long-att) + self.combiner = self.add_sub_node("fc", GatedMixer(pc, d_model, 3)) + + def get_output_dims(self, *input_dims): + return (self.d_model, ) + + def __call__(self, input_expr, mask_arr=None): + mask_expr = None if mask_arr is None else BK.input_real(mask_arr) + short_output = self.short_att(input_expr, input_expr, input_expr, mask_expr) + long_output = self.long_att(input_expr, input_expr, input_expr, mask_expr) + final_output = self.combiner([input_expr, short_output, long_output]) + return final_output + +# my-transformer (multiple-layers) +class Transformer2Encoder(BasicNode): + def __init__(self, pc, n_layers, d_model, aconf: AttConf, short_range:int, long_ranges:List[int], name=None, init_rop=None): + super().__init__(pc, name, init_rop) + self.d_model = d_model + self.n_layers = n_layers + # + if long_ranges is None: + long_ranges = [None] * self.n_layers + else: + assert len(long_ranges) == self.n_layers, "Unmatched long ranges and number of layers" + self.layers = [self.add_sub_node("one", Transformer2EncoderLayer(pc, d_model, aconf, short_range=short_range, long_range=long_ranges[i])) for i in range(n_layers)] + + def get_output_dims(self, *input_dims): + return (self.d_model, ) + + def __call__(self, input_expr, mask_arr=None): + mask_expr = None if mask_arr is None else BK.input_real(mask_arr) + x = input_expr for layer in self.layers: x = layer(x, mask_expr) return x diff --git a/msp/nn/layers/ff.py b/msp/nn/layers/ff.py index 6bd96da..5f01064 100644 --- a/msp/nn/layers/ff.py +++ b/msp/nn/layers/ff.py @@ -15,7 +15,7 @@ # [inputs] or input -> output class Affine(BasicNode): # n_ins: [n_ins0, n_ins1, ...] - def __init__(self, pc, n_ins, n_out, act="linear", bias=True, affine2=True, name=None, init_rop=None): + def __init__(self, pc, n_ins, n_out, act="linear", bias=True, which_affine=2, name=None, init_rop=None): super().__init__(pc, name, init_rop) # list of n_ins and n_outs have different meanings: horizontal and vertical if not isinstance(n_ins, Iterable): @@ -40,7 +40,7 @@ def __init__(self, pc, n_ins, n_out, act="linear", bias=True, affine2=True, name self.drop_node = self.add_sub_node("drop", Dropout(pc, (self.n_out,))) self._input_list = self._init_input_list() # - self._affine_f = BK.affine2 if affine2 else BK.affine + self._affine_f = {1: BK.affine, 2: BK.affine2, 3: BK.affine3}[which_affine] def _init_input_list(self): # called when refresh @@ -77,10 +77,18 @@ def __call__(self, input_exp): def get_output_dims(self, *input_dims): return (self.n_out, ) + # special usage + def zero_params(self): + with BK.no_grad_env(): + for w in self.ws: + w.zero_() + if self.bias: + self.b.zero_() + # ========== # special ones class LayerNorm(BasicNode): - def __init__(self, pc, size, a_init=1., eps=1e-6, name=None): + def __init__(self, pc, size, a_init=1., eps=1e-6, name=None, std_no_grad=False): super().__init__(pc, name, None) # todo(+N): is this all right? a_init_v = np.sqrt(3./size) if a_init is None else a_init @@ -88,12 +96,18 @@ def __init__(self, pc, size, a_init=1., eps=1e-6, name=None): self.a_2 = self.add_param(name="A", shape=(size, ), init=np.full(size, a_init_v, dtype=np.float32)) self.b_2 = self.add_param(name="B", shape=(size, ), init=np.zeros(size, dtype=np.float32)) self.eps = eps + self.std_no_grad = std_no_grad # no droput here def __call__(self, x): mean = x.mean(-1, keepdim=True) std = x.std(-1, keepdim=True) - return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + if self.std_no_grad: + std = std.detach() + # todo(+N): if std is small, (for example, zero input), then div 1 since otherwise there will be NAN + # currently depend on outside flag (for the very first (input) layer) + to_div = std + self.eps + return self.a_2 * (x - mean) / to_div + self.b_2 # ===== Looking up Nodes # like an embedding, usually used as the score matrix @@ -124,10 +138,13 @@ class Embedding(BasicNode): def __init__(self, pc, n_words, n_dim, fix_row0=True, dropout_wordceil=None, npvec=None, name=None, init_rop=None, freeze=False): super(Embedding, self).__init__(pc, name, init_rop) if npvec is not None: - zcheck(len(npvec.shape) == 2 and npvec.shape[0] == n_words and npvec.shape[1] == n_dim, "Wrong dimension for init embeddings.") - zlog("Add embed W from npvec %s." % (npvec.shape,)) - if freeze: - self.rop.add_fixed_value("trainable", False) + if not (len(npvec.shape) == 2 and npvec.shape[0] == n_words and npvec.shape[1] == n_dim): + zlog(f"Wrong dimension for init embeddings {npvec.shape} instead of ({n_words}, {n_dim}), use random instead!!") + npvec = None + else: + zlog("Add embed W from npvec %s." % (npvec.shape,)) + if freeze: + self.rop.add_fixed_value("trainable", False) else: if freeze: self.rop.add_fixed_value("trainable", False) @@ -149,8 +166,10 @@ def replace_weights(self, npvec): zcheck(num_dim == self.n_dim, "Cannot change embedding dimension!") # replace # todo(+N): simply add one param here, the original one is still around + # todo(WARN): only use this for testing, since the new weights will not be added to the optimizer zlog(f"Replacing the embedding weights from ({self.n_words}, {num_dim}) to ({num_words}, {num_dim})") - self.E = self.add_param("E", (num_words, num_dim), init=npvec, lookup=True) + # here, we are adding params at the outside + self.E = self.add_param("E", (num_words, num_dim), init=npvec, lookup=True, check_stack=False) self.n_words = num_words self.dropout_wordceil = self.dropout_wordceil_hp if self.dropout_wordceil_hp is not None else self.n_words @@ -159,7 +178,7 @@ def _input_f_obtain(self, edrop): return lambda x: x else: # todo(warn): replaced sample, maybe an efficient but approx. impl & only works for 2d - edrop_rands = Random.random_sample((int(self.dropout_wordceil*edrop),), "sample") # [0,1) + edrop_rands = Random.random_sample((int(self.dropout_wordceil*edrop),)) # [0,1) edrop_idxes = [int(self.dropout_wordceil*z) for z in edrop_rands] edrop_set = set(edrop_idxes) return lambda x: [0 if one in edrop_set else one for one in x] # drop to 0 if fall in the set @@ -198,7 +217,7 @@ def __init__(self, pc: BK.ParamCollection, n_dim: int, max_len: int=5000, init_s if init_sincos: pe = np.zeros([max_len, n_dim]) position = np.arange(0, max_len).reshape([-1, 1]) - div_term = np.exp((np.arange(0, n_dim, 2) * -(math.log(10000.0)/n_dim))) + div_term = np.exp((np.arange(0, n_dim, 2) * -(math.log(2*max_len)/n_dim))) div_results = position * div_term pe[:, 0::2] = np.sin(div_results) pe[:, 1::2] = np.cos(div_results) @@ -236,3 +255,28 @@ def __call__(self, input_idxes): h0 = BK.lookup(self.E, clamped_idx_repr) h1 = self.drop_node(h0) return h1 + +# relative positional embeddings: [min, max]; no dropout! +class RelPosiEmbedding(BasicNode): + def __init__(self, pc: BK.ParamCollection, dim: int, max: int, min: int = None): + super().__init__(pc, None, None) + self.dim = dim + assert max>=0 + self.max = max + self.min = -max if min is None else min + assert self.min<=self.max + self.E = self.add_param("E", (self.max-self.min+1, dim), lookup=True) + + def get_output_dims(self, *input_dims): + return (self.dim, ) + + def __repr__(self): + return f"# RelativePositionalEmbedding (dim={self.dim}, range=[{self.min}, {self.max}])" + + def __call__(self, input_idxes): + # input should be a list of ints or int + if isinstance(input_idxes, int): + input_idxes = [input_idxes] + clamped_idx_repr = BK.clamp(BK.input_idx(input_idxes), min=self.min, max=self.max) - self.min + h0 = BK.lookup(self.E, clamped_idx_repr) + return h0 diff --git a/msp/nn/layers/multi.py b/msp/nn/layers/multi.py index 7928640..bcab2fe 100644 --- a/msp/nn/layers/multi.py +++ b/msp/nn/layers/multi.py @@ -40,9 +40,6 @@ class Summer(BasicNode): def __init__(self, pc, name=None, init_rop=None): super().__init__(pc, name, init_rop) - def refresh(self, rop=None): - pass - # NO dropouts!! def __call__(self, input_list): # todo(warn): not that efficient! @@ -61,9 +58,6 @@ class Concater(BasicNode): def __init__(self, pc, name=None, init_rop=None): super().__init__(pc, name, init_rop) - def refresh(self, rop=None): - pass - # NO dropouts!! def __call__(self, input_list): return BK.concat(input_list, -1) @@ -72,8 +66,30 @@ def get_output_dims(self, *input_dims): xs = input_dims[0] return sum(xs) +# sum the inputs with gates +class GatedMixer(BasicNode): + def __init__(self, pc, dim, num_input, name=None, init_rop=None): + super().__init__(pc, name, init_rop) + self.dim = dim + self.num_input = num_input + # + self.ff = self.add_sub_node("g", Affine(pc, dim*num_input, dim*num_input, act="sigmoid")) + self.ff2 = self.add_sub_node("f", Affine(pc, dim*num_input, dim, act="tanh")) + + def __call__(self, input_list): + concat_input_t = BK.concat(input_list, -1) + gates_t = self.ff(concat_input_t) + final_output_t = self.ff2(concat_input_t * gates_t) + # output_shape = BK.get_shape(concat_output_t)[:-1] + [self.num_input, -1] + # # [*, num, D] ->(avg)-> [*, D] + # return BK.avg(concat_output_t.view(output_shape), dim=-2) + return final_output_t + + def get_output_dims(self, *input_dims): + return (self.dim, ) + # input -> various outputs -> join -# -- join mode can be "cat" or "add" +# -- join mode can be "cat", "add" class Joiner(BasicNode): def __init__(self, pc, node_iter, mode="cat", name=None, init_rop=None): super().__init__(pc, name, init_rop) @@ -117,11 +133,11 @@ def get_output_dims(self, *input_dims): # Adding with the first arg class AddNormWrapper(NodeWrapper): - def __init__(self, node, node_last_dims): + def __init__(self, node, node_last_dims, std_no_grad=False): super().__init__(node, node_last_dims) self.size = self.input_ns[0] zcheck(self.size==self.output_ns[0], "AddNormWrapper meets unequal dims.") - self.normer = self.add_sub_node("n", LayerNorm(self.pc, self.size)) + self.normer = self.add_sub_node("n", LayerNorm(self.pc, self.size, std_no_grad=std_no_grad)) def get_output_dims(self, *input_dims): return (self.size, ) @@ -165,11 +181,12 @@ def __call__(self, *args): # shortcuts # shortcur for MLP -def get_mlp(pc, n_ins, n_out, n_hidden, n_hidden_layer=1, hidden_act="tanh", final_act="linear", hidden_bias=True, final_bias=True, hidden_init_rop=None, final_init_rop=None): +def get_mlp(pc, n_ins, n_out, n_hidden, n_hidden_layer=1, hidden_act="tanh", final_act="linear", hidden_bias=True, final_bias=True, hidden_init_rop=None, final_init_rop=None, hidden_which_affine=2): layer_dims = [n_ins] + [n_hidden]*n_hidden_layer nodes = [] for idx in range(n_hidden_layer): - hidden_one = Affine(pc, layer_dims[idx], layer_dims[idx+1], act=hidden_act, bias=hidden_bias, init_rop=hidden_init_rop) + hidden_one = Affine(pc, layer_dims[idx], layer_dims[idx+1], act=hidden_act, bias=hidden_bias, + init_rop=hidden_init_rop, which_affine=hidden_which_affine) nodes.append(hidden_one) nodes.append(Affine(pc, layer_dims[-1], n_out, act=final_act, bias=final_bias, init_rop=final_init_rop)) return Sequential(pc, nodes, name="mlp") diff --git a/msp/nn/modules/__init__.py b/msp/nn/modules/__init__.py index a1a1ce9..5325652 100644 --- a/msp/nn/modules/__init__.py +++ b/msp/nn/modules/__init__.py @@ -4,3 +4,4 @@ from .embedder import EmbedConf, MyEmbedder from .encoder import EncConf, MyEncoder +from .berter import BerterConf, Berter diff --git a/msp/nn/modules/berter.py b/msp/nn/modules/berter.py new file mode 100644 index 0000000..7f3fd6e --- /dev/null +++ b/msp/nn/modules/berter.py @@ -0,0 +1,443 @@ +# + +# extract features from pretrained bert + +try: + from transformers import BertModel, BertTokenizer, BertForMaskedLM +except: + transformers = None + +import numpy as np +from typing import Tuple, Iterable, List +from collections import namedtuple + +from msp.utils import Conf, zcheck, zlog, zwarn, Helper +from msp.nn import BK + +# +class BerterConf(Conf): + def __init__(self): + # basic + self.bert_model = "bert-base-multilingual-cased" # or "bert-base-cased", "bert-large-cased", "bert-base-chinese" + self.bert_lower_case = False + self.bert_layers = [-1] # which layers to extract, use concat for final output + self.bert_trainable = False + self.bert_to_device = True # perform things on default devices + # cache dir for downloading bert models + self.bert_cache_dir = "" + # for feature extracting + self.bert_sent_extend = 0 # use how many context sentence before and after (but overall is constrained within bert's input limit) + self.bert_tok_extend = 0 # similar to the sent one, extra tok constrains + self.bert_extend_sent_bound = True # extend at sent boundary + self.bert_root_mode = 0 # specially for artificial root, 0 means nothing, 1 means alwasy CLS, -1 means previous sub-token + self.bert_previous_rate = 0.75 # if limited budget, what is the rate that we prefer for the previous contexts + # specific + self.bert_batch_size = 1 # forwarding batch size + self.bert_single_len = 256 # if segement length >= this, then ignore bsize and forward one by one + # extra special one, zero padding embedding + self.bert_zero_pademb = False # whether make padding embedding zero vector + +# +SubwordToks = namedtuple('SubwordToks', ['subword_toks', 'subword_ids', 'subword_is_start']) + +# +class Berter(object): + def __init__(self, pc: BK.ParamCollection, bconf: BerterConf): + # super().__init__(pc, None, None) + # TODO(+N): currently use freezed bert features, therefore, also not a BasicNode + self.bconf = bconf + assert not bconf.bert_trainable, "Currently only using this part for feature extractor" + MODEL_NAME = bconf.bert_model + zlog(f"Loading pre-trained bert model of {MODEL_NAME}") + # Load pretrained model/tokenizer + self.tokenizer = BertTokenizer.from_pretrained(MODEL_NAME, do_lower_case=bconf.bert_lower_case, + cache_dir=None if (not bconf.bert_cache_dir) else bconf.bert_cache_dir) + # model = BertModel.from_pretrained(MODEL_NAME, cache_dir="./") + self.model = BertModel.from_pretrained(MODEL_NAME, output_hidden_states=True, + cache_dir=None if (not bconf.bert_cache_dir) else bconf.bert_cache_dir) + self.model.eval() + # ===== + # for word predictions + voc = self.tokenizer.vocab + self.tok_word_list = [None] * len(voc) # idx -> str + for s, i in voc.items(): + self.tok_word_list[i] = s + self.wp_model = BertForMaskedLM.from_pretrained(MODEL_NAME, cache_dir=None if (not bconf.bert_cache_dir) else bconf.bert_cache_dir) + self.wp_model.eval() + # ===== + # zero padding embeddings? + if bconf.bert_zero_pademb: + with BK.no_grad_env(): + # todo(warn): specific!! + zlog(f"Unusual operation: make bert's padding embedding (idx0) zero!!") + self.model.embeddings.word_embeddings.weight[0].fill_(0.) + if bconf.bert_to_device: + zlog(f"Moving bert model to default device {BK.DEFAULT_DEVICE}") + BK.to_device(self.model) + BK.to_device(self.wp_model) + + # if not using anymore, release the resources + def delete_self(self): + self.model = None + self.wp_model = None + + # ===== + # only forward to get features + + # prepare subword-tokens and subword-word correspondences (could be document level) + # todo(note): here we assume that the input does not have special ROOT, otherwise that will be strange to bert! + # todo(note): here does not add CLS or SEP, since later we can concatenate things! + def subword_tokenize(self, tokens: List[str], no_special_root: bool, mask_idx=-1, mask_mode="all", mask_repl=""): + assert no_special_root + subword_toks = [] # all sub-tokens, no CLS or SEP + subword_is_start = [] # whether is the start of orig-token positions + subword_ids = [] + for i, t in enumerate(tokens): + cur_toks = self.tokenizer.tokenize(t) + # in some cases, there can be empty strings -> put the original word + if len(cur_toks) == 0: + cur_toks = [t] + # ===== + # mask mode + if i == mask_idx: + mask_tok = mask_repl if mask_repl else self.tokenizer.mask_token + if mask_mode == "all": + cur_toks = [mask_tok] * (len(cur_toks)) + elif mask_mode == "first": + cur_toks[0] = mask_tok + elif mask_mode == "one": + cur_toks = [mask_tok] + elif mask_mode == "pass": + continue # todo(note): special mode, ignore current tok, but need later post-processing + else: + raise NotImplementedError("UNK mask mode") + # ===== + # todo(warn): use the first BPE piece's vector for the whole word + cur_is_start = [0]*(len(cur_toks)) + cur_is_start[0] = 1 + subword_is_start.extend(cur_is_start) + subword_toks.extend(cur_toks) + subword_ids.extend(self.tokenizer.convert_tokens_to_ids(cur_toks)) + assert len(subword_toks) == len(subword_is_start) + assert len(subword_toks) == len(subword_ids) + return SubwordToks(subword_toks, subword_ids, subword_is_start) + + # forward with a collection of sentences, input is List[(subwords, starts)] + # todo(note): currently all change to np.ndarray and return, + # todo(+N): maybe more efficient to directly return tensor, but how to deal with batch? + def extract_features(self, sents: List[Tuple]): + # ===== + MAX_LEN = 510 # save two for [CLS] and [SEP] + BACK_LEN = 100 # for splitting cases, still remaining some of previous sub-tokens for context + model = self.model + tokenizer = self.tokenizer + CLS, CLS_IDX = tokenizer.cls_token, tokenizer.cls_token_id + SEP, SEP_IDX = tokenizer.sep_token, tokenizer.sep_token_id + # ===== + # prepare the forwarding + bconf = self.bconf + bert_sent_extend = bconf.bert_sent_extend + bert_tok_extend = bconf.bert_tok_extend + bert_extend_sent_bound = bconf.bert_extend_sent_bound + bert_root_mode = bconf.bert_root_mode + # + num_sent = len(sents) + all_ids = [z.subword_ids for z in sents] + # all_subwords = [z.subword_toks for z in sents] + all_prep_sents = [] + for sent_idx in range(num_sent): + cur_sent = sents[sent_idx] + center_subwords, center_ids, center_starts = cur_sent.subword_toks, cur_sent.subword_ids, cur_sent.subword_is_start + # extend extra tokens + # prev_all_subwords, post_all_subwords = self.extend_subwords( + # all_subwords, sent_idx, MAX_LEN-len(center_ids), bert_sent_extend, bert_tok_extend, bert_extend_sent_bound) + prev_all_ids, post_all_ids = self.extend_subwords( + all_ids, sent_idx, MAX_LEN-len(center_ids), bert_sent_extend, bert_tok_extend, bert_extend_sent_bound) + # put them together + cur_all_ids = [CLS_IDX] + prev_all_ids + center_ids + post_all_ids + [SEP_IDX] + cur_all_starts = [0] * (len(prev_all_ids)+1) + if bert_root_mode == -1: + cur_all_starts[-1] = 1 # previous token (can also be CLS if no prev ones) + elif bert_root_mode == 1: + cur_all_starts[0] = 1 # always take CLS + cur_all_starts = cur_all_starts + center_starts + [0] * (len(post_all_ids)+1) + cur_all_len = len(cur_all_ids) + assert cur_all_len == len(cur_all_starts) + all_prep_sents.append((cur_all_ids, cur_all_starts, cur_all_len, sent_idx)) # put more info here + # forward + bert_batch_size = bconf.bert_batch_size + bert_single_len = bconf.bert_single_len + hit_single_thresh = (bert_batch_size==1) # if bsize==1, directly single mode + all_prep_sents.sort(key=lambda x: x[-2]) # sort by sent length + ret_arrs = [None] * num_sent + with BK.no_grad_env(): + cur_ss_idx = 0 # idx for sorted sentence + while cur_ss_idx < len(all_prep_sents): + if all_prep_sents[cur_ss_idx][-2] >= bert_single_len: + hit_single_thresh = True + # if already hit single thresh + if hit_single_thresh: + cur_single_sent = all_prep_sents[cur_ss_idx] + one_feat_arr = self.forward_single(cur_single_sent[0], cur_single_sent[1]) + one_orig_idx = cur_single_sent[-1] + assert ret_arrs[one_orig_idx] is None + ret_arrs[one_orig_idx] = one_feat_arr + cur_ss_idx += 1 + else: + # collect current batch + cur_batch = [] + for _ in range(bert_batch_size): + if cur_ss_idx >= len(all_prep_sents): + break + one_sent = all_prep_sents[cur_ss_idx] + if one_sent[-2] >= bert_single_len: + break + cur_batch.append(one_sent) + cur_ss_idx += 1 + if len(cur_batch) > 0: + cur_batch_arrs = self.forward_batch([z[0] for z in cur_batch], [z[1] for z in cur_batch]) + for one_feat_arr, one_single_sent in zip(cur_batch_arrs, cur_batch): + one_orig_idx = one_single_sent[-1] + assert ret_arrs[one_orig_idx] is None + ret_arrs[one_orig_idx] = one_feat_arr + return ret_arrs + + # simple mode for extracting features, ignoring certain confs + # add one with CLS, but not SEP + def extract_feature_simple_mode(self, sents: List[Tuple]): + # ===== + MAX_LEN = 510 # save two for [CLS] and [SEP] + tokenizer = self.tokenizer + CLS, CLS_IDX = tokenizer.cls_token, tokenizer.cls_token_id + SEP, SEP_IDX = tokenizer.sep_token, tokenizer.sep_token_id + # ===== + ret_arrs = [] + for one_sent in sents: + assert len(one_sent.subword_ids) <= MAX_LEN + cur_all_ids = [CLS_IDX] + one_sent.subword_ids + [SEP_IDX] + cur_all_starts = [1] + one_sent.subword_is_start + [0] + one_feat_arr = self.forward_single(cur_all_ids, cur_all_starts) + ret_arrs.append(one_feat_arr) + return ret_arrs + + # ===== + # helpers + + def forward_features(self, ids_expr, mask_expr, output_layers): + # token_type_ids are by default 0 + final_layer, _, encoded_layers = self.model(ids_expr, attention_mask=mask_expr) + concated_expr = BK.concat([encoded_layers[li] for li in output_layers], -1) # [bsize, slen, DIM*layer] + return concated_expr + + def forward_single(self, cur_ids: List[int], cur_starts: List[int]): + # todo(note): we may need to split long segment in single mode + tokenizer, model = self.tokenizer, self.model + output_layers = self.bconf.bert_layers + # + MAX_LEN = 510 # save two for [CLS] and [SEP] + BACK_LEN = 100 # for splitting cases, still remaining some of previous sub-tokens for context + CLS, CLS_IDX = tokenizer.cls_token, tokenizer.cls_token_id + SEP, SEP_IDX = tokenizer.sep_token, tokenizer.sep_token_id + # + if len(cur_ids) < MAX_LEN+2: + final_outputs = self.forward_features(BK.input_idx(cur_ids).unsqueeze(0), None, output_layers)[0] #[La, *] + ret = final_outputs[BK.input_idx(cur_starts) == 1] # direct masked select + else: + # forwarding (multiple times within the model max-length constrain) + all_outputs = [] + cur_sub_idx = 0 + while cur_sub_idx < len(cur_ids)-1: # minus 1 to ignore ending SEP + cur_slice_start = max(1, cur_sub_idx - BACK_LEN) + cur_slice_end = min(cur_slice_start + MAX_LEN, len(cur_ids)-1) + cur_toks = [CLS_IDX] + cur_ids[cur_slice_start:cur_slice_end] + [SEP_IDX] + features = self.forward_features(BK.input_idx(cur_toks).unsqueeze(0), None, output_layers) # [bs, L, *] + cur_features = features[0] # [L, *] + assert len(cur_features) == len(cur_toks) + # only include CLS in the first run, no SEP included + if cur_sub_idx == 0: + # include CLS, exclude SEP + all_outputs.append(cur_features[:-1]) + else: + # include only new ones, discard BACK ones, exclude CLS, SEP + all_outputs.append(cur_features[cur_sub_idx-cur_slice_start+1:-1]) + zwarn(f"Add multiple-seg range: [{cur_slice_start}, {cur_sub_idx}, {cur_slice_end})] " + f"for all-len={len(cur_ids)}") + cur_sub_idx = cur_slice_end + final_outputs = BK.concat(all_outputs, 0) # [La, *] + ret = final_outputs[BK.input_idx(cur_starts[:-1])==1] # todo(note) SEP is surely not selected + return BK.get_value(ret) # directly return one arr arr(real-len, DIM*layer) + + def forward_batch(self, batched_ids: List[List], batched_starts: List[List]): + PAD_IDX = self.tokenizer.pad_token_id + output_layers = self.bconf.bert_layers + # + bsize = len(batched_ids) + max_len = max(len(z) for z in batched_ids) + input_shape = (bsize, max_len) + # first collect on CPU + input_ids_arr = np.full(input_shape, PAD_IDX, dtype=np.int64) + input_mask_arr = np.full(input_shape, 0, dtype=np.float32) + input_is_start = np.full(input_shape, 0, dtype=np.int64) + for bidx in range(bsize): + cur_ids, cur_starts = batched_ids[bidx], batched_starts[bidx] + cur_len = len(cur_ids) + input_ids_arr[bidx, :cur_len] = cur_ids + input_is_start[bidx, :cur_len] = cur_starts + input_mask_arr[bidx, :cur_len] = 1. + # then real forward + # todo(note): for batched mode, assume things are all within bert's max-len + features = self.forward_features(BK.input_idx(input_ids_arr), BK.input_real(input_mask_arr), output_layers) # [bs, slen, *] + start_idxes, start_masks = BK.mask2idx(BK.input_idx(input_is_start).float()) # [bsize, ?] + start_expr = BK.gather_first_dims(features, start_idxes, 1) # [bsize, ?, DIM*layer] + # get values + start_expr_arr = BK.get_value(start_expr) + start_valid_len_arr = BK.get_value(start_masks.sum(-1).int()) + return [v[:slen] for v, slen in zip(start_expr_arr, start_valid_len_arr)] # List[arr(real-len, DIM*layer)] + + def extend_subwords(self, sents, center_sid, budget, bert_sent_extend, bert_tok_extend, bert_extend_sent_bound): + # todo(note): we prefer previous contexts + budget_prev = max(0, budget * self.bconf.bert_previous_rate) + budget_prev = min(bert_tok_extend, budget_prev) + # collect prev ones + included_prev_sents = [] + for step in range(bert_sent_extend): + one_sid = center_sid-1-step + if one_sid>=0: + prev_subwords = sents[one_sid] + if budget_prev >= len(prev_subwords): + included_prev_sents.append(prev_subwords) + budget_prev -= len(prev_subwords) # decrease budget + else: + if not bert_extend_sent_bound: # we can further add sub-sent + included_prev_sents.append(prev_subwords[-budget_prev:]) + budget_prev = 0 + break # budget run out + else: + break # start of doc + prev_all_subwords = [] + for one_prev_subwords in reversed(included_prev_sents): # remember to reverse + prev_all_subwords.extend(one_prev_subwords) + # collect post ones + post_all_subwords = [] + budget_post = max(0, budget - len(prev_all_subwords)) # remaining budget for post + budget_post = min(bert_tok_extend, budget_post) + for step in range(bert_sent_extend): + one_sid = center_sid+step+1 + if one_sid < len(sents): + post_subwords = sents[one_sid] + if budget_post >= len(post_subwords): + post_all_subwords.extend(post_subwords) # driectly extend + budget_post -= len(post_subwords) + else: + if not bert_extend_sent_bound: # we can further add sub-sent + post_all_subwords.extend(post_subwords[:budget_post]) + budget_post = 0 + break # not enough budget + else: + break # end of doc + return prev_all_subwords, post_all_subwords + + # ===== + # using for word prediction + + # def predict_word_simple_mode(self, sents: List[Tuple], original_sent): + # # ===== + # MAX_LEN = 510 # save two for [CLS] and [SEP] + # tokenizer = self.tokenizer + # CLS, CLS_IDX = tokenizer.cls_token, tokenizer.cls_token_id + # SEP, SEP_IDX = tokenizer.sep_token, tokenizer.sep_token_id + # # ===== + # rets = [] # (list of combined-str, list of list of str) + # orig_subword_ids = [CLS_IDX] + original_sent.subword_ids + [SEP_IDX] + # for one_sent in sents: + # cur_subword_ids = one_sent.subword_ids + # cur_subword_is_start = one_sent.subword_is_start + # # calculate + # assert len(cur_subword_ids) <= MAX_LEN + # cur_all_ids = [CLS_IDX] + cur_subword_ids + [SEP_IDX] + # word_scores, = self.wp_model(BK.input_idx(cur_all_ids).unsqueeze(0)) # [1, slen+2, Vocab] + # word_scores2 = BK.log_softmax(word_scores, -1) + # max_scores, max_idxes = word_scores2[0].max(-1) + # # todo(note): length mismatch if in one mode!! + # orig_scores = word_scores2[0].gather(-1, BK.input_idx(orig_subword_ids).unsqueeze(-1)).squeeze(-1) + # # then ignore CLS and SEP and group by starts + # max_strs = [self.tok_word_list[z] for z in BK.get_value(max_idxes)[1:-1]] + # max_scores_arr = BK.get_value(max_scores)[1:-1] + # orig_scores_arr = BK.get_value(orig_scores)[1:-1] + # assert len(max_strs) == len(cur_subword_is_start) + # one_all_strs = [] + # one_all_lists = [] + # one_all_scores = [] + # one_all_orig_scores = [] + # for this_str, this_is_start, this_score, this_orig_score in \ + # zip(max_strs, cur_subword_is_start, max_scores_arr, orig_scores_arr): + # if this_is_start: + # one_all_strs.append(this_str) + # one_all_lists.append([this_str]) + # one_all_scores.append([this_score]) + # one_all_orig_scores.append([this_orig_score]) + # else: + # one_all_strs[-1] = one_all_strs[-1] + this_str.lstrip("##") + # one_all_lists[-1].append(this_str) # todo(note): error if starting with not is-start + # one_all_scores[-1].append(this_score) + # one_all_orig_scores[-1].append(this_orig_score) + # rets.append((one_all_strs, one_all_lists, + # [np.average(z) for z in one_all_scores], [np.average(z) for z in one_all_orig_scores])) + # return rets + + # mask out each word and predict + def predict_each_word(self, tokens: List[str]): + # todo(+N): looking inside the transformers (v2.0.0) class, may break in the future version?? + # and for prediction, we only put one MASK for predicting one subword + tokenizer = self.tokenizer + CLS, CLS_IDX = tokenizer.cls_token, tokenizer.cls_token_id + SEP, SEP_IDX = tokenizer.sep_token, tokenizer.sep_token_id + MASK, MASK_IDX = tokenizer.mask_token, tokenizer.mask_token_id + # two forwards: get sum-log-probs and best one word replacement + # first do tokenization and prepare inputs + # - original subwords + orig_subword_lists = [] # [orig_len] + orig_word_start = [] # [orig_len] + orig_word_sublen = [] # [orig_len] + orig_subword_ids = [CLS_IDX] # [1+extened_len+1] + for i,t in enumerate(tokens): + cur_toks = tokenizer.tokenize(t) + # in some cases, there can be empty strings -> put the original word + if len(cur_toks) == 0: + cur_toks = [t] + orig_subword_lists.append(cur_toks) + cut_ids = tokenizer.convert_tokens_to_ids(cur_toks) + orig_word_start.append(len(orig_subword_ids)) + orig_word_sublen.append(len(cut_ids)) + orig_subword_ids.extend(cut_ids) + orig_subword_ids.append(SEP_IDX) + # forw1 (LM like) + forw1_hiddens = [] # [extend_len] + for one_start, one_sublen in zip(orig_word_start, orig_word_sublen): + for j in range(one_sublen): + cur_inputs = orig_subword_ids[:one_start+j] + [MASK_IDX] * (one_sublen-j) + orig_subword_ids[one_start+one_sublen:] + # todo(note): not API + cur_hidden = self.wp_model.bert(BK.input_idx(cur_inputs).unsqueeze(0))[0][0][one_start+j] # [D] + forw1_hiddens.append(cur_hidden) + # todo(note): not API + forw1_logits = self.wp_model.cls(BK.stack(forw1_hiddens, 0)) # [ELEN, V] + forw1_all_scores = BK.log_softmax(forw1_logits, -1) + forw1_ext_scores = forw1_all_scores.gather(-1, BK.input_idx(orig_subword_ids[1:-1]).unsqueeze(-1)).squeeze(-1) # [ELEN] + forw1_ext_scores_arr = BK.get_value(forw1_ext_scores) # [ELEN] + # todo(note): sum of logprobs, remember to minus one to exclude CLS + forw1_scores_arr = np.asarray([np.sum(forw1_ext_scores_arr[one_start-1:one_start-1+one_sublen]) + for one_start, one_sublen in zip(orig_word_start, orig_word_sublen)]) # [OLEN] + # forw2 (prediction with one word) + forw2_hiddens = [] # [orig_len] + for one_start, one_sublen in zip(orig_word_start, orig_word_sublen): + cur_inputs = orig_subword_ids[:one_start] + [MASK_IDX] + orig_subword_ids[one_start+one_sublen:] + # todo(note): not API + cur_hidden = self.wp_model.bert(BK.input_idx(cur_inputs).unsqueeze(0))[0][0][one_start] # [D] + forw2_hiddens.append(cur_hidden) + # todo(note): not API + forw2_logits = self.wp_model.cls(BK.stack(forw2_hiddens, 0)) # [OLEN, V] + forw2_all_scores = BK.log_softmax(forw2_logits, -1) + forw2_max_scores, forw2_max_idxes = forw2_all_scores.max(-1) + forw2_max_scores_arr, forw2_max_idxes_arr = BK.get_value(forw2_max_scores), BK.get_value(forw2_max_idxes) # [OLEN] + forw2_max_strs = [self.tok_word_list[z] for z in forw2_max_idxes_arr] + return forw1_scores_arr, forw2_max_scores_arr, forw2_max_strs, orig_subword_lists diff --git a/msp/nn/modules/embedder.py b/msp/nn/modules/embedder.py index 04325de..2472bc7 100644 --- a/msp/nn/modules/embedder.py +++ b/msp/nn/modules/embedder.py @@ -6,7 +6,8 @@ from msp.utils import Conf, zcheck from msp.data import VocabPackage from msp.nn import BK -from msp.nn.layers import BasicNode, Embedding, CnnLayer, PosiEmbedding, Affine, LayerNorm, Sequential, DropoutLastN +from msp.nn.layers import BasicNode, Embedding, CnnLayer, PosiEmbedding, Affine, LayerNorm, Sequential, \ + DropoutLastN, Dropout # class EmbedConf(Conf): @@ -30,6 +31,9 @@ def __init__(self): # self.extra_names = ["pos", ] self.dim_extras = [] self.extra_names = [] + # aux direct features, like mbert ... + self.dim_auxes = [] + self.fold_auxes = [] # how many folds per aux-repr, used to perform Elmo's styled linear combination # # -- project the concat before contextual encoders self.emb_proj_dim = 0 @@ -38,6 +42,11 @@ def __init__(self): def do_validate(self): # confirm int for dims self.dim_extras = [int(z) for z in self.dim_extras] + self.dim_auxes = [int(z) for z in self.dim_auxes] + self.fold_auxes = [int(z) for z in self.fold_auxes] + # by default, fold=1 + fold_gap = max(0, len(self.dim_auxes) - len(self.fold_auxes)) + self.fold_auxes += [1] * fold_gap # combining the inputs after projected to embeddings # todo(0): names in vpack should be consistent: word, char, pos, ... @@ -73,8 +82,19 @@ def __init__(self, pc: BK.ParamCollection, econf: EmbedConf, vpack: VocabPackage self.extra_embeds = [] for one_extra_dim, one_name in zip(self.dim_extras, self.extra_names): self.extra_embeds.append(self.add_sub_node( - "ext", Embedding(self.pc, len(vpack.get_voc(one_name)), one_extra_dim, name="extra"))) + "ext", Embedding(self.pc, len(vpack.get_voc(one_name)), one_extra_dim, + npvec=vpack.get_emb(one_name, None), name="extra:"+one_name))) repr_sizes.append(one_extra_dim) + # auxes + self.dim_auxes = econf.dim_auxes + self.fold_auxes = econf.fold_auxes + self.aux_overall_gammas = [] + self.aux_fold_lambdas = [] + for one_aux_dim, one_aux_fold in zip(self.dim_auxes, self.fold_auxes): + repr_sizes.append(one_aux_dim) + # aux gamma and fold trainable lambdas + self.aux_overall_gammas.append(self.add_param("AG", (), 1.)) # scalar + self.aux_fold_lambdas.append(self.add_param("AL", (), [1./one_aux_fold for _ in range(one_aux_fold)])) # [#fold] # ===== # another projection layer? & set final dim zcheck(len(repr_sizes)>0, "No inputs?") @@ -96,6 +116,8 @@ def __init__(self, pc: BK.ParamCollection, econf: EmbedConf, vpack: VocabPackage self.dropmd_word = self.add_sub_node("md", DropoutLastN(pc, lastn=1)) self.dropmd_char = self.add_sub_node("md", DropoutLastN(pc, lastn=1)) self.dropmd_extras = [self.add_sub_node("md", DropoutLastN(pc, lastn=1)) for _ in self.extra_names] + # dropouts for aux + self.drop_auxes = [self.add_sub_node("aux", Dropout(pc, (one_aux_dim,))) for one_aux_dim in self.dim_auxes] # def get_output_dims(self, *input_dims): @@ -104,9 +126,11 @@ def get_output_dims(self, *input_dims): def __repr__(self): return "# MyEmbedder: %s -> %s" % (self.repr_sizes, self.output_dim) - # word_arr: None or [*, seq-len], char_arr: None or [*, seq-len, word-len], extra_arrs: list of [*, seq-len] + # word_arr: None or [*, seq-len], char_arr: None or [*, seq-len, word-len], + # extra_arrs: list of [*, seq-len], aux_arrs: list of [*, seq-len, D] # todo(warn): no masks in this step? - def __call__(self, word_arr:np.ndarray=None, char_arr:np.ndarray=None, extra_arrs:Iterable[np.ndarray]=()): + def __call__(self, word_arr:np.ndarray=None, char_arr:np.ndarray=None, + extra_arrs:Iterable[np.ndarray]=(), aux_arrs: Iterable[np.ndarray]=()): exprs = [] # word/char/extras/posi seq_shape = None @@ -132,6 +156,20 @@ def __call__(self, word_arr:np.ndarray=None, char_arr:np.ndarray=None, extra_arr posi_input0 = BK.unsqueeze(posi_input0, 0) posi_input1 = BK.expand(posi_input0, tuple(seq_shape)+(-1,)) exprs.append(posi_input1) + # + assert len(aux_arrs) == len(self.drop_auxes) + for one_aux_arr, one_aux_dim, one_aux_drop, one_fold, one_gamma, one_lambdas in \ + zip(aux_arrs, self.dim_auxes, self.drop_auxes, self.fold_auxes, self.aux_overall_gammas, self.aux_fold_lambdas): + # fold and apply trainable lambdas + input_aux_repr = BK.input_real(one_aux_arr) + input_shape = BK.get_shape(input_aux_repr) + # todo(note): assume the original concat is [fold/layer, D] + reshaped_aux_repr = input_aux_repr.view(input_shape[:-1]+[one_fold, one_aux_dim]) # [*, slen, fold, D] + lambdas_softmax = BK.softmax(one_gamma, -1).unsqueeze(-1) # [fold, 1] + weighted_aux_repr = (reshaped_aux_repr * lambdas_softmax).sum(-2) * one_gamma # [*, slen, D] + one_aux_expr = one_aux_drop(weighted_aux_repr) + exprs.append(one_aux_expr) + # concated_exprs = BK.concat(exprs, dim=-1) # optional proj if self.has_proj: diff --git a/msp/nn/modules/encoder.py b/msp/nn/modules/encoder.py index 1390a10..39f6951 100644 --- a/msp/nn/modules/encoder.py +++ b/msp/nn/modules/encoder.py @@ -2,31 +2,49 @@ from typing import List # shared encoder for various parsing methods -from msp.utils import Conf, zfatal, zcheck +from msp.utils import Conf, zfatal, zcheck, zwarn, zlog from msp.nn import BK, layers -from msp.nn.layers import BasicNode, RnnLayer, RnnLayerBatchFirstWrapper, CnnLayer, TransformerEncoder +from msp.nn.layers import BasicNode, RnnLayer, RnnLayerBatchFirstWrapper, CnnLayer, TransformerEncoder, \ + Transformer2Encoder, AttConf, Sequential, Dropout # conf class EncConf(Conf): def __init__(self): self._input_dim = -1 # to be filled self.enc_hidden = 400 # concat-dimension - self.enc_ordering = ["rnn","cnn","att"] # default short-range to long-range + self.enc_ordering = ["rnn","cnn","att","att2"] # default short-range to long-range + self.no_final_dropout = False # disable dropout for the final layer of this module # various encoders - self.enc_rnn_type = "lstm" + # rnn + self.enc_rnn_type = "lstm2" self.enc_rnn_layer = 1 self.enc_rnn_bidirect = True + # cnn self.enc_cnn_windows = [3, 5] # split dim by windows self.enc_cnn_layer = 0 - self.enc_att_rel_clip = 0 # relative positional reprs if >0 - self.enc_att_rel_neg = True # use neg for rel-posi - self.enc_att_ff = 512 + # att(basic) self.enc_att_layer = 0 - self.enc_att_dropout = 0.1 # special attention dropout + self.enc_att_conf = AttConf() self.enc_att_add_wrapper = "addnorm" - self.enc_att_type = "mh" # multi-head - self.enc_attf_selfw = -100. - self.enc_att_use_ranges = False + self.enc_att_ff = 512 + self.enc_att_fixed_ranges = [] # should sth like 2,4,8,16,32,64 + self.enc_att_final_act = "linear" + # reuse some of the options in original att + self.enc_att2_layer = 0 + self.enc_att2_conf = AttConf() + self.enc_att2_short_range = 3 + self.enc_att2_long_ranges = [] # similar to enc_att_fixed_ranges + + def do_validate(self): + def _res(rs, checked_length): + if rs is None or len(rs) == 0: + return None + else: + assert len(rs) == checked_length + return [int(x) for x in rs] + # make sure the ranges are ok! + self.enc_att_fixed_ranges = _res(self.enc_att_fixed_ranges, self.enc_att_layer) + self.enc_att2_long_ranges = _res(self.enc_att2_long_ranges, self.enc_att2_layer) # various kinds of encoders class MyEncoder(BasicNode): @@ -47,21 +65,30 @@ def __init__(self, pc: BK.ParamCollection, econf: EncConf): rnn_enc_size = self.enc_hidden//2 if rnn_bidirect else self.enc_hidden rnn_layer = self.add_sub_node("rnn", RnnLayerBatchFirstWrapper(pc, RnnLayer(pc, last_dim, rnn_enc_size, econf.enc_rnn_layer, node_type=econf.enc_rnn_type, bidirection=rnn_bidirect))) self.layers.append(rnn_layer) + # todo(+2): different i/o sizes for cnn and att? elif name == "cnn": if econf.enc_cnn_layer > 0: per_cnn_size = self.enc_hidden // len(econf.enc_cnn_windows) - cnn_layer = self.add_sub_node("cnn", layers.Sequential(pc, [CnnLayer(pc, last_dim, per_cnn_size, econf.enc_cnn_windows, act="relu") for _ in range(econf.enc_cnn_layer)])) + cnn_layer = self.add_sub_node("cnn", Sequential(pc, [CnnLayer(pc, last_dim, per_cnn_size, econf.enc_cnn_windows, act="elu") for _ in range(econf.enc_cnn_layer)])) self.layers.append(cnn_layer) elif name == "att": if econf.enc_att_layer > 0: zcheck(last_dim == self.enc_hidden, "I/O should have same dim for Att-Enc") - att_layer = self.add_sub_node("att", TransformerEncoder(pc, econf.enc_att_layer, last_dim, econf.enc_att_ff, att_type=econf.enc_att_type, add_wrapper=econf.enc_att_add_wrapper, att_dropout=econf.enc_att_dropout, att_use_ranges=econf.enc_att_use_ranges, attf_selfw=econf.enc_attf_selfw, clip_dist=econf.enc_att_rel_clip, use_neg_dist=econf.enc_att_rel_neg)) + att_layer = self.add_sub_node("att", TransformerEncoder(pc, econf.enc_att_layer, last_dim, econf.enc_att_ff, econf.enc_att_add_wrapper, econf.enc_att_conf, final_act=econf.enc_att_final_act, fixed_range_vals=econf.enc_att_fixed_ranges)) self.layers.append(att_layer) + elif name == "att2": + if econf.enc_att2_layer > 0: + zcheck(last_dim == self.enc_hidden, "I/O should have same dim for Att-Enc") + att2_layer = self.add_sub_node("att2", Transformer2Encoder(pc, econf.enc_att2_layer, last_dim, econf.enc_att2_conf, short_range=econf.enc_att2_short_range, long_ranges=econf.enc_att2_long_ranges)) + self.layers.append(att2_layer) else: zfatal("Unknown encoder name: "+name) if len(self.layers) > 0: last_dim = self.layers[-1].get_output_dims()[-1] self.output_dim = last_dim + # + if econf.no_final_dropout: + self.disable_final_dropout() def __repr__(self): return "# MyEncoder: %s -> %s [%s]" % (self.input_dim, self.output_dim, ", ".join([str(z) for z in self.layers])) @@ -75,3 +102,28 @@ def __call__(self, embeds_expr, word_mask_arr): for one_node in self.layers: v = one_node(v, word_mask_arr) return v + + # todo(+2): specific for every type! + def disable_final_dropout(self): + if len(self.layers) < 1: + zwarn("Cannot disable final dropout since this Enc layer is empty!!") + else: + # get the final one from sequential + final_layer = self.layers[-1] + while isinstance(final_layer, Sequential): + final_layer = final_layer.ns_[-1] if len(final_layer.ns_) else None + # get final dropout node + final_drop_node: Dropout = None + if isinstance(final_layer, RnnLayerBatchFirstWrapper): + final_drop_nodes = final_layer.rnn_node.drop_nodes + if final_drop_nodes is not None and len(final_drop_nodes)>0: + final_drop_node = final_drop_nodes[-1] + elif isinstance(final_layer, CnnLayer): + final_drop_node = final_layer.drop_node + elif isinstance(final_layer, TransformerEncoder): + pass # todo(note): final is LayerNorm? + if final_drop_node is None: + zwarn(f"Failed at disabling final enc-layer dropout: type={type(final_layer)}: {final_layer}") + else: + final_drop_node.rop.add_fixed_value("hdrop", 0.) + zlog(f"Ok at disabling final enc-layer dropout: type={type(final_layer)}: {final_layer}") diff --git a/msp/nn/modules/labeler.py b/msp/nn/modules/labeler.py new file mode 100644 index 0000000..77cb42a --- /dev/null +++ b/msp/nn/modules/labeler.py @@ -0,0 +1,24 @@ +# + +# the labeling component, managing loss functions and outputing as middle layer (as in attention) +# this module itself usually does not contain parameters, mostly based on outside scores + +from msp.utils import Conf, zfatal, zcheck +from msp.nn import BK, layers +from msp.nn.layers import BasicNode + +class LabConf(Conf): + def __init__(self): + self._label_dim = -1 # to be filled + self.binary_mode = False # does not compare among last dim, but treat each one as its own against 0 + +# todo(+N): currently only take care of single-class gold, there are also situations where we need multi-labels +class MyLabeler(BasicNode): + def __init__(self, pc: BK.ParamCollection, lconf: LabConf): + super().__init__(pc, None, None) + self.conf = lconf + + # [*, D] or [*] + def __call__(self, raw_scores, gold_idxes, **kwargs): + # TODO(!) + pass diff --git a/msp/search/__init__.py b/msp/search/__init__.py new file mode 100644 index 0000000..a5f7180 --- /dev/null +++ b/msp/search/__init__.py @@ -0,0 +1,4 @@ +# + +# Transition-System (here) / Search-Graph (here) / Computation-Graph (managed by NN-Toolkit) + diff --git a/msp/search/hsg/__init__.py b/msp/search/hsg/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/msp/search/lsg/__init__.py b/msp/search/lsg/__init__.py new file mode 100644 index 0000000..9d6779e --- /dev/null +++ b/msp/search/lsg/__init__.py @@ -0,0 +1,10 @@ +# + +# the linear search graph (lsg) related tools + +# todo(warn): this submodule can be re-written in cython/c++ to speed up + +from .search import Agenda, Searcher, Component +from .search_bfs import BfsAgenda, BfsSearcher, BfsSearcherFactory, \ + BfsExpander, BfsLocalSelector, BfsGlobalArranger, BfsEnder +from .system import State, Action, Candidates, Graph, Oracler, Coster, Signaturer diff --git a/msp/search/lsg/search.py b/msp/search/lsg/search.py new file mode 100644 index 0000000..2e769ba --- /dev/null +++ b/msp/search/lsg/search.py @@ -0,0 +1,24 @@ +# + +from typing import List + +# +class Agenda: + def is_end(self): + raise NotImplementedError() + +# +class Searcher: + def refresh(self, *args, **kwargs): + raise NotImplementedError() + + def go(self, ags: List[Agenda]): + raise NotImplementedError() + +# +class Component: + def refresh(self, *args, **kwargs): + raise NotImplementedError() + + def __repr__(self): + return f"<{self.__class__.__name__}>" diff --git a/msp/search/lsg/search_bfs.py b/msp/search/lsg/search_bfs.py new file mode 100644 index 0000000..f50ad26 --- /dev/null +++ b/msp/search/lsg/search_bfs.py @@ -0,0 +1,328 @@ +# + +# the (batched) BFS-Style Linear Searchers +# the three layers: State, Agenda, Searcher + +from typing import List, Tuple, Dict +from .system import State, Graph, Oracler, Coster, Signaturer +from .search import Agenda, Searcher, Component + +# agenda for one instance +class BfsAgenda(Agenda): + def __init__(self, inst, init_beam: List[State]=None, init_gbeam: List[State]=None): + self.inst = inst + self.beam = [] if init_beam is None else init_beam # plain search beam + self.gbeam = [] if init_gbeam is None else init_gbeam # gold-informed ones + # for loss collection + self.local_golds = [] # gold (oracle) states for local loss + self.special_points = [] # special snapshots of the plain/gold beams (for special update mode in training) + # for results collection + self.ends = [], [] # endings for plain/gold ones + + @staticmethod + def init_agenda(state_creator, inst, require_sg: bool): + sg = None + if require_sg: + sg = Graph() + init_state = state_creator(prev=None, sg=sg, inst=inst) + a = BfsAgenda(inst, init_beam=[init_state]) + return a + + def is_end(self): + return all(x.is_end() for x in self.beam) and all(x.is_end() for x in self.gbeam) + +# a snapshot (special point) of the current status of beam +class BfsSPoint: + def __init__(self, plain_finals, oracle_finals, violation=0.): + self.plain_finals = plain_finals + self.oracle_finals = oracle_finals + self.violation = violation + +# ===== +# scoring helpers + +# todo(WARN): it seems that in previous code here, the margin get doubled? +# But might not be True if cost is different than the simple ones, thus temporally leave it here! + +def state_plain_ranker_getter(margin: float=0.): + return lambda x: x.get_loss_aug_score(margin) + +def state_oracle_ranker(s: State): + return (-s.cost_accu, s.score_accu) + +# ===== +# (batched) main components + +# expanding candidates (with possible scorer calculations) +# this one is mostly system-specific! +class BfsExpander(Component): + def expand(self, ags: List[BfsAgenda]): + raise NotImplementedError() + +# select with multiple modes (plain + oracle) +# some usually-used cases: 1) plain(topk/sample)-search for decoding +class BfsLocalSelector(Component): + def __init__(self, plain_mode, oracle_mode, plain_k_arc, plain_k_label, oracle_k_arc, oracle_k_label, oracler: Oracler): + self.plain_mode = plain_mode + self.oracle_mode = oracle_mode + self.plain_k_arc, self.plain_k_label, self.oracle_k_arc, self.oracle_k_label = \ + plain_k_arc, plain_k_label, oracle_k_arc, oracle_k_label + self.oracler = oracler + + def __repr__(self): + return f"<{self.__class__.__name__}: plain:mode={self.plain_mode},arcK={self.plain_k_arc},labelK={self.plain_k_label}, " \ + f"oracle:mode={self.oracle_mode},arcK={self.oracle_k_arc},labelK={self.oracle_k_label}, oracler={self.oracler}>" + + # ===== + # these two only return flattened results + + def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]: + raise NotImplementedError() + + def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]: + raise NotImplementedError() + + # ===== + # finally assembling final results + + # return List[(plain-list, oracle-list)] + def select(self, ags: List[BfsAgenda], candidates) -> List[Tuple[List, List]]: + plain_results = self.select_plain(ags, candidates, self.plain_mode, self.plain_k_arc, self.plain_k_label) + oracle_results = self.select_oracle(ags, candidates, self.oracle_mode, self.oracle_k_arc, self.oracle_k_label) + # assemble for final results + sidx = 0 + final_results = [] + for ag in ags: + one_cands_plain, one_cands_oracle = [], [] # plain-beam, gold-beam + for plain_state in ag.beam: + cur_plain_nexts, cur_oracle_nexts = plain_results[sidx], oracle_results[sidx] + one_cands_plain.extend(cur_plain_nexts) # plain+plain => plain + one_cands_oracle.extend(cur_oracle_nexts) # plain+oracle => oracle + sidx += 1 + for oracle_state in ag.gbeam: + cur_plain_nexts, cur_oracle_nexts = plain_results[sidx], oracle_results[sidx] + # WHAT.extend(cur_plain_nexts) # oracle+plain => ? todo(note): discard! + one_cands_oracle.extend(cur_oracle_nexts) # oracle+oracle => oracle + sidx += 1 + final_results.append((one_cands_plain, one_cands_oracle)) + assert len(plain_results) == sidx and len(oracle_results) == sidx + return final_results + +class BfsGlobalArranger(Component): + def __init__(self, plain_beam_size: int, gold_beam_size: int, coster: Coster, signaturer: Signaturer): + self.plain_beam_size = plain_beam_size + self.gold_beam_size = gold_beam_size + # components to provide state-info + self.coster = coster + self.signaturer = signaturer + self.margin = 0. + # + self.plain_ranker = state_plain_ranker_getter(self.margin) + self.oracle_ranker = state_oracle_ranker + + def __repr__(self): + return f"<{self.__class__.__name__}: plainK={self.plain_beam_size}, oracleK={self.gold_beam_size}, " \ + f"coster={self.coster}, signaturer={self.signaturer}>" + + def refresh(self, margin=0., **kwargs): + self.margin = margin + self.plain_ranker = state_plain_ranker_getter(self.margin) + + # return whether survived + def _add_or_merged(self, sig_map: Dict, cand: State) -> bool: + cur_sig = cand.sig + if cur_sig is not None: + cur_merger_state = sig_map.get(cur_sig, None) + if cur_merger_state is None: + sig_map[cur_sig] = cand + return True + else: + cand.merge_by(cur_merger_state) + return False + return True + + # rank and select for the beam + def arrange(self, ags: List[BfsAgenda], local_selections: List) -> List[Tuple[List, List]]: + coster, signaturer = self.coster, self.signaturer + plain_beam_size, gold_beam_size = self.plain_beam_size, self.gold_beam_size + # + cur_margin = self.margin + batch_len = len(ags) + global_selections: List[Tuple[List, List]] = [None] * batch_len + assert batch_len == len(local_selections) + # for each instance + for batch_idx in range(batch_len): + pcands, gcands = local_selections[batch_idx] + # plain sort by score+cost*margin, oracle sort by (-cost, score) + # set cost if needed (mostly in training) + if coster is not None: + coster.set_costs(pcands) + pcands.sort(key=self.plain_ranker, reverse=True) + coster.set_costs(gcands) + gcands.sort(key=self.oracle_ranker, reverse=True) # ignore margin for oracle ones + else: + pcands.sort(key=self.plain_ranker, reverse=True) + assert len(gcands) <= 1, "Err: No Coster to rank the oracle states!" + # set signature if needed + if signaturer is not None: + signaturer.set_sigs(pcands) + signaturer.set_sigs(gcands) + # loop them all + id_map, sig_map = {}, {} + plain_finals, oracle_finals = [], [] + for plain_cand in pcands: + if self._add_or_merged(sig_map, plain_cand): + plain_finals.append(plain_cand) + # todo(note): id map is for excluding repeated states between plain and oracle + id_map[(plain_cand.prev, plain_cand.action)] = plain_cand + if len(plain_finals) >= plain_beam_size: + break + for oracle_cand in gcands: + cur_id = (oracle_cand.prev, oracle_cand.action) + if cur_id not in id_map: + if self._add_or_merged(sig_map, oracle_cand): + oracle_finals.append(oracle_cand) + if len(oracle_finals) >= gold_beam_size: + break + global_selections[batch_idx] = (plain_finals, oracle_finals) + return global_selections + +# the preparation for the next step +class BfsEnder(Component): + def __init__(self, ending_mode): + self.ending_mode = ending_mode + self._modify_f = {"plain": self._modify_plain, "eu": self._modify_eu, + "maxv": self._modify_maxv, "bso": self._modify_bso}[ending_mode] + self.margin = 0. + # + self.plain_ranker = state_plain_ranker_getter(self.margin) + self.oracle_ranker = state_oracle_ranker + + def __repr__(self): + return f"<{self.__class__.__name__}: ending_mode={self.ending_mode}>" + + def refresh(self, margin=0., **kwargs): + self.margin = margin + self.plain_ranker = state_plain_ranker_getter(self.margin) + + def end(self, ags: List[BfsAgenda], final_selections: List): + for ag, one_selections in zip(ags, final_selections): + new_beam, new_gbeam = [], [] + if len(one_selections[0]) > 0: + # with special modifications + plain_finals, oracle_finals = self._modify_f(ag, one_selections) + # pick out ended ones for both beams + for which_end, which_beam, which_finals in zip([0,1], [new_beam, new_gbeam], [plain_finals, oracle_finals]): + for s in which_finals: + if self.is_end(s): + s.mark_end() + ag.ends[which_end].append(s) + else: + which_beam.append(s) + ag.beam = new_beam + ag.gbeam = new_gbeam + + def wrapup(self, ags: List[BfsAgenda]): + # whether adding the endings states + ending_mode = self.ending_mode + if ending_mode == "plain" or ending_mode == "bso": + for ag in ags: + ag.special_points.append(BfsSPoint(ag.ends[0], ag.ends[1])) + elif ending_mode == "eu": + for ag in ags: + if len(ag.special_points) == 0: + ag.special_points.append(BfsSPoint(ag.ends[0], ag.ends[1])) + elif ending_mode == "maxv": + pass + else: + raise NotImplementedError() + + # ===== + # special modes for modifying the beam + + def _modify_plain(self, ag: BfsAgenda, final_selections): + plain_finals, oracle_finals = final_selections + return plain_finals, oracle_finals + + def _modify_eu(self, ag: BfsAgenda, final_selections): + # check whether gold drop out + plain_finals, oracle_finals = final_selections + if len(oracle_finals)>0: + best_oracle = max(oracle_finals, key=self.oracle_ranker) + # if oracle dropped out of the beam + if all(best_oracle.cost_accu0: + best_oracle = max(oracle_finals, key=self.oracle_ranker) + # if oracle dropped out of the beam + if all(best_oracle.cost_accu < x.cost_accu for x in plain_finals): + ag.special_points.append(BfsSPoint(plain_finals, oracle_finals)) + best_oracle.mark_restart() + return [best_oracle], [] # add it back!! + return plain_finals, oracle_finals + + def _modify_maxv(self, ag: BfsAgenda, final_selections): + # check violation + plain_finals, oracle_finals = final_selections + best_plain = max(plain_finals, key=self.plain_ranker) + best_oracle = max(plain_finals+oracle_finals, key=self.oracle_ranker) + cur_violation = self.plain_ranker(best_plain) - self.plain_ranker(best_oracle) # score diff + # keep the max-violation one + if len(ag.special_points) == 0: + ag.special_points.append(BfsSPoint(plain_finals, oracle_finals, cur_violation)) + elif cur_violation > ag.special_points[0].violation: + ag.special_points[0] = (BfsSPoint(plain_finals, oracle_finals, cur_violation)) + return plain_finals, oracle_finals + # ===== + + # todo(note): to be implemented for specific system + def is_end(self, state: State): + raise NotImplementedError() + +# ===== +# the (batched) searcher itself + +class BfsSearcher(Searcher): + def __init__(self): + # components + self.expander: BfsExpander = None + self.local_selector: BfsLocalSelector = None + self.global_arranger: BfsGlobalArranger = None + self.ender: BfsEnder = None + + def go(self, ags: List[BfsAgenda]): + while True: + # finish if all ended + if all(a.is_end() for a in ags): + break + # step 1: expand (list of list of candidates) + candidates = self.expander.expand(ags) + # step 2: score (batched all) & select locally + local_selections = self.local_selector.select(ags, candidates) + # step 3: select globally (skip if not necessary) + if self.global_arranger is None: + final_selections = local_selections + else: + final_selections = self.global_arranger.arrange(ags, local_selections) + # step 4: next step (inplaced modifications) + self.ender.end(ags, final_selections) + # final ending + self.ender.wrapup(ags) + # results are stored in ags + # pass + +# ===== +# creator + +# todo(+N): currently specific for specific systems, can the common things be grouped together? + +class BfsSearcherFactory: + # ===== + # general creator: put all components together and get a Searcher + pass diff --git a/msp/search/lsg/system.py b/msp/search/lsg/system.py new file mode 100644 index 0000000..9cb7fcf --- /dev/null +++ b/msp/search/lsg/system.py @@ -0,0 +1,178 @@ +# + +# the core system components (the strctures) + +from typing import List +from msp.utils import Constants, zcheck + +# basic state information, only maintains the properties of basic graph info, other things are in derived classes. +class State: + # _STATUS = "NONE"(unknown), "EXPAND"(survived), "MERGED", "END" + STATUS_NONE, STATUS_EXPAND, STATUS_MERGED, STATUS_END = 0, 1, 2, 3 + + # ==== constructions + def __init__(self, prev: 'State'=None, action: 'Action'=None, score: float=0., sg: 'Graph'=None): + # 1. basic info + self.sg: 'Graph' = sg # related SearchGraph + self.prev: 'State' = prev # parent + self.length: int = 0 # how many steps from start + # graph info + self.id: int = None + self.nexts: List = None # valid next states + self.merge_list: List = None # merge others + self.merger: State = None # merged by other + + # 2. status + self.status: int = State.STATUS_NONE + self.restart: bool = False + + # 3. values + self.action: Action = action + # model scores + self.score: float = score + self.score_accu: float = 0. + # actual cost (unknown currently, to be filled by Oracler) + self.cost: float = None + self.cost_accu: float = None + # signature + self.sig = None # None actually means no sig, but remember None != None for sig + # ===== + # depend on prev state + if prev is not None: + self.sg = prev.sg + self.length = prev.length + 1 + self.score_accu = prev.score_accu + self.score + # todo(note): set the end points for the action + self.action.set_from_to(prev, self) + else: + self.restart = True # root is also a restart + assert action is None, "Err: starting State does not accept Action." + # no need to preserve graph info if not recording it + if self.sg is not None: + self.nexts = [] # valid next states + self.merge_list = [] # merge others + self.id = self.sg.reg(self) + if prev is not None: + prev.status = State.STATUS_EXPAND + prev.nexts.append(self) + + def get_pid(self): + if self.prev is None: + return -1 + else: + return self.prev.id + + def __len__(self): + return self.length + + def __repr__(self): + return f"State[{self.id}<-{self.get_pid()}] len={self.length}, act={self.action}, " \ + f"sc={self.score_accu:.3f}({self.score:.3f})" + + # graph related + def is_end(self): + return self.status == State.STATUS_END + + def is_start(self): + return self.prev is None + + def is_restart(self): + return self.restart + + def mark_restart(self): + self.restart = True + + def mark_end(self): + self.status = State.STATUS_END + if self.sg is not None: + self.sg.add_end(self) + + # get the history path towards root + def get_path(self, forward_order=True, depth=Constants.INT_PRAC_MAX, include_restart=False): + cur = self + ret = [] + count = 0 + # get them until START or meet max depth + while cur is not None and count < depth: + if not include_restart and cur.restart: + break + ret.append(cur) + cur = cur.prev + count += 1 + # prepare the direction in the list + if forward_order: + ret.reverse() + return ret + + # merge: modify both states + def merge_by(self, s: 'State'): + zcheck(s.merger is None, "Err: multiple level of merges!") + self.status = State.STATUS_MERGED + self.merger = s + s.add_merge(self) + + def add_merge(self, s: 'State'): + if self.merge_list: + self.merge_list.append(s) + + # ===== + def get_loss_aug_score(self, margin): + return self.score_accu+margin*self.cost_accu + +# one action from one state to another +class Action: + def __init__(self): + self.state_from: State = None + self.state_to: State = None + + def set_from_to(self, fstate: State, tstate: State): + self.state_from = fstate + self.state_to = tstate + +# one expansion from a state, can lead to real candidate states if further selected by the selector +class Candidates: + pass + +# the linear search graph (lattice) +class Graph(object): + def __init__(self, info=None): + # information about the settings like srcs + self.info = info + # unique ids for states (must be >=1) + self.counts = 0 + # special nodes + self.root = None + self.ends = [] + + # register a new state, return its id in this graph + def reg(self, s: State): + self.counts += 1 + # todo(+N): record more info about the state? + return self.counts + + def set_root(self, s: State): + zcheck(s.sg is self, "SGErr: State does not belong here") + zcheck(self.root is None, "SGErr: Can only have one root") + zcheck(s.is_start(), "SGErr: Only start node can be root") + self.root = s + + def add_end(self, s: State): + self.ends.append(s) + +# +# other component classes for the searching procedure (decoding or training) + +# ===== + +class Oracler: + pass + +class Coster: + def set_costs(self, states: List[State]): + raise NotImplementedError() + +class Signaturer: + def set_sigs(self, states: List[State]): + raise NotImplementedError() + +# ===== diff --git a/msp/utils/__init__.py b/msp/utils/__init__.py index 1b937cc..ae402c8 100644 --- a/msp/utils/__init__.py +++ b/msp/utils/__init__.py @@ -1,14 +1,16 @@ # from .algo import AlgoHelper -from .check import zwarn, zfatal, zcheck, Checker, ZException +from .check import zwarn, zfatal, zcheck, Checker +from .color import wrap_color from .conf import Conf from .log import zopen, zlog, printing, Logger from .math import MathHelper from .random import Random -from .system import system, dir_msp, get_statm, FileHelper +from .seria import JsonRW, PickleRW +from .system import system, dir_msp, get_statm, FileHelper, extract_stack from .task import Timer, StatRecorder -from .utils import Constants, Helper, NumHelper, StrHelper, JsonRW, PickleRW +from .utils import Constants, Helper, NumHelper, StrHelper, ZObject from sys import stderr, argv from platform import uname diff --git a/msp/utils/check.py b/msp/utils/check.py index 51bbe0c..cc9b794 100644 --- a/msp/utils/check.py +++ b/msp/utils/check.py @@ -3,46 +3,37 @@ from .log import zlog from collections import Callable -# Checkings: also including some shortcuts for convenience - -class ZException(RuntimeError): - def __init__(self, s): - super().__init__(s) +# Checkings def zfatal(ss=""): zlog(ss, func="fatal") - raise ZException(ss) + raise RuntimeError(ss) + +def zwarn(ss="", **kwargs): + zlog(ss, func="warn", **kwargs) -def zwarn(ss=""): - zlog(ss, func="warn") +def _get_v(ff): + return ff() if isinstance(ff, Callable) else ff -def zcheck(ff, ss, func="error", forced=False): - if Checker.enabled(func) or forced: - Checker.check(ff, ss, func) +def zcheck(ff, ss, fatal=True, level=3): + if level >= Checker._level: + check_val, check_ss = _get_v(ff), _get_v(ss) + if not check_val: + if fatal: + zfatal(check_ss) + else: + zwarn(check_ss) # should be used when debugging or only fatal ones, comment out if real usage class Checker(object): - _checker_filters = {"warn": True, "error": True, "fatal": True} - _checker_handlers = {"warn": (lambda: 0), "error": (lambda: zfatal("ERROR")), "fatal": (lambda: zfatal("FATAL-ERROR"))} - - @staticmethod - def init(): - # todo(0): this may be enough - pass - - @staticmethod - def _get_v(v): - if isinstance(v, Callable): - return v() - else: - return v + _level = 3 @staticmethod - def check(expr, ss, func): - if not Checker._get_v(expr): - zlog(Checker._get_v(ss), func=func) - Checker._checker_handlers[func]() + def init(level=3): + Checker.set_level(level) @staticmethod - def enabled(func): - return Checker._checker_filters[func] + def set_level(level): + if level != Checker._level: + zlog(f"Change check level from {Checker._level} to {level}") + Checker._level = level diff --git a/msp/utils/color.py b/msp/utils/color.py new file mode 100644 index 0000000..0844c00 --- /dev/null +++ b/msp/utils/color.py @@ -0,0 +1,42 @@ +# + +# colorful printing + +try: + from colorama import colorama_init + colorama_init() + from colorama import Fore, Back, Style + RESET_ALL = Style.RESET_ALL +except: + class Fore: + BLACK = '\033[30m' + RED = '\033[31m' + GREEN = '\033[32m' + YELLOW = '\033[33m' + BLUE = '\033[34m' + MAGENTA = '\033[35m' + CYAN = '\033[36m' + WHITE = '\033[37m' + RESET = '\033[39m' + + class Back: + BLACK = '\033[40m' + RED = '\033[41m' + GREEN = '\033[42m' + YELLOW = '\033[43m' + BLUE = '\033[44m' + MAGENTA = '\033[45m' + CYAN = '\033[46m' + WHITE = '\033[47m' + RESET = '\033[49m' + RESET_ALL = '\033[0m' + +class Special: + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +def wrap_color(s: str, fcolor: str=None, bcolor: str=None, smode: str=None): + f_prefix = ("" if fcolor is None else getattr(Fore, str.upper(fcolor))) + b_prefix = ("" if bcolor is None else getattr(Back, str.upper(bcolor))) + s_prefix = ("" if smode is None else getattr(Special, str.upper(smode))) + return "".join([f_prefix, b_prefix, s_prefix, s, RESET_ALL]) diff --git a/msp/utils/conf.py b/msp/utils/conf.py index ae13e2f..36ec37a 100644 --- a/msp/utils/conf.py +++ b/msp/utils/conf.py @@ -1,11 +1,13 @@ # from collections import Iterable, OrderedDict +import json from .log import printing, zopen from .check import zcheck, zfatal, zwarn from .utils import Constants +# todo(+N): need to provide more detailed names if there are ambiguity, anyway to set default ones? class Conf(object): NV_SEP = ":" HIERARCHICAL_SEP = "." @@ -54,7 +56,8 @@ def update_from_conf(self, cc): self.__dict__[n] = type(self.__dict__[n])(v) # ===== - # collecting shortcuts + # collecting shortcuts (all k-last suffixes * zero/one previous random one) + # todo(warn): the number will explode if there are too many layers? def collect_all_names(self): def _add_rec(v, cur_conf, path): for n in cur_conf.__dict__: @@ -67,9 +70,16 @@ def _add_rec(v, cur_conf, path): full_path = Conf.HIERARCHICAL_SEP.join(path) for i in range(len(path)): short_name = Conf.HIERARCHICAL_SEP.join(path[i:]) - if short_name not in v: - v[short_name] = [] - v[short_name].append(full_path) + # further combination: (n^2): single previous + for lead_name in [""] + path[:max(0, i-1)]: # i-1 to avoid repeating for continuous one + if lead_name: + final_name = Conf.HIERARCHICAL_SEP.join([lead_name, short_name]) + else: + final_name = short_name + # + if final_name not in v: + v[final_name] = [] + v[final_name].append(full_path) path.pop() # ret = {} # post_name -> full_name set @@ -95,10 +105,15 @@ def _getv(tt, vv): last_name = ks[-1] old_v = sub_conf.__dict__[last_name] item_type = type(old_v) + # todo(+1): can also use python's "eval" function if isinstance(old_v, list): + try: + v_toassign = json.loads(v) + except json.JSONDecodeError: + v_toassign = v.split(Conf.LIST_SEP) if len(v)>0 else [] # todo(warn): for default empty ones, also special treating for bool - ele_type = type(old_v[0]) if len(old_v)>0 else str - sub_conf.__dict__[last_name] = [_getv(ele_type, z) for z in v.split(Conf.LIST_SEP)] + ele_type = type(old_v[0]) if len(old_v) > 0 else str + sub_conf.__dict__[last_name] = [_getv(ele_type, z) for z in v_toassign] else: sub_conf.__dict__[last_name] = _getv(item_type, v) # return (name, old value, new value) @@ -111,14 +126,14 @@ def update_from_dict(self, d): for n, v in d.items(): name_list = name_maps.get(n, None) if name_list is None: - bad_ones.append("Unknown config %s=%s" % (n, v)) + bad_ones.append(f"Unknown config {n}={v}") continue if len(name_list) != 1: - bad_ones.append("Bad(ambiguous or non-exist) config %s=%s, -> %s" % (n, v, name_list)) + bad_ones.append(f"Bad(ambiguous or non-exist) config {n}={v}, -> {name_list}") continue full_name = name_list[0] - _, _, new_v = self._do_update(full_name, v) - good_ones.append("Update config '%s=%s': %s to %s" % (n,v,full_name,new_v)) + _, old_v, new_v = self._do_update(full_name, v) + good_ones.append(f"Update config '{n}={v}': {full_name} = {old_v} -> {new_v}") return good_ones, bad_ones def update_from_kwargs(self, **kwargs): diff --git a/msp/utils/log.py b/msp/utils/log.py index ed79d51..d974e97 100644 --- a/msp/utils/log.py +++ b/msp/utils/log.py @@ -5,31 +5,35 @@ # files and loggings def zopen(filename, mode='r', encoding="utf-8"): + suffix = "" if ('b' in mode) else "t" + if 'b' in mode: + encoding = None if filename.endswith('.gz'): # "t" for text mode of gzip - return gzip.open(filename, mode+"t", encoding=encoding) + return gzip.open(filename, mode+suffix, encoding=encoding) elif filename.endswith('.bz2'): - return bz2.open(filename, mode+"t", encoding=encoding) + return bz2.open(filename, mode+suffix, encoding=encoding) else: return open(filename, mode, encoding=encoding) -def zlog(s, func="plain", end='\n', flush=True, timed=False): - Logger._instance._log(str(s), func, end, flush, timed) +def zlog(s, func="plain", end='\n', flush=True, timed=False, level=3): + Logger._instance._log(str(s), func, end, flush, timed, level) printing = zlog class Logger(object): _instance = None _logger_heads = { - "plain": "-- ", "time":"## ", "io":"== ", "result": ">> ", "report": "** ", - "warn":"!! ", "error": "ER ", "fatal":"KI ", "config": "CC ", - "debug":"DE ", "none":"**INVALID-CODE**", "nope": "", + "plain": "--", "time": "##", "io": "==", "result": ">>", "report": "**", + "warn": "!!", "error": "ER", "fatal": "KI", "config": "CC", + "debug": "DE", "nope": "", } @staticmethod def _get_ch(func): # get code & head if func not in Logger._logger_heads: - func = "none" - return func, Logger._logger_heads[func] + return func, func + else: + return func, Logger._logger_heads[func] MAGIC_CODE = "?THE-NATURE-OF-HUMAN-IS?" # REPEATER? @staticmethod @@ -46,7 +50,7 @@ def init(files): # zlog("START!!", func="plain") # ===== - def __init__(self, file_filters, func_filters): + def __init__(self, file_filters, func_filters, level=3): self.file_filters = file_filters self.func_filters = func_filters self.fds = {} @@ -56,22 +60,28 @@ def __init__(self, file_filters, func_filters): self.fds[f] = zopen(f, mode="w") else: self.fds[f] = f + self.level = level def __del__(self): for f in self.file_filters: if isinstance(f, str): self.fds[f].close() - def _log(self, s, func, end, flush, timed): - func, head = Logger._get_ch(func) - if self.func_filters[func]: - if timed: - ss = head + "[" + time.ctime() + "]" + s - else: - ss = head + s - for f in self.fds: - if self.file_filters[f]: - print(ss, end=end, file=self.fds[f], flush=flush) + def set_level(self, level): + zlog(f"Change log level from {self.level} to {level}") + self.level = level + + def _log(self, s, func, end, flush, timed, level): + if level >= self.level: + func, head = Logger._get_ch(func) + if self.func_filters[func]: + if timed: + ss = f"{head}[{'-'.join(time.ctime().split()[-4:])}, L{level}] {s}" + else: + ss = f"{head}[L{level}] {s}" + for f in self.fds: + if self.file_filters[f]: + print(ss, end=end, file=self.fds[f], flush=flush) # todo(1): register or filter files & codes diff --git a/msp/utils/math.py b/msp/utils/math.py index ce87ca3..6bd216e 100644 --- a/msp/utils/math.py +++ b/msp/utils/math.py @@ -1,6 +1,9 @@ # import numpy as np -from scipy.misc import logsumexp as scipy_logsumexp +try: + from scipy.misc import logsumexp as scipy_logsumexp +except: + from scipy.special import logsumexp as scipy_logsumexp from math import isclose as math_isclose, exp as math_exp class MathHelper(object): diff --git a/msp/utils/random.py b/msp/utils/random.py index 47a4e74..b679b8d 100644 --- a/msp/utils/random.py +++ b/msp/utils/random.py @@ -91,3 +91,17 @@ def random_bool(true_rate, size=None, task=""): @staticmethod def randint(low, high=None, size=None, task=""): return Random.get_generator(task).randint(low, high, size) + + # ===== + # batched forever streamer + @staticmethod + def stream(f, batch_size=1024): + while True: + for one in f(size=batch_size): + yield one + + @staticmethod + def stream_bool(true_rate, batch_size=1024): + while True: + for one in Random.random_bool(true_rate, batch_size): + yield one diff --git a/msp/utils/seria.py b/msp/utils/seria.py new file mode 100644 index 0000000..b4f1970 --- /dev/null +++ b/msp/utils/seria.py @@ -0,0 +1,134 @@ +# serialization + +import json +import pickle +from typing import List +from .log import zopen, zlog + +# todo(note): +# for json serialization: customed to_builtin, has simple __dict__ or self is builtin +# otherwise, use pickle serialization + +# io +# builtin types will not use this! +class _MyJsonEncoder(json.JSONEncoder): + def default(self, one): + if hasattr(one, "to_builtin"): + return one.to_builtin() + else: + return one.__dict__ + +class JsonRW: + # update from v and return one if one is not builtin, directly return otherwise + @staticmethod + def _update_return(one, v): + if hasattr(one, "from_builtin"): + one.from_builtin(v) + # todo(2): is this one OK? + elif hasattr(one, "__dict__"): + one.__dict__.update(v) + else: + return v + return one + + @staticmethod + def from_file(one, fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'r') as fd: + return JsonRW.from_file(one, fd) + else: + return JsonRW._update_return(one, json.load(fn_or_fd)) + + @staticmethod + def to_file(one, fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'w') as fd: + JsonRW.to_file(one, fd) + else: + json.dump(one, fn_or_fd, cls=_MyJsonEncoder) + + @staticmethod + def from_str(one, s): + v = json.loads(s) + return JsonRW._update_return(one, v) + + @staticmethod + def to_str(one): + return json.dumps(one, cls=_MyJsonEncoder) + + @staticmethod + def save_list(ones: List, fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'w') as fd: + JsonRW.save_list(ones, fd) + else: + for one in ones: + fn_or_fd.write(JsonRW.to_str(one) + "\n") + + @staticmethod + def yield_list(fn_or_fd, max_num=-1): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd) as fd: + for one in JsonRW.yield_list(fd, max_num): + yield one + else: + c = 0 + while c!=max_num: + line = fn_or_fd.readline() + if len(line) == 0: + break + yield JsonRW.from_str(None, line) + c += 1 + + @staticmethod + def load_list(fn_or_fd, max_num=-1): + return list(JsonRW.yield_list(fn_or_fd, max_num)) + +# +class PickleRW: + @staticmethod + def from_file(fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'rb') as fd: + return PickleRW.from_file(fd) + else: + return pickle.load(fn_or_fd) + + @staticmethod + def to_file(one, fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'wb') as fd: + PickleRW.to_file(one, fd) + else: + pickle.dump(one, fn_or_fd) + + @staticmethod + def save_list(ones: List, fn_or_fd): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'wb') as fd: + PickleRW.save_list(ones, fd) + else: + for one in ones: + pickle.dump(one, fn_or_fd) + + @staticmethod + def yield_list(fn_or_fd, max_num=-1): + if isinstance(fn_or_fd, str): + with zopen(fn_or_fd, 'rb') as fd: + for one in PickleRW.yield_list(fd, max_num): + yield one + else: + c = 0 + while c!=max_num: + yield pickle.load(fn_or_fd) + c += 1 + + @staticmethod + def load_list(fn_or_fd, max_num=-1): + ret = [] + try: + for one in PickleRW.yield_list(fn_or_fd, max_num): + ret.append(one) + except EOFError: + pass + return ret diff --git a/msp/utils/system.py b/msp/utils/system.py index 026bf2c..b30c7b5 100644 --- a/msp/utils/system.py +++ b/msp/utils/system.py @@ -1,6 +1,6 @@ # about system information -import sys, os, subprocess +import sys, os, subprocess, traceback, glob from .log import zopen, zlog from .check import zcheck @@ -64,3 +64,20 @@ def read_multiline(fd, skip_f=lambda x: len(x.strip()) == 0, ignore_f=lambda x: if len(line) == 0: break return lines + + @staticmethod + def glob(pathname, assert_exist=False, assert_only_one=False): + files = glob.glob(pathname) + if assert_only_one: + assert len(files) == 1 + elif assert_exist: # only_one leads to exists + assert len(files) > 0 + return files + +# get tracebacks +def extract_stack(num=-1): + # exclude this frame + frames = traceback.extract_stack()[1:] + if num > 0: + frames = frames[:num] + return frames diff --git a/msp/utils/task.py b/msp/utils/task.py index e033c8c..bf65835 100644 --- a/msp/utils/task.py +++ b/msp/utils/task.py @@ -130,7 +130,7 @@ class StatRecorder(object): @staticmethod def guess_default_handler(name, v): - if isinstance(name, Iterable): + if isinstance(name, (tuple, list)): name = str(name[0]) fileds = name.split("_") start_name = fileds[0] diff --git a/msp/utils/utils.py b/msp/utils/utils.py index 4d59c24..c729bfb 100644 --- a/msp/utils/utils.py +++ b/msp/utils/utils.py @@ -1,5 +1,5 @@ # misc -import json, sys, pickle, re +import sys, re from collections import Iterable import numpy as np from .log import zopen, zlog @@ -13,10 +13,10 @@ class Constants(object): REAL_PRAC_MAX = float(PRAC_NUM_) REAL_PRAC_MIN = -REAL_PRAC_MAX # - INT_MAX = 2**31-1 - INT_MIN = -2**31+1 - REAL_MAX = sys.float_info.max - REAL_MIN = -REAL_MAX + # INT_MAX = 2**31-1 + # INT_MIN = -2**31+1 + # REAL_MAX = sys.float_info.max + # REAL_MIN = -REAL_MAX class StrHelper(object): NUM_PATTERN = re.compile(r"[0-9]+|[0-9]+\.[0-9]*|[0-9]+[0-9,]+|[0-9]+\\/[0-9]+|[0-9]+/[0-9]+") @@ -231,71 +231,18 @@ def check_end_iter(iter): except StopIteration: pass -# io -# builtin types will not use this! -class _MyJsonEncoder(json.JSONEncoder): - def default(self, one): - if hasattr(one, "to_builtin"): - return one.to_builtin() - else: - return one.__dict__ - -class JsonRW(object): - # update from v and return one if one is not builtin, directly return otherwise - @staticmethod - def _update_return(one, v): - if hasattr(one, "from_builtin"): - one.from_builtin(v) - # todo(+2): is this one OK? - elif hasattr(one, "__dict__"): - one.__dict__.update(v) - else: - return v - return one - - @staticmethod - def load_from_file(one, file_name): - with zopen(file_name, 'r') as fd: - v = json.load(fd) - return JsonRW._update_return(one, v) - - @staticmethod - def save_to_file(one, file_name): - with zopen(file_name, 'w') as fd: - json.dump(one, fd, cls=_MyJsonEncoder) - - @staticmethod - def load_from_str(one, s): - v = json.loads(s) - return JsonRW._update_return(one, v) - - @staticmethod - def save_to_str(one): - return json.dumps(one, cls=_MyJsonEncoder) - # ===== - # some useful tools @staticmethod - def save_list(ones, file_name): - with zopen(file_name, 'w') as fd: - for one in ones: - fd.write(JsonRW.save_to_str(one)+"\n") + def check_is_range(idxes, length): + return len(idxes) == length and all(i == v for i, v in enumerate(idxes)) - @staticmethod - def load_list(file_name): - with zopen(file_name) as fd: - ones = [JsonRW.load_from_str(None, line) for line in fd] - return ones +class ZObject(object): + def __init__(self, m=None): + if m is not None: + self.update(m) -# -class PickleRW(object): - @staticmethod - def read(fname): - with open(fname, "rb") as fd: - x = pickle.load(fd) - return x - - @staticmethod - def write(fname, x): - with open(fname, "wb") as fd: - pickle.dump(x, fd) + def update(self, m): + if isinstance(m, ZObject): + m = m.__dict__ + for k, v in m.items(): + setattr(self, k, v) diff --git a/msp/zext/ana/__init__.py b/msp/zext/ana/__init__.py new file mode 100644 index 0000000..91e5a3c --- /dev/null +++ b/msp/zext/ana/__init__.py @@ -0,0 +1,6 @@ +# + +# some tools for analyzing (and possibly light-weighted annotating) + +from .node2 import ZRecNode, ZNodeVisitor, ZObject +from .analyzer import Analyzer, AnalyzerConf, AnnotationTask diff --git a/msp/zext/ana/analyzer.py b/msp/zext/ana/analyzer.py new file mode 100644 index 0000000..735babb --- /dev/null +++ b/msp/zext/ana/analyzer.py @@ -0,0 +1,316 @@ +# + +# the general analyzer + +import numpy as np +import pandas as pd +import traceback +from typing import List, Iterable +from shlex import split as sh_split +from pdb import set_trace + +from msp.utils import zopen, zlog, Conf, PickleRW, ZObject, Helper + +from .node2 import ZRecNode, ZNodeVisitor + +class AnalyzerConf(Conf): + def __init__(self): + self.auto_save_name = "./_tmp_analysis.pkl" + +# analyzer: need to conform to each commands' local naming protocol +class Analyzer: + def __init__(self, conf: AnalyzerConf): + self.conf = conf + self.history = [] + # naming convention: those start with "_" will not be saved + self.vars = ZObject() # vars mapping + self.traces = {} # var_name -> set-history-idx + # auto save + self.last_save_point = 0 # last save before this idx + # current cmd info + self.cur_cmd_line = None + self.cur_cmd_target = None + self.cur_cmd_args = None + + def __del__(self): + if self.last_save_point < len(self.history): + auto_save_name = self.conf.auto_save_name + if len(auto_save_name) > 0: + self.do_save(self.conf.auto_save_name) + + # ===== + # some helper functions + def set_var(self, target, v, explanation=None, history_idx=None): + if hasattr(self.vars, target): + zlog(f"Overwriting the existing var `{target}`") + if target not in self.traces: + self.traces[target] = [] + # (explanation, history-idx) + history_idx = len(self.history) if history_idx is None else history_idx + self.traces[target].append((explanation, history_idx)) # the current cmd will be recorded into history + setattr(self.vars, target, v) + + def get_var(self, target): + if hasattr(self.vars, target): + return getattr(self.vars, target) + else: + assert False, f"Cannot find var `{target}`" + + def str_cmd(self, one): + target, args = one + return f"{target} = {args}" + + # ===== + # some general commands + def do_history(self, num=-10): + num = int(num) + cur_len = len(self.history) + if num<0: # listing recent mode + num = min(int(-num), cur_len) + zlog(f"Listing histories: all = {cur_len}") + # back to front + for i in range(1, num+1): + real_idx = cur_len - i + zlog(f"[#{i}|{real_idx}] {self.str_cmd(self.history[real_idx])}") + else: + zlog(f"Listing histories idx = {num}: {self.str_cmd(self.history[num])}") + return None + + def do_trace(self, target): + v = self.get_var(target) + for explanation, history_idx in self.traces[target]: + zlog(f"Var `{target}`: ({explanation}, {self.str_cmd(self.history[history_idx]) if history_idx>=0 else ''})") + + # most general runner + def do_eval(self, code): + vs = self.vars # convenient local variable + _ff = compile(code, "", "eval") + ret = eval(_ff) + return ret + + def do_pdb(self): + set_trace() + return None + + def do_load(self, file): + zlog(f"Try loading vars from {file}") + x = PickleRW.from_file(file) + self.vars.update(x) + + def do_save(self, file): + zlog(f"Try saving vars to {file}") + self.last_save_point = len(self.history) + 1 # plus one for the current one + PickleRW.to_file(self.vars, file) + + # ===== + # some useful ones + + # protocol: target: instances, fcode: d for each instances; return instances + def do_filter(self, insts_target: str, fcode: str) -> List: + vs = self.vars + _ff = compile(fcode, "", "eval") + insts = self.get_and_check_type(insts_target, list) + ret = [d for d in insts if eval(_ff)] + zlog(f"Filter by {fcode}: from {len(insts)} to {len(ret)}, {len(ret)/(len(insts)+1e-7)}") + return ret + + # protocol: target of instances, jcode return a list for each of them + def do_join(self, insts_target: str, jcode: str) -> List: + vs = self.vars + _ff = compile(jcode, "", "eval") + insts = self.get_and_check_type(insts_target, list) + ret = [eval(_ff) for d in insts] + ret = Helper.join_list(ret) + zlog(f"Join-list by {jcode}: from {len(insts)} to {len(ret)}") + return ret + + # protocol: target: instances, kcode: (key) d for inst: return sorted list + def do_sort(self, insts_target: str, kcode: str) -> List: + vs = self.vars + _ff = compile(kcode, "", "eval") + insts = self.get_and_check_type(insts_target, list) + tmp_tuples = [(d, eval(_ff)) for d in insts] + tmp_tuples.sort(key=lambda x: x[1]) + ret = [x[0] for x in tmp_tuples] + zlog(f"Sort by key={kcode}: len = {len(ret)}") + return ret + + # protocol: target: instances, gcode: d for each instances; return one node + def do_group(self, insts_target: str, gcode: str, sum_key="count") -> ZRecNode: + return self._do_group(insts_target, gcode, sum_key, None) + + # protocol: target: instances, gcode: d for each instances; return one node + def _do_group(self, insts_target: str, gcode: str, sum_key: str, visitor: ZNodeVisitor) -> ZRecNode: + vs = self.vars + _ff = compile(gcode, "", "eval") + insts = self.get_and_check_type(insts_target, list) + ret = ZRecNode(None, []) + for d in insts: + ret.add_seq(eval(_ff), obj=d) + # visitor + if visitor is not None: + try: + ret.rec_visit(visitor) + except: + zlog(traceback.format_exc()) + zlog("Error of visitor.") + _show_node_f = lambda x: visitor.show(x) + else: + _show_node_f = lambda x: "--" + # some slight summaries here + all_count = len(insts) + all_nodes = ret.get_descendants(key=sum_key) + ss = [] + for z in all_nodes: + all_parents = z.get_antecedents() + if len(all_parents) > 0: + assert all_parents[0].count == all_count + perc_info = ', '.join([f"{z.count/(zp.count+1e-6):.4f}" for zp in all_parents]) + ss.append(['=='*len(z.path), str(z.path), f"{z.count}({perc_info})" f"{_show_node_f(z)}"]) + # sstr = "\n".join(ss) + # sstr = "" + pd.set_option('display.width', 1000) + pd.set_option('display.max_colwidth', 1000) + pdf = pd.DataFrame(ss) + zlog(f"Group {len(insts)} instances by {gcode}, all {len(ss)} nodes:\n{pdf.to_string()}") + return ret + + # similar to group, but using pd instead + def do_get_pd(self, insts_target: str, gcode: str) -> pd.DataFrame: + vs = self.vars + _ff = compile(gcode, "", "eval") + insts = self.get_and_check_type(insts_target, list) + # + fields = [eval(_ff) for d in insts] + ret = pd.DataFrame(fields) + zlog(f"Group {len(insts)} instances by {gcode} to pd.DataFrame shape={ret.shape}.") + return ret + + # calculation and manipulation on pd, also assuming the local name is d + def do_cal_pd(self, inst_pd: str, scode: str): + vs = self.vars + _ff = compile(scode, "", "eval") + d = self.get_and_check_type(inst_pd, pd.DataFrame) + # + ret = eval(_ff) + zlog(f"Calculation on pd.DataFrame by {scode}, and get another one as: {str(ret)}") + return ret + + # ===== + # helpers + def get_and_check_type(self, target, dtype): + ret = self.get_var(target) + assert isinstance(ret, dtype), f"Wrong typed target, {type(ret)} instead of {dtype}" + return ret + + # ===== + # + def process(self, args: List[str]): + assert len(args) > 0, "Empty command" + cmd_name, real_args = args[0], args[1:] + # similar to the cmd package + method_name = "do_" + cmd_name + assert hasattr(self, method_name), f"Unknown command {cmd_name}" + zlog(f"Performing command: {args}") + return getattr(self, method_name)(*real_args) + + # format is "[target-name=]args..." + def get_cmd(self): + target, args = None, [] + line = input(">> ") + # first check assignment target + cmd = line.strip() + tmp_fields = cmd.split("=", 1) + if len(tmp_fields) > 1: + tmp_target, remainings = [x.strip() for x in tmp_fields] + if str.isidentifier(tmp_target): + target = tmp_target + cmd = remainings + # then split into args + args = sh_split(cmd) + return target, args, line + + def loop(self): + target, args = None, [] + stop = False + while True: + # read one command + while True: + try: + target, args, line = self.get_cmd() + except KeyboardInterrupt: + continue + except EOFError: + stop = True + break + if len(args) > 0: # get valid commands + break + if stop: + break + # process it + try: + self.cur_cmd_line = line + self.cur_cmd_target = target + self.cur_cmd_args = args + v = self.process(args) + except AssertionError as e: + zlog(f"Checking Error: " + str(e)) + continue + except: + zlog(f"Command error: " + str(traceback.format_exc())) + continue + cmd_count = len(self.history) + if v is not None: # does not store None + if target is not None: + self.set_var(target, v) + self.history.append((target, args)) + zlog(f"Finish command #{cmd_count}: {self.str_cmd(self.history[-1])}") + + +# requiring that objs have "id" field +class AnnotationTask: + def __init__(self, objs: List): + self.pool = objs + self.focus = 0 + self.length = len(self.pool) + self.remaining = len(self.pool) + self.annotations = {z.id: None for z in objs} # id -> ann-obj + + # to be overridden + def obj_finished(self, obj): + return self.annotations[obj.id] is not None + + # get current focus + def get_obj(self): + idx = self.focus + obj = self.pool[idx] + return obj + + # + def annotate_obj(self): + raise NotImplementedError("to be implemented") + + # setting focus and return the new focus + def set_focus(self, idx=None, offset=0): + if idx is not None: + new_focus = idx + else: + new_focus = self.focus + offset + if new_focus>=0 and new_focus ZRecNode + # + self.objs: List = [] # only added to the end points + self.props = {} # other to-be-set properties + + # ===== + # like a dictionary + def __getitem__(self, item): + return self.elems[item] + + def __contains__(self, item): + return item in self.elems + + def keys(self): + return self.elems.keys() + + def values(self): + return self.elems.values() + + def is_root(self): + return self.parent is None + + def __getattr__(self, item): + if item in self.__dict__: + return self.__dict__[item] + elif item in self.props: + return self.props[item] + else: + raise AttributeError() + + # get nodes of descendents + def get_descendants(self, recursive=True, key="count", preorder=True): + ret = [self] if preorder else [] + if recursive: + for v in sorted(self.elems.values(), key=lambda v: getattr(v, key)): + ret.extend(v.get_descendants(recursive, key, preorder)) + if not preorder: + ret.extend(self) + return ret + + # get parants, grandparents, etc; h2l means high to low + def get_antecedents(self, h2l=True): + ret = [] + n = self + while n.parent is not None: + ret.append(n.parent) + n = n.parent + if h2l: + ret.reverse() + return ret + + # ===== + # recording specific + def add_seq(self, seq, count=1, obj=None): + assert self.parent is None, "Currently only support adding from ROOT" + # make it list + if not isinstance(seq, (list, tuple)): + seq = [seq] + # recursive adding + cur_node = self + cur_path = [] + while 1: + # update for current node + cur_node.count += count + if obj is not None: + cur_node.objs.append(obj) + # next one + if len(seq) <= 0: + cur_node.count_end += count + break + seq0, seq = seq[0], seq[1:] + cur_path.append(seq0) + next_node = cur_node.elems.get(seq0, None) + if next_node is None: + next_node = ZRecNode(cur_node, cur_path) # no need copy, since changed to a new tuple later. + cur_node.elems[seq0] = next_node + cur_node = next_node + + # recursively visit (updating) + def rec_visit(self, visitor: 'ZNodeVisitor'): + # first visit all the children nodes + values = [] + for node in visitor.sort(self.values()): + values.append(node.rec_visit(visitor)) + return visitor.visit(self, values) + +# +class ZNodeVisitor: + def sort(self, nodes: Iterable[ZRecNode]): + return nodes + + # process one Node + def visit(self, node: ZRecNode, values: List): + raise NotImplementedError("To be overridden!") + + # show result for one Node (mainly used for printing) + def show(self, node: ZRecNode): + raise NotImplementedError("To be overridden!") diff --git a/msp/zext/dpar/conllu_eval.py b/msp/zext/dpar/conllu_eval.py index afdbe3c..f7b9551 100644 --- a/msp/zext/dpar/conllu_eval.py +++ b/msp/zext/dpar/conllu_eval.py @@ -4,8 +4,12 @@ from msp.utils import Helper, NumHelper class ParserEvaler: - # only for PTB, CTB + # for UD, PTB, CTB + # DEFAULT_PUNCT_POS_SET = {"PUNCT", "SYM", '.', '``', "''", ':', ',', 'PU'} + # todo(warn): following CoNLL, no excluding punct for UD DEFAULT_PUNCT_POS_SET = {'.', '``', "''", ':', ',', 'PU'} + # FUNCTION relations + punct relation + CLAS_EXCLUDE_LABELS = {'aux', 'case', 'cc', 'clf', 'cop', 'det', 'mark', 'punct'} def __init__(self, ignore_punct=True, punct_set=DEFAULT_PUNCT_POS_SET): self.ignore_punct = ignore_punct @@ -22,32 +26,70 @@ def add_confusion(self, tg, tp): self.confusion[tg][tp] += 1 # evaluate for one sentence - def eval_one(self, gold_pos, gold_heads, gold_labels, pred_heads, pred_labels): + def eval_one(self, gold_pos, gold_heads, gold_labels, pred_pos, pred_heads, pred_labels): curr_stat = {} # add 1 for current stat def _a1cs(_name): Helper.stat_addone(curr_stat, _name) - # + # ===== + # collect children info for UD*** + gold_children, pred_children = [[] for _ in range(len(gold_pos)+1)], [[] for _ in range(len(gold_pos)+1)] + gold_children_labels, pred_children_labels = [[] for _ in range(len(gold_pos)+1)], [[] for _ in range(len(gold_pos)+1)] + for i, m in enumerate(gold_heads): + gold_children[m].append(i) + gold_children_labels[m].append(gold_labels[i]) + for i, m in enumerate(pred_heads): + pred_children[m].append(i) + pred_children_labels[m].append(pred_labels[i]) + gold_children, pred_children = gold_children[1:], pred_children[1:] + gold_children_labels, pred_children_labels = gold_children_labels[1:], pred_children_labels[1:] + # collect labels info for CLAS + gold_label_inc = [(z.split(":")[0] not in ParserEvaler.CLAS_EXCLUDE_LABELS) for z in gold_labels] + pred_label_inc = [(z.split(":")[0] not in ParserEvaler.CLAS_EXCLUDE_LABELS) for z in pred_labels] + # ===== _a1cs("sent") - for idx in range(len(gold_pos)): - cur_pos = gold_pos[idx] - cur_gold_head, cur_gold_label = gold_heads[idx], gold_labels[idx] - cur_pred_head, cur_pred_label = pred_heads[idx], pred_labels[idx] - # + cur_len = len(gold_pos) + if pred_pos is None: + pred_pos = [""] * cur_len + for idx in range(cur_len): + cur_gold_pos, cur_pred_pos = gold_pos[idx], pred_pos[idx] + cur_gold_head, cur_gold_label, cur_gold_children, cur_gold_children_labels = \ + gold_heads[idx], gold_labels[idx], gold_children[idx], gold_children_labels[idx] + cur_pred_head, cur_pred_label, cur_pred_children, cur_pred_children_labels = \ + pred_heads[idx], pred_labels[idx], pred_children[idx], pred_children_labels[idx] + # token count _a1cs("tok") - if cur_pos not in self.punct_set: + if cur_gold_pos not in self.punct_set: the_suffixes = ("", "_np") _a1cs("tok_np") else: the_suffixes = ("", ) - # + # detailed evals for one_suffix in the_suffixes: + labeled_correct = False + # pos + if cur_gold_pos == cur_pred_pos: + _a1cs("pos_corr" + one_suffix) + # plain UAS/LAS if cur_gold_head == cur_pred_head: _a1cs("tok_corrU" + one_suffix) if cur_gold_label == cur_pred_label: + labeled_correct = True _a1cs("tok_corrL" + one_suffix) if one_suffix == "": self.add_confusion(cur_gold_label, cur_pred_label) + # HDUAS/HDLAS + if cur_gold_head == cur_pred_head and cur_gold_children == cur_pred_children: + _a1cs("tok_corrHDU" + one_suffix) + if cur_gold_label == cur_pred_label and cur_gold_children_labels == cur_pred_children_labels: + _a1cs("tok_corrHDL" + one_suffix) + # CLAS + if gold_label_inc[idx]: + _a1cs("tok_clas_all"+one_suffix) + if labeled_correct: + _a1cs("tok_clas_corr"+one_suffix) # must be also included by pred if correct + if pred_label_inc[idx]: + _a1cs("tok_clas_pall"+one_suffix) # ROOT if cur_gold_head == 0: _a1cs("tok_root_all") @@ -63,12 +105,14 @@ def _a1cs(_name): for which_metric in ("U", "L"): if curr_stat.get("tok_corr" + which_metric + one_suffix, 0) == this_tok_num: _a1cs("sent_corr" + which_metric + one_suffix) + # ===== Helper.stat_addv(self.stat, curr_stat) - # return self.summary_one(curr_stat) + return curr_stat # return (report_str, result) def summary_one(self, stat, verbose): _P = lambda x: NumHelper.truncate_float(x, 6) + _DIV = lambda a, b: 0. if b==0 else a/b # final_stat = {} s = "" @@ -87,27 +131,53 @@ def summary_one(self, stat, verbose): sent_corrL = stat.get("sent_corrL"+one_suffix, 0) sent_uas = _P(sent_corrU/sent_all) sent_las = _P(sent_corrL/sent_all) + # token: plain tok_all = stat.get("tok"+one_suffix, 0) tok_corrU = stat.get("tok_corrU"+one_suffix, 0) tok_corrL = stat.get("tok_corrL"+one_suffix, 0) tok_uas = _P(tok_corrU/tok_all) tok_las = _P(tok_corrL/tok_all) + # special: root tok_root_all = stat.get("tok_root_all", 0) tok_root_pall = stat.get("tok_root_pall", 0) - tok_corrR = stat.get("tok_root_corr", 0) - # todo(+1): this is only ROOT recall - tok_root_recall = _P(tok_corrR/tok_root_all) - tok_root_precision = _P(tok_corrR/tok_root_pall if tok_root_pall>0. else 0.) + tok_root_corr = stat.get("tok_root_corr", 0) + tok_root_recall = _P(_DIV(tok_root_corr, tok_root_all)) + tok_root_precision = _P(_DIV(tok_root_corr, tok_root_pall)) + tok_root_f1 = _P(_DIV(2*tok_root_recall*tok_root_precision, tok_root_recall+tok_root_precision)) + # special: hd + tok_corrHDU = stat.get("tok_corrHDU"+one_suffix, 0) + tok_corrHDL = stat.get("tok_corrHDL"+one_suffix, 0) + tok_hduas = _P(tok_corrHDU/tok_all) + tok_hdlas = _P(tok_corrHDL/tok_all) + # special: clas + tok_clas_all = stat.get("tok_clas_all"+one_suffix, 0) + tok_clas_pall = stat.get("tok_clas_pall"+one_suffix, 0) + tok_clas_corr = stat.get("tok_clas_corr"+one_suffix, 0) + tok_clas_recall = _P(_DIV(tok_clas_corr, tok_clas_all)) + tok_clas_precision = _P(_DIV(tok_clas_corr, tok_clas_pall)) + tok_clas_f1 = _P(_DIV(2*tok_clas_recall*tok_clas_precision, tok_clas_recall+tok_clas_precision)) + # special: pos for all (including punct) + tok_pos_corr = stat.get("pos_corr", 0) + tok_pos_all = stat.get("tok", 0) + tok_pos_acc = _P(_DIV(tok_pos_corr, tok_pos_all)) # - final_stat[one_suffix] = {"sent_uas": sent_uas, "sent_las": sent_las, "tok_uas": tok_uas, "tok_las": tok_las, - "tok_root_rec": tok_root_recall, "tok_root_pre": tok_root_precision} + final_stat[one_suffix] = { + "sent_uas": sent_uas, "sent_las": sent_las, "tok_uas": tok_uas, "tok_las": tok_las, + "tok_root_rec": tok_root_recall, "tok_root_pre": tok_root_precision, "tok_root_f1": tok_root_f1, + "tok_hduas": tok_hduas, "tok_hdlas": tok_hdlas, + "tok_clas_recall": tok_clas_recall, "tok_clas_precision": tok_clas_precision, "tok_clas_f1": tok_clas_f1, + "tok_pos_acc": tok_pos_acc, + } s += "== Eval of " + one_suffix + ": " - s += "tok: %d/%d(%.5f)/%d(%.5f) || " % (tok_all, tok_corrU, tok_uas, tok_corrL, tok_las) - s += "sent: %d/%d(%.5f)/%d(%.5f) || " % (sent_all, sent_corrU, sent_uas, sent_corrL, sent_las) - s += "root: %d/%d(%.5f);%d/%d(%.5f)\n" % (tok_root_all, tok_corrR, tok_root_recall, - tok_root_pall, tok_corrR, tok_root_precision) - # todo(warn): should look at LAS - final_stat[one_suffix]["res"] = final_stat[one_suffix]["tok_las"] + s += f"tok: {tok_all:d}/{tok_corrU:d}({tok_uas:.5f})/{tok_corrL:d}({tok_las:.5f}) || " + s += f"pos: {tok_pos_acc:5f} || " + s += f"sent: {sent_all:d}/{sent_corrU:d}({sent_uas:.5f})/{sent_corrL:d}({sent_las:.5f}) || " + s += f"root: {tok_root_all:d}/{tok_root_corr:d}({tok_root_recall:.5f});{tok_root_pall:d}/{tok_root_corr:d}({tok_root_precision:.5f}) [{tok_root_f1:.5f}] || " + s += f"hd: {tok_all:d}/{tok_corrHDU:d}({tok_hduas:.5f})/{tok_corrHDL:d}({tok_hdlas:.5f}) || " + s += f"clas: {tok_clas_all:d}/{tok_clas_corr:d}({tok_clas_recall:.5f});{tok_clas_pall:d}/{tok_clas_corr:d}({tok_clas_precision:.5f}) [{tok_clas_f1:.5f}]\n" + # todo(warn): here we average uas and las? + # final_stat[one_suffix]["res"] = final_stat[one_suffix]["tok_las"] + final_stat[one_suffix]["res"] = (final_stat[one_suffix]["tok_las"]+final_stat[one_suffix]["tok_uas"])/2 s += "== End Eval" # return if self.ignore_punct: @@ -117,4 +187,3 @@ def summary_one(self, stat, verbose): def summary(self, verbose=False): return self.summary_one(self.stat, verbose) - diff --git a/msp/zext/dpar/conllu_reader.py b/msp/zext/dpar/conllu_reader.py index 7d5d425..ad7f1bf 100644 --- a/msp/zext/dpar/conllu_reader.py +++ b/msp/zext/dpar/conllu_reader.py @@ -6,7 +6,7 @@ # one instance of parse class ConlluToken: - def __init__(self, parse, idx, w, up, xp, h, la, enh_hs, enh_ls): + def __init__(self, parse, idx, w, up, xp, h, la, enh_hs, enh_ls, misc=None): self.parse = parse # basic properties (on construction) self.idx = idx @@ -16,6 +16,15 @@ def __init__(self, parse, idx, w, up, xp, h, la, enh_hs, enh_ls): self.head = h self.label = la self.label0 = la.split(":")[0] if la is not None else None + self.misc = {} + if misc is not None and len(misc)>0 and misc != "_": + try: + for s in misc.split("|"): + k, v = s.split("=") + assert k not in self.misc, f"Err: Repeated key: {k}" + self.misc[k] = v + except: + pass # enhanced dependencies self.enh_heads = enh_hs self.enh_labels = enh_ls @@ -23,6 +32,7 @@ def __init__(self, parse, idx, w, up, xp, h, la, enh_hs, enh_ls): # extra helpful ones (build later, todo(warn) but not for ellip nodes!!) # dep-distance and root-distance self.ddist = 0 # mod - head + self.rdist = -1 # distance to root # self.rdist = 0 # left & right & all child (sorted left-to-right) self.lchilds = [] @@ -33,7 +43,17 @@ def __init__(self, parse, idx, w, up, xp, h, la, enh_hs, enh_ls): self.childs_labels = [] def __repr__(self): - return "\t".join(str(z) for z in (self.idx, self.word, self.upos, self.xpos, self.head, self.label0)) + return "\t".join(str(z) for z in (self.idx, self.word, self.upos, self.xpos, self.head, self.label0, self.misc)) + + def __getattr__(self, item): + if item in self.__dict__: + return self.__dict__[item] + else: + misc = self.__dict__.get("misc", None) + if misc is not None and item in misc: + return misc[item] + else: + raise AttributeError() def is_root(self): return self.idx == 0 @@ -55,7 +75,7 @@ def get_node(self, idx): return self.parse.get_tok(idx) def get_head(self): - return None if self.head is None else self.get_node(self.head) + return None if (self.head is None or self.head<0) else self.get_node(self.head) class ConlluParse: ROOT_SYM = "" @@ -64,12 +84,16 @@ class ConlluParse: IGNORE_F = lambda x: x.startswith("#") def __init__(self): + # headlines + self.headlines = [] # ordinary tokens self.tokens = [] # [0] is root token # enhance tokens: {int/tuple[id,n]: *} self.enhance_tokens = {} # crossed? self._crossed = None + # other properties + self.misc = {} @staticmethod def calc_crossed(heads): @@ -104,6 +128,10 @@ def crossed(self): def __repr__(self): return "\n".join([str(z) for z in self.tokens]) + # add comment + def add_headline(self, s): + self.headlines.append(s) + # sequentially add one def add_token(self, k, t: ConlluToken, keep_seq=True): if isinstance(k, int): @@ -135,6 +163,14 @@ def finish_tokens(self): head_tok.childs.append(cur_idx) head_tok.childs_labels.append(cur_label) cur_tok.ddist = cur_ddist + # calculate root distance + # ===== + def _calc_rdist(tok, d): + tok.rdist = d + for ii in tok.childs: + _calc_rdist(self.tokens[ii], d+1) + # ===== + _calc_rdist(self.tokens[0], 0) # get real tokens (exclude root) def get_tokens(self): @@ -201,18 +237,23 @@ def parse_enhance_dep(self, enh_str): return hs, ls def read_one(self, fd): - lines = FileHelper.read_multiline(fd, ConlluParse.SKIP_F, ConlluParse.IGNORE_F) + # lines = FileHelper.read_multiline(fd, ConlluParse.SKIP_F, ConlluParse.IGNORE_F) + lines = FileHelper.read_multiline(fd, ConlluParse.SKIP_F) if lines is None: return None # RR = ConlluParse.ROOT_SYM parse = ConlluParse() - root_token = ConlluToken(parse, 0, RR, RR, RR, -1, None, None, None) + root_token = ConlluToken(parse, 0, RR, RR, RR, -1, None, None, None, None) parse.add_token(0, root_token) # prev_idx = 0 for one_line in lines: - fileds = one_line.strip().split('\t') + if one_line.startswith("#"): + # add headlines + parse.add_headline(one_line[1:].strip()) + continue + fileds = one_line.strip("\n").split('\t') zcheck(len(fileds)==ConlluParse.NUM_FIELDS, "Conllu Error: Unmatched num of fields.") # id id_str = fileds[0] @@ -228,10 +269,10 @@ def read_one(self, fd): else: zcheck(one_idx[0] == prev_idx, "Conllu Error: wrong ELLI line-idx") # add basic ones - word, upos, xpos, head, label = fileds[1], fileds[3], fileds[4], fileds[6], fileds[7] + word, upos, xpos, head, label, misc = fileds[1], fileds[3], fileds[4], fileds[6], fileds[7], fileds[9] head = int(head) if head!="_" else None ehn_hs, enh_ls = self.parse_enhance_dep(fileds[8]) - one_token = ConlluToken(parse, one_idx, word, upos, xpos, head, label, ehn_hs, enh_ls) + one_token = ConlluToken(parse, one_idx, word, upos, xpos, head, label, ehn_hs, enh_ls, misc) parse.add_token(one_idx, one_token) parse.finish_tokens() return parse @@ -244,8 +285,13 @@ def yield_ones(self, fd): yield parse # writer -def write_conllu(fd, word, pos, heads, labels): +def write_conllu(fd, word, pos, heads, labels, miscs=None, headlines=None): length = len(word) + if miscs is None: + miscs = ["_"] * length + if headlines is not None: + for line in headlines: + fd.write(f"# {line}\n") for idx in range(length): fields = ["_"] * 10 fields[0] = str(idx+1) @@ -253,6 +299,8 @@ def write_conllu(fd, word, pos, heads, labels): fields[3] = pos[idx] fields[6] = str(heads[idx]) fields[7] = labels[idx] + if len(miscs[idx])>0: + fields[9] = miscs[idx] s = "\t".join(fields) + "\n" fd.write(s) fd.write("\n") diff --git a/msp/zext/evaler.py b/msp/zext/evaler.py new file mode 100644 index 0000000..a67653b --- /dev/null +++ b/msp/zext/evaler.py @@ -0,0 +1,100 @@ +# + +from typing import Dict +from collections import Counter + +from msp.utils import zlog, Helper + +class DivResult: + def __init__(self, a, b): + self.a = float(a) + self.b = float(b) + if b == 0.: + assert a == 0., "Illegal non-zero/zero" + self.res = 0. + else: + self.res = a/b + + def __repr__(self): + return f"[{self.a}/{self.b}={self.res:.4f}]" + + def __float__(self): + return self.res + +class F1Result: + def __init__(self, p, r): + self.p = p + self.r = r + p, r = float(p), float(r) + self.f1 = float(DivResult(2*p*r, p+r)) + + def __repr__(self): + return f"(p={self.p}, r={self.r}, f={self.f1:.4f})" + + def __float__(self): + return self.f1 + +class LabelF1Evaler: + def __init__(self, name): + # key -> List[labels] + self.name = name + self.golds = {} + self.preds = {} + self.labels = set() + + # ===== + # adding ones + + def _add_group(self, d: Dict, key, label): + if key not in d: + d[key] = [label] + else: + d[key].append(label) + self.labels.add(label) + + def add_gold(self, key, label): + self._add_group(self.golds, key, label) + + def add_pred(self, key, label): + self._add_group(self.preds, key, label) + + # ===== + # results + + # this can be precision(base=pred) or recall(base=gold) + def _calc_result(self, base_d: Dict, query_d: Dict): + counts = {k:0 for k in self.labels} + corrs = {k:0 for k in self.labels} + ucorrs = 0 + for one_k, one_labels in base_d.items(): + other_label_list = query_d.get(one_k, []) + other_label_maps = Counter(other_label_list) # label -> count + ucorrs += min(len(one_labels), len(other_label_list)) + for one_lab in one_labels: + counts[one_lab] += 1 + remaining_budget = other_label_maps.get(one_lab, 0) + if remaining_budget > 0: + corrs[one_lab] += 1 + other_label_maps[one_lab] = remaining_budget - 1 + results = {k: DivResult(corrs[k], counts[k]) for k in self.labels} + all_result_u = DivResult(ucorrs, sum(counts.values())) + all_result_l = DivResult(sum(corrs.values()), sum(counts.values())) + return all_result_u, all_result_l, results + + # evaluate results + def eval(self, quiet=True, breakdown=False): + all_pre_u, all_pre_l, label_pres = self._calc_result(self.preds, self.golds) + all_rec_u, all_rec_l, label_recs = self._calc_result(self.golds, self.preds) + all_f_u = F1Result(all_pre_u, all_rec_u) + all_f_l = F1Result(all_pre_l, all_rec_l) + label_fs = {k: F1Result(label_pres[k], label_recs[k]) for k in self.labels} if breakdown else {} + if not quiet: + zlog(f"Overall f1 score for {self.name}: unlabeled {all_f_u}; labeled {all_f_l}") + zlog("Breakdowns: \n" + Helper.printd_str(label_fs)) + return all_f_u, all_f_l, label_fs + + # analysis + def analyze(self): + pass + +# b msp/zext/evaler:87 diff --git a/msp/zext/ie/__init__.py b/msp/zext/ie/__init__.py new file mode 100644 index 0000000..082b7cd --- /dev/null +++ b/msp/zext/ie/__init__.py @@ -0,0 +1,7 @@ +# + +# todo(note): move here to make it easier to save(serialize) + +from .label import HLabelConf, HLabelVocab, HLabelIdx +from .utils import * + diff --git a/msp/zext/ie/keyword.py b/msp/zext/ie/keyword.py new file mode 100644 index 0000000..786742e --- /dev/null +++ b/msp/zext/ie/keyword.py @@ -0,0 +1,195 @@ +# + +# from gensim.models import TfidfModel +# with TfidfModel as the reference + +import glob +from typing import Dict, List, Tuple +import numpy as np +# from multiprocessing import Process +# from threading import Thread, Lock +from multiprocessing.pool import ThreadPool, Pool +from msp.utils import Conf, zopen, Helper, PickleRW, zlog + +# +class KeyWordConf(Conf): + def __init__(self): + # how to build + self.build_files = [] + self.build_num_core = 4 + self.save_file = "" + self.load_file = "" + +class KeyWordModel: + def __init__(self, conf: KeyWordConf, num_doc=0, w2f=None, w2d=None): + self.conf = conf + # ----- + self.num_doc = num_doc + self.w2f = {} if w2f is None else w2f # word -> freq + self.w2d = {} if w2d is None else w2d # word -> num-doc (dfs) + # ----- + self._calc() + zlog(f"Create(or Load) {self}") + + # + def __repr__(self): + return f"KeyWordModel: #doc={self.num_doc}, #wordtype={len(self.w2f)}" + + # + def get_rank(self, w): + RANK_INF = len(self.w2i) + return self.w2i.get(w, RANK_INF) + + # + def _calc(self): + self.i2w = sorted(self.w2f.keys(), key=lambda x: (-self.w2f[x], x)) # rank by freq + self.w2i = {w:i for i,w in enumerate(self.i2w)} # reversely word -> rank + + # save and load + def save(self, f): + PickleRW.to_file([self.num_doc, self.w2f, self.w2d], f) + zlog(f"Save {self}") + + @staticmethod + def load(f, conf=None): + if conf is None: + conf = KeyWordConf() + num_doc, w2f, w2d = PickleRW.from_file(f) + m = KeyWordModel(conf, num_doc, w2f, w2d) + return m + + @staticmethod + def collect_freqs(file): + zlog(f"Starting dealing with {file}") + MIN_TOK_PER_DOC = 100 # minimum token per doc + MIN_PARA_PER_DOC = 5 # minimum line-num (paragraph) + word2info = {} # str -> [count, doc-count] + num_doc = 0 + with zopen(file) as fd: + docs = fd.read().split("\n\n") + for one_doc in docs: + tokens = one_doc.split() # space or newline + if len(tokens)>=MIN_TOK_PER_DOC and len(one_doc.split("\n"))>=MIN_PARA_PER_DOC: + num_doc += 1 + # first raw counts + for t in tokens: + t = str.lower(t) # todo(note): lowercase!! + if t not in word2info: + word2info[t] = [0, 0] + word2info[t][0] += 1 + # then doc counts (must be there) + for t in set(tokens): + t = str.lower(t) + word2info[t][1] += 1 + return num_doc, word2info + + @staticmethod + def build(conf: KeyWordConf): + # collect all files + all_build_files = [] + for one in conf.build_files: + all_build_files.extend(glob.glob(one)) + # using multiprocessing + pool = Pool(conf.build_num_core) + # pool = ThreadPool(conf.build_num_core) + results = pool.map(KeyWordModel.collect_freqs, all_build_files) + pool.close() + pool.join() + # then merge all together + word2info = {} # str -> [count, doc-count] + num_doc = 0 + for one_num_doc, one_w2i in results: + num_doc += one_num_doc + for k, info in one_w2i.items(): + if k not in word2info: + word2info[k] = [0, 0] + word2info[k][0] += info[0] + word2info[k][1] += info[1] + # build the model + m = KeyWordModel(conf, num_doc=num_doc, w2f={w:z[0] for w,z in word2info.items()}, + w2d={w:z[1] for w,z in word2info.items()}) + return m + + # ===== + # sort and return range ([start, end)) for "the freq of freq" + def _freq2range(self, freqs: List[int]) -> List[Tuple[int, int]]: + sorting_items = [(f,i) for i,f in enumerate(freqs)] + sorting_items.sort(reverse=True) + ranges = [None for _ in freqs] + prev_v, prev_idxes = None, [] + cur_rank = 0 + for cur_v, cur_idx in sorting_items+[(None, None)]: + if cur_v != prev_v: + # close previous ones + one_range = (cur_rank-len(prev_idxes), cur_rank) + for one_idx in prev_idxes: + assert ranges[one_idx] is None + ranges[one_idx] = one_range + prev_idxes.clear() + prev_v = cur_v + prev_idxes.append(cur_idx) + cur_rank += 1 + assert all(x is not None for x in ranges) + return ranges + + def extract(self, tokens: List[str], n_tf='n', n_df='t'): + # get local vocabs + local_w2f = {} + for t in tokens: + t = str.lower(t) # todo(note): lowercase!! + local_w2f[t] = local_w2f.get(t, 0.) + 1 + # get tf/idf + f_tf = { + "n": lambda tf: tf, + "l": lambda tf: 1+np.log(tf)/np.log(2), + "a": lambda tf: 0.5 + (0.5 * tf / tf.max(axis=0)), + "b": lambda tf: tf.astype('bool').astype('int'), + "L": lambda tf: (1 + np.log(tf) / np.log(2)) / (1 + np.log(tf.mean(axis=0) / np.log(2))), + }[n_tf] + totaldocs = self.num_doc + f_idf = { + "n": lambda df: df, + "t": lambda df: np.log(1.0 * totaldocs / df) / np.log(2), + "p": lambda df: np.log((1.0 * totaldocs - df) / df) / np.log(2), + }[n_df] + # tf-idf + local_words = sorted(local_w2f.keys(), key=lambda x: (-local_w2f[x], x)) + cur_tf = f_tf(np.array([local_w2f[w] for w in local_words])) + cur_idf = f_idf(np.array([self.w2d.get(w, 0)+1. for w in local_words])) # +1 smoothing + cur_tf_idf = cur_tf * cur_idf + # comparing the rankings + local_freqs = [local_w2f[x] for x in local_words] + global_freqs = [self.w2f.get(x, 0) for x in local_words] + local_ranges = self._freq2range(local_freqs) + global_ranges = self._freq2range(global_freqs) + # ranking range center jump forward + cur_range_jump = [(r2[0]+r2[1]-r1[0]-r1[1])/2. for r1,r2 in zip(local_ranges, global_ranges)] + # return local_words, global_rank, cur_tf, cur_tf_idf, cur_range_jump + global_ranks = [self.get_rank(x) for x in local_words] + ret = [(w,fr,gr,s1,s2) for w,fr,gr,s1,s2 in zip(local_words, global_ranks, cur_tf, cur_tf_idf, cur_range_jump)] + return ret + +if __name__ == '__main__': + import sys + conf = KeyWordConf() + conf.update_from_args(sys.argv[1:]) + if len(conf.build_files)>0: + m = KeyWordModel.build(conf) + if conf.save_file: + m.save(conf.save_file) + else: + m = KeyWordModel.load(conf.load_file, conf) + # test it + testing_file = "1993.01.tok" + with zopen(testing_file) as fd: + docs = fd.read().split("\n\n") + for one_doc in docs: + tokens = one_doc.split() # space or newline + one_result = m.extract(tokens) + res_sort1 = sorted(one_result, key=lambda x: -x[-2]) + res_sort2 = sorted(one_result, key=lambda x: -x[-1]) + res_sort0 = sorted(one_result, key=lambda x: -x[-3]) + zzzzz = 0 + +# PYTHONPATH=../src/ python3 k2.py build_files:./*.tok build_num_core:10 save_file:./nyt.voc.pic +# PYTHONPATH=../src/ python3 -m pdb k2.py build_files:[] load_file:./nyt.voc.pic diff --git a/msp/zext/ie/label.py b/msp/zext/ie/label.py new file mode 100644 index 0000000..171b2d1 --- /dev/null +++ b/msp/zext/ie/label.py @@ -0,0 +1,260 @@ +# + +# hierachical labeling vocab +# the hierarchical labeling system +# todo(note): the labeling str format is L1.L2.L3..., each layer has only 26-characters + +from typing import List, Tuple, Iterable, Dict +import numpy as np +from collections import defaultdict + +from msp.utils import Conf, Random, zlog +from msp.data import Vocab, WordVectors +from msp.zext.seq_helper import DataPadder + +from .utils import split_camel, LABEL2LEX_MAP + +# +class HLabelIdx: + def __init__(self, types: List, idxes: List = None): + # todo(note): input idxes is readable only! + self.types = [] + for i, z in enumerate(types): + if z is not None: + self.types.append(z) + else: + assert all(zz is None for zz in types[i:]) # once None, must all be None + break + self.idxes: Tuple = (None if idxes is None else tuple(idxes)) + + def is_nil(self): + return len(self.types)==0 + + @staticmethod + def construct_hlidx(label: str, layered: bool): + if label is None: + types = [] + else: + if layered: + types = [z for z in label.split(".") if len(z)>0] # List[str] + else: + types = [label] + return HLabelIdx(types) + + def pad_types(self, num_layer) -> Tuple: + if len(self.types) >= num_layer: + ret = self.types + else: + ret = self.types + [None] * (num_layer - len(self.types)) + return tuple(ret) + + def get_idx(self, layer): + return self.idxes[layer] + + def __eq__(self, other): + assert self.idxes is not None + return self.idxes == other.idxes + + def __hash__(self): + assert self.idxes is not None + return self.idxes + + def __len__(self): + return len(self.types) + + # directly using __str__ to get the original string label + def __repr__(self): + return ".".join(self.types) + +# configuration for HLabelVocab and HLabelModel +class HLabelConf(Conf): + def __init__(self): + # part 1: layering + self.layered = False # whether split layers for the labels + # part 2: associating types (how to get/combine one type's embedding) + self.pool_split_camel = False + self.pool_sharing = "nope" # nope: no sharing, layered: sharing for each layer, shared: fully shared on lexicon + +# vocab: for each layer: 0 as NIL, [1,n] as types, n+1 as UNK +# two parts of label embeddings: +# pools(flattened embed pool) -> layer-label(hierarchical structured one mapping to one or sum of the ones in the pool) +# todo(+N): can the link be soften? like with trainable linking parameters? +# ===== +# todo(note): summary of this complex class's fields +class HLabelVocab: + def __init__(self, name, conf: HLabelConf, keys: Dict[str, int], nil_as_zero=True): + assert nil_as_zero, "Currently assume nil as zero for all layers" + self.conf = conf + self.name = name + # from original vocab + self.orig_counts = {k: v for k,v in keys.items()} + keys = sorted(set(keys.keys())) # for example, can be "vocab.trg_keys()" + max_layer = 0 + # ===== + # part 1: layering + v = {} + keys = [None] + keys # ad None as NIL + for k in keys: + cur_idx = HLabelIdx.construct_hlidx(k, conf.layered) + max_layer = max(max_layer, len(cur_idx)) + v[k] = cur_idx + # collect all the layered types and put idxes + self.max_layer = max_layer + self.layered_v = [{} for _ in range(max_layer)] # key -> int-idx for each layer + self.layered_k = [[] for _ in range(max_layer)] # int-idx -> key for each layer + self.layered_prei = [[] for _ in range(max_layer)] # int-idx -> int-idx: idx of prefix in previous layer + self.layered_hlidx = [[] for _ in range(max_layer)] # int-idx -> hlidx for each layer + for k in keys: + cur_hidx = v[k] + cur_types = cur_hidx.pad_types(max_layer) + assert len(cur_types) == max_layer + cur_idxes = [] # int-idxes for each layer + # assign from 0 to max-layer + for cur_layer_i in range(max_layer): + cur_layered_v, cur_layered_k, cur_layered_prei, cur_layered_hlidx = \ + self.layered_v[cur_layer_i], self.layered_k[cur_layer_i], \ + self.layered_prei[cur_layer_i], self.layered_hlidx[cur_layer_i] + # also put empty classes here + for cur_layer_types in [cur_types[:cur_layer_i]+(None,), cur_types[:cur_layer_i+1]]: + if cur_layer_types not in cur_layered_v: + new_idx = len(cur_layered_k) + cur_layered_v[cur_layer_types] = new_idx + cur_layered_k.append(cur_layer_types) + cur_layered_prei.append(0 if cur_layer_i==0 else cur_idxes[-1]) # previous idx + cur_layered_hlidx.append(HLabelIdx(cur_layer_types, None)) # make a new hlidx, need to fill idxes later + # put the actual idx + cur_idxes.append(cur_layered_v[cur_types[:cur_layer_i+1]]) + cur_hidx.idxes = cur_idxes # put the idxes (actually not useful here) + self.nil_as_zero = nil_as_zero + if nil_as_zero: + assert all(z[0].is_nil() for z in self.layered_hlidx) # make sure each layer's 0 is all-Nil + # put the idxes for layered_hlidx + self.v = {} + for cur_layer_i in range(max_layer): + cur_layered_hlidx = self.layered_hlidx[cur_layer_i] + for one_hlidx in cur_layered_hlidx: + one_types = one_hlidx.pad_types(max_layer) + one_hlidx.idxes = [self.layered_v[i][one_types[:i+1]] for i in range(max_layer)] + self.v[str(one_hlidx)] = one_hlidx # todo(note): further layers will over-written previous ones + self.nil_idx = self.v[""] + # ===== + # (main) part 2: representation + # link each type representation to the overall pool + self.pools_v = {None: 0} # NIL as 0 + self.pools_k = [None] + self.pools_hint_lexicon = [[]] # hit lexicon to look up in pre-trained embeddings: List[pool] of List[str] + # links for each label-embeddings to the pool-embeddings: List(layer) of List(label) of List(idx-in-pool) + self.layered_pool_links = [[] for _ in range(max_layer)] + # masks indicating local-NIL(None) + self.layered_pool_isnil = [[] for _ in range(max_layer)] + for cur_layer_i in range(max_layer): + cur_layered_pool_links = self.layered_pool_links[cur_layer_i] # List(label) of List + cur_layered_k = self.layered_k[cur_layer_i] # List(full-key-tuple) + cur_layered_pool_isnil = self.layered_pool_isnil[cur_layer_i] # List[int] + for one_k in cur_layered_k: + one_k_final = one_k[cur_layer_i] # use the final token at least for lexicon hint + if one_k_final is None: + cur_layered_pool_links.append([0]) # todo(note): None is always zero + cur_layered_pool_isnil.append(1) + continue + cur_layered_pool_isnil.append(0) + # either splitting into pools or splitting for lexicon hint + one_k_final_elems: List[str] = split_camel(one_k_final) + # adding prefix according to the strategy + if conf.pool_sharing == "nope": + one_prefix = f"L{cur_layer_i}-" + ".".join(one_k[:-1]) + "." + elif conf.pool_sharing == "layered": + one_prefix = f"L{cur_layer_i}-" + elif conf.pool_sharing == "shared": + one_prefix = "" + else: + raise NotImplementedError(f"UNK pool-sharing strategy {conf.pool_sharing}!") + # put in the pools (also two types of strategies) + if conf.pool_split_camel: + # possibly multiple mappings to the pool, each pool-elem gets only one hint-lexicon + cur_layered_pool_links.append([]) + for this_pool_key in one_k_final_elems: + this_pool_key1 = one_prefix + this_pool_key + if this_pool_key1 not in self.pools_v: + self.pools_v[this_pool_key1] = len(self.pools_k) + self.pools_k.append(this_pool_key1) + self.pools_hint_lexicon.append([self.get_lexicon_hint(this_pool_key)]) + cur_layered_pool_links[-1].append(self.pools_v[this_pool_key1]) + else: + # only one mapping to the pool, each pool-elem can get multiple hint-lexicons + this_pool_key1 = one_prefix + one_k_final + if this_pool_key1 not in self.pools_v: + self.pools_v[this_pool_key1] = len(self.pools_k) + self.pools_k.append(this_pool_key1) + self.pools_hint_lexicon.append([self.get_lexicon_hint(z) for z in one_k_final_elems]) + cur_layered_pool_links.append([self.pools_v[this_pool_key1]]) + assert self.pools_v[None] == 0, "Internal error!" + # padding and masking for the links (in numpy) + self.layered_pool_links_padded = [] # List[arr(#layered-label, #elem)] + self.layered_pool_links_mask = [] # List[arr(...)] + padder = DataPadder(2, mask_range=2) # separately for each layer + for cur_layer_i in range(max_layer): + cur_arr, cur_mask = padder.pad(self.layered_pool_links[cur_layer_i]) # [each-sublabel, padded-max-elems] + self.layered_pool_links_padded.append(cur_arr) + self.layered_pool_links_mask.append(cur_mask) + self.layered_prei = [np.asarray(z) for z in self.layered_prei] + self.layered_pool_isnil = [np.asarray(z) for z in self.layered_pool_isnil] + self.pool_init_vec = None + # + zlog(f"Build HLabelVocab {name} (max-layer={max_layer} from pools={len(self.pools_k)}): " + "; ".join([f"L{i}={len(self.layered_k[i])}" for i in range(max_layer)])) + + # set npvec to init HLNode + def set_pool_init(self, npvec: np.ndarray): + assert len(npvec) == len(self.pools_k) + self.pool_init_vec = npvec + + # filter embeddings for init pool (similar to Vocab.filter_embeds) + def filter_pembed(self, wv: WordVectors, init_nohit=0., scale=1.0, assert_all_hit=True, set_init=True): + if init_nohit <= 0.: + get_nohit = lambda s: np.zeros((s,), dtype=np.float32) + else: + get_nohit = lambda s: (Random.random_sample((s,)).astype(np.float32)-0.5) * (2*init_nohit) + # + ret = [np.zeros((wv.embed_size,), dtype=np.float32)] # init NIL is zero + record = defaultdict(int) + for ws in self.pools_hint_lexicon[1:]: + res = np.zeros((wv.embed_size,), dtype=np.float32) + for w in ws: + hit, norm_name, norm_w = wv.norm_until_hit(w) + if hit: + value = np.asarray(wv.get_vec(norm_w, norm=False), dtype=np.float32) + record[norm_name] += 1 + else: + value = get_nohit(wv.embed_size) + record["no-hit"] += 1 + res += value + ret.append(res) + # + assert not assert_all_hit or record["no-hit"]==0, f"Filter-embed error: assert all-hit but get no-hit of {record['no-hit']}" + zlog(f"Filter pre-trained Pembed: {record}, no-hit is inited with {init_nohit}.") + ret = np.asarray(ret, dtype=np.float32) * scale + if set_init: + self.set_pool_init(ret) + return ret + + # ===== + # query + def val2idx(self, item: str) -> HLabelIdx: + if item is None: + item = "" + return self.v[item] + + def idx2val(self, idx: HLabelIdx) -> str: + return str(idx) + + def get_hlidx(self, idx: int, eff_max_layer: int): + return self.layered_hlidx[eff_max_layer-1][idx] + + def val2count(self, item: str) -> int: + return self.orig_counts[item] + + # transform single unit into lexicon + @staticmethod + def get_lexicon_hint(key: str): + k0 = key.lower() + return LABEL2LEX_MAP.get(k0, k0) diff --git a/msp/zext/ie/utils.py b/msp/zext/ie/utils.py new file mode 100644 index 0000000..091d574 --- /dev/null +++ b/msp/zext/ie/utils.py @@ -0,0 +1,34 @@ +# ===== +# other helpers + +# split by Camel Capital letter or "-"/".", return list of lowercased parts +def split_camel(ss): + prev_upper = False + all_cs = [] + for c in ss: + if str.isupper(c): + if not prev_upper: # splitting point + all_cs.append("-") + prev_upper = True + elif str.islower(c): + prev_upper = False + else: + prev_upper = False + if c == ".": + c = '-' + assert c=="-", f"Illegal char {c} for label" + all_cs.append(str.lower(c)) # here, first lower all + ret = [z for z in "".join(all_cs).split("-") if len(z)>0] + return ret + +# norm things into camel format without "-" +def rearrange_camel(ss): + parts = split_camel(ss) + return "".join([str.upper(z[0])+z[1:] for z in parts]) + +# todo(note): specific to ACE/ERE style +# mainly for Entity types +LABEL2LEX_MAP = { + "per": "person", "org": "organization", "gpe": "geopolitics", "loc": "location", "fac": "facility", + "veh": "vehicle", "wea": "weapon", +} diff --git a/msp/zext/process_train.py b/msp/zext/process_train.py index 6041c55..0ab0819 100644 --- a/msp/zext/process_train.py +++ b/msp/zext/process_train.py @@ -46,12 +46,17 @@ def __init__(self, patience, anneal_restarts_cap): self.train_records = [] self.dev_records = [] # about dev checks + # all best self.best_dev_record = RecordResult({}) self.best_point = -1 + # deal with anneal with all bests self.bad_counter = 0 self.bad_points = [] self.anneal_restarts_done = 0 self.anneal_restarts_points = [] + # only save best + self.save_best_dev_record = RecordResult({}) + self.save_best_point = -1 # def current_suffix(self): @@ -62,18 +67,27 @@ def current_suffix(self): def current_idxes(self): return {"aidx": self.anneal_restarts_done, "eidx": self.eidx, "iidx": self.iidx, "uidx": self.uidx, "cidx": len(self.chp_names)} - def best_info(self): - return [self.best_point, self.chp_names[self.best_point], str(self.best_dev_record)] + def info_best(self): + if self.best_point < 0: + return [-1, "None", "-inf"] + else: + return [self.best_point, self.chp_names[self.best_point], str(self.best_dev_record)] + + def info_save_best(self): + if self.save_best_point < 0: + return [-1, "None", "-inf"] + else: + return [self.save_best_point, self.chp_names[self.save_best_point], str(self.save_best_dev_record)] # called after one dev, log dev result and restart train-record - def checkpoint(self, train_result: RecordResult, dev_result: RecordResult): + def checkpoint(self, train_result: RecordResult, dev_result: RecordResult, use_save_best: bool): # log down sname = self.current_suffix() self.chp_names.append(sname) self.train_records.append(train_result) self.dev_records.append(dev_result) - # - if_best = if_anneal = False + # anneal is for all the points + if_best = if_save_best = if_anneal = False if float(dev_result) > float(self.best_dev_record): self.bad_counter = 0 self.best_dev_record = dev_result @@ -88,33 +102,46 @@ def checkpoint(self, train_result: RecordResult, dev_result: RecordResult): if self.bad_counter >= self.patience: self.bad_counter = 0 self.anneal_restarts_points.append(cur_info) - if self.anneal_restarts_done < self.anneal_restarts_cap: - self.anneal_restarts_done += 1 - utils.zlog("Anneal++, now %s." % (self.anneal_restarts_done,), func="warn") - if_anneal = True - else: + # there can be chances of continuing when restarts enough (min training settings) + self.anneal_restarts_done += 1 + utils.zlog("Anneal++, now %s." % (self.anneal_restarts_done,), func="warn") + if_anneal = True + if self.anneal_restarts_done >= self.anneal_restarts_cap: utils.zlog("Sorry, Early Stop !!", func="warn") self.estop = True - return if_best, if_anneal + # save is only for certain points + if use_save_best: + if float(dev_result) > float(self.save_best_dev_record): + self.save_best_dev_record = dev_result + self.save_best_point = len(self.chp_names) - 1 + if_save_best = True + return if_best, if_save_best, if_anneal # class RConf(Conf): def __init__(self): # self.skip_batch = 0. # rate of randomly skipping each batch + # todo(+N): the effect of split_batch can influence lr in more advanced optimizer? (since we are div lr) self.split_batch = 1 # num of batch splits, default 1 means no splitting + # (now div loss) self.grad_div_for_sb = 0 # divide grad for split-batch mode rather than div lrate self.flag_debug = False self.flag_verbose = True # mostly based on uidx: number of updates self.report_freq = 1000 - self.valid_freq = Constants.INT_MAX + self.valid_freq = Constants.INT_PRAC_MAX * 100 # this may be enough... self.validate_epoch = True self.validate_first = False + self.min_epochs = 0 + self.min_save_epochs = 0 self.max_epochs = 100 - self.max_updates = 1000000 + self.min_updates = 0 + self.min_save_updates = 0 + self.max_updates = 10000000 # annealing self.anneal_times = 100 self.patience = 3 + self.anneal_restore = False # restore previous best model when annealing # write model self.model_overwrite = True self.model_name = "zmodel" @@ -122,8 +149,26 @@ def __init__(self): self.suffix_best = ".best" # # lrate schedule - self.lrate = SVConf().init_from_kwargs(init_val=0.001, which_idx="aidx", mode="exp", k=0.75) + self.lrate = SVConf().init_from_kwargs(val=0.001, which_idx="aidx", mode="exp", m=0.75, min_val=0.00001) self.lrate_warmup = 0 # linear increasing lrate as warmup for how many steps (minus values means epoch) + self.lrate_anneal_alpha = 0. # similar to ATT-is-ALL-you-Need, sth like -0.5 (after warmup, step^anneal) + + def do_validate(self): + # min_* must be >= min_save_* + self.min_epochs = max(self.min_epochs, self.min_save_epochs) + self.min_updates = max(self.min_updates, self.min_save_updates) + +# optimizer conf +class OptimConf(Conf): + def __init__(self): + self.optim = "adam" + self.sgd_momentum = 0.85 # for "sgd" + self.adam_betas = [0.9, 0.9] # for "adam" + self.adam_eps = 1e-4 # for "adam" + self.adadelta_rho = 0.9 # for "adadelta" + self.grad_clip = 5.0 + self.no_step_lrate0 = True # no step when lrate<=0., even no momentum accumulating + self.weight_decay = 0. # common practice for training class TrainingRunner(object): @@ -148,9 +193,34 @@ def add_scheduled_values(self, sv_list): utils.zlog("Adding scheduled value %s in training." % (one,)) self.scheduled_values.append(one) - def _finished(self): + def _adjust_scheduled_values(self): + # schedule values + ss = self.current_name() + cur_idxes = self._tp.current_idxes() + for one_sv in self.scheduled_values: + one_sv.adjust_at_ckp(ss, cur_idxes) + + # ===== + # ending criteria + + # already reach any of the right point of training (max-epoch/update or max-restart) + def _reach_right_end(self): return self._tp.estop or self._tp.eidx >= self.rconf.max_epochs \ - or self._tp.uidx >= self.rconf.max_updates + or self._tp.uidx >= self.rconf.max_updates + + # already reach all of the left point of training (min-epoch/update) + def _reach_left_end(self): + return self._tp.eidx >= self.rconf.min_epochs and self._tp.uidx >= self.rconf.min_updates + + # reach all of the min save conditions + def _reach_save_end(self): + return self._tp.eidx >= self.rconf.min_save_epochs and self._tp.uidx >= self.rconf.min_save_updates + + # only finish when reaching both ends + def _finished(self): + return self._reach_left_end() and self._reach_right_end() + + # ===== def current_name(self): return self._tp.current_suffix() @@ -158,10 +228,11 @@ def current_name(self): # training for one batch def _fb_batch(self, insts): num_splits = self.rconf.split_batch + loss_factor = 1. / num_splits splitted_insts = Helper.split_list(insts, num_splits) with self.train_recorder.go(): for one_insts in splitted_insts: - res = self._run_fb(one_insts) + res = self._run_fb(one_insts, loss_factor) self.train_recorder.record(res) self._tp.iidx += self.batch_size_f(insts) @@ -182,20 +253,24 @@ def _validate(self, dev_streams): # validate dev_result = self._run_validate(dev_streams) # record - if_best, if_anneal = self._tp.checkpoint(train_result, dev_result) + cur_use_save_best = self._reach_save_end() + if_best, if_save_best, if_anneal = self._tp.checkpoint(train_result, dev_result, cur_use_save_best) # checkpoint - save curr & best self.save(rconf.model_name+rconf.suffix_curr) if not rconf.model_overwrite: self.save(rconf.model_name+ss) if if_best: self.save(rconf.model_name+rconf.suffix_best) - utils.zlog("Curr is best: " + str(self._tp.best_info()), func="result") + utils.zlog("Curr is best: " + str(self._tp.info_best()), func="result") else: - utils.zlog("Curr not best, the best is " + str(self._tp.best_info()), func="result") - # schedule values - cur_idxes = self._tp.current_idxes() - for one_sv in self.scheduled_values: - one_sv.adjust_at_ckp(ss, cur_idxes) + utils.zlog("Curr not best, the best is " + str(self._tp.info_best()), func="result") + if if_save_best: + # todo(+2): here overwrite the previous best point, will this small mismatch damage reloading? + utils.zlog("But Curr is save_best, overwrite the best point!") + self.save(rconf.model_name + rconf.suffix_best) + if if_anneal and self.rconf.anneal_restore: + utils.zlog("Restore from previous best model!!") + self.load(rconf.model_name+rconf.suffix_best, False) utils.zlog("") def run(self, train_stream, dev_streams): @@ -203,7 +278,8 @@ def run(self, train_stream, dev_streams): last_report_uidx, last_dev_uidx = 0, 0 if rconf.validate_first: self._validate(dev_streams) - # for lrate warm + # ===== + # for lrate warmup and annealing if rconf.lrate_warmup < 0: # calculate epochs steps_per_epoch = 0 @@ -216,29 +292,44 @@ def run(self, train_stream, dev_streams): n_steps = rconf.lrate_warmup else: n_steps = 0 - utils.zlog(f"For lrate-warmup, will go with the first {n_steps} steps.") + max_lrate = self.lrate.value + # final_lrate = lrate * anneal_factor * (step)^lrate_anneal_alpha + # final_lrate(n_steps) = max_lrate + lrate_anneal_alpha = rconf.lrate_anneal_alpha + if n_steps > 0: + anneal_factor = 1. / (n_steps**lrate_anneal_alpha) + else: + anneal_factor = 1. self.lrate_warmup_steps = n_steps - # - num_splits = rconf.split_batch + utils.zlog(f"For lrate-warmup, will go with the first {n_steps} steps up to {max_lrate}, " + f"then anneal with lrate*{anneal_factor}*step^{lrate_anneal_alpha}") + # ===== while not self._finished(): + # todo(note): epoch start from 1!! + self._tp.eidx += 1 with Timer(tag="Train-Iter", info="Iter %s" % self._tp.eidx, print_date=True) as et: + self._adjust_scheduled_values() # adjust at the start of each epoch + act_lrate = 0. # for batches for insts in train_stream: # skip this batch - if Random.random_bool(rconf.skip_batch, task="train"): + if Random.random_bool(rconf.skip_batch): continue # train on batch, return a dictionary # possibly split batch to save memory self._fb_batch(insts) self._tp.uidx += 1 # get the effective lrate - act_lrate = self.lrate.value / num_splits # compensate for batch-split + act_lrate = self.lrate.value if self._tp.uidx < self.lrate_warmup_steps: act_lrate *= (self._tp.uidx / self.lrate_warmup_steps) + else: + act_lrate *= anneal_factor * (self._tp.uidx**lrate_anneal_alpha) # - self._run_update(act_lrate) + self._run_update(act_lrate, 1.) # report on training process if rconf.flag_verbose and (self._tp.uidx-last_report_uidx)>=rconf.report_freq: + utils.zlog(f"Current act_lrate is {act_lrate}.") self._run_train_report() last_report_uidx = self._tp.uidx # time for validating @@ -246,37 +337,38 @@ def run(self, train_stream, dev_streams): self._validate(dev_streams) last_dev_uidx = self._tp.uidx last_report_uidx = self._tp.uidx + # todo(+N): do we need to adjust sv at a finer grained? + self._adjust_scheduled_values() # adjust after uidx validation if self._finished(): break - if self._finished(): - break # validate at the end of epoch? + utils.zlog(f"End of epoch: Current act_lrate is {act_lrate}.") if rconf.validate_epoch: self._validate(dev_streams) last_dev_uidx = self._tp.uidx last_report_uidx = self._tp.uidx - self._tp.eidx += 1 - utils.zlog("zzzzzfinal: After training, the best point is: %s." % (str(self._tp.best_info()))) + utils.zlog("") + utils.zlog("zzzzzfinal: After training, the best point is: %s." % (str(self._tp.info_save_best()))) # save & load def save(self, base_name): - JsonRW.save_to_file(self._tp, base_name+".pr.json") + JsonRW.to_file(self._tp, base_name+".pr.json") self.model.save(base_name) utils.zlog("Save TrainRunner to <%s*>." % (base_name,), func="io") def load(self, base_name, load_process): if load_process: - JsonRW.load_from_file(self._tp, base_name+".pr.json") + JsonRW.from_file(self._tp, base_name+".pr.json") self.model.load(base_name) utils.zlog("Load TrainRunner from <%s*> (load-pr=%s)." % (base_name, load_process), func="io") # to be implemented # return one train result - def _run_fb(self, insts): + def _run_fb(self, insts, loss_factor: float): raise NotImplementedError() # return None - def _run_update(self, lrate: float): + def _run_update(self, lrate: float, grad_factor: float): raise NotImplementedError() # print and return train summary @@ -293,7 +385,8 @@ def _run_validate(self, dev_streams) -> RecordResult: class SVConf(Conf): def __init__(self): - self.init_val = 0. # should be set properly + # self.init_val = 0. # should be set properly (deprecated because of confusing name) + self.val = 0. # basic value # how to schedule the value self.which_idx = "eidx" # count steps on which: aidx, eidx, iidx, uidx self.mode = "none" # none, linear, exp, isigm @@ -301,35 +394,57 @@ def __init__(self): self.start_bias = 0 self.scale = 1.0 self.min_val = Constants.REAL_PRAC_MIN + self.max_val = Constants.REAL_PRAC_MAX + self.b = 0. self.k = 1.0 + self.m = 1.0 # specific one class ScheduledValue(object): - def __init__(self, name, sv_conf): + def __init__(self, name, sv_conf, special_ff=None): self.name = name self.sv_conf = sv_conf - # - self.init_val = sv_conf.init_val - self.cur_val = self.init_val + self.val = sv_conf.val + self.cur_val = None # mode = sv_conf.mode k = sv_conf.k + b = sv_conf.b + m = sv_conf.m + # + if special_ff is not None: + assert mode == "none", "Confusing setting for schedule function!" + # if mode == "linear": - self._ff = lambda idx: 1.0-k*idx + self._ff = lambda idx: b+k*m*idx + elif mode == "poly": + self._ff = lambda idx: b+k*(idx**m) elif mode == "exp": - self._ff = lambda idx: math.pow(k, idx) + self._ff = lambda idx: b+k*(math.pow(m, idx)) elif mode == "isigm": - self._ff = lambda idx: k/(k+math.exp(idx/k)) + self._ff = lambda idx: b+k*(m/(m+math.exp(idx/m))) elif mode == "div": - self._ff = lambda idx: 1/(1.+k*idx) + self._ff = lambda idx: b+k*(1/(1.+m*idx)) elif mode == "none": - self._ff = lambda idx: 1.0 + if special_ff is not None: + self._ff = lambda idx: b+k*special_ff(idx) # self-defined schedule: lambda idx: return ... + else: + self._ff = lambda idx: 1.0 else: raise NotImplementedError(mode) + # init setting + self._set(0) + utils.zlog(f"Init scheduled value {self.name} as {self.cur_val}.") @property def value(self): return self.cur_val + def __float__(self): + return float(self.cur_val) + + def __int__(self): + return int(self.cur_val) + def transform_idx(self, idx): x = max(0, idx-self.sv_conf.start_bias) return x/self.sv_conf.scale @@ -337,13 +452,18 @@ def transform_idx(self, idx): def __repr__(self): return "SV-%s=%s" % (self.name, self.cur_val) + def _set(self, the_idx): + new_idx = self.transform_idx(the_idx) + old_val = self.cur_val + # todo(note): self.val as the basis, multiplied by the factor + self.cur_val = max(self.sv_conf.min_val, self.val * self._ff(new_idx)) + self.cur_val = min(self.sv_conf.max_val, self.cur_val) + return old_val, self.cur_val + # adjust at checkpoint def adjust_at_ckp(self, sname, cur_idxes): the_idx = cur_idxes[self.sv_conf.which_idx] - # - new_idx = self.transform_idx(the_idx) - old_val = self.cur_val - self.cur_val = max(self.sv_conf.min_val, self.init_val * self._ff(new_idx)) + old_val, new_val = self._set(the_idx) if self.cur_val != old_val: utils.zlog("Change scheduled value %s at %s: %s => %s." % (self.name, sname, old_val, self.cur_val)) else: diff --git a/tasks/compile.sh b/tasks/compile.sh index 5e46b98..9574d65 100644 --- a/tasks/compile.sh +++ b/tasks/compile.sh @@ -2,5 +2,7 @@ # compile the cython modules +# 1st order algorithms python zdpar/algo/setup.py build_ext - +# high order algorithm +bash zdpar/algo/gp2/compile.sh diff --git a/tasks/zdpar/__init__.py b/tasks/zdpar/__init__.py index bf71605..5cc6833 100644 --- a/tasks/zdpar/__init__.py +++ b/tasks/zdpar/__init__.py @@ -1,3 +1,4 @@ # # package for dependency parsing + diff --git a/tasks/zdpar/algo/__init__.py b/tasks/zdpar/algo/__init__.py index 9ddeea2..dac14b1 100644 --- a/tasks/zdpar/algo/__init__.py +++ b/tasks/zdpar/algo/__init__.py @@ -12,4 +12,5 @@ # Tensor versions (can be simply wrappers) from .nmst import nmst_unproj, nmst_proj, nmst_greedy from .nmst import nmarginal_unproj, nmarginal_proj, nmarginal_greedy - +# +from .hop import hop_decode diff --git a/tasks/zdpar/algo/gp2/AD3-201907.zip b/tasks/zdpar/algo/gp2/AD3-201907.zip new file mode 100644 index 0000000..9617a09 Binary files /dev/null and b/tasks/zdpar/algo/gp2/AD3-201907.zip differ diff --git a/tasks/zdpar/algo/gp2/FactorHead.h b/tasks/zdpar/algo/gp2/FactorHead.h new file mode 100644 index 0000000..ec94fd7 --- /dev/null +++ b/tasks/zdpar/algo/gp2/FactorHead.h @@ -0,0 +1,517 @@ +// + +// Factor for one head (supporting features up to o3gsib) + +#ifndef FACTOR_HEAD_H +#define FACTOR_HEAD_H + +#include +#include "ad3/GenericFactor.h" +#include "utils.h" + +using std::ostream; + +// todo(note): no considering [0,0] other than the special [g,h] +namespace AD3 { + class FactorHead: public GenericFactor{ + public: + FactorHead() {} + virtual ~FactorHead() { ClearActiveSet(); } + + // Print as a string. + void Print(ostream& stream) { + stream << "HEAD_AUTOMATON up to o3gsib (details omitted)"; + Factor::Print(stream); + } + + void info(ostream& stream, const vector* variable_scores, const vector* additional_scores){ + vector self_variable_scores; + vector self_additional_scores; + // get self if nullptr + if(variable_scores == nullptr){ + for(auto* v : binary_variables_){ + self_variable_scores.push_back(v->GetLogPotential()); + } + variable_scores = &self_variable_scores; + } + if(additional_scores == nullptr){ + additional_scores = &additional_log_potentials_; + } + // collect and print info + // first + stream << "FactorHead: head=" << cur_head_ << ", is_right=" << is_right_ << endl; + // gs + if(head0_){ + CHECK(cur_head_==0, ""); + CHECK(length_gs_ == 1, ""); + CHECK(base_gs_ == 0, ""); + stream << "gs: With only Dummy 0" << endl; + } + else{ + stream << "gs: "; + for(int i = 0; i < rindex_incoming_.size(); i++){ + stream << rindex_incoming_[i] << "(" << variable_scores->at(i) << ")" << " "; + } + stream << endl; + CHECK(base_gs_ == rindex_incoming_.size(), ""); + } + // ms + stream << "ms: "; + for(int i = 0; i < rindex_modifiers_.size(); i++){ + stream << rindex_modifiers_[i] << "(" << variable_scores->at(i+base_gs_) << ")" << " "; + } + stream << endl; + CHECK(base_gs_ + length_ms_ == variable_scores->size(), ""); + // addional + int cur_idx_base = 0; + vector cur_ms, cur_hs, cur_ss, cur_gs; + vector cur_scores; + // o2sib + if(has_o2sib_){ + for(int m = 0; m < length_ms_; m++){ + for(int s = 0; s < length_ms_; s++){ + int cur_idx = index_siblings_[m][s]; + if(cur_idx >= 0){ + cur_ms.push_back(rindex_modifiers_[m]); cur_hs.push_back(cur_head_); cur_ss.push_back(rindex_modifiers_[s]); + cur_scores.push_back(additional_scores->at(cur_idx_base)); + cur_idx_base++; + } + } + } + debug_print(stream, &cur_ms, &cur_hs, &cur_ss, nullptr, &cur_scores); + cur_ms.clear(); cur_hs.clear(); cur_ss.clear(); cur_scores.clear(); + } + // o2g + if(has_o2g_){ + for(int m = 0; m < length_ms_; m++){ + for(int g = 0; g < length_gs_; g++){ + int cur_idx = index_grandparents_[g][m]; + if(cur_idx >= 0){ + cur_ms.push_back(rindex_modifiers_[m]); cur_hs.push_back(cur_head_); cur_gs.push_back(rindex_incoming_[g]); + cur_scores.push_back(additional_scores->at(cur_idx_base)); + cur_idx_base++; + } + } + } + debug_print(stream, &cur_ms, &cur_hs, nullptr, &cur_gs, &cur_scores); + cur_ms.clear(); cur_hs.clear(); cur_gs.clear(); cur_scores.clear(); + } + // o3gsib + if(has_o3gsib_){ + for(int m = 0; m < length_ms_; m++){ + for(int s = 0; s < length_ms_; s++){ + for(int g = 0; g < length_gs_; g++){ + int cur_idx = index_grandsiblings_[g][m][s]; + if(cur_idx >= 0){ + cur_ms.push_back(rindex_modifiers_[m]); cur_hs.push_back(cur_head_); cur_ss.push_back(rindex_modifiers_[s]); cur_gs.push_back(rindex_incoming_[g]); + cur_scores.push_back(additional_scores->at(cur_idx_base)); + cur_idx_base++; + } + } + } + } + debug_print(stream, &cur_ms, &cur_hs, &cur_ss, &cur_gs, &cur_scores); + } + // + CHECK(cur_idx_base == additional_scores->size(), ""); + } + + // ===== + // main ones + private: + // get + inline double get_o2sib_score(const vector &additional_log_potentials, int m, int s){ + if(has_o2sib_){ + int idx = index_siblings_[m][s]; + CHECK(idx >= 0, "Illegal index"); + return additional_log_potentials[idx]; + } + else return 0.; + } + + inline double get_o2g_score(const vector &additional_log_potentials, int g, int m){ + if(has_o2g_){ + int idx = index_grandparents_[g][m]; + CHECK(idx >= 0, "Illegal index"); + return additional_log_potentials[idx]; + } + else return 0.; + } + + inline double get_o3gsib_score(const vector &additional_log_potentials, int g, int m, int s){ + if(has_o3gsib_){ + int idx = index_grandsiblings_[g][m][s]; + CHECK(idx >= 0, "Illegal index"); + return additional_log_potentials[idx]; + } + else return 0.; + } + + // set + inline void add_o2sib_weight(vector *additional_posteriors, double weight, int m, int s){ + if(has_o2sib_){ + int idx = index_siblings_[m][s]; + CHECK(idx >= 0, "Illegal index"); + (*additional_posteriors)[idx] += weight; + } + } + + inline void add_o2g_weight(vector *additional_posteriors, double weight, int g, int m){ + if(has_o2g_){ + int idx = index_grandparents_[g][m]; + CHECK(idx >= 0, "Illegal index"); + (*additional_posteriors)[idx] += weight; + } + } + + inline void add_o3gsib_weight(vector *additional_posteriors, double weight, int g, int m, int s){ + if(has_o3gsib_){ + int idx = index_grandsiblings_[g][m][s]; + CHECK(idx >= 0, "Illegal index"); + (*additional_posteriors)[idx] += weight; + } + } + + public: + // Get the combinatory best solution + // variable_log_potentials: (g,h)+(h,m), additional_log_potentials: o2sib+o2g+o3gsib + // -- todo(note): and be careful about [0,0] + void Maximize(const vector &variable_log_potentials, + const vector &additional_log_potentials, + Configuration &configuration, + double *value) { + // Decode maximizing over the grandparents and using the Viterbi algorithm as an inner loop. + int num_g = length_gs_; + int num_m = length_ms_; + // ----- + // overall best + int overall_best_grandparent = -1; + vector overall_best_modifiers; + double overall_best_score = MY_NEGINF; + // Run Viterbi for each possible grandparent. + for(int g = 0; g < num_g; ++g) { + int orig_g = rindex_incoming_[g]; + // gh score should be added at the final step! + double gh_score = (!head0_) ? variable_log_potentials[g] : 0.; // no gh score or simply ignore (0,0) + // 0 as the start, thus the recording idxes are all +1 + // these means up-to-this-step(and selecting current, what is the best history) + vector best_path(num_m + 1, 0); // for each step, the best prev-m (in recording-m) + vector best_scores(num_m + 1, MY_NEGINF); // best scores for each step + // starting best is selecting nothing + int best_ending = 0; + double best_score = 0.; + best_scores[0] = 0.; + // for each mod as step + // todo(note): not-optimized for no-sib situation + for(int m = 0; m < num_m; ++m) { + int orig_m = rindex_modifiers_[m]; + if(orig_g == orig_m) + continue; // g cannot be m! + double o2g_score = get_o2g_score(additional_log_potentials, g, m); + double hm_score = variable_log_potentials[m+base_gs_]; // offset the gh ones + int recording_m = m + 1; // step number, since padding a 0 as start state + // adding current m, find best prev recording-m + // first the special one for the starting mod: [m, m] means this + int cur_best_prev_rm = 0; + double cur_best_score = get_o2sib_score(additional_log_potentials, m, m) + get_o3gsib_score(additional_log_potentials, g, m, m); + // + // for other nonzero ones + for(int s = 0; s < m; s++){ + int orig_s = rindex_modifiers_[s]; + if(orig_g == orig_s) + continue; // g cannot be s! + int mr = s + 1; + double cur_one_score = best_scores[mr]; + cur_one_score += get_o2sib_score(additional_log_potentials, m, s) + get_o3gsib_score(additional_log_potentials, g, m, s); + if(cur_one_score > cur_best_score){ + cur_best_score = cur_one_score; + cur_best_prev_rm = mr; + } + } + // update for one g + cur_best_score += o2g_score + hm_score; + best_path[recording_m] = cur_best_prev_rm; + best_scores[recording_m] = cur_best_score; + if(cur_best_score > best_score){ + best_score = cur_best_score; + best_ending = recording_m; + } + } + // update overall + double best_full_score = best_score + gh_score; + if(best_full_score > overall_best_score){ + overall_best_score = best_full_score; + overall_best_grandparent = g; + // back tracking modifiers + overall_best_modifiers.clear(); + for(int cur_mr = best_ending; cur_mr > 0; cur_mr=best_path[cur_mr]){ + overall_best_modifiers.push_back(cur_mr-1); // remember to get real idx + } + std::reverse(overall_best_modifiers.begin(), overall_best_modifiers.end()); // correct order + } + } + // Now write the configuration. + vector *grandparent_modifiers = static_cast*>(configuration); + grandparent_modifiers->push_back(overall_best_grandparent); + grandparent_modifiers->insert(grandparent_modifiers->end(), overall_best_modifiers.begin(), overall_best_modifiers.end()); + *value = overall_best_score; + } + + // Compute the score of a given assignment. + void Evaluate(const vector &variable_log_potentials, + const vector &additional_log_potentials, + const Configuration configuration, + double *value) { + const vector* grandparent_modifiers = static_cast*>(configuration); + double tmp_value = 0.; + int g = (*grandparent_modifiers)[0]; + if(!head0_) + tmp_value += variable_log_potentials[g]; + int prev = -1; + for(int i = 1; i < grandparent_modifiers->size(); ++i) { + int m = (*grandparent_modifiers)[i]; + int s = (prev < 0) ? m : prev; + tmp_value += variable_log_potentials[base_gs_+m]; + // the three additional ones + tmp_value += get_o2sib_score(additional_log_potentials, m, s); + tmp_value += get_o2g_score(additional_log_potentials, g, m); + tmp_value += get_o3gsib_score(additional_log_potentials, g, m, s); + prev = m; + } + *value = tmp_value; + } + + // Given a configuration with a probability (weight), + // increment the vectors of variable and additional posteriors. + void UpdateMarginalsFromConfiguration( + const Configuration &configuration, + double weight, + vector *variable_posteriors, + vector *additional_posteriors) { + const vector* grandparent_modifiers = static_cast*>(configuration); + int g = (*grandparent_modifiers)[0]; + if(!head0_) + (*variable_posteriors)[g] += weight; + int prev = -1; + for(int i = 1; i < grandparent_modifiers->size(); ++i) { + int m = (*grandparent_modifiers)[i]; + int s = (prev < 0) ? m : prev; + (*variable_posteriors)[base_gs_+m] += weight; + // the three additional ones + add_o2sib_weight(additional_posteriors, weight, m, s); + add_o2g_weight(additional_posteriors, weight, g, m); + add_o3gsib_weight(additional_posteriors, weight, g, m, s); + prev = m; + } + } + + // ===== + // configuration + + // Count how many common values two configurations have. + int CountCommonValues(const Configuration &configuration1, const Configuration &configuration2) { + const vector *values1 = static_cast*>(configuration1); + const vector *values2 = static_cast*>(configuration2); + int count = 0; + if((*values1)[0] == (*values2)[0]) ++count; // Grandparents matched. + int j = 1; + for(int i = 1; i < values1->size(); ++i) { + for(; j < values2->size(); ++j) { + if((*values2)[j] >= (*values1)[i]) break; + } + if(j < values2->size() && (*values2)[j] == (*values1)[i]) { + ++count; + ++j; + } + } + return count; + } + + // Check if two configurations are the same. + bool SameConfiguration( + const Configuration &configuration1, + const Configuration &configuration2) { + const vector *values1 = static_cast*>(configuration1); + const vector *values2 = static_cast*>(configuration2); + if(values1->size() != values2->size()) return false; + for(int i = 0; i < values1->size(); ++i) { + if((*values1)[i] != (*values2)[i]) return false; + } + return true; + } + + // Delete configuration. + void DeleteConfiguration( + Configuration configuration) { + vector *values = static_cast*>(configuration); + delete values; + } + + Configuration CreateConfiguration() { + // The first element is the index of the grandparent. + // The remaining elements are the indices of the modifiers. + vector* grandparent_modifiers = new vector; + return static_cast(grandparent_modifiers); + } + + public: + // ===== + // init + + // h, gs, ms: original idxes (in sentence) for the possible gs and ms + // -- gs is l2r, ms is l2r/r2l according to the direction + // the rest arguments indicate the high-order parts, idxes are the ones that are effective + // gs can include 0, but ms does not has 0! + // todo(note): no checking here, assuming at the outside the processes are correct without including strange ones + void Initialize(const int h, const bool is_right, const vector &gs, const vector &ms, + const O2SibPack* po2sib, const vector& po2sib_idxes, + const O2gPack* po2g, const vector& po2g_idxes, + const O3gsibPack* po3gsib, const vector& po3gsib_idxes) { + // ===== + cur_head_ = h; + is_right_ = is_right; + if(h == 0){ + head0_ = true; + CHECK(is_right, "Err: 0-left"); + CHECK(gs.size() == 0, "Err: head-of-0"); + } + else{ + head0_ = false; + } + // length/idxes are relative to the head position. + length_gs_ = head0_ ? 1 : gs.size(); + length_ms_ = ms.size(); // no stopping (outmost) part! + base_gs_ = gs.size(); + has_o2sib_ = has_o2g_ = has_o3gsib_ = false; + if(po2sib){ + index_siblings_.assign(length_ms_, vector(length_ms_, -1)); + has_o2sib_ = true; + } + if(po2g){ + index_grandparents_.assign(length_gs_, vector(length_ms_, -1)); + has_o2g_ = true; + } + if(po3gsib){ + index_grandsiblings_.assign(length_gs_, vector >(length_ms_, vector(length_ms_, -1))); + has_o3gsib_ = true; + } + need_sib_ = (has_o2sib_ || has_o3gsib_); + need_g_ = (has_o2g_ || has_o3gsib_); + // ===== + // Create a temporary index of modifiers. + // todo(note): here, no special idx0, since o2sib[m][m] means start-mod; expect for [0][m] for root-self-loop + index_modifiers_.clear(); + for(int k = 0; k < ms.size(); ++k) { + int m = ms[k]; + int position = is_right ? (m - h) : (h - m); + index_modifiers_.resize(position + 1, -1); + index_modifiers_[position] = k; + } + // Create a temporary index of grandparents. + // todo(note): special treating for (0,0) + index_incoming_.clear(); + if(cur_head_ == 0){ + // specially put 0 here + index_incoming_.push_back(0); + } + else{ + for(int k = 0; k < gs.size(); ++k) { + int g = gs[k]; + int position = g; + index_incoming_.resize(position + 1, -1); + index_incoming_[position] = k; + } + } + // ===== + int cur_idx_base = 0; + // Construct index of siblings. + for(int r = 0; r < po2sib_idxes.size(); r++){ + int ri = po2sib_idxes[r]; + //int cur_h = po2sib->idxes_h[ri]; + int m = po2sib->idxes_m[ri]; + int s = po2sib->idxes_s[ri]; + // change idx + int position_modifier = is_right ? (m - h) : (h - m); + int position_sibling = is_right ? (s - h) : (h - s); + int index_modifier = index_modifiers_[position_modifier]; + int index_sibling = index_modifiers_[position_sibling]; + // final overall idx + index_siblings_[index_modifier][index_sibling] = cur_idx_base; + cur_idx_base++; + } + // Construct index of grandparents. + for(int r = 0; r < po2g_idxes.size(); r++){ + int ri = po2g_idxes[r]; + //int cur_h = po2g->idxes_h[ri]; + int m = po2g->idxes_m[ri]; + int g = po2g->idxes_g[ri]; + // change idx + int position_modifier = is_right ? (m - h) : (h - m); + int position_grandparent = g; + int index_modifier = index_modifiers_[position_modifier]; + int index_grandparent = index_incoming_[position_grandparent]; + // final idx + index_grandparents_[index_grandparent][index_modifier] = cur_idx_base; + cur_idx_base++; + } + // Construct index of grandsiblings. + for(int r = 0; r < po3gsib_idxes.size(); r++){ + int ri = po3gsib_idxes[r]; + //int cur_h = po3gsib->idxes_h[ri]; + int m = po3gsib->idxes_m[ri]; + int s = po3gsib->idxes_s[ri]; + int g = po3gsib->idxes_g[ri]; + // change idx + int position_modifier = is_right ? (m - h) : (h - m); + int position_sibling = is_right ? (s - h) : (h - s); + int position_grandparent = g; + int index_modifier = index_modifiers_[position_modifier]; + int index_sibling = index_modifiers_[position_sibling]; + int index_grandparent = index_incoming_[position_grandparent]; + // final overall idx + index_grandsiblings_[index_grandparent][index_modifier][index_sibling] = cur_idx_base; + cur_idx_base++; + } + // reverse index + rindex_incoming_.assign(length_gs_, -1); + rindex_modifiers_.assign(length_ms_, -1); + for(unsigned i = 0; i < index_incoming_.size(); i++){ + int x = index_incoming_[i]; + if(x >= 0) rindex_incoming_[x] = i; + } + for(unsigned i = 0; i < index_modifiers_.size(); i++){ + int x = index_modifiers_[i]; + if(x >= 0) rindex_modifiers_[x] = ((is_right) ? (h+i) : (h-i)); + } + } + + private: + // + int cur_head_; + bool head0_; // cur_head is 0 + bool is_right_; + // overall (valid) lengths + int length_gs_; + int length_ms_; + int base_gs_; // basis as the varaible-offset (excluding 0,0) + // valid pieces + bool has_o2sib_; + bool has_o2g_; + bool has_o3gsib_; + bool need_sib_; + bool need_g_; + // idxes from relative-idx to flattened-additional-scores + vector > index_siblings_; // [len_ms_m, len_ms_s] + vector > index_grandparents_; // [len_gs, len_ms] + vector > > index_grandsiblings_; // // [len_gs, len_ms_m, len_ms_s] + // for debugging + // orig_idx -> cur_idx + vector index_incoming_; + vector index_modifiers_; + // reverse: cur_idx -> orig_idx/orig_distance + vector rindex_incoming_; + vector rindex_modifiers_; +}; +} + +#endif // !FACTOR_HEAD_H diff --git a/tasks/zdpar/algo/gp2/FactorTree.h b/tasks/zdpar/algo/gp2/FactorTree.h new file mode 100644 index 0000000..c6dd151 --- /dev/null +++ b/tasks/zdpar/algo/gp2/FactorTree.h @@ -0,0 +1,186 @@ +// + +// Factor for the overall tree structure + +#ifndef FACTOR_TREE_H +#define FACTOR_TREE_H + +#include "utils.h" + +// whether put edge's o1-score as factor-variable (which will be divided to all the links in the algorithm) or additional-score of this TreeFactor +// #define FACTOR_TREE_ADDITIONAL_SCORE + +namespace AD3 { + class FactorTree: public GenericFactor { + public: + FactorTree() {} + virtual ~FactorTree() { ClearActiveSet(); } + + // Print as a string. + void Print(ostream& stream) { + stream << "ARBORESCENCE (details omitted)"; + Factor::Print(stream); + } + + void info(ostream& stream, const vector* variable_scores, const vector* additional_scores){ + vector self_variable_scores; + vector self_additional_scores; + // get self if nullptr + if(variable_scores == nullptr){ + for(auto* v : binary_variables_){ + self_variable_scores.push_back(v->GetLogPotential()); + } + variable_scores = &self_variable_scores; + } + if(additional_scores == nullptr){ + additional_scores = &additional_log_potentials_; + } + // check + CHECK(variable_scores->size() == hs_->size(), "Err: size mismatch"); +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + CHECK(additional_scores->size() == variable_scores->size(), "Err: size mismatch"); +#else + CHECK(additional_scores->size() == 0, "Err: size mismatch"); +#endif + // print + stream << "FactorTree: slen=" << length_ << ", narc=" << variable_scores->size() << ", projective=" << projective_ << endl; + int count = 0; + vector cur_ms; + vector cur_hs; + vector cur_scores; + for(int m = 0; m < length_; m++){ + for(int h = 0; h < length_; h++){ + int cur_i = index_arcs_[h][m]; + if(cur_i >= 0){ + cur_ms.push_back(m); + cur_hs.push_back(h); + cur_scores.push_back(variable_scores->at(cur_i)); + } + } + } + CHECK(variable_scores->size() == cur_ms.size(), "Err: size mismatch"); + // print2 + debug_print(stream, &cur_ms, &cur_hs, nullptr, nullptr, &cur_scores); + } + + // ===== + // main ones + // use additional_log_potentials as extra o1 scores + + // maximize using specific algorithms + void Maximize(const vector &variable_log_potentials, + const vector &additional_log_potentials, + Configuration &configuration, + double *value) { + vector* heads = static_cast*>(configuration); +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + vector cur_scores = variable_log_potentials; + for(int i = 0; i < cur_scores.size(); i++) + cur_scores[i] += additional_log_potentials[i]; + mst_decode(length_, projective_, *hs_, *ms_, cur_scores, heads, value); +#else + mst_decode(length_, projective_, *hs_, *ms_, variable_log_potentials, heads, value); +#endif + } + + // Compute the score of a given assignment. + void Evaluate(const vector &variable_log_potentials, + const vector &additional_log_potentials, + const Configuration configuration, + double *value) { + const vector *heads = static_cast*>(configuration); + // Heads belong to {0,1,2,...} + *value = 0.0; + for(int m = 1; m < heads->size(); ++m) { + int h = (*heads)[m]; + int index = index_arcs_[h][m]; + *value += variable_log_potentials[index]; +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + *value += additional_log_potentials[index]; +#endif + } + } + + // Given a configuration with a probability (weight), + // increment the vectors of variable and additional posteriors. + void UpdateMarginalsFromConfiguration( + const Configuration &configuration, + double weight, + vector *variable_posteriors, + vector *additional_posteriors) { + const vector *heads = static_cast*>(configuration); + for(int m = 1; m < heads->size(); ++m) { + int h = (*heads)[m]; + int index = index_arcs_[h][m]; + (*variable_posteriors)[index] += weight; +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + (*additional_posteriors)[index] += weight; +#endif + } + } + + // Count how many common values two configurations have. + int CountCommonValues(const Configuration &configuration1, + const Configuration &configuration2) { + const vector *heads1 = static_cast*>(configuration1); + const vector *heads2 = static_cast*>(configuration2); + int count = 0; + for(int i = 1; i < heads1->size(); ++i) { + if((*heads1)[i] == (*heads2)[i]) { + ++count; + } + } + return count; + } + + // Check if two configurations are the same. + bool SameConfiguration( + const Configuration &configuration1, + const Configuration &configuration2) { + const vector *heads1 = static_cast*>(configuration1); + const vector *heads2 = static_cast*>(configuration2); + for(int i = 1; i < heads1->size(); ++i) { + if((*heads1)[i] != (*heads2)[i]) return false; + } + return true; + } + + // Delete configuration. + void DeleteConfiguration( + Configuration configuration) { + vector *heads = static_cast*>(configuration); + delete heads; + } + + // Create configuration. + Configuration CreateConfiguration() { + vector* heads = new vector(length_); + return static_cast(heads); + } + + public: + void Initialize(const bool projective, const int length, const vector* hs, const vector* ms) { + projective_ = projective; + length_ = length; + hs_ = hs; + ms_ = ms; + index_arcs_.assign(length, vector(length, -1)); + for(int k = 0; k < hs->size(); ++k) { + int h = (*hs_)[k]; + int m = (*ms_)[k]; + index_arcs_[h][m] = k; + } + } + + private: + bool projective_; // If true, assume projective trees. + int length_; // Sentence length (including root symbol). + const vector* hs_; + const vector* ms_; + vector > index_arcs_; + }; +} + +#endif // !FACTOR_TREE_H + +//TODO: deal with the decoding algo diff --git a/tasks/zdpar/algo/gp2/algo.cpp b/tasks/zdpar/algo/gp2/algo.cpp new file mode 100644 index 0000000..e0c38fe --- /dev/null +++ b/tasks/zdpar/algo/gp2/algo.cpp @@ -0,0 +1,380 @@ +// decoding algorithm + +// adapted from TurboParser::DepedencyDecoder.cpp + +#include +#include "utils.h" + +// -- nonprojective + +// Decoder for the basic model; it finds a maximum weighted arborescence +void RunChuLiuEdmondsIteration( + vector *disabled, + vector > *candidate_heads, + vector > *candidate_scores, + vector *heads, + double *value) { + // Original number of nodes (including the root). + int length = disabled->size(); + + // Pick the best incoming arc for each node. + heads->resize(length); + vector best_scores(length); + for(int m = 1; m < length; ++m) { + if((*disabled)[m]) continue; + int best = -1; + for(int k = 0; k < (*candidate_heads)[m].size(); ++k) { + if(best < 0 || + (*candidate_scores)[m][k] >(*candidate_scores)[m][best]) { + best = k; + } + } + if(best < 0) { + // No spanning tree exists. Assign the parent of this node + // to the root, and give it a minus infinity score. + (*heads)[m] = 0; + best_scores[m] = -std::numeric_limits::infinity(); + } + else { + (*heads)[m] = (*candidate_heads)[m][best]; //best; + best_scores[m] = (*candidate_scores)[m][best]; //best; + } + } + + // Look for cycles. Return after the first cycle is found. + vector cycle; + vector visited(length, 0); + for(int m = 1; m < length; ++m) { + if((*disabled)[m]) continue; + // Examine all the ancestors of m until the root or a cycle is found. + int h = m; + while(h != 0) { + // If already visited, break and check if it is part of a cycle. + // If visited[h] < m, the node was visited earlier and seen not + // to be part of a cycle. + if(visited[h]) break; + visited[h] = m; + h = (*heads)[h]; + } + + // Found a cycle to which h belongs. + // Obtain the full cycle. + if(visited[h] == m) { + m = h; + do { + cycle.push_back(m); + m = (*heads)[m]; + } while(m != h); + break; + } + } + + // If there are no cycles, then this is a well formed tree. + if(cycle.empty()) { + *value = 0.0; + for(int m = 1; m < length; ++m) { + *value += best_scores[m]; + } + return; + } + + // Build a cycle membership vector for constant-time querying and compute the + // score of the cycle. + // Nominate a representative node for the cycle and disable all the others. + double cycle_score = 0.0; + vector in_cycle(length, false); + int representative = cycle[0]; + for(int k = 0; k < cycle.size(); ++k) { + int m = cycle[k]; + in_cycle[m] = true; + cycle_score += best_scores[m]; + if(m != representative) (*disabled)[m] = true; + } + + // Contract the cycle. + // 1) Update the score of each child to the maximum score achieved by a parent + // node in the cycle. + vector best_heads_cycle(length); + for(int m = 1; m < length; ++m) { + if((*disabled)[m] || m == representative) continue; + double best_score; + // If the list of candidate parents of m is shorter than the length of + // the cycle, use that. Otherwise, loop through the cycle. + int best = -1; + for(int k = 0; k < (*candidate_heads)[m].size(); ++k) { + if(!in_cycle[(*candidate_heads)[m][k]]) continue; + if(best < 0 || (*candidate_scores)[m][k] > best_score) { + best = k; + best_score = (*candidate_scores)[m][best]; + } + } + if(best < 0) continue; + best_heads_cycle[m] = (*candidate_heads)[m][best]; + + // Reconstruct the list of candidate heads for this m. + int l = 0; + for(int k = 0; k < (*candidate_heads)[m].size(); ++k) { + int h = (*candidate_heads)[m][k]; + double score = (*candidate_scores)[m][k]; + if(!in_cycle[h]) { + (*candidate_heads)[m][l] = h; + (*candidate_scores)[m][l] = score; + ++l; + } + } + // If h is in the cycle and is not the representative node, + // it will be dropped from the list of candidate heads. + (*candidate_heads)[m][l] = representative; + (*candidate_scores)[m][l] = best_score; + (*candidate_heads)[m].resize(l + 1); + (*candidate_scores)[m].resize(l + 1); + } + + // 2) Update the score of each candidate parent of the cycle supernode. + vector best_modifiers_cycle(length, -1); + vector candidate_heads_representative; + vector candidate_scores_representative; + + vector best_scores_cycle(length); + // Loop through the cycle. + for(int k = 0; k < cycle.size(); ++k) { + int m = cycle[k]; + for(int l = 0; l < (*candidate_heads)[m].size(); ++l) { + // Get heads out of the cycle. + int h = (*candidate_heads)[m][l]; + if(in_cycle[h]) continue; + + double score = (*candidate_scores)[m][l] - best_scores[m]; + if(best_modifiers_cycle[h] < 0 || score > best_scores_cycle[h]) { + best_modifiers_cycle[h] = m; + best_scores_cycle[h] = score; + } + } + } + for(int h = 0; h < length; ++h) { + if(best_modifiers_cycle[h] < 0) continue; + double best_score = best_scores_cycle[h] + cycle_score; + candidate_heads_representative.push_back(h); + candidate_scores_representative.push_back(best_score); + } + + // Reconstruct the list of candidate heads for the representative node. + (*candidate_heads)[representative] = candidate_heads_representative; + (*candidate_scores)[representative] = candidate_scores_representative; + + // Save the current head of the representative node (it will be overwritten). + int head_representative = (*heads)[representative]; + + // Call itself recursively. + RunChuLiuEdmondsIteration(disabled, + candidate_heads, + candidate_scores, + heads, + value); + + // Uncontract the cycle. + int h = (*heads)[representative]; + (*heads)[representative] = head_representative; + (*heads)[best_modifiers_cycle[h]] = h; + + for(int m = 1; m < length; ++m) { + if((*disabled)[m]) continue; + if((*heads)[m] == representative) { + // Get the right parent from within the cycle. + (*heads)[m] = best_heads_cycle[m]; + } + } + for(int k = 0; k < cycle.size(); ++k) { + int m = cycle[k]; + (*disabled)[m] = false; + } +} + +// interface +void RunChuLiuEdmonds(int sentence_length, const vector& hs, const vector& ms, const vector& scores, + vector *heads, double *value){ + vector > candidate_heads(sentence_length); + vector > candidate_scores(sentence_length); + vector disabled(sentence_length, false); + for(int r = 0; r < hs.size(); ++r) { + int h = hs[r]; + int m = ms[r]; + candidate_heads[m].push_back(h); + candidate_scores[m].push_back(scores[r]); + } + heads->assign(sentence_length, -1); + RunChuLiuEdmondsIteration(&disabled, &candidate_heads, &candidate_scores, heads, value); +} + +// ===== +void RunEisnerBacktrack( + const vector &incomplete_backtrack, + const vector > &complete_backtrack, + const vector > &index_arcs, + int h, int m, bool complete, vector *heads) { + if(h == m) return; + if(complete) { + CHECK_GE(h, 0); + CHECK_LT(h, complete_backtrack.size()); + CHECK_GE(m, 0); + CHECK_LT(m, complete_backtrack.size()); + int u = complete_backtrack[h][m]; + //CHECK_GE(u, 0) << h << " " << m; + CHECK_GE(u, 0); + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + h, u, false, heads); + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + u, m, true, heads); + } + else { + int r = index_arcs[h][m]; + CHECK_GE(r, 0); + CHECK_LT(r, incomplete_backtrack.size()); + CHECK_GE(h, 0); + CHECK_LT(h, heads->size()); + CHECK_GE(m, 0); + CHECK_LT(m, heads->size()); + (*heads)[m] = h; + int u = incomplete_backtrack[r]; + if(h < m) { + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + h, u, true, heads); + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + m, u + 1, true, heads); + } + else { + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + m, u, true, heads); + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, + h, u + 1, true, heads); + } + } +} + +// Run Eisner's algorithm for finding a maximal weighted projective dependency +// tree. +void RunEisner(int sentence_length, const vector& hs, const vector& ms, const vector& scores, + vector *heads, double *value) { + // + vector > index_arcs(sentence_length, vector(sentence_length, -1)); + int num_arcs = hs.size(); + for(int r = 0; r < num_arcs; ++r) { + int h = hs[r]; + int m = ms[r]; + index_arcs[h][m] = r; + } + + heads->assign(sentence_length, -1); + + // Initialize CKY table. + vector > complete_spans(sentence_length, vector( + sentence_length, 0.0)); + vector > complete_backtrack(sentence_length, + vector(sentence_length, -1)); + vector incomplete_spans(num_arcs); + vector incomplete_backtrack(num_arcs, -1); + + // Loop from smaller items to larger items. + for(int k = 1; k < sentence_length; ++k) { + for(int s = 1; s < sentence_length - k; ++s) { + int t = s + k; + + // First, create incomplete items. + int left_arc_index = index_arcs[t][s]; + int right_arc_index = index_arcs[s][t]; + if(left_arc_index >= 0 || right_arc_index >= 0) { + double best_value = -std::numeric_limits::infinity(); + int best = -1; + for(int u = s; u < t; ++u) { + double val = complete_spans[s][u] + complete_spans[t][u + 1]; + if(best < 0 || val > best_value) { + best = u; + best_value = val; + } + } + if(left_arc_index >= 0) { + incomplete_spans[left_arc_index] = + best_value + scores[left_arc_index]; + incomplete_backtrack[left_arc_index] = best; + } + if(right_arc_index >= 0) { + incomplete_spans[right_arc_index] = + best_value + scores[right_arc_index]; + incomplete_backtrack[right_arc_index] = best; + } + } + + // Second, create complete items. + // 1) Left complete item. + double best_value = -std::numeric_limits::infinity(); + int best = -1; + for(int u = s; u < t; ++u) { + int left_arc_index = index_arcs[t][u]; + if(left_arc_index >= 0) { + double val = complete_spans[u][s] + incomplete_spans[left_arc_index]; + if(best < 0 || val > best_value) { + best = u; + best_value = val; + } + } + } + complete_spans[t][s] = best_value; + complete_backtrack[t][s] = best; + + // 2) Right complete item. + best_value = -std::numeric_limits::infinity(); + best = -1; + for(int u = s + 1; u <= t; ++u) { + int right_arc_index = index_arcs[s][u]; + if(right_arc_index >= 0) { + double val = complete_spans[u][t] + incomplete_spans[right_arc_index]; + if(best < 0 || val > best_value) { + best = u; + best_value = val; + } + } + } + complete_spans[s][t] = best_value; + complete_backtrack[s][t] = best; + } + } + + // Get the optimal (single) root. + double best_value = -std::numeric_limits::infinity(); + int best = -1; + for(int s = 1; s < sentence_length; ++s) { + int arc_index = index_arcs[0][s]; + if(arc_index >= 0) { + double val = complete_spans[s][1] + complete_spans[s][sentence_length - 1] + + scores[arc_index]; + if(best < 0 || val > best_value) { + best = s; + best_value = val; + } + } + } + + *value = best_value; + (*heads)[best] = 0; + + // Backtrack. + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, best, + 1, true, heads); + RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, best, + sentence_length - 1, true, heads); + + //*value = complete_spans[0][sentence_length-1]; + //RunEisnerBacktrack(incomplete_backtrack, complete_backtrack, index_arcs, 0, + // sentence_length-1, true, heads); +} + +// ==== +// interface + +void mst_decode(int sentence_length, bool projective, const vector& hs, const vector& ms, const vector& scores, + vector *heads, double *value){ + if(projective) + RunEisner(sentence_length, hs, ms, scores, heads, value); + else + RunChuLiuEdmonds(sentence_length, hs, ms, scores, heads, value); +} diff --git a/tasks/zdpar/algo/gp2/compile.sh b/tasks/zdpar/algo/gp2/compile.sh new file mode 100644 index 0000000..258f62d --- /dev/null +++ b/tasks/zdpar/algo/gp2/compile.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# dir +PREV_DIR=`pwd` +RUNNING_DIR="$( cd "$( dirname ${BASH_SOURCE[0]} )" && pwd )" +cd $RUNNING_DIR + +# get ad3 (specific version) +git clone https://github.com/andre-martins/AD3/ +cd AD3 +git checkout 22131c7457614dd159546500cd1a0fd8cdf2d282 +cd .. +mv AD3/ad3 . +mv AD3/Eigen . +rm -rf AD3 + +# fix mem-leak? +sed -i '437s/if (j == k) continue//' ad3/GenericFactor.cpp + +# compile +c++ -O3 -Wall -Wno-sign-compare -shared -std=c++11 -fPIC `python3 -m pybind11 --includes` -I./ad3/ -I./Eigen -I. ./ad3/*.cpp algo.cpp parser2.cpp -o parser2`python3-config --extension-suffix` +mv parser2`python3-config --extension-suffix` .. +cd $PREV_DIR diff --git a/tasks/zdpar/algo/gp2/data.txt b/tasks/zdpar/algo/gp2/data.txt new file mode 100644 index 0000000..20a00c5 --- /dev/null +++ b/tasks/zdpar/algo/gp2/data.txt @@ -0,0 +1,16 @@ +5 1 1 0 0 0 +25 1 0 0 0 0 1 0 1 0 0 0 0 0 0 1 0 1 1 0 0 0 1 0 1 0 +25 2.0 -1.0 -1.0 -1.0 -1.0 3.0358071 -1.0 1.005926 0.46386775 0.23816329 -1.0 0.124281086 -1.0 0.47994298 2.817563 -1.0 0.94585943 2.6263208 -1.0 0.09201018 -1.0 3.0067449 0.41978186 0.7602935 -1.0 +14 0 0 1 1 1 1 2 3 3 3 3 4 4 4 +14 0 0 0 0 2 2 4 1 1 2 2 1 1 3 +14 0 1 0 1 1 3 2 3 4 1 3 3 4 4 +14 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 +12 0 1 1 2 2 3 3 3 4 4 4 4 +12 0 0 2 4 4 1 1 2 1 1 3 3 +12 0 0 4 1 3 0 2 4 0 2 1 2 +12 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 +20 0 0 1 1 1 1 2 2 3 3 3 3 3 3 4 4 4 4 4 4 +20 0 0 0 0 2 2 4 4 1 1 1 1 2 2 1 1 1 1 3 3 +20 0 1 0 1 1 3 2 2 3 3 4 4 1 3 3 3 4 4 4 4 +20 0 0 0 0 4 4 1 3 0 2 0 2 4 4 0 2 0 2 1 2 +20 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 diff --git a/tasks/zdpar/algo/gp2/parser2.cpp b/tasks/zdpar/algo/gp2/parser2.cpp new file mode 100644 index 0000000..3d07806 --- /dev/null +++ b/tasks/zdpar/algo/gp2/parser2.cpp @@ -0,0 +1,443 @@ +// + +// Modified high-order nonproj parser from AD3's parsing example and TurboParser +// upon https://github.com/andre-martins/AD3/commit/22131c7457614dd159546500cd1a0fd8cdf2d282 +// upon https://github.com/andre-martins/TurboParser/commit/a87b8e45694c18b826bb3c42e8344bd32928007d + +#include "ad3/FactorGraph.h" +#include "ad3/Utils.h" +#include "utils.h" +#include "FactorTree.h" +#include "FactorHead.h" +#include + +// deubg on windows +#if defined(_WIN32) +//#define MY_DEBUG +//#define USE_PSDD +#endif + +#ifndef MY_DEBUG +#include +#include +#include +namespace py = pybind11; +#endif // MY_DEBUG + +// ===== +void parse_with_packs(const int slen, const bool projective, + const O1Pack* po1, const O2SibPack* po2sib, const O2gPack* po2g, const O3gsibPack* po3gsib, + vector* output_heads, double* output_value + ){ + // Start + // Variables of the factor graph. + vector variables; + // Create factor graph. + AD3::FactorGraph *factor_graph = new AD3::FactorGraph; + + // todo(note): po1 must exist for gh! And po1[0,0]=false, and disallow 0 at m or s! + // get basic o1 info: [m, h] + const vector& o1_masks = po1->masks; + const vector& o1_scores = po1->scores; + // sparse repr + vector so1_hs; + vector so1_ms; + vector so1_scores; + + // Build arc basic variables + // [h, m] -> index in the flattened all valid arcs (basic variables) + vector > index_arcs = vector >(slen, vector(slen, -1)); // [h, m] for the algorithm + int arc_idx = 0; + int tmp_cur_idx = 0; // skip mod==0 as root + for(int m = 0; m < slen; m++){ + for(int h = 0; h < slen; h++){ + if(o1_masks[tmp_cur_idx]){ + double s = o1_scores[tmp_cur_idx]; + AD3::BinaryVariable* v = factor_graph->CreateBinaryVariable(); +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + // todo(WARN): no set potential + v->SetLogPotential(0.); +#else + // todo(WARN): set potential here, which will be divided by num-links per var + v->SetLogPotential(s); +#endif + variables.push_back(v); + // idx in all valid arcs (basic variables) + index_arcs[h][m] = arc_idx; + arc_idx++; + // + so1_hs.push_back(h); + so1_ms.push_back(m); + so1_scores.push_back(s); + } + tmp_cur_idx++; + } + } + + // Build tree factor + AD3::FactorTree *tree_factor = new AD3::FactorTree; + tree_factor->Initialize(projective, slen, &so1_hs, &so1_ms); + vector local_variables = variables; // copy +#ifdef FACTOR_TREE_ADDITIONAL_SCORE + // todo(note): extra o1 scores! + tree_factor->SetAdditionalLogPotentials(so1_scores); +#endif + factor_graph->DeclareFactor(tree_factor, local_variables, true); // put variables in + + // Collect high-order parts + // the ordering is [o2sib, o2g, o3gsib] + // the main purpose is to get a compact inputs for each head-automaton (split into different heads) + // First collect indexes into the *Parts + + // o2sib: [h,m,s]: "m==s" means start child + vector > po2sib_left_idxes(slen); + vector > po2sib_right_idxes(slen); + vector > po2sib_left_scores(slen); + vector > po2sib_right_scores(slen); + if(po2sib){ + for(int r = 0; r < po2sib->size(); r++){ + // get one entry + int cur_h = po2sib->idxes_h[r]; + int cur_m = po2sib->idxes_m[r]; + int cur_s = po2sib->idxes_s[r]; + double cur_score = po2sib->scores[r]; + if(cur_m == 0 || cur_s == 0) + continue; + // left or right + if(cur_h > cur_m){ // left + if(cur_s >= cur_m && cur_s < cur_h){ // only get valid ones + po2sib_left_idxes[cur_h].push_back(r); + po2sib_left_scores[cur_h].push_back(cur_score); + } + } + else{ // right + if(cur_s <= cur_m && cur_s > cur_h){ + po2sib_right_idxes[cur_h].push_back(r); + po2sib_right_scores[cur_h].push_back(cur_score); + } + } + } + } + + // o2g: [g,h,m]: when "h==0" => "g==0" + vector > po2g_left_idxes(slen); + vector > po2g_right_idxes(slen); + vector > po2g_left_scores(slen); + vector > po2g_right_scores(slen); + if(po2g){ + for(int r = 0; r < po2g->size(); r++){ + // get one entry + int cur_h = po2g->idxes_h[r]; + int cur_m = po2g->idxes_m[r]; + int cur_g = po2g->idxes_g[r]; + double cur_score = po2g->scores[r]; + if(cur_m == 0 || cur_m == cur_g) + continue; + // left or right + if(cur_h > cur_m){ // left + po2g_left_idxes[cur_h].push_back(r); + po2g_left_scores[cur_h].push_back(cur_score); + } + else{ // right + po2g_right_idxes[cur_h].push_back(r); + po2g_right_scores[cur_h].push_back(cur_score); + } + } + } + + // o3gsib: [g,h,m,s] ... (both) + vector > po3gsib_left_idxes(slen); + vector > po3gsib_right_idxes(slen); + vector > po3gsib_left_scores(slen); + vector > po3gsib_right_scores(slen); + if(po3gsib){ + for(int r = 0; r < po3gsib->size(); r++){ + // get one entry + int cur_h = po3gsib->idxes_h[r]; + int cur_m = po3gsib->idxes_m[r]; + int cur_s = po3gsib->idxes_s[r]; + int cur_g = po3gsib->idxes_g[r]; + double cur_score = po3gsib->scores[r]; + if(cur_m == 0 || cur_s == 0 || cur_m == cur_g || cur_s == cur_g) + continue; + // left or right + if(cur_h > cur_m){ // left + if(cur_s >= cur_m && cur_s < cur_h){ // only get valid ones + po3gsib_left_idxes[cur_h].push_back(r); + po3gsib_left_scores[cur_h].push_back(cur_score); + } + } + else{ // right + if(cur_s <= cur_m && cur_s > cur_h){ + po3gsib_right_idxes[cur_h].push_back(r); + po3gsib_right_scores[cur_h].push_back(cur_score); + } + } + } + } + + // Build Head Automaton (go through each head and creat left/right) + // TODO(+N): [0,0] as both head and mod --> currently seem fine? + for(int cur_head = 0; cur_head < slen; cur_head++){ + // collect gs + vector gs; + vector gs_variables; + for(int g = 0; g < slen; g++){ + int index_in_valids = index_arcs[g][cur_head]; + if(index_in_valids >= 0){ + gs.push_back(g); + gs_variables.push_back(variables[index_in_valids]); + } + } + // collect ms-left and build + if(cur_head > 0){ + // collect + vector ms_left; + vector left_variables = gs_variables; + for(int m = cur_head - 1; m > 0; m--){ + int index_in_valids = index_arcs[cur_head][m]; + if(index_in_valids >= 0){ + ms_left.push_back(m); + left_variables.push_back(variables[index_in_valids]); + } + } + // build + AD3::FactorHead *factor = new AD3::FactorHead; + factor->Initialize(cur_head, false, gs, ms_left, + po2sib, po2sib_left_idxes[cur_head], po2g, po2g_left_idxes[cur_head], po3gsib, po3gsib_left_idxes[cur_head]); + vector additional_log_potentials = po2sib_left_scores[cur_head]; + additional_log_potentials.insert(additional_log_potentials.end(), po2g_left_scores[cur_head].begin(), po2g_left_scores[cur_head].end()); + additional_log_potentials.insert(additional_log_potentials.end(), po3gsib_left_scores[cur_head].begin(), po3gsib_left_scores[cur_head].end()); + factor->SetAdditionalLogPotentials(additional_log_potentials); + factor_graph->DeclareFactor(factor, left_variables, true); + } + // collect ms-right and build + if(true){ + // collect + vector ms_right; + vector right_variables = gs_variables; + for(int m = cur_head + 1; m < slen; m++){ + int index_in_valids = index_arcs[cur_head][m]; + if(index_in_valids >= 0){ + ms_right.push_back(m); + right_variables.push_back(variables[index_in_valids]); + } + } + // build + AD3::FactorHead *factor = new AD3::FactorHead; + factor->Initialize(cur_head, true, gs, ms_right, + po2sib, po2sib_right_idxes[cur_head], po2g, po2g_right_idxes[cur_head], po3gsib, po3gsib_right_idxes[cur_head]); + vector additional_log_potentials = po2sib_right_scores[cur_head]; + additional_log_potentials.insert(additional_log_potentials.end(), po2g_right_scores[cur_head].begin(), po2g_right_scores[cur_head].end()); + additional_log_potentials.insert(additional_log_potentials.end(), po3gsib_right_scores[cur_head].begin(), po3gsib_right_scores[cur_head].end()); + factor->SetAdditionalLogPotentials(additional_log_potentials); + factor_graph->DeclareFactor(factor, right_variables, true); + } + } + +#ifdef MY_DEBUG + /*// debug print info + (static_cast(factor_graph->GetFactor(0)))->info(cerr, nullptr, nullptr); + for(int i = 1; i < factor_graph->GetNumFactors(); i++){ + (static_cast(factor_graph->GetFactor(i)))->info(cerr, nullptr, nullptr); + }*/ + // verboisty + factor_graph->SetVerbosity(3); +#endif // MY_DEBUG + + // ===== + // Run AD3 + vector posteriors; + vector additional_posteriors; + double value_ref; + double *value = &value_ref; + // +#ifdef USE_PSDD + factor_graph->SetEtaPSDD(1.0); + factor_graph->SetMaxIterationsPSDD(500); + factor_graph->SolveLPMAPWithPSDD(&posteriors, &additional_posteriors, value); +#else + factor_graph->SetMaxIterationsAD3(500); + //factor_graph->SetMaxIterationsAD3(200); + factor_graph->SetEtaAD3(0.05); + factor_graph->AdaptEtaAD3(true); + factor_graph->SetResidualThresholdAD3(1e-3); + //factor_graph->SetResidualThresholdAD3(1e-6); + factor_graph->SolveLPMAPWithAD3(&posteriors, &additional_posteriors, value); +#endif + +#ifdef MY_DEBUG + /*// debug print info + (static_cast(factor_graph->GetFactor(0)))->info(cerr, nullptr, nullptr); + for(int i = 1; i < factor_graph->GetNumFactors(); i++){ + (static_cast(factor_graph->GetFactor(i)))->info(cerr, nullptr, nullptr); + }*/ +#endif // MY_DEBUG + + delete factor_graph; + // ===== + // final decode with posteriors + mst_decode(slen, projective, so1_hs, so1_ms, posteriors, output_heads, output_value); + return; +} + +// interface +#ifndef MY_DEBUG +template +vector array2vector_1d(py::array_t* arr){ + py::buffer_info buf = arr->request(); + CHECK(buf.ndim == 1, "Err: not equal dim == 1"); + int size = buf.size; + T* ptr1 = static_cast(buf.ptr); + // copy here. todo(+2): any way to tmply borrow the data? + vector ret(ptr1, ptr1 + size); + return ret; +} + +vector parse2(const int slen, const int projective, + const int use_o1, py::array_t* o1_masks, py::array_t* o1_scores, + const int use_o2sib, py::array_t* o2sib_ms, py::array_t* o2sib_hs, py::array_t* o2sib_ss, py::array_t* o2sib_scores, + const int use_o2g, py::array_t* o2g_ms, py::array_t* o2g_hs, py::array_t* o2g_gs, py::array_t* o2g_scores, + const int use_o3gsib, py::array_t* o3gsib_ms, py::array_t* o3gsib_hs, py::array_t* o3gsib_ss, py::array_t* o3gsib_gs, py::array_t* o3gsib_scores) +{ + O1Pack* ppo1 = nullptr; + O2SibPack* ppo2sib = nullptr; + O2gPack* ppo2g = nullptr; + O3gsibPack* ppo3gsib = nullptr; + // o1 + CHECK(use_o1, "Error: no po1 info provided!"); + vector po1_masks = array2vector_1d(o1_masks); + vector po1_scores = array2vector_1d(o1_scores); + // todo(note): exclude [0,0], treat (g,h) specially + po1_masks[0] = 0; + // + ppo1 = new O1Pack(po1_masks, po1_scores); + // other vectors + vector po2sib_hs, po2sib_ms, po2sib_ss, po2g_hs, po2g_ms, po2g_gs, po3gsib_hs, po3gsib_ms, po3gsib_ss, po3gsib_gs; + vector po2sib_scores, po2g_scores, po3gsib_scores; + // o2sib + if(use_o2sib){ + po2sib_ms = array2vector_1d(o2sib_ms); + po2sib_hs = array2vector_1d(o2sib_hs); + po2sib_ss = array2vector_1d(o2sib_ss); + po2sib_scores = array2vector_1d(o2sib_scores); + ppo2sib = new O2SibPack(po2sib_hs, po2sib_ms, po2sib_ss, po2sib_scores); + } + // o2g + if(use_o2g){ + po2g_ms = array2vector_1d(o2g_ms); + po2g_hs = array2vector_1d(o2g_hs); + po2g_gs = array2vector_1d(o2g_gs); + po2g_scores = array2vector_1d(o2g_scores); + ppo2g = new O2gPack(po2g_hs, po2g_ms, po2g_gs, po2g_scores); + } + // o3gsib + if(use_o3gsib){ + po3gsib_ms = array2vector_1d(o3gsib_ms); + po3gsib_hs = array2vector_1d(o3gsib_hs); + po3gsib_ss = array2vector_1d(o3gsib_ss); + po3gsib_gs = array2vector_1d(o3gsib_gs); + po3gsib_scores = array2vector_1d(o3gsib_scores); + ppo3gsib = new O3gsibPack(po3gsib_hs, po3gsib_ms, po3gsib_ss, po3gsib_gs, po3gsib_scores); + } + // decode + vector ret(slen, -1); + double v; + parse_with_packs(slen, projective!=0, ppo1, ppo2sib, ppo2g, ppo3gsib, &ret, &v); + // delete + delete ppo1; + delete ppo2sib; + delete ppo2g; + delete ppo3gsib; + return ret; +} + +//c++ -O3 -Wall -Wno-sign-compare -shared -std=c++11 -fPIC `python3 -m pybind11 --includes` -I./ad3/ -I./Eigen -I. ./ad3/*.cpp algo.cpp parser2.cpp -o parser2`python3-config --extension-suffix` +// default policy is return_value_policy::take_ownership when returning pointer +PYBIND11_MODULE(parser2, m) { + m.doc() = "The high-order graph parsing algorithm with AD3"; // optional module docstring + m.def("parse2", &parse2, "A function for high-order parsing"); +} +#endif + +// for debug +/* +b parser2.cpp:300 +p slen +p *po1_masks._M_impl._M_start@25 +p po2sib_hs.size() +p po2g_hs.size() +p po3gsib_hs.size() +b parser2.cpp:62 +*/ + +//#define MY_DEBUG +// c++ -g -DMY_DEBUG -Wall -Wno-sign-compare -std=c++11 -I./ad3/ -I./Eigen -I. ./ad3/*.cpp algo.cpp parser2.cpp +#ifdef MY_DEBUG +template +vector read_vector_1d(std::istream& fin){ + // first read size + int size = 0; + fin >> size; + vector ret(size); + for(int i = 0; i < size; i++){ + fin >> ret[i]; + } + return ret; +} + +void my_debug(const char* data_fname){ + auto fin = std::ifstream(data_fname); + int count = 0; + while(fin){ + // read one record + // start + int slen, projective; + int use_o1, use_o2sib, use_o2g, use_o3gsib; + fin >> slen >> projective >> use_o1 >> use_o2sib >> use_o2g >> use_o3gsib; + if(!fin) + break; + // o1 + auto po1_masks = read_vector_1d(fin); + auto po1_scores = read_vector_1d(fin); + po1_masks[0] = 0; + O1Pack* ppo1 = use_o1 ? (new O1Pack(po1_masks, po1_scores)) : nullptr; + // o2sib + auto po2sib_ms = read_vector_1d(fin); + auto po2sib_hs = read_vector_1d(fin); + auto po2sib_ss = read_vector_1d(fin); + auto po2sib_scores = read_vector_1d(fin); + O2SibPack* ppo2sib = use_o2sib ? (new O2SibPack(po2sib_hs, po2sib_ms, po2sib_ss, po2sib_scores)) : nullptr; + // o2g + auto po2g_ms = read_vector_1d(fin); + auto po2g_hs = read_vector_1d(fin); + auto po2g_gs = read_vector_1d(fin); + auto po2g_scores = read_vector_1d(fin); + O2gPack* ppo2g = use_o2g ? (new O2gPack(po2g_hs, po2g_ms, po2g_gs, po2g_scores)) : nullptr; + // o3gsib + auto po3gsib_ms = read_vector_1d(fin); + auto po3gsib_hs = read_vector_1d(fin); + auto po3gsib_ss = read_vector_1d(fin); + auto po3gsib_gs = read_vector_1d(fin); + auto po3gsib_scores = read_vector_1d(fin); + O3gsibPack* ppo3gsib = use_o3gsib ? (new O3gsibPack(po3gsib_hs, po3gsib_ms, po3gsib_ss, po3gsib_gs, po3gsib_scores)) : nullptr; + // + // decode + vector ret(slen, -1); + double v; + parse_with_packs(slen, projective!=0, ppo1, ppo2sib, ppo2g, ppo3gsib, &ret, &v); + // delete + delete ppo1; + delete ppo2sib; + delete ppo2g; + delete ppo3gsib; + count++; + } +} + +int main(int argc, char** argv){ + //my_debug(argv[1]); + my_debug("data.txt"); +} + +#endif // MY_DEBUG diff --git a/tasks/zdpar/algo/gp2/setup.py b/tasks/zdpar/algo/gp2/setup.py new file mode 100644 index 0000000..fc1dbef --- /dev/null +++ b/tasks/zdpar/algo/gp2/setup.py @@ -0,0 +1,26 @@ +import os, sys +import glob +import platform + +from distutils.core import setup, Extension +from distutils import sysconfig + +if platform.system() == "Windows": + cpp_args = ['-DNOMINMAX'] +else: + cpp_args = ['-std=c++11', '-O3', '-fPIC'] + +sfc_module = Extension( + 'parser2', + sources=glob.glob("./ad3/*.cpp")+["algo.cpp", "parser2.cpp"], + include_dirs=['pybind11/include', './ad3/', './Eigen/', '.'], + language='c++', + extra_compile_args=cpp_args, + ) + +setup( + name='parser2', + ext_modules=[sfc_module], +) + +# python setup.py build_ext diff --git a/tasks/zdpar/algo/gp2/utils.h b/tasks/zdpar/algo/gp2/utils.h new file mode 100644 index 0000000..5a5d3d4 --- /dev/null +++ b/tasks/zdpar/algo/gp2/utils.h @@ -0,0 +1,123 @@ +// + +#ifndef P2_UTILS_H +#define P2_UTILS_H + +// helpful utils + +#include +#include +#include +#include + +using std::vector; +using std::string; +using std::cerr; +using std::endl; +using std::ostream; + +// ===== +// Use four types of sub-trees + +// special O1Pack + +// Order 1 is different than others +struct O1Pack{ + const vector& masks; // whether this edge is valid (not pruned) [slen*slen]; should be bool, but make it int for api + const vector& scores; // scores for each edge [slen*slen] + O1Pack(const vector& ma, const vector& ss): masks(ma), scores(ss) {} +}; + +// The rest uses the indexes + +// [h, m, s], when m==s, it means first sib +struct O2SibPack{ + const vector& idxes_h; + const vector& idxes_m; + const vector& idxes_s; + const vector& scores; + O2SibPack(const vector& h, const vector& m, const vector& s, const vector& ss): + idxes_h(h), idxes_m(m), idxes_s(s), scores(ss){} + const unsigned size() const { return scores.size(); } +}; + +// [h, m, g], when h==0, g==0 +struct O2gPack{ + const vector& idxes_h; + const vector& idxes_m; + const vector& idxes_g; + const vector& scores; + O2gPack(const vector& h, const vector& m, const vector& g, const vector& ss): + idxes_h(h), idxes_m(m), idxes_g(g), scores(ss){} + const unsigned size() const { return scores.size(); } +}; + +// the combined one +struct O3gsibPack{ + const vector& idxes_h; + const vector& idxes_m; + const vector& idxes_s; + const vector& idxes_g; + const vector& scores; + O3gsibPack(const vector& h, const vector& m, const vector& s, const vector& g, const vector& ss): + idxes_h(h), idxes_m(m), idxes_s(s), idxes_g(g), scores(ss){} + const unsigned size() const { return scores.size(); } +}; + +// ===== +// helpers + +const double MY_NEGINF = -1e20; + +inline void MY_ERROR(const string& x){ + cerr << "Error: " << x << endl; + throw x; +} + +inline void CHECK(bool v, const string& x){ + if(!v) MY_ERROR(x); +} + +template +inline void CHECK_GE(T1 x, T2 y){ CHECK(x >= y, "Err: CHECK_GE"); } + +template +inline void CHECK_LT(T1 x, T2 y){ CHECK(x < y, "Err: CHECK_LT"); } + +inline void debug_print(ostream& fout, const vector* ms, const vector* hs, const vector* ss, const vector* gs, const vector* scores){ + int size = scores->size(); + if(ms){ + fout << "m "; + for(int x : *ms) + fout << x << " "; + fout << endl; + } + if(hs){ + fout << "h "; + for(int x : *hs) + fout << x << " "; + fout << endl; + } + if(ss){ + fout << "s "; + for(int x : *ss) + fout << x << " "; + fout << endl; + } + if(gs){ + fout << "g "; + for(int x : *gs) + fout << x << " "; + fout << endl; + } + fout << "v "; + for(double x : *scores) + fout << x << " "; + fout << endl; +} + +// ===== +void mst_decode(int sentence_length, bool projective, const vector& hs, const vector& ms, const vector& scores, + vector *heads, double *value); + +#endif // !P2_UTILS_H diff --git a/tasks/zdpar/algo/hop.py b/tasks/zdpar/algo/hop.py new file mode 100644 index 0000000..635c88c --- /dev/null +++ b/tasks/zdpar/algo/hop.py @@ -0,0 +1,50 @@ +# + +# high-order parsing + +from msp.utils import zwarn, zfatal + +# todo(note): which version of TurboParser and AD3? +# Updated: now directly use AD3's example and TurboParser +# https://github.com/andre-martins/AD3/commit/22131c7457614dd159546500cd1a0fd8cdf2d282 +# https://github.com/andre-martins/TurboParser/commit/a87b8e45694c18b826bb3c42e8344bd32928007d +# +import numpy as np +try: + from .parser2 import parse2 +except: + zwarn("Cannot find high-order parsing lib, please compile them if needed") + def parse2(*args, **kwargs): + raise NotImplementedError("Compile the C++ codes for this one!") + +# TODO(WARN): be careful about when there are no trees in the current pruning mask! (especially for proj) +# and here only do unlabeled parsing, since labeled ones will cost more, handling labels at the outside +# high order parsing decode +# for the *_pack, is an list/tuple of related indexes and scores, None means no such features +def hop_decode(slen: int, projective: bool, o1_masks, o1_scores, o2sib_pack, o2g_pack, o3gsib_pack): + # dummy arguments, use None will get argument error + dummy_int_arr = np.array([], dtype=np.int32) + dummy_double_arr = dummy_int_arr.astype(np.double) + # prepare inputs + projective = int(projective) + assert o1_masks is not None, "Must provide masks" + if o1_scores is None: # dummy o1 scores + o1_scores = np.full([slen, slen], 0., dtype=np.double) + use_o2sib = int(o2sib_pack is not None) + use_o2g = int(o2g_pack is not None) + use_o3gsib = int(o3gsib_pack is not None) + # o2sib: m, h, s, scores + if not use_o2sib: + o2sib_pack = [dummy_int_arr, dummy_int_arr, dummy_int_arr, dummy_double_arr] + # o2g: m, h, g, scores + if not use_o2g: + o2g_pack = [dummy_int_arr, dummy_int_arr, dummy_int_arr, dummy_double_arr] + # o3gsib: m, h, s, g, scores + if not use_o3gsib: + o3gsib_pack = [dummy_int_arr, dummy_int_arr, dummy_int_arr, dummy_int_arr, dummy_double_arr] + # ===== + # return is vector which becomes a list in py + ret = parse2(slen, projective, True, o1_masks.reshape(-1), o1_scores.reshape(-1), + use_o2sib, *o2sib_pack, use_o2g, *o2g_pack, use_o3gsib, *o3gsib_pack) + ret[0] = 0 # set ROOT's parent as self + return ret diff --git a/tasks/zdpar/algo/nmst.py b/tasks/zdpar/algo/nmst.py index 7906dfc..dc7f670 100644 --- a/tasks/zdpar/algo/nmst.py +++ b/tasks/zdpar/algo/nmst.py @@ -47,7 +47,6 @@ def _common_nmst(CPU_f, scores_expr, mask_expr, lengths_arr, labeled, ret_arr): def nmst_unproj(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False): return _common_nmst(mst_unproj, scores_expr, mask_expr, lengths_arr, labeled, ret_arr) -# TODO(!): still have bugs in proj! def nmst_proj(scores_expr, mask_expr, lengths_arr, labeled=True, ret_arr=False): return _common_nmst(mst_proj, scores_expr, mask_expr, lengths_arr, labeled, ret_arr) @@ -90,9 +89,11 @@ def _ensure_margins_norm(marginals_expr): full_shape = BK.get_shape(marginals_expr) combined_marginals_expr = marginals_expr.view(full_shape[:-2] + [-1]) # [BS, Len, Len*L] # should be 1., but for no-solution situation there can be small values (+1 for all in this case, norm later) - combined_marginals_expr += (combined_marginals_expr.sum(dim=-1, keepdim=True) < 1e-10).float() + # make 0./0. = 0. + combined_marginals_sum = combined_marginals_expr.sum(dim=-1, keepdim=True) + combined_marginals_sum += (combined_marginals_sum < 1e-5).float() * 1e-5 # then norm - combined_marginals_expr /= combined_marginals_expr.sum(dim=-1, keepdim=True) + combined_marginals_expr /= combined_marginals_sum return combined_marginals_expr.view(full_shape) # [BS, Len, Len, L], [BS, Len] -> [BS, Len, Len, L] @@ -102,6 +103,7 @@ def _ensure_margins_norm(marginals_expr): # todo(warn): also the order of dimension is slightly different than the original algorithm, here we have [m, h] # todo(+1): simple for unlabeled situation # todo(warn): be careful about Numerical Unstability when the matrix is not inversable, which will make it 0/0!! +# todo(note): mask out non-valid values (diag, padding, root-mod), need to be careful about this? def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True): assert labeled with BK.no_grad_env(): @@ -111,11 +113,15 @@ def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True): diag1_m = BK.eye(maxlen).double() # [m, h] scores_expr_d = scores_expr.double() mask_expr_d = mask_expr.double() + invalid_pad_expr_d = 1.-mask_expr_d + # [*, m, h] + full_invalid_d = (diag1_m + invalid_pad_expr_d.unsqueeze(-1) + invalid_pad_expr_d.unsqueeze(-2)).clamp(0., 1.) + full_invalid_d[:, 0] = 1. # # first make it unlabeled by sum-exp scores_unlabeled = BK.logsumexp(scores_expr_d, dim=-1) # [BS, m, h] - # force small values at diag entries - scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * diag1_m + # force small values at diag entries and padded ones + scores_unlabeled_diag_neg = scores_unlabeled + Constants.REAL_PRAC_MIN * full_invalid_d # # minus the MaxElement to make it more stable with larger values, to make it numerically stable. # [BS, m, h] # todo(+N): since one and only one Head is selected, thus minus by Max will make it the same? @@ -156,12 +162,14 @@ def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True): # # det and inverse; using LU decomposition to hit two birds with one stone. diag1_m00 = BK.eye(maxlen-1).double() - LM00_inv, LM00_lu = diag1_m00.gesv(LM00) # [BS, m-1, h-1] + # deprecated operation + # LM00_inv, LM00_lu = diag1_m00.gesv(LM00) # [BS, m-1, h-1] # # todo(warn): lacking P here, but the partition should always be non-negative! # LM00_det = BK.abs((LM00_lu*diag1_m00).sum(-1).prod(-1)) # [BS, ] # d(logZ)/d(LM00) = (LM00^-1)^T # # directly inverse (need pytorch >= 1.0) # LM00_inv = LM00.inverse() + LM00_inv = BK.get_inverse(LM00, diag1_m00) LM00_grad = LM00_inv.transpose(-1, -2) # [BS, m-1, h-1] # marginal(m,h) = d(logZ)/d(score(m,h)) = d(logZ)/d(LM00) * d(LM00)/d(score(m,h)) = INV_mm - INV_mh # padding and minus @@ -192,7 +200,8 @@ def nmarginal_unproj(scores_expr, mask_expr, lengths_arr, labeled=True): # assert False, "Problem here" # # back to plain float32 - ret = marginals_labeled.float() + masked_marginals_labeled = marginals_labeled * (1.-full_invalid_d).unsqueeze(-1) + ret = masked_marginals_labeled.float() return _ensure_margins_norm(ret) # b tasks/zdpar/algo/nmst:102 diff --git a/tasks/zdpar/common/confs.py b/tasks/zdpar/common/confs.py index bee2723..5503145 100644 --- a/tasks/zdpar/common/confs.py +++ b/tasks/zdpar/common/confs.py @@ -5,19 +5,30 @@ import json from msp import utils, nn from msp.nn import NIConf -from msp.utils import Conf, Logger, zopen, zfatal +from msp.utils import Conf, Logger, zopen, zfatal, zwarn -from ..graph.parser import GraphParserConf, GraphParser - -# +# todo(note): when testing, unless in the default setting, need to specify: "dict_dir" and "model_load_name" class DConf(Conf): def __init__(self): + # whether allow multi-source mode + self.multi_source = False # split file names for inputs (use multi_reader) # data paths self.train = "" self.dev = "" self.test = "" self.cache_data = True # turn off if large data + self.to_cache_shuffle = False self.dict_dir = "./" + # cuttings for training (simply for convenience without especially preparing data...) + self.cut_train = "" + self.cut_dev = "" + # extra aux inputs, must be aligned with the corresponded data + self.aux_repr_train = "" + self.aux_repr_dev = "" + self.aux_repr_test = "" + self.aux_score_train = "" + self.aux_score_dev = "" + self.aux_score_test = "" # save name for trainer not here!! self.model_load_name = "zmodel.best" # load name self.output_file = "zout" @@ -25,7 +36,7 @@ def __init__(self): self.input_format = "conllu" self.output_format = "conllu" # pretrain - self.pretrain_file = "" + self.pretrain_file = [] self.init_from_pretrain = False self.pretrain_scale = 1.0 self.pretrain_init_nohit = 1.0 @@ -43,12 +54,11 @@ def __init__(self): self.code_train = "" self.code_dev = "" self.code_test = "" - self.code_pretrain = "" + self.code_pretrain = [] # testing mode extra embeddings self.test_extra_pretrain_files = [] self.test_extra_pretrain_codes = [] - # the overall conf class DepParserConf(Conf): def __init__(self, partype, args): @@ -62,9 +72,25 @@ def __init__(self, partype, args): self.dconf = DConf() # data-conf # parser-conf if partype == "graph": + from ..graph.parser import GraphParserConf self.pconf = GraphParserConf() + elif partype == "td": + from ..transition.topdown.parser import TdParserConf + self.pconf = TdParserConf() + elif partype == "ef": + from ..ef.parser import EfParserConf + self.pconf = EfParserConf() + elif partype == "g1": + from ..ef.parser import G1ParserConf + self.pconf = G1ParserConf() + elif partype == "g2": + from ..ef.parser import G2ParserConf + self.pconf = G2ParserConf() + elif partype == "s2": + from ..ef.parser import S2ParserConf + self.pconf = S2ParserConf() else: - zfatal("Unknown parser type: %s, please provide correct type with the option `partype:{graph,}`") + zfatal(f"Unknown parser type: {partype}, please provide correct type with the option.") # ===== # self.update_from_args(args) @@ -75,13 +101,36 @@ def build_model(partype, conf, vpack): pconf = conf.pconf parser = None if partype == "graph": + # original first-order graph with various output constraints + from ..graph.parser import GraphParser parser = GraphParser(pconf, vpack) + elif partype == "td": + # re-implementation of the top-down stack-pointer parser + zwarn("Warning: Current implementation of td-mode is deprecated and outdated.") + from ..transition.topdown.parser import TdParser + parser = TdParser(pconf, vpack) + elif partype == "ef": + # generalized easy-first parser + from ..ef.parser import EfParser + parser = EfParser(pconf, vpack) + elif partype == "g1": + # first-order graph parser + from ..ef.parser import G1Parser + parser = G1Parser(pconf, vpack) + elif partype == "g2": + # higher-order graph parser + from ..ef.parser import G2Parser + parser = G2Parser(pconf, vpack) + elif partype == "s2": + # two-stage parser + from ..ef.parser import S2Parser + parser = S2Parser(pconf, vpack) else: zfatal("Unknown parser type: %s") return parser # the start of everything -def init_everything(args): +def init_everything(args, ConfType=None): # search for basic confs all_argv, basic_argv = Conf.search_args(args, ["partype", "conf_output", "log_file", "msp_seed"], [str, str, str, int], [None, None, Logger.MAGIC_CODE, None]) @@ -93,9 +142,14 @@ def init_everything(args): if conf_output: with zopen(conf_output, "w") as fd: for k,v in all_argv.items(): - fd.write(f"{k}:{v}\n") + # todo(note): do not save this one + if k != "conf_output": + fd.write(f"{k}:{v}\n") utils.init(log_file, msp_seed) # real init of the conf - conf = DepParserConf(parser_type, args) + if ConfType is None: + conf = DepParserConf(parser_type, args) + else: + conf = ConfType(parser_type, args) nn.init(conf.niconf) return conf diff --git a/tasks/zdpar/common/data.py b/tasks/zdpar/common/data.py index ab46735..3eacc4f 100644 --- a/tasks/zdpar/common/data.py +++ b/tasks/zdpar/common/data.py @@ -3,11 +3,12 @@ # data for dependency parsing import json -from typing import Iterable +from typing import Iterable, List, Set, Dict import numpy as np +import pickle -from msp.utils import zfatal, zopen, zwarn -from msp.data import Instance, FileOrFdStreamer, VocabHelper, MultiHelper +from msp.utils import zfatal, zopen, zwarn, Random, Helper, MathHelper, zcheck +from msp.data import Instance, FileOrFdStreamer, VocabHelper, MultiHelper, AdapterStreamer, MultiJoinStreamer from msp.zext.dpar import ConlluReader, write_conllu, ConlluParse from msp.zext.seq_data import InstanceHelper, SeqFactor, InputCharFactor @@ -37,31 +38,127 @@ def __init__(self, words, poses=None, heads=None, labels=None, code=""): heads = [0] + heads if labels is not None: labels = _tmp_root_list + labels - # todo(warn): calculate the unproj situations: 0 proj edge, 1 unproj edge - if heads is not None: - unprojs = [0] + ConlluParse.calc_crossed(heads[1:]) - else: - unprojs = None - # self.poses = SeqFactor(poses) self.heads = SeqFactor(heads) - self.labels = SeqFactor(labels) - # - self.children_mask_arr = None + self.labels = SeqFactor(labels) # todo(warn): for td, no processing, directly use 0 as padding!! + # ===== + # for top-down parsing, but deprecated now + self.children_mask_arr: np.ndarray = None # [N, N] + # self.children_se: List[Set] = None + self.children_list: List[List] = None # list of children + self.descendant_list: List[List] = None # list of all descendants # - self.unprojs = SeqFactor(unprojs) + self.free_dist_alpha: float = None + # ===== + # other helpful info (calculate in need) + self._unprojs = None + self._sibs = None + self._gps = None + # all children should be aranged l2r + self._children_left = None + self._children_right = None + self._children_all = None # predictions (optional prediction probs) self.pred_poses = SeqFactor(None) self.pred_pos_scores = SeqFactor(None) self.pred_heads = SeqFactor(None) self.pred_labels = SeqFactor(None) self.pred_par_scores = SeqFactor(None) + self.pred_miscs = SeqFactor(None) # for real length self.length = InstanceHelper.check_equal_length([self.words, self.chars, self.poses, self.heads, self.labels]) - 1 + # extra inputs, for example, those from mbert + self.extra_features = {"aux_repr": None} + # extra preds info + self.extra_pred_misc = {} def __len__(self): return self.length + # ===== + # helpful functions + @staticmethod + def get_children(heads): + cur_len = len(heads) + children_left, children_right = [[] for _ in range(cur_len)], [[] for _ in range(cur_len)] + for i in range(1, cur_len): + h = heads[i] + if i < h: + children_left[h].append(i) + else: + children_right[h].append(i) + return children_left, children_right + + @staticmethod + def get_sibs(children_left, children_right): + # get sibling list: sided nearest sibling (self if single) + sibs = [-1] * len(children_left) + for vs in children_left: + if len(vs) > 0: + prev_s = vs[-1] + for cur_m in reversed(vs): + sibs[cur_m] = prev_s + prev_s = cur_m + for vs in children_right: + if len(vs) > 0: + prev_s = vs[0] + for cur_m in vs: + sibs[cur_m] = prev_s + prev_s = cur_m + # todo(+2): how about 0? currently set to 0 + sibs[0] = 0 + return sibs + + @staticmethod + def get_gps(heads): + # get grandparent list + # todo(+2): how about 0? currently will be 0 + return [heads[x] for x in heads] + + # ===== + # other properties + @property + def unprojs(self): + # todo(warn): calculate the unproj situations: 0 proj edge, 1 unproj edge + if self._unprojs is None: + self._unprojs = [0] + ConlluParse.calc_crossed(self.heads.vals[1:]) + return self._unprojs + + @property + def sibs(self): + if self._sibs is None: + self._sibs = ParseInstance.get_sibs(self.children_left, self.children_right) + return self._sibs + + @property + def gps(self): + if self._gps is None: + heads = self.heads.vals + self._gps = ParseInstance.get_gps(heads) + return self._gps + + @property + def children_left(self): + if self._children_left is None: + heads = self.heads.vals + self._children_left, self._children_right = ParseInstance.get_children(heads) + return self._children_left + + @property + def children_right(self): + if self._children_right is None: + heads = self.heads.vals + self._children_left, self._children_right = ParseInstance.get_children(heads) + return self._children_right + + @property + def children_all(self): + if self._children_all is None: + self._children_all = [a+b for a,b in zip(self.children_left, self.children_right)] + return self._children_all + + # ===== + # special processing for training # for h-local-loss def get_children_mask_arr(self, add_self_if_leaf=True): if self.children_mask_arr is None: @@ -79,7 +176,88 @@ def get_children_mask_arr(self, add_self_if_leaf=True): self.children_mask_arr = masks return self.children_mask_arr - # + # set once when reading! + def set_children_info(self, oracle_strategy, label_ranking_dict:Dict=None, free_dist_alpha:float=0.): + heads = self.heads.vals + the_len = len(heads) + # self.children_set = [set() for _ in range(the_len)] + self.children_list = [[] for _ in range(the_len)] + tmp_descendant_list = [None for _ in range(the_len)] + # exclude root + for m, h in enumerate(heads[1:], 1): + # self.children_set[h].add(m) + self.children_list[h].append(m) # l2r order + # re-arrange list order (left -> right) + if oracle_strategy == "i2o": + for h in range(the_len): + self.children_list[h].sort(key=lambda x: -x if x1: + Random.shuffle(one_list) + else: + for i, one_list in enumerate(self.children_list): + if len(one_list)>1: + values = [abs(i-z)*alpha for z in one_list] + # TODO(+N): is it correct to use Gumble for ranking + logprobs = np.log(MathHelper.softmax(values)) + G = np.random.random_sample(len(logprobs)) + ranking_values = np.log(-np.log(G)) - logprobs + self.children_list[i] = [one_list[z] for z in np.argsort(ranking_values)] + + INST_RAND = Random.stream(Random.random_sample) + # shuffle for n2f mode + def shuffle_children_n2f(self): + rr = ParseInstance.INST_RAND + for i, one_list in enumerate(self.children_list): + if len(one_list) > 1: + # todo(warn): use small random to break tie + values = [abs(i-z)+next(rr) for z in one_list] + self.children_list[i] = [one_list[z] for z in np.argsort(values)] + # ===== + # todo(warn): exclude artificial root node + def get_real_values_select(self, selections): ret = [] for name in selections: @@ -101,22 +279,65 @@ def get_real_values_all(self): return ret # ===== Data Reader -def get_data_reader(file_or_fd, input_format, aug_code, use_la0): +def get_data_reader(file_or_fd, input_format, aug_code, use_la0, aux_repr_file=None, aux_score_file=None, cut=None): + cut = -1 if (cut is None or len(cut)==0) else int(cut) if input_format == "conllu": - return ParseConlluReader(file_or_fd, aug_code, use_la0=use_la0) + r = ParseConlluReader(file_or_fd, aug_code, use_la0=use_la0, cut=cut) elif input_format == "plain": - return ParseTextReader(file_or_fd, aug_code) + r = ParseTextReader(file_or_fd, aug_code, cut=cut) elif input_format == "json": - return ParseJsonReader(file_or_fd, aug_code, use_la0=use_la0) + r = ParseJsonReader(file_or_fd, aug_code, use_la0=use_la0, cut=cut) else: zfatal("Unknown input_format %s, should select from {conllu,plain,json}" % input_format) + r = None + if aux_repr_file is not None and len(aux_repr_file)>0: + r = AuxDataReader(r, aux_repr_file, "aux_repr") + if aux_score_file is not None and len(aux_score_file)>0: + r = AuxDataReader(r, aux_score_file, "aux_score") + return r + +# pre-computed auxiliary data +class AuxDataReader(AdapterStreamer): + def __init__(self, base_streamer, aux_repr_file, aux_name): + super().__init__(base_streamer) + self.file = aux_repr_file + self.fd = None + self.aux_name = aux_name + + def __del__(self): + if self.fd is not None: + self.fd.close() + + def _restart(self): + self.base_streamer_.restart() + if isinstance(self.file, str): + if self.fd is not None: + self.fd.close() + self.fd = zopen(self.file, mode='rb', encoding=None) + else: + zcheck(self.restart_times_ == 0, "Cannot restart a FdStreamer") + + def _next(self): + one = self.base_streamer_.next() + if self.base_streamer_.is_eos(one): + return None + res = pickle.load(self.fd) + # todo(warn): specific checks + if isinstance(res, (tuple, list)): + assert len(res) == 2 + assert all(len(z)==len(one)+1 for z in res), "Unmatched length for the aux_score arr" + else: + assert len(res) == len(one)+1, "Unmatched length for the aux_repr arr" + one.extra_features[self.aux_name] = res + return one # read from conllu file class ParseConlluReader(FileOrFdStreamer): - def __init__(self, file_or_fd, aug_code, use_xpos=False, use_la0=False): + def __init__(self, file_or_fd, aug_code, use_xpos=False, use_la0=False, cut=-1): super().__init__(file_or_fd) self.aug_code = aug_code self.reader = ConlluReader() + self.cut = cut # if use_xpos: self.pos_f = lambda t: t.xpos @@ -128,6 +349,8 @@ def __init__(self, file_or_fd, aug_code, use_xpos=False, use_la0=False): self.label_f = lambda t: t.label def _next(self): + if self.count_ == self.cut: + return None parse = self.reader.read_one(self.fd) if parse is None: return None @@ -140,13 +363,16 @@ def _next(self): # read raw text from plain file, no annotations (mainly for prediction) class ParseTextReader(FileOrFdStreamer): - def __init__(self, file_or_fd, aug_code, skip_empty_line=False, sep=None): + def __init__(self, file_or_fd, aug_code, skip_empty_line=False, sep=None, cut=-1): super().__init__(file_or_fd) self.aug_code = aug_code self.skip_empty_line = skip_empty_line self.sep = sep + self.cut = cut def _next(self): + if self.count_ == self.cut: + return None while True: line = self.fd.readline() if len(line)==0: @@ -157,16 +383,20 @@ def _next(self): break words = line.split(self.sep) one = ParseInstance(words, code=self.aug_code) + one.init_idx = self.count() return one # read json, one line per instance class ParseJsonReader(FileOrFdStreamer): - def __init__(self, file_or_fd, aug_code, use_la0=False): + def __init__(self, file_or_fd, aug_code, use_la0=False, cut=-1): super().__init__(file_or_fd) self.aug_code = aug_code self.use_la0 = use_la0 + self.cut = cut def _next(self): + if self.count_ == self.cut: + return None line = self.fd.readline() if len(line)==0: return None @@ -177,6 +407,7 @@ def _next(self): the_labels = [z.split(":")[0] for z in the_labels] one = ParseInstance(words=vv.get("word", None), poses=vv.get("pos", None), heads=vv.get("head", None), labels=the_labels, code=self.aug_code) + one.init_idx = self.count() return one # ===== @@ -214,11 +445,17 @@ def write(self, insts): class ParserConlluWriter(ParserWriter): def __init__(self, file_or_fd): super().__init__(file_or_fd) - self.obtain_names = ["words", "poses", "pred_heads", "pred_labels"] + self.obtain_names = ["words", "poses", "pred_heads", "pred_labels", "pred_miscs"] def write_one(self, inst: ParseInstance): values = inst.get_real_values_select(self.obtain_names) - write_conllu(self.fd, *values) + # record ROOT info in headlines + headlines = [] + if inst.pred_miscs.has_vals(): + headlines.append(json.dumps({"ROOT-MISC": inst.pred_miscs.vals[0]})) + if len(inst.extra_pred_misc) > 0: + headlines.append(json.dumps(inst.extra_pred_misc)) + write_conllu(self.fd, *values, headlines=headlines) class ParserPlainWriter(ParserWriter): def __init__(self, file_or_fd): @@ -235,3 +472,27 @@ def __init__(self, file_or_fd): def write_one(self, inst: ParseInstance): values = inst.get_real_values_all() self.fd.write(json.dumps(values)) + +# ===== +# multi-source reader +# split the srings by "," +# todo(note): best if input data are balanced, otherwise mix and shuffle data should be better +def get_multisoure_data_reader(file, input_format, aug_code, use_la0, aux_repr_file="", aux_score_file="", cut="", sep=","): + # ----- + def _get_and_pad(num, s, pad): + results = s.split(sep) if len(s)>0 else [] + if len(results) < num: + results.extend([pad] * (num - len(results))) + return results + # ----- + list_file = file.split(sep) + num_files = len(list_file) + list_aug_code = _get_and_pad(num_files, aug_code, "") + list_aux_repr_file = _get_and_pad(num_files, aux_repr_file, "") + list_aux_score_file = _get_and_pad(num_files, aux_score_file, "") + list_cut = _get_and_pad(num_files, cut, "") + # prepare all streams + list_streams = [get_data_reader(f, input_format, ac, use_la0, arf, asf, cc) + for f, ac, arf, asf, cc in zip(list_file, list_aug_code, list_aux_repr_file, list_aux_score_file, list_cut)] + # join them + return MultiJoinStreamer(list_streams) diff --git a/tasks/zdpar/common/model.py b/tasks/zdpar/common/model.py new file mode 100644 index 0000000..09d4ce6 --- /dev/null +++ b/tasks/zdpar/common/model.py @@ -0,0 +1,513 @@ +# + +# basic modules handling inputs and encodings for the parsers + +from typing import List +import numpy as np + +from msp.utils import Conf, Random, zlog, JsonRW, zfatal, zwarn +from msp.data import VocabPackage, MultiHelper +from msp.model import Model +from msp.nn import BK +from msp.nn import refresh as nn_refresh +from msp.nn.layers import BasicNode, Affine, RefreshOptions, NoDropRop +from msp.nn.modules import EmbedConf, MyEmbedder, EncConf, MyEncoder +# +from msp.zext.seq_helper import DataPadder +from msp.zext.process_train import RConf, SVConf, ScheduledValue, OptimConf + +from .data import ParseInstance + +# ===== +# the input modeling part + +# ----- +# for joint pos prediction + +class JPosConf(Conf): + def __init__(self): + self._input_dim = -1 # to be filled + self.jpos_multitask = False # the overall switch for multitask + self.jpos_lambda = 0. # lambda(switch) for training: loss_parsing + lambda*loss_pos + self.jpos_decode = True # switch for decoding + self.jpos_enc = EncConf().init_from_kwargs(enc_rnn_type="lstm2", enc_hidden=1024, enc_rnn_layer=0) + self.jpos_stacking = True # stacking as inputs (using adding here) + +class JPosModule(BasicNode): + def __init__(self, pc: BK.ParamCollection, jconf: JPosConf, pos_vocab): + super().__init__(pc, None, None) + self.jpos_stacking = jconf.jpos_stacking + self.jpos_multitask = jconf.jpos_multitask + self.jpos_lambda = jconf.jpos_lambda + self.jpos_decode = jconf.jpos_decode + # encoder0 + jconf.jpos_enc._input_dim = jconf._input_dim + self.enc = self.add_sub_node("enc0", MyEncoder(self.pc, jconf.jpos_enc)) + self.enc_output_dim = self.enc.get_output_dims()[0] + # output + # todo(warn): here, include some other things for convenience + num_labels = len(pos_vocab) + self.pred = self.add_sub_node("pred", Affine(self.pc, self.enc_output_dim, num_labels, init_rop=NoDropRop())) + # further stacking (if not, then simply multi-task learning) + if jconf.jpos_stacking: + self.pos_weights = self.add_param("w", (num_labels, self.enc_output_dim)) # [n, dim] to be added + else: + self.pos_weights = None + + def get_output_dims(self, *input_dims): + return (self.enc_output_dim, ) + + # return (out_expr, jpos_pack) + def __call__(self, input_repr, mask_arr, require_loss, require_pred, gold_pos_arr=None): + enc0_expr = self.enc(input_repr, mask_arr) # [*, len, d] + # + enc1_expr = enc0_expr + pos_probs, pos_losses_expr, pos_preds_expr = None, None, None + if self.jpos_multitask: + # get probabilities + pos_logits = self.pred(enc0_expr) # [*, len, nl] + pos_probs = BK.softmax(pos_logits, dim=-1) + # stacking for input -> output + if self.jpos_stacking: + enc1_expr = enc0_expr + BK.matmul(pos_probs, self.pos_weights) + # simple cross entropy loss + if require_loss and self.jpos_lambda>0.: + gold_probs = BK.gather_one_lastdim(pos_probs, gold_pos_arr).squeeze(-1) # [*, len] + # todo(warn): multiplying the factor here, but not maksing here (masking in the final steps) + pos_losses_expr = (-self.jpos_lambda) * gold_probs.log() + # simple argmax for prediction + if require_pred and self.jpos_decode: + pos_preds_expr = pos_probs.max(dim=-1)[1] + return enc1_expr, (pos_probs, pos_losses_expr, pos_preds_expr) + +# ----- +# the real bottom part + +class BTConf(Conf): + def __init__(self): + # embedding and encoding layer + self.emb_conf = EmbedConf().init_from_kwargs(dim_word=300, dim_char=30, dim_extras='[50]', extra_names='["pos"]') + self.enc_conf = EncConf().init_from_kwargs(enc_rnn_type="lstm2", enc_hidden=1024, enc_rnn_layer=3) + # joint pos encoder (layer0) + self.jpos_conf = JPosConf() + # ===== + # other options + # inputs + self.char_max_length = 45 + # dropouts + self.drop_embed = 0.33 + self.dropmd_embed = 0. + self.drop_hidden = 0.33 + self.gdrop_rnn = 0.33 # gdrop (always fixed for recurrent connections) + self.idrop_rnn = 0.33 # idrop for rnn + self.fix_drop = True # fix drop for one run for each dropout + self.singleton_unk = 0.5 # replace singleton words with UNK when training + self.singleton_thr = 2 # only replace singleton if freq(val) <= this (also decay with 1/freq) + + def do_validate(self): + if self.jpos_conf.jpos_multitask: + # assert self.jpos_conf.jpos_lambda > 0. # for training + assert 'pos' not in self.emb_conf.extra_names, "No usage for pos prediction?" + +# bottom part of the parsers, also include some shared procedures +class ParserBT(BasicNode): + def __init__(self, pc: BK.ParamCollection, bconf: BTConf, vpack: VocabPackage): + super().__init__(pc, None, None) + self.bconf = bconf + # ===== Vocab ===== + self.word_vocab = vpack.get_voc("word") + self.char_vocab = vpack.get_voc("char") + self.pos_vocab = vpack.get_voc("pos") + # ===== Model ===== + # embedding + self.emb = self.add_sub_node("emb", MyEmbedder(self.pc, bconf.emb_conf, vpack)) + emb_output_dim = self.emb.get_output_dims()[0] + # encoder0 for jpos + # todo(note): will do nothing if not use_jpos + bconf.jpos_conf._input_dim = emb_output_dim + self.jpos_enc = self.add_sub_node("enc0", JPosModule(self.pc, bconf.jpos_conf, self.pos_vocab)) + enc0_output_dim = self.jpos_enc.get_output_dims()[0] + # encoder + # todo(0): feed compute-on-the-fly hp + bconf.enc_conf._input_dim = enc0_output_dim + self.enc = self.add_sub_node("enc", MyEncoder(self.pc, bconf.enc_conf)) + self.enc_output_dim = self.enc.get_output_dims()[0] + # ===== Input Specification ===== + # inputs (word, char, pos) and vocabulary + self.need_word = self.emb.has_word + self.need_char = self.emb.has_char + # todo(warn): currently only allow extra fields for POS + self.need_pos = False + if len(self.emb.extra_names) > 0: + assert len(self.emb.extra_names) == 1 and self.emb.extra_names[0] == "pos" + self.need_pos = True + # todo(warn): currently only allow one aux field + self.need_aux = False + if len(self.emb.dim_auxes) > 0: + assert len(self.emb.dim_auxes) == 1 + self.need_aux = True + # + self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2) + self.char_padder = DataPadder(3, pad_lens=(0, 0, bconf.char_max_length), pad_vals=self.char_vocab.pad) + self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad) + # + self.random_sample_stream = Random.stream(Random.random_sample) + + def get_output_dims(self, *input_dims): + return (self.enc_output_dim, ) + + # + def refresh(self, rop=None): + zfatal("Should call special_refresh instead!") + + # whether enabling joint-pos multitask + def jpos_multitask_enabled(self): + return self.jpos_enc.jpos_multitask + + # ==== + # special routines + + def special_refresh(self, embed_rop, other_rop): + self.emb.refresh(embed_rop) + self.enc.refresh(other_rop) + self.jpos_enc.refresh(other_rop) + + def prepare_training_rop(self): + mconf = self.bconf + embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop) + other_rop = RefreshOptions(hdrop=mconf.drop_hidden, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, + fix_drop=mconf.fix_drop) + return embed_rop, other_rop + + # ===== + # run + def _prepare_input(self, insts, training): + word_arr, char_arr, extra_arrs, aux_arrs = None, None, [], [] + # ===== specially prepare for the words + wv = self.word_vocab + W_UNK = wv.unk + UNK_REP_RATE = self.bconf.singleton_unk + UNK_REP_THR = self.bconf.singleton_thr + word_act_idxes = [] + if training and UNK_REP_RATE>0.: # replace unfreq/singleton words with UNK + for one_inst in insts: + one_act_idxes = [] + for one_idx in one_inst.words.idxes: + one_freq = wv.idx2val(one_idx) + if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR: + if next(self.random_sample_stream) < (UNK_REP_RATE/one_freq): + one_idx = W_UNK + one_act_idxes.append(one_idx) + word_act_idxes.append(one_act_idxes) + else: + word_act_idxes = [z.words.idxes for z in insts] + # todo(warn): still need the masks + word_arr, mask_arr = self.word_padder.pad(word_act_idxes) + # ===== + if not self.need_word: + word_arr = None + if self.need_char: + chars = [z.chars.idxes for z in insts] + char_arr, _ = self.char_padder.pad(chars) + if self.need_pos or self.jpos_multitask_enabled(): + poses = [z.poses.idxes for z in insts] + pos_arr, _ = self.pos_padder.pad(poses) + if self.need_pos: + extra_arrs.append(pos_arr) + else: + pos_arr = None + if self.need_aux: + aux_arr_list = [z.extra_features["aux_repr"] for z in insts] + # pad + padded_seq_len = int(mask_arr.shape[1]) + final_aux_arr_list = [] + for cur_arr in aux_arr_list: + cur_len = len(cur_arr) + if cur_len > padded_seq_len: + final_aux_arr_list.append(cur_arr[:padded_seq_len]) + else: + final_aux_arr_list.append(np.pad(cur_arr, ((0,padded_seq_len-cur_len),(0,0)), 'constant')) + aux_arrs.append(np.stack(final_aux_arr_list, 0)) + # + input_repr = self.emb(word_arr, char_arr, extra_arrs, aux_arrs) + # [BS, Len, Dim], [BS, Len] + return input_repr, mask_arr, pos_arr + + # todo(warn): for rnn, need to transpose masks, thus need np.array + # return input_repr, enc_repr, mask_arr + def run(self, insts, training): + # ===== calculate + # [BS, Len, Di], [BS, Len], [BS, len] + input_repr, mask_arr, gold_pos_arr = self._prepare_input(insts, training) + # enc0 for joint-pos multitask + input_repr0, jpos_pack = self.jpos_enc( + input_repr, mask_arr, require_loss=training, require_pred=(not training), gold_pos_arr=gold_pos_arr) + # [BS, Len, De] + enc_repr = self.enc(input_repr0, mask_arr) + return input_repr, enc_repr, jpos_pack, mask_arr + + # special routine + def aug_words_and_embs(self, aug_vocab, aug_wv): + orig_vocab = self.word_vocab + if self.emb.has_word: + orig_arr = self.emb.word_embed.E.detach().cpu().numpy() + # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed? + # todo(warn): here aug_vocab should be find in aug_wv + aug_arr = aug_vocab.filter_embed(aug_wv, assert_all_hit=True) + new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr, aug_override=True) + # assign + self.word_vocab = new_vocab + self.emb.word_embed.replace_weights(new_arr) + else: + zwarn("No need to aug vocab since delexicalized model!!") + new_vocab = orig_vocab + return new_vocab + +# ===== +# base parser + +# decoding conf +class BaseInferenceConf(Conf): + def __init__(self): + # overall + self.batch_size = 32 + self.infer_single_length = 100 # single-inst batch if >= this length + +# training conf +class BaseTrainingConf(RConf): + def __init__(self): + super().__init__() + # about files + self.no_build_dict = False + self.load_model = False + self.load_process = False + # batch arranger + self.batch_size = 32 + self.train_skip_length = 120 + self.shuffle_train = True + # optimizer and lrate factor for enc&dec&sl(mid) + self.enc_optim = OptimConf() + self.enc_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + self.dec_optim = OptimConf() + self.dec_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + # todo(note): by default freeze this one! + self.dec2_optim = OptimConf() + self.dec2_lrf = SVConf().init_from_kwargs(val=0., which_idx="eidx", mode="none", min_val=0., max_val=1.) + self.mid_optim = OptimConf() + self.mid_lrf = SVConf().init_from_kwargs(val=1., which_idx="eidx", mode="none", min_val=0., max_val=1.) + # margin + self.margin = SVConf().init_from_kwargs(val=0.0) + # scheduled sampling (0. as always oracle) + self.sched_sampling = SVConf().init_from_kwargs(val=0.0) + # other regularization + self.reg_scores_lambda = 0. # score norm regularization + # === + # overwrite default ones + self.patience = 8 + self.anneal_times = 10 + self.max_epochs = 200 + +# base conf +class BaseParserConf(Conf): + def __init__(self, iconf: BaseInferenceConf, tconf: BaseTrainingConf): + self.iconf = iconf + self.tconf = tconf + # Model + self.new_name_conv = False # set this to false to be compatible with previous models (param names)! + self.bt_conf = BTConf() + # others (in inherited classes) + +# base parser +class BaseParser(Model): + def __init__(self, conf: BaseParserConf, vpack: VocabPackage): + self.conf = conf + self.vpack = vpack + tconf = conf.tconf + # ===== Vocab ===== + self.label_vocab = vpack.get_voc("label") + # ===== Model ===== + self.pc = BK.ParamCollection(conf.new_name_conv) + # bottom-part: input + encoder + self.bter = ParserBT(self.pc, conf.bt_conf, vpack) + self.enc_output_dim = self.bter.get_output_dims()[0] + self.enc_lrf_sv = ScheduledValue("enc_lrf", tconf.enc_lrf) + self.pc.optimizer_set(tconf.enc_optim.optim, self.enc_lrf_sv, tconf.enc_optim, + params=self.bter.get_parameters(), check_repeat=True, check_full=True) + # upper-part: decoder + # todo(+2): very ugly here! + self.scorer = self.build_decoder() + self.dec_lrf_sv = ScheduledValue("dec_lrf", tconf.dec_lrf) + self.dec2_lrf_sv = ScheduledValue("dec2_lrf", tconf.dec2_lrf) + try: + params, params2 = self.scorer.get_split_params() + self.pc.optimizer_set(tconf.dec_optim.optim, self.dec_lrf_sv, tconf.dec_optim, + params=params, check_repeat=True, check_full=False) + self.pc.optimizer_set(tconf.dec2_optim.optim, self.dec2_lrf_sv, tconf.dec2_optim, + params=params2, check_repeat=True, check_full=True) + except: + self.pc.optimizer_set(tconf.dec_optim.optim, self.dec_lrf_sv, tconf.dec_optim, + params=self.scorer.get_parameters(), check_repeat=True, check_full=True) + # middle-part: structured layer at the middle (build later for convenient re-loading) + self.slayer = None + self.mid_lrf_sv = ScheduledValue("mid_lrf", tconf.mid_lrf) + # ===== For training ===== + # schedule values + self.margin = ScheduledValue("margin", tconf.margin) + self.sched_sampling = ScheduledValue("ss", tconf.sched_sampling) + self._scheduled_values = [self.margin, self.sched_sampling, self.enc_lrf_sv, + self.dec_lrf_sv, self.dec2_lrf_sv, self.mid_lrf_sv] + self.reg_scores_lambda = conf.tconf.reg_scores_lambda + # for refreshing dropouts + self.previous_refresh_training = True + + # to be implemented + def build_decoder(self): + raise NotImplementedError() + + def build_slayer(self): + return None # None means no slayer + + # build after possible re-loading; + # zero_extra_params: means making the extra params output 0 (zero params for the final Affine layers) + def add_slayer(self): + tconf = self.conf.tconf + self.slayer = self.build_slayer() + if self.slayer is not None: + self.pc.optimizer_set(tconf.mid_optim.optim, self.mid_lrf_sv, tconf.mid_optim, + params=self.slayer.get_parameters(), check_repeat=True, check_full=True) + + def get_extra_refresh_nodes(self): + if self.slayer is None: + return [self.scorer] + else: + return [self.scorer, self.slayer] + + # called before each mini-batch + def refresh_batch(self, training: bool): + # refresh graph + # todo(warn): make sure to remember clear this one + nn_refresh() + # refresh nodes + if not training: + if not self.previous_refresh_training: + # todo(+1): currently no need to refresh testing mode multiple times + return + self.previous_refresh_training = False + embed_rop = other_rop = RefreshOptions(training=False) # default no dropout + else: + embed_rop, other_rop = self.bter.prepare_training_rop() + # todo(warn): once-bug, don't forget this one!! + self.previous_refresh_training = True + # manually refresh + self.bter.special_refresh(embed_rop, other_rop) + for node in self.get_extra_refresh_nodes(): + node.refresh(other_rop) + + def update(self, lrate, grad_factor): + self.pc.optimizer_update(lrate, grad_factor) + + def add_scheduled_values(self, v): + self._scheduled_values.append(v) + + def get_scheduled_values(self): + return self._scheduled_values + + # == load and save models + # todo(warn): no need to load confs here + def load(self, path, strict=True): + self.pc.load(path, strict) + # self.conf = JsonRW.load_from_file(path+".json") + zlog(f"Load {self.__class__.__name__} model from {path}.", func="io") + + def save(self, path): + self.pc.save(path) + JsonRW.to_file(self.conf, path + ".json") + zlog(f"Save {self.__class__.__name__} model to {path}.", func="io") + + def aug_words_and_embs(self, aug_vocab, aug_wv): + return self.bter.aug_words_and_embs(aug_vocab, aug_wv) + + # ===== + # common procedures + def pred2real_labels(self, preds): + return [self.label_vocab.trg_pred2real(z) for z in preds] + + def real2pred_labels(self, reals): + return [self.label_vocab.trg_real2pred(z) for z in reals] + + # ===== + # decode and training for the JPos Module + + def jpos_decode(self, insts: List[ParseInstance], jpos_pack): + # jpos prediction (directly index, no converting) + jpos_preds_expr = jpos_pack[2] + if jpos_preds_expr is not None: + jpos_preds_arr = BK.get_value(jpos_preds_expr) + for one_idx, one_inst in enumerate(insts): + cur_length = len(one_inst) + 1 # including the artificial ROOT + one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab) + + def jpos_loss(self, jpos_pack, mask_expr): + jpos_losses_expr = jpos_pack[1] + if jpos_losses_expr is not None: + # collect loss with mask, also excluding the first symbol of ROOT + final_losses_masked = (jpos_losses_expr * mask_expr)[:, 1:] + # todo(note): no need to scale lambda since already multiplied previously + final_loss_sum = BK.sum(final_losses_masked) + return final_loss_sum + else: + return None + + def reg_scores_loss(self, *scores): + if self.reg_scores_lambda>0.: + sreg_losses = [(z**2).mean() for z in scores] + if len(sreg_losses)>0: + sreg_loss = BK.stack(sreg_losses).mean() * self.reg_scores_lambda + return sreg_loss + return None + + # ===== + # inference and training: to be implemented + # ... + +# ===== +# general label set +class DepLabelHelper: + @staticmethod + def get_label_list(sname): + sname = str.lower(sname) + if sname == "": + sname = "none" + return { + "none": [], + "punct": ['punct'], + "func": ['aux', 'case', 'cc', 'clf', 'cop', 'det', 'mark', 'punct'], + "all": ['punct', 'case', 'nmod', 'amod', 'det', 'obl', 'nsubj', 'root', 'advmod', 'conj', 'obj', 'cc', 'mark', 'aux', 'acl', 'nummod', 'flat', 'cop', 'advcl', 'xcomp', 'appos', 'compound', 'expl', 'ccomp', 'fixed', 'iobj', 'parataxis', 'dep', 'csubj', 'orphan', 'discourse', 'clf', 'goeswith', 'vocative', 'list', 'dislocated', 'reparandum'], + "my1": ['aux', 'cc', 'clf', 'cop', 'det', 'mark', 'punct'], # keep 'case' + }[sname] + + # return {idx -> True/False} or [idx -> True/False] + @staticmethod + def select_label_idxes(sname_or_list, concerned_label_list: List, return_mask: bool, include_true: bool): + if isinstance(sname_or_list, str): + sname_or_list = DepLabelHelper.get_label_list(sname_or_list) + tmp_set = set(sname_or_list) + ret_set = set() + for idx, lab in enumerate(concerned_label_list): + lab = lab.split(":")[0] # todo(warn): first layer of label + in_list = lab in tmp_set + if include_true: + if in_list: + ret_set.add(idx) + else: + if not in_list: + ret_set.add(idx) + if return_mask: + ret_list = [False] * len(concerned_label_list) + for idx in ret_set: + ret_list[idx] = True + return ret_list + else: + return ret_set + +# b tasks/zdpar/common/model.py:215 diff --git a/tasks/zdpar/common/run.py b/tasks/zdpar/common/run.py index cfbb75a..870c207 100644 --- a/tasks/zdpar/common/run.py +++ b/tasks/zdpar/common/run.py @@ -14,14 +14,15 @@ # for indexing instances class IndexerStreamer(FAdapterStreamer): - def __init__(self, in_stream, vpack: ParserVocabPackage): - super().__init__(in_stream, self._go_index, True) + def __init__(self, in_stream, vpack: ParserVocabPackage, inst_preparer): + super().__init__(in_stream, self._go_index, False) # self.word_normer = vpack.word_normer self.w_vocab = vpack.get_voc("word") self.c_vocab = vpack.get_voc("char") self.p_vocab = vpack.get_voc("pos") self.l_vocab = vpack.get_voc("label") + self.inst_preparer = inst_preparer def _go_index(self, inst: ParseInstance): # word @@ -36,12 +37,16 @@ def _go_index(self, inst: ParseInstance): inst.poses.build_idxes(self.p_vocab) if inst.labels.has_vals(): inst.labels.build_idxes(self.l_vocab) + if self.inst_preparer is not None: + inst = self.inst_preparer(inst) + # in fact, inplaced if not wrapping model specific preparer + return inst # -def index_stream(in_stream, vpack, cached): - i_stream = IndexerStreamer(in_stream, vpack) +def index_stream(in_stream, vpack, cached, cache_shuffle, inst_preparer): + i_stream = IndexerStreamer(in_stream, vpack, inst_preparer) if cached: - return InstCacher(i_stream) + return InstCacher(i_stream, shuffle=cache_shuffle) else: return i_stream @@ -65,7 +70,7 @@ def __init__(self, vpack, outf, goldf, out_format): # self.vpack = vpack self.outf = outf - self.goldf = goldf + self.goldf = goldf # todo(note): goldf is only for reporting which file to compare? self.out_format = out_format def add(self, ones: List[ParseInstance]): @@ -82,16 +87,22 @@ def end(self): data_writer.write(self.insts) # evaler = ParserEvaler() - eval_arg_names = ["poses", "heads", "labels", "pred_heads", "pred_labels"] + # evaler2 = ParserEvaler(ignore_punct=True, punct_set={"PUNCT", "SYM"}) + eval_arg_names = ["poses", "heads", "labels", "pred_poses", "pred_heads", "pred_labels"] for one_inst in self.insts: # todo(warn): exclude the ROOT symbol; the model should assign pred_* real_values = one_inst.get_real_values_select(eval_arg_names) evaler.eval_one(*real_values) + # evaler2.eval_one(*real_values) report_str, res = evaler.summary() + # _, res2 = evaler2.summary() # zlog("Results of %s vs. %s" % (self.outf, self.goldf), func="result") zlog(report_str, func="result") + res["gold"] = self.goldf # record which file + # res2["gold"] = self.goldf # record which file zlog("zzzzztest: testing result is " + str(res)) + # zlog("zzzzztest2: testing result is " + str(res2)) zlog("zzzzzpar: %s" % res["res"], func="result") return res @@ -122,7 +133,17 @@ def __init__(self, result_ds: List[Dict]): rr = {} for idx, vv in enumerate(result_ds): rr["zres"+str(idx)] = vv # record all the results for printing - super().__init__(rr, result_ds[0]["res"]) # but only use the first one as DEV-res + # res = 0. if (len(result_ds)==0) else result_ds[0]["res"] # but only use the first one as DEV-res + # todo(note): use the first dev result, and use UAS if res<=0. + if len(result_ds) == 0: + res = 0. + else: + # todo(note): the situation of totally unlabeled parsing + if result_ds[0]["res"] <= 0.: + res = result_ds[0]["tok_uas"] + else: + res = result_ds[0]["res"] + super().__init__(rr, res) # training runner class ParserTrainingRunner(TrainingRunner): @@ -139,23 +160,23 @@ def __init__(self, rconf, model, vpack, dev_outfs, dev_goldfs, dev_out_format): self.dev_outfs = [dev_outfs] * len(dev_goldfs) # to be implemented - def _run_fb(self, annotated_insts): - res = self.model.fb_on_batch(annotated_insts) + def _run_fb(self, annotated_insts, loss_factor: float): + res = self.model.fb_on_batch(annotated_insts, loss_factor=loss_factor) return res def _run_train_report(self): x = self.train_recorder.summary() y = GLOBAL_RECORDER.summary() # todo(warn): get loss/tok - x["loss_tok"] = x["loss_sum"]/x["tok"] + x["loss_tok"] = x.get("loss_sum", 0.)/x["tok"] if len(y) > 0: x.update(y) - zlog(x, "report") + # zlog(x, "report") + Helper.printd(x, " || ") return RecordResult(x) - def _run_update(self, lrate: float): - # already done in fb - self.model.update(lrate) + def _run_update(self, lrate: float, grad_factor: float): + self.model.update(lrate, grad_factor) def _run_validate(self, dev_streams): if not isinstance(dev_streams, Iterable): diff --git a/tasks/zdpar/common/vocab.py b/tasks/zdpar/common/vocab.py index 631597f..f310ded 100644 --- a/tasks/zdpar/common/vocab.py +++ b/tasks/zdpar/common/vocab.py @@ -47,8 +47,15 @@ def build_from_stream(dconf: DConf, stream, extra_stream): for inst in extra_stream: for w in word_normer.norm_stream(inst.words.vals): extra_word_set.add(w) - # must provide dconf.pretrain_file - w2vec = WordVectors.load(dconf.pretrain_file, aug_code=dconf.code_pretrain) + # ----- load (possibly multiple) pretrain embeddings + # must provide dconf.pretrain_file (there can be multiple pretrain files!) + list_pretrain_file, list_code_pretrain = dconf.pretrain_file, dconf.code_pretrain + list_code_pretrain.extend([""] * len(list_pretrain_file)) # pad default ones + w2vec = WordVectors.load(list_pretrain_file[0], aug_code=list_code_pretrain[0]) + if len(list_pretrain_file) > 1: + w2vec.merge_others([WordVectors.load(list_pretrain_file[i], aug_code=list_code_pretrain[i]) + for i in range(1, len(list_pretrain_file))]) + # ----- # first filter according to thresholds word_builder.filter(lambda ww, rank, val: (val >= dconf.word_fthres and rank <= dconf.word_rthres) or w2vec.has_key(ww)) # then add extra ones diff --git a/tasks/zdpar/ef/__init__.py b/tasks/zdpar/ef/__init__.py new file mode 100644 index 0000000..ecb2ba3 --- /dev/null +++ b/tasks/zdpar/ef/__init__.py @@ -0,0 +1,22 @@ +# + +# ready go! + +""" +TODO(+N): +- currently cost does not count future cost (by acyclic constraint), which may bring troubles? +- currently still not full dynamic oracle +----- +- 1) weight by cost for the final loss, +- 1.5) check other systems (like td/nf/lr) +- 1.6) check speed +- 2) other enders, like early-update/max-violation/BSO +- 2.5) refining the parts about the losses +-- (arc/label: cost(arc_wrong=>label-wrong), oracle_states(what to include), separate arc/label & with diff div) +- 2.6) features (par/chs) +# ===== +- ??) multi-margin? +- 3) high order graph and init with graph trained parser +TODO(+N): + 4) acording to profiling, most of the time are still on State-building, which can be surely speed up +""" diff --git a/tasks/zdpar/ef/analysis/ann.py b/tasks/zdpar/ef/analysis/ann.py new file mode 100644 index 0000000..11e0d48 --- /dev/null +++ b/tasks/zdpar/ef/analysis/ann.py @@ -0,0 +1,133 @@ +# + +# simply compare two system outputs and one gold file and possibly support instance viewing and annotation + +import sys +import json +import traceback +import numpy as np +import pandas as pd +from typing import List, Iterable +from collections import defaultdict, Counter +from msp.zext.dpar.conllu_reader import ConlluReader +from msp.data import Vocab +from msp.utils import zopen, zlog, Random, Conf, PickleRW, Constants, ZObject, wrap_color, Helper + +from msp.zext.ana import AnalyzerConf, Analyzer, ZRecNode, ZNodeVisitor, AnnotationTask + +try: + from .common import * +except: + from common import * + +# truncat/norm float numbers +def n_float(x, digit=4): + mul = 10**digit + return int(x*mul)/mul + +class AnalysisConf(Conf): + def __init__(self, args): + super().__init__() + self.gold = "" # gold parse file + self.fs = [] # list of system files + self.vocab = "" # (optional) vocab file + self.ana = AnalyzerConf() # basic analyzer conf + self.getter = GetterConf() + self.labeled = False + # save and load + self.save_name = "" + self.load_name = "" + # + # ===== + self.update_from_args(args) + self.validate() + +# +class AccVisitor(ZNodeVisitor): + def visit(self, node: ZRecNode, values: List): + # calculate UAS and LAS + zc_count = len(node.objs) + assert zc_count == node.count + zc_uas, zc_las = [], [] + if len(node.objs)>0: + for i in range(node.objs[0].nsys): + one_uas = sum(z.ss[i].ucorr for z in node.objs) / (zc_count + 1e-5) + one_las = sum(z.ss[i].lcorr for z in node.objs) / (zc_count + 1e-5) + zc_uas.append(n_float(one_uas)) + zc_las.append(n_float(one_las)) + node.props.update({"acc": f" u={zc_uas} Du={[n_float(a-b) for a,b in zip(zc_uas[:-1], zc_uas[1:])]} " + f"l={zc_las} Dl={[n_float(a-b) for a,b in zip(zc_las[:-1], zc_las[1:])]}"}) + return None + + def show(self, node: ZRecNode): + return f"{node.props.get('acc', 'NOPE')}" + +# +# general protocol: sents, tokens, vs +class ParsingAnalyzer(Analyzer): + def __init__(self, conf, all_sents: List, all_tokens: List, labeled: bool=False, vocab: Vocab=None): + super().__init__(conf) + # + self.labeled = labeled + self.set_var("sents", all_sents, explanation="init", history_idx=-1) + self.set_var("tokens", all_tokens, explanation="init", history_idx=-1) + self.set_var("voc", vocab, explanation="init", history_idx=-1) + # + self.cur_ann_task: PerrAnnotationTask = None + + # ===== + # commands + + # protocol: target: instances, gcode: d for each instances; return one node + def do_group(self, insts_target: str, gcode: str, sum_key="count") -> ZRecNode: + return self._do_group(insts_target, gcode, sum_key, AccVisitor()) + + # protocol: target: znode, scode: show code + def do_show(self, inst_node: str, scode: str): + # todo(note): maybe using pd is better!? + pass + + # print for an obj + def do_print(self, ocode): + vs = self.vars + _ff = compile(ocode, "", "eval") + obj = eval(_ff) + print_obj(obj) + + # ----- + # looking/annotating at specific instances. protocol: target + + # start new annotation task or TODO(!) recover the previous one + def do_ann_start(self, insts_target: str) -> AnnotationTask: + assert self.cur_cmd_target is not None, "Should assign this to a var to avoid accidental loss!" + vs = self.vars + insts = self.get_and_check_type(insts_target, list) + new_task = PerrAnnotationTask(insts) + # todo(note): here need auto save? + if self.cur_ann_task is not None and self.cur_ann_task.remaining>0: + zlog("Warn: the previous annotation task has not been finished yet!") + self.cur_ann_task = new_task + return new_task + + # create one annotation + def do_ac(self, *ann): + pass + + # jump to and print a new instance, 0 means printing current instance + def do_aj(self, offset=0): + offset = int(offset) + assert self.cur_ann_task is not None, "Now attached ann task yet!" + self.cur_ann_task.set_focus(offset=offset) + self.cur_ann_task.show_current() + +# parsing error annotation +class PerrAnnotationTask(AnnotationTask): + def __init__(self, objs: List): + super().__init__(objs) + assert len(objs)>0 + # todo(warn): specific method for telling sent-mode or tok-mod + self.is_sent = hasattr(objs[0], "toks") + + def show_current(self): + obj = self.get_obj() + print_obj(obj, self.focus, self.length) diff --git a/tasks/zdpar/ef/analysis/ark/ana0806.py b/tasks/zdpar/ef/analysis/ark/ana0806.py new file mode 100644 index 0000000..b8897a5 --- /dev/null +++ b/tasks/zdpar/ef/analysis/ark/ana0806.py @@ -0,0 +1,422 @@ +# + +# this one is the overall analysis, maybe need another one for one-vs-one compare/record/annotation + +# analyze on the conllu files of 1) gold 2) system-outputs, 3) system2-outputs +# -- sentence-wise, token/step-wise + +import sys +import json +import numpy as np +import pandas as pd +from typing import List, Iterable +from collections import defaultdict, Counter +from msp.zext.dpar.conllu_reader import ConlluReader +from msp.utils import zopen, zlog, StatRecorder, Helper, Conf, PickleRW, Constants + +NEG_INF = Constants.REAL_PRAC_MIN + +# ===== +def yield_ones(file): + r = ConlluReader() + with zopen(file) as fd: + for p in r.yield_ones(fd): + # add extra info + for line in p.headlines: + try: + one = json.loads(line) + p.misc.update(one) + except: + continue + # add extra info to tokens + cur_length = len(p) + if "ef_score" in p.misc: + ef_order, ef_score = p.misc["ef_order"], p.misc["ef_score"] + assert len(ef_order) == cur_length and len(ef_score) == cur_length + for cur_step in range(cur_length): + cur_idx, cur_score = ef_order[cur_step], ef_score[cur_step] + assert cur_idx > 0 + cur_t = p.get_tok(cur_idx) + cur_t.misc["efi"] = cur_step + cur_t.misc["efs"] = cur_score + if "g1_score" in p.misc: + g1_score = p.misc["g1_score"][1:] + assert len(g1_score) == cur_length + sorted_idxes = list(reversed(np.argsort(g1_score))) + for cur_step in range(cur_length): + cur_idx = sorted_idxes[cur_step] + 1 # offset by root + cur_score = g1_score[cur_idx-1] + assert cur_idx > 0 + cur_t = p.get_tok(cur_idx) + cur_t.misc["g1i"] = cur_step + cur_t.misc["g1s"] = cur_score + yield p + +# +class AnalysisConf(Conf): + def __init__(self, args): + self.gold = "" # gold parse file + self.f1s = [] # main set of system files + self.f2s = [] # second set of system files for comparison + self.cache_pkl = "" # cache stat to file + self.force_refresh_cache = False # ignore cache + self.output = "" + # analyze commands + self.loop = False # whether start a loop for interactive mode + self.grain = "tokens" # sents, tokens, steps + self.fcode = "True" # filter code + self.mcode = "1" # map code + self.rcode = "" # reduce code + # ===== + self.update_from_args(args) + self.validate() + +def bool2int(x): + return 1 if x else 0 + +def norm_pos(x): + if x in {"NOUN", "PROPN", "PRON"}: + return "N" + else: + return x + +# ===== +# helpers + +class ZObject(object): + def __init__(self, m=None): + if m is not None: + self.update(m) + + def update(self, m): + for k, v in m.items(): + setattr(self, k, v) + +def _get_token_object(gp_token, s1_tokens: List, s2_tokens: List): + # first put basic nodes + res = {"g": gp_token, "s1": s1_tokens, "s2": s2_tokens, "s1a": None, "s2a": None, "cr": None} + for idx, node in enumerate(s1_tokens): + res["s1"+str(idx)] = node + for idx, node in enumerate(s2_tokens): + res["s2"+str(idx)] = node + # next get correctness info + g_head, g_label = gp_token.head, gp_token.label + for prefix in ["s1", "s2"]: + cur_tokens = res[prefix] + # results + cur_heads = [z.head for z in cur_tokens] + cur_labels = [z.label for z in cur_tokens] + # correctness + ucorrs = [int(z==g_head) for z in cur_heads] + lcorrs = [int(u and z==g_label) for u,z in zip(ucorrs, cur_labels)] + # set correctness + for tok, ucorr, lcorr in zip(cur_tokens, ucorrs, lcorrs): + tok.misc["ucorr"] = ucorr + tok.misc["lcorr"] = lcorr + # aggregate: agree only >= n//2+1! + agr_all_num = len(cur_tokens) + agr_target = (agr_all_num//2)+1 + u_counter = Counter(cur_heads) + agr_head, unum = u_counter.most_common(1)[0] + if unum> ").strip() + conf.update_from_args([z.strip() for z in cmd.split("\"") if z.strip()!=""]) + except KeyboardInterrupt: + continue # ignore Ctrl-C + else: + break + if fout is not None: + fout.close() + +# ===== +# helper functions + +# helper for div +def hdiv(a, b): + if b == 0: + return (a, b, 0.) + else: + return (a, b, a/b) + +# helper for plain counter & distribution +def hdistr(cs, key="count", ret_format="", min_count=0): + r = Counter(cs) + c = len(cs) + get_keys_f = { + "count": lambda r: sorted(r.keys(), key=(lambda x: -r[x])), + "name": lambda r: sorted(r.keys()), + }[key] + thems = [(k, r[k], r[k]/c) for k in get_keys_f(r)] + # + ret_f = { + "": lambda x: x, + "s": lambda x: "\n".join([str(z) for z in x]), + }[ret_format] + thems = [z for z in thems if z[1]>=min_count] + return ret_f(thems) + +# helper for multi-level counter & distribution +def hmdistr(cs, fks=(), fss=(), ret_format="", min_count=0, show_level=Constants.INT_PRAC_MAX): + # ===== + # collect + max_level = 0 + record = [0, None, {}] # str -> (count, value, next) + for one in cs: + max_level = max(len(one), max_level) + cur_r = record + cur_r[0] += 1 # overall count + for i, x in enumerate(one): + if x not in cur_r[-1]: + cur_r[-1][x] = [0, None, {}] + cur_r = cur_r[-1][x] + cur_r[0] += 1 + # ===== + # sorting methods: Dict[Record] -> List[str] + if not isinstance(fks, (list, tuple)): + fks = [fks] * (max_level+1) + elif len(fks) < max_level+1: + fks = list(fks) + [""] * (max_level+1) + MAP_F_KEYS = { + "count": lambda r: sorted(r.keys(), key=(lambda x: -r[x][0])), # sort by count + "value": lambda r: sorted(r.keys(), key=(lambda x: -r[x][1])), # sort by (summary) value + "name": lambda r: sorted(r.keys()), + } + MAP_F_KEYS[""] = MAP_F_KEYS["count"] + f_keys = [MAP_F_KEYS[k] for k in fks] + # summary methods: Record(count, _, Dict[Record]) -> value + if not isinstance(fss, (list, tuple)): + fss = [fss] * (max_level+1) + elif len(fss) < max_level+1: + fss = list(fss) + [""] * (max_level+1) + MAP_F_SUMS = { + "zero": lambda r: 0., + "avg": lambda r: sum(k*v[0]/r[0] for k,v in r[-1].items()), + } + MAP_F_SUMS[""] = MAP_F_SUMS["zero"] + f_sums = [MAP_F_SUMS[k] for k in fss] + # ===== + # recursively visit + stack_counts = [] + stack_keys = [] + def _visit(cur_r, cur_key_fs=None, cur_sum_fs=None, outputs=None): + level = len(stack_counts) + cur_count, _, cur_next = cur_r + stack_counts.append(cur_count) + visiting_keys = cur_next.keys() if (cur_key_fs is None) else (cur_key_fs[level](cur_next)) + for k in visiting_keys: + stack_keys.append(k) + # stat + next_r = cur_next[k] + _visit(next_r, cur_key_fs, cur_sum_fs, outputs) + next_count = next_r[0] + next_val = next_r[1] + if outputs is not None and level<=show_level: + outputs.append((level, stack_keys.copy(), next_count, next_val, [next_count/z for z in stack_counts])) + stack_keys.pop() + # summarize + if cur_sum_fs is not None: + cur_r[1] = cur_sum_fs[level](cur_r) + stack_counts.pop() + _visit(record, cur_sum_fs=f_sums) + thems = [] + _visit(record, cur_key_fs=f_keys, outputs=thems) + # + ret_f = { + "": lambda x: x, + "s": lambda x: "\n".join([str(z) for z in x]), + }[ret_format] + thems = [z for z in thems if z[2] >= min_count] + return ret_f(thems) + +# +def hgroup(k, length, N, is_step=True): + step = k + all_steps = (2*length-2) if is_step else length-1 + g = int(step/all_steps*N) + return g + +# +def hbin(k, length, N): + return int(k/length*N) + +# +def hcmp_err(a, b): + return {(0,0):"both_good", (1,1):"both_bad", (0,1):"sys_good", (1,0):"sys_bad"}[(a,b)] + +def hclip(x, a, b): + return min(max((x,a)),b) + +def halb0(x): + return x.split(":")[0] + +# ===== +# helper functions version 2: dumpers + +# for mcode +def hdm_err_break(key, z): + # unlabeled + s1_corr = 1-z.s1p_uerr + s2_corr = 1-z.s2p_uerr + both_corr = s1_corr * s2_corr + s1_err = z.s1p_uerr + s1_first_err = z.s1p_real_uerr + # labeled + s1_corr_L = 1-z.s1p_lerr + s2_corr_L = 1-z.s2p_lerr + both_corr_L = s1_corr_L*s2_corr_L + # + s12_diff, s12_diff_L = s1_corr-s2_corr, s1_corr_L-s2_corr_L + return (key, [both_corr, s1_corr, s2_corr, s12_diff, both_corr_L, s1_corr_L, s2_corr_L, s12_diff_L, s1_err, s1_first_err]) + +# for rcode +def hdr_err_break(cs): + r = {} + for key, arr in cs: + if key not in r: + r[key] = [] + r[key].append(arr) + ret = {k: np.array(arrs).mean(axis=0).tolist() + [len(arrs)+0., len(arrs)/len(cs)] for k, arrs in r.items()} + return json.dumps(ret) + +# formatting for printing +def hdformat_err_break(line): + import json + ret = json.loads(line) + for key, res in ret.items(): + both_corr, s1_corr, s2_corr, s12_diff, both_corr_L, s1_corr_L, s2_corr_L, s12_diff_L, s1_err, s1_first_err, *_ = res + print(f"{key} & {s1_corr*100:.2f}/{s1_corr_L*100:.2f} & {s2_corr*100:.2f}/{s2_corr_L*100:.2f} & {s12_diff*100:.2f}/{s12_diff_L*100:.2f} \\\\") + +# questions (interactive tryings) +""" +PYTHONPATH=../src/ python3 ana.py gold:en_dev.conllu f1s:out.ef f2s:out.g1 loop:1 +# +# 1. UAS/LAS/Error-rate? +"mcode:z.s1a.ucorr_avg" "rcode:hdiv(sum(cs), len(cs))" +"mcode:z.s1a.lcorr_avg" "rcode:hdiv(sum(cs), len(cs))" +# 2. which steps are first? +# step-bin vs. UAS +"mcode:(hbin(z.s10.efi, z.sobj.len, 10), z.s10.ucorr)" "rcode:hmdistr(cs, fks=['name'], ret_format='s', min_count=10)" +# step-bin vs. label +"mcode:(hbin(z.s10.efi, z.sobj.len, 10), halb0(z.s10.label))" "rcode:hmdistr(cs, fks=['name'], ret_format='s', min_count=10)" +# step-bin vs. ddist +"mcode:(hbin(z.s10.efi, z.sobj.len, 10), z.s10.ddist)" "rcode:hmdistr(cs, fks=['name'], ret_format='s', min_count=10)" +# avg percentage of label +"mcode:(halb0(z.s10.label), z.s10.efi/z.sobj.len)" "rcode:hmdistr(cs, fks=['value'], fss=['', 'avg'], ret_format='s', min_count=10, show_level=0)" +# avg percentage of ddist +"mcode:(hclip(z.s10.ddist, -10, 10), z.s10.efi/z.sobj.len)" "rcode:hmdistr(cs, fks=['value'], fss=['', 'avg'], ret_format='s', min_count=10, show_level=0)" +# 3. cmp +# +# 4. error breakdown +""" + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/tasks/zdpar/ef/analysis/ark/see0827.py b/tasks/zdpar/ef/analysis/ark/see0827.py new file mode 100644 index 0000000..e4b8615 --- /dev/null +++ b/tasks/zdpar/ef/analysis/ark/see0827.py @@ -0,0 +1,24 @@ +# + +try: + from tasks.zdpar.ef.analysis.common import * +except: + from common import * + +import sys + +# +def main(args): + pass + +if __name__ == '__main__': + main(*sys.argv[1:]) + +# todo(note): error fixing with multiple passes of fixing templates +# 0. ignore punct, 1. mwe/complex-np errors (usually local), 2. attachment error (PP, NP, CONJ, ...) +# looking at the relations of the errored nodes in the original tree (according to the frame-structure of the original tree) +# bottom-up or top-down fixing? +# what is the actual relation between the wrongly predicted edge (mod, pred-head, correct-head)? +# -> the "important" errors are the one that creates cycles if added to the original tree +# --> attachment error vs. head error (phrase-internal error and phrase-external error) +# ===== diff --git a/tasks/zdpar/ef/analysis/common.py b/tasks/zdpar/ef/analysis/common.py new file mode 100644 index 0000000..b07c6bb --- /dev/null +++ b/tasks/zdpar/ef/analysis/common.py @@ -0,0 +1,251 @@ +# + +# common procedures for analysis + +import json +import numpy as np +import pandas as pd +from typing import List, Iterable +from collections import defaultdict + +from msp.zext.dpar.conllu_reader import ConlluReader +from msp.utils import zopen, zlog, Random, Conf, PickleRW, Constants, ZObject, wrap_color, Conf + +try: + from .fixer import * +except: + from fixer import * + +# ===== +# reading +def yield_ones(file): + r = ConlluReader() + with zopen(file) as fd: + for p in r.yield_ones(fd): + # add extra info + for line in p.headlines: + try: + one = json.loads(line) + p.misc.update(one) + except: + continue + # add extra info to tokens + cur_length = len(p) + if "ef_score" in p.misc: + ef_order, ef_score = p.misc["ef_order"], p.misc["ef_score"] + assert len(ef_order) == cur_length and len(ef_score) == cur_length + for cur_step in range(cur_length): + cur_idx, cur_score = ef_order[cur_step], ef_score[cur_step] + assert cur_idx > 0 + cur_t = p.get_tok(cur_idx) + cur_t.misc["efi"] = cur_step + cur_t.misc["efs"] = cur_score + if "g1_marg" in p.misc: + g1_marg = p.misc["g1_marg"][1:] + assert len(g1_marg) == cur_length + sorted_idxes = list(reversed(np.argsort(g1_marg))) + for cur_step in range(cur_length): + cur_idx = sorted_idxes[cur_step] + 1 # offset by root + cur_score = g1_marg[cur_idx - 1] + assert cur_idx > 0 + cur_t = p.get_tok(cur_idx) + cur_t.misc["gmi"] = cur_step + cur_t.misc["gms"] = cur_score + if "g1_score" in p.misc: + g1_score = p.misc["g1_score"][1:] + assert len(g1_score) == cur_length + sorted_idxes = list(reversed(np.argsort(g1_score))) + for cur_step in range(cur_length): + cur_idx = sorted_idxes[cur_step] + 1 # offset by root + cur_score = g1_score[cur_idx-1] + assert cur_idx > 0 + cur_t = p.get_tok(cur_idx) + cur_t.misc["g1i"] = cur_step + cur_t.misc["g1s"] = cur_score + yield p + +# ===== +# object and printing + +class GetterConf(Conf): + def __init__(self): + self.build_trees = True + self.build_back_edges = True + self.build_fixes = True + + def do_validate(self): + if self.build_back_edges or self.build_fixes: + assert self.build_trees + +def get_token_object(gp_token, sys_tokens: List, tid: int, conf: GetterConf): + # ===== + def _calc_agr(objs: List): + x = objs[0] + return all(x==z for z in objs) + # ===== + res = {"g": gp_token, "ss": sys_tokens, "id": tid, "nsys": len(sys_tokens)} + # correctness info + g_head, g_label = gp_token.head, gp_token.label + all_heads, all_labels = [g_head], [g_label] + for cur_idx, cur_token in enumerate(sys_tokens, 1): + prefix = "s" + str(cur_idx) # s1, s2, ... + res[prefix] = cur_token + cur_head, cur_label = cur_token.head, cur_token.label + all_heads.append(cur_head) + all_labels.append(cur_label) + ucorr = (cur_head == g_head) + lcorr = (ucorr and (cur_label == g_label)) + cur_token.misc["ucorr"] = int(ucorr) + cur_token.misc["lcorr"] = int(lcorr) + # agreement info + uagr2, uagr3 = _calc_agr(all_heads[1:]), _calc_agr(all_heads) + lagr2, lagr3 = _calc_agr(all_labels[1:]) and uagr2, _calc_agr(all_labels) and uagr3 + agr_info = {"uagr2": int(uagr2), "uagr3": int(uagr3), "lagr2": int(lagr2), "lagr3": int(lagr3)} + res.update(agr_info) + # + ret = ZObject(res) + return ret + +def get_sent_object(cur_token_objs, sid: int, conf: GetterConf): + res = { + "id": sid, + "len": len(cur_token_objs), + "toks": [None] + cur_token_objs, # padding for the artificial root to adjust the idxes + "rtoks": cur_token_objs, # real tokens + } + # build tree? + if conf.build_trees: + nsys = cur_token_objs[0].nsys + gold_heads = [0] + [z.g.head for z in cur_token_objs] + gold_labels = ["_"] + [z.g.label0 for z in cur_token_objs] + all_pred_heads = [[0] + [z.ss[i].head for z in cur_token_objs] for i in range(nsys)] + # todo(note): use level0 labels + all_pred_labels = [["_"] + [z.ss[i].label0 for z in cur_token_objs] for i in range(nsys)] + gold_tree = DepTree(gold_heads, gold_labels, True) + all_pred_trees = [DepTree(a, b, False) for a,b in zip(all_pred_heads, all_pred_labels)] + res["gtree"] = gold_tree + res["ptrees"] = all_pred_trees + # build back_edge? + if conf.build_back_edges: + for i in range(nsys): + edges = all_pred_trees[i].characterize_edges(gold_tree)[1:] + for t, e in zip(cur_token_objs, edges): + assert e is not None + t.ss[i].misc["edge"] = e + # build fixes? + fixer = GoldRefFixer(None) + if conf.build_fixes: + fix_cates = [] # List[set of categories] + for i in range(nsys): + fixer.fix(gold_tree, all_pred_trees[i]) + fix_cates.append(set(z.category for z in all_pred_trees[i].fixes)) + res["fix_cates"] = fix_cates + # build reversed links for token nodes + sobj = ZObject(res) + for t in cur_token_objs: # additional link + t.sent = sobj + t.sid = sid + return sobj + +# +def get_objs(gold_parses: List, sys_parses: List[List], getter: GetterConf): + assert all(len(gold_parses) == len(z) for z in sys_parses) + all_sents, all_tokens = [], [] + for sent_idx in range(len(gold_parses)): + gp_tokens = gold_parses[sent_idx].tokens + sys_tokens = [] + for one_sys_parse in sys_parses: + sys_tokens.append(one_sys_parse[sent_idx].tokens) + assert [z.word for z in gp_tokens] == [z.word for z in sys_tokens[-1]] + # ===== + # pre-computation + cur_tokens = [] + # for each valid token + for token_idx in range(1, len(gp_tokens)): + now_gp_token = gp_tokens[token_idx] + now_sys_tokens = [z[token_idx] for z in sys_tokens] + cur_t_obj = get_token_object(now_gp_token, now_sys_tokens, token_idx, getter) + cur_tokens.append(cur_t_obj) + # for the whole sentence + cur_sent = get_sent_object(cur_tokens, sent_idx, getter) + all_sents.append(cur_sent) + all_tokens.extend(cur_tokens) + return all_sents, all_tokens + +def print_obj(obj: ZObject, ann_focus: int=-1, ann_length: int=-1) -> pd.DataFrame: + if hasattr(obj, "toks"): + sent_obj = obj + highlight_tid = -1 + else: + sent_obj = obj.sent + highlight_tid = obj.id + # collect sentence info + cur_len = sent_obj.len + cur_tidx = 0 + all_fields = [] + err_head, err_label = defaultdict(list), defaultdict(list) + nerr_head, nerr_label = defaultdict(int), defaultdict(int) + for cur_tok in sent_obj.toks[1:]: + if cur_tidx == highlight_tid: + CUR_ERR_BACK = "RED" + else: + CUR_ERR_BACK = "BLUE" + # + cur_tidx += 1 + cur_fields = [] + cur_g = cur_tok.g + head, label, word, upos = cur_g.head, cur_g.label, cur_g.word, cur_g.upos + noerr_head, noerr_label = True, True + for sys_idx in range(1, cur_tok.nsys + 1): + prefix = "s" + str(sys_idx) + cur_s = getattr(cur_tok, prefix) + shead, slabel = cur_s.head, cur_s.label + # get ranking and confidence + efi = getattr(cur_s, "efi", None) + gmi = getattr(cur_s, "gmi", None) + if efi is not None: + confi_descr = f"{efi}/{cur_len}" + elif gmi is not None: + gms = getattr(cur_s, "gms") + confi_descr = f"{gmi}/{gms:.4f}" + else: + confi_descr = "--" + # + if shead != head: + err_head[prefix].append(cur_tidx) + nerr_head[prefix] += 1 + noerr_head = False + if slabel != label: + err_label[prefix].append(cur_tidx) + nerr_label[prefix] += 1 + noerr_label = False + cur_fields.append(wrap_color(str(shead), bcolor=("black" if shead == head else CUR_ERR_BACK))) + cur_fields.append(wrap_color(slabel, bcolor=("black" if slabel == label else CUR_ERR_BACK))) + cur_fields.append(confi_descr) + cur_fields = [wrap_color(str(head), bcolor=("black" if noerr_head else CUR_ERR_BACK)), + wrap_color(label, bcolor=("black" if noerr_label else CUR_ERR_BACK))] + cur_fields + cur_fields.extend([word, upos]) + all_fields.append(cur_fields) + x = pd.DataFrame(all_fields, index=[f"({z})" for z in range(1, cur_len + 1)]) + # print the table and summarize + zlog(f"Current instance\n{x.to_string()}\nSid={sent_obj.id}({ann_focus}/{ann_length}), Wid={highlight_tid}," + f" ErrHead={nerr_head};{err_head}, ErrLabel={nerr_label};{err_label}") + # print fix info if accessible + ptrees = getattr(sent_obj, "ptrees", None) + if ptrees is not None: + for sys_idx, one_ptree in enumerate(ptrees, 1): + zlog(f"Print fixes for system {sys_idx}") + for fid, one_fix in enumerate(one_ptree.fixes): + zlog(f"#{fid}: {str(one_fix)}") + return x + + +# ===== +# helper functions +def val2bin(x, bins): + i = 0 + for b in bins: + if x<=b: + return i + i += 1 + return i diff --git a/tasks/zdpar/ef/analysis/fixer.py b/tasks/zdpar/ef/analysis/fixer.py new file mode 100644 index 0000000..397c065 --- /dev/null +++ b/tasks/zdpar/ef/analysis/fixer.py @@ -0,0 +1,400 @@ +# + +# fixing: transform from pred to gold (multiple rounds) +# bottom up: heading & attaching & labeling +## heading: gold as the backbone +## attaching/labeling: individually for the remainings + +from typing import List, Set, Tuple +from collections import namedtuple +from msp.utils import Conf + +# ===== +# types + +ulabel_types = { +"Nivre17": { + "FUN": ["aux", "case", "cc", "clf", "cop", "det", "mark"], + "MWE": ["compound", "fixed", "flat", "goeswith"], + "CORE": ["ccomp", "csubj", "iobj", "nsubj", "obj", "xcomp"], + "NON-CORE": ["acl", "advcl", "advmod", "amod", "appos", "conj", "dep", "discourse", "dislocated", "expl", "list", "nmod", "nummod", "obl", "orphan", "parataxis", "reparandum", "root", "vocative"], + "PUNCT": ["punct"], +}, +"UDDOC": { + # first put the others (some can be non-core, but too rare) + "Others": ["vocative", "expl", "dislocated", "discourse", "list", "parataxis", "orphan", "reparandum", "dep"], + # then special ones + "Root": ["root"], + "Punct": ["punct"], + "Mwe": ["compound", "fixed", "flat", "goeswith"], + "Conj": ["conj", "cc"], + # function + "Fun": ["aux", "case", "clf", "cop", "det", "mark"], + # Nom/Clause/Mod(other) -> Core/Ncore/Nom, need more details to distinguish N/PP + ("N", "Core"): ["nsubj", "obj", "iobj"], + # ("N", "Ncore"): ["obl"], + # ("N", "Nom"): ["nmod", "appos"], + ("N", "Other"): ["obl", "nmod", "appos"], + ("C", "Core"): ["csubj", "ccomp", "xcomp"], + # ("C", "Ncore"): ["advcl"], + # ("C", "Nom"): ["acl"], + ("C", "Other"): ["advcl", "acl"], + ("Mod", ): ["advmod", "amod", "nummod"], +}, +# ['Others', 'Root', 'Punct', 'Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod'] +} + +ulabel2type = {} # type-system -> type-name -> List[edge-types] +for tsys, td in ulabel_types.items(): + ulabel2type[tsys] = {} + for k, vs in td.items(): + if not isinstance(k, (List, Tuple)): + k = [k] + k = tuple(k) + for v in vs: + ulabel2type[tsys][v] = k + +def get_coarse_type(v: str, tsys_name: str="UDDOC", layer=2): + type_tuple = ulabel2type[tsys_name].get(v.split(":")[0], ("UNK",)) + return ".".join(type_tuple[:layer]) +# ===== + +# +class FixerConf(Conf): + def __init__(self): + pass + +# one edge +class DepEdge: + def __init__(self, m, h, label, back_gap=-1): + self.m = m + self.h = h + self.label = label + self.back_gap = back_gap # h is actually the decedent of m by how many gaps, <0 means no such relation + +# one fix step for one edge +class FixChange: + def __init__(self, old_edge: DepEdge, new_h, new_label=None): + self.old_edge = old_edge + self.new_h = new_h + self.new_label = new_label + + @property + def m(self): return self.old_edge.m + + @property + def old_label(self): return self.old_edge.label + + @property + def old_h(self): return self.old_edge.h + + def __repr__(self): + return f"{self.m}<{self.old_label}>{self.old_h} => <{self.new_label}>{self.new_h}" + +# group of fixing changes in one step +class FixOperation: + def __init__(self, changes: List[FixChange], cur_type: str, cur_round: int, **kwargs): + assert len(changes) > 0 + self.changes = changes + self.main_change = self.changes[0] + self.type = cur_type + self.round = cur_round + # + self.cate = None + self.category = None + self.corrections: List = None # list of m-idxes that are corrected (final fix) by this fix + # set extra attribute + for k, v in kwargs.items(): + assert not hasattr(self, k) + setattr(self, k, v) + + def __str__(self): + changes_str = ', '.join([str(z) for z in self.changes]) + x = f"{self.type}-R{self.round}, {self.category}, [{changes_str}], <{len(self.corrections)}>{self.corrections}" + return x + + # category for this fixing according to template + def set_category(self, gold_tree: 'DepTree') -> Tuple: + m, old_h, old_label, new_h = self.main_change.m, self.main_change.old_h, \ + self.main_change.old_label, self.main_change.new_h + gold_h, gold_label = gold_tree.heads[m], gold_tree.labels[m] + if self.type == "attaching": + assert gold_h == new_h + cate = (gold_label, old_label) + # elif self.type == "heading1": + # cate = (old_label, gold_label, gold_tree.labels[old_h]) + elif self.type == "heading": + assert self.back_gap == 1, "No implementation for larger gap" + cate = (gold_tree.labels[old_h], old_label) + elif self.type == "labeling": + assert old_h == new_h and gold_h == new_h + cate = (gold_label, old_label) + else: + raise NotImplementedError(f"No implementation of checking category for {self.type}") + self.cate = cate + self.category = self.match_cetegory(cate) + + def match_cetegory(self, cate: Tuple): + return tuple(get_coarse_type(z) for z in cate) + +# general data structure +class DepTree: + # todo(note): assume already padding 0 for artificial root + def __init__(self, heads: List[int], labels: List[str], is_gold): + # ===== + # these fields will be changed + self.length = len(heads) + self.heads = heads + self.labels = labels + # build the children set + self.children = [set() for _ in range(self.length)] + for m, h in enumerate(heads): + if m>0: + self.children[h].add(m) + # ===== + # only used for gold tree + # build the antecedent info, might waste some space/time, but which is ok + self.is_gold = is_gold + if is_gold: + self.antecedent_maps = self.get_antecedent_maps() # antecedent -> gap + self.depths = self.get_depths() + self.fixes = None + self.fixing_sources = None + else: + self.antecedent_maps = None + self.depths = None + self.fixes = [] # list of FixOperations + # list[modifier] of related fixing operations + self.fixing_sources: List[List[FixOperation]] = [[] for _ in range(self.length)] + + def get_antecedent_maps(self): + ret = [{} for _ in range(self.length)] + def _get_ante_map(cur_idx: int, cur_path: List): + assert len(ret[cur_idx]) == 0 + ret[cur_idx] = {x: i+1 for i,x in enumerate(reversed(cur_path))} # gap along the spine + cur_path.append(cur_idx) + for new_idx in self.children[cur_idx]: + _get_ante_map(new_idx, cur_path) + cur_path.pop() + tmp_path = [] + _get_ante_map(0, tmp_path) + return ret + + def get_depths(self): + ret = [0] * self.length + def _get_depths(cur_idx: int, cur_depth: int): + ret[cur_idx] = cur_depth + for new_idx in self.children[cur_idx]: + _get_depths(new_idx, cur_depth+1) + _get_depths(0, 0) + return ret + + # for example: get_nodes("depths", True) for bottom up ones + def get_nodes(self, sort_by, reverse): + sorting_keys = getattr(self, sort_by) + # todo(note): here, there is an extra backoff idx sorting criterion at the last + sorting_items = [(k, idx) for idx, k in enumerate(sorting_keys)] + sorting_items.sort(reverse=reverse) + return [z[-1] for z in sorting_items] + + # apply a fix operation + def apply_fix(self, fix: FixOperation): + assert not self.is_gold + for one_change in fix.changes: + m, orig_h, new_h, new_lab = one_change.m, one_change.old_h, one_change.new_h, one_change.new_label + assert self.heads[m] == orig_h + self.heads[m] = one_change.new_h + if new_lab is not None: + self.labels[m] = new_lab + self.children[orig_h].remove(m) + self.children[new_h].add(m) + self.fixing_sources[m].append(fix) + self.fixes.append(fix) + + # finish the fixing, group each error edges to the final fix + def finish_fix(self, gold_tree: 'DepTree'): + # classify each fixes + assert not self.is_gold + for one_fix in self.fixes: + one_fix.set_category(gold_tree) + one_fix.corrections = [] # open this one! + # group edge-errors into their final fixes + assert len(self.fixing_sources[0]) == 0 + for m_idx, one_fix_list in enumerate(self.fixing_sources): + assert self.heads[m_idx] == gold_tree.heads[m_idx] + if len(one_fix_list) > 0: + one_fix_list[-1].corrections.append(m_idx) + pass + + # get edge with back_edge characters + def characterize_edges(self, gold_tree: 'DepTree') -> List[DepEdge]: + assert self.length == gold_tree.length + ret = [None] * self.length + for idx in range(0, gold_tree.length): + gold_ante_map = gold_tree.antecedent_maps[idx] + pred_children_set: Set[int] = self.children[idx] # set of children in current pred + for one in pred_children_set: + one_gap = gold_ante_map.get(one) + one_pred_label = self.labels[one] + if one_gap is not None: + one_edge = DepEdge(one, idx, one_pred_label, one_gap) + else: + one_edge = DepEdge(one, idx, one_pred_label) + assert ret[one] is None + ret[one] = one_edge + return ret + +# ===== +# error fixer using the gold as the reference +class GoldRefFixer: + def __init__(self, conf: FixerConf): + pass + + def fix_heading(self, gold_tree: DepTree, pred_tree: DepTree, cur_round: int, applying: bool, max_gap: int): + # ===== + # pass: fix heading errors with bottom-up on gold (fix from the aspect of h->m) + # actually this is fixing from head to mod, instead of mod attaching as in the attaching error + # todo(+N): should we use pred-based, should it be dfs-post_order, or alternative strategies? + all_fos = [] + gold_depths = gold_tree.depths + for to_fix_m in gold_tree.get_nodes("depths", True): + gold_ante_map = gold_tree.antecedent_maps[to_fix_m] + pred_children_set: Set[int] = pred_tree.children[to_fix_m] # set of children in current pred + # check if there are gold-antecedents + pred_children_antecedents: List[int] = [] + pred_children_others: List[int] = [] + for h in pred_children_set: + if h in gold_ante_map: + pred_children_antecedents.append(h) + # # todo(+N): do we filter by max-gap here or later? + # if max_gap<=0 or gold_ante_map[h]<=max_gap: # max_gap not activate if <=0 + # pred_children_antecedents.append(h) + else: + pred_children_others.append(h) + # sort by depth and only fix the lowest one + pred_children_antecedents.sort(key=lambda x: -gold_depths[x]) + pred_children_antecedents_gaps = [gold_depths[to_fix_m]-gold_depths[x] for x in pred_children_antecedents] + # prune by max_gap (allowing continous gaps with each one <=max_gap) + if max_gap > 0: + prev_gap = 0 + cur_idx = 0 + while cur_idx to_fix_m_upper) + to_fix_m_upper = to_fix_m + while True: + # cannot be always true, otherwise there will be loop to lowest_child_antecedent + assert to_fix_m_upper > 0 + to_fix_m_upper_head = pred_tree.heads[to_fix_m_upper] + if to_fix_m_upper_head == gold_tree.heads[to_fix_m_upper]: + to_fix_m_upper = to_fix_m_upper_head + else: + break + all_changes = [] + plabs = pred_tree.labels + # todo(note): back edge as the first fix! + # 1) the lowest antecedent is attached to the first incorrect pred-antecedent (also change label here) + all_changes.append(FixChange(DepEdge(lowest_child_antecedent, to_fix_m, + plabs[lowest_child_antecedent], pred_children_antecedents_gaps[0]), + to_fix_m_upper_head, pred_tree.labels[to_fix_m_upper])) + # 2) link to_fix_m's first incorrect pred-antecedent's child to lowest antecedent + all_changes.append(FixChange(DepEdge(to_fix_m_upper, to_fix_m_upper_head, plabs[to_fix_m_upper]), + lowest_child_antecedent)) + # 3) make all upper antecedents to the lower one (still back-edge to be fixed later) + all_changes.extend([FixChange(DepEdge(z, to_fix_m, g), lowest_child_antecedent) + for z, g in zip(pred_children_antecedents[1:], pred_children_antecedents_gaps[1:])]) + # 4) the all-antecedents' children are given to the lowest antecedent + all_antecedents_children = set() + for ante_idx in pred_children_antecedents: + all_antecedents_children.update(gold_tree.children[ante_idx]) + all_changes.extend([FixChange(DepEdge(z, to_fix_m, plabs[z]), lowest_child_antecedent) + for z in pred_children_others if z in all_antecedents_children]) + fo = FixOperation(all_changes, "heading", cur_round, back_gap=pred_children_antecedents_gaps[0]) + all_fos.append(fo) + if applying: + pred_tree.apply_fix(fo) # directly apply the fix + return all_fos + + def fix_attaching(self, gold_tree: DepTree, pred_tree: DepTree, cur_round: int, applying: bool): + # attach each errors independently (fix from the aspect of m) and also fix label here! + all_fos = [] + for to_fix_m in gold_tree.get_nodes("depths", True): + if to_fix_m>0: + pred_h, pred_lab = pred_tree.heads[to_fix_m], pred_tree.labels[to_fix_m] + gold_h, gold_lab = gold_tree.heads[to_fix_m], gold_tree.labels[to_fix_m] + if pred_h != gold_h: + plabs = pred_tree.labels + fo = FixOperation([FixChange(DepEdge(to_fix_m, pred_h, plabs[to_fix_m]), gold_h, + new_label=(None if (pred_lab==gold_lab) else gold_lab))], "attaching", cur_round) + all_fos.append(fo) + if applying: + pred_tree.apply_fix(fo) # directly apply the fix + return all_fos + + def fix_labeling(self, gold_tree: DepTree, pred_tree: DepTree, cur_round: int, applying: bool): + # only fix labeling errors for correct attachments + all_fos = [] + for to_fix_m in range(gold_tree.length): + pred_h, pred_lab = pred_tree.heads[to_fix_m], pred_tree.labels[to_fix_m] + gold_h, gold_lab = gold_tree.heads[to_fix_m], gold_tree.labels[to_fix_m] + if pred_h==gold_h and pred_lab!=gold_lab: + fo = FixOperation([FixChange(DepEdge(to_fix_m, pred_h, pred_lab), gold_h, gold_lab)], "labeling", cur_round) + all_fos.append(fo) + if applying: + pred_tree.apply_fix(fo) # directly apply the fix + return all_fos + + # this is actually fix_heading(gap=1) + def fix_direct_reverse(self, gold_tree: DepTree, pred_tree: DepTree, cur_round: int, applying: bool): + # fix with reversed edges (fix from the aspect of m) + all_fos = [] + for to_fix_m in gold_tree.get_nodes("depths", True): + if to_fix_m > 0: + # can be iterative + while True: + pred_h = pred_tree.heads[to_fix_m] + if gold_tree.heads[pred_h] == to_fix_m: + assert pred_h != gold_tree.heads[to_fix_m] + plabs = pred_tree.labels + all_changes = [] + # 1) m is attached to pred_h's pred_h + upper_h = pred_tree.heads[pred_h] # pred_h's pred_h + upper_label = pred_tree.labels[pred_h] + all_changes.append(FixChange(DepEdge(to_fix_m, pred_h, plabs[to_fix_m], 1), upper_h, upper_label)) + # 2) pred_h is attached to m + all_changes.append(FixChange(DepEdge(pred_h, upper_h, plabs[pred_h]), to_fix_m)) + # 3) pull some children to m + gold_children = gold_tree.children[to_fix_m] + all_changes.extend([FixChange(DepEdge(z, pred_h, plabs[z]), to_fix_m) + for z in pred_tree.children[pred_h] if z!=to_fix_m and z in gold_children]) + fo = FixOperation(all_changes, "heading", cur_round, back_gap=1) + all_fos.append(fo) + if applying: + pred_tree.apply_fix(fo) # directly apply the fix + else: + break + else: + break + return all_fos + + # ===== + # specific fixing schemes + + # todo(WARN): there will be some labeling errors counting missing: thus we majorly focus on unlabeled errors + def fix(self, gold_tree: DepTree, pred_tree: DepTree): + self.fix_labeling(gold_tree, pred_tree, 0, True) + for i in range(0, gold_tree.length): + fos = self.fix_direct_reverse(gold_tree, pred_tree, i, True) + # fos = self.fix_heading(gold_tree, pred_tree, i, True, 1) # still max-gap==1 + if len(fos) == 0: + break + self.fix_attaching(gold_tree, pred_tree, 0, True) + pred_tree.finish_fix(gold_tree) diff --git a/tasks/zdpar/ef/analysis/run0901.py b/tasks/zdpar/ef/analysis/run0901.py new file mode 100644 index 0000000..98e3ab3 --- /dev/null +++ b/tasks/zdpar/ef/analysis/run0901.py @@ -0,0 +1,78 @@ +# + +import sys +from msp.zext.ana import AnalyzerConf, Analyzer, ZRecNode, AnnotationTask + +try: + from .ann import * +except: + from ann import * + +def main(args): + conf = AnalysisConf(args) + # ===== + if conf.load_name == "": + # recalculate them + # read them + zlog("Read them all ...") + gold_parses = list(yield_ones(conf.gold)) + sys_parses = [list(yield_ones(z)) for z in conf.fs] + # use vocab? + voc = Vocab.read(conf.vocab) if len(conf.vocab)>0 else None + # ===== + # stat them + zlog("Stat them all ...") + all_sents, all_tokens = get_objs(gold_parses, sys_parses, conf.getter) + analyzer = ParsingAnalyzer(conf.ana, all_sents, all_tokens, conf.labeled, vocab=voc) + if conf.save_name != "": + analyzer.do_save(conf.save_name) + else: + analyzer = ParsingAnalyzer(conf.ana, None, None, conf.labeled) + analyzer.do_load(conf.load_name) + # ===== + # analyze them + zlog("Analyze them all ...") + analyzer.loop() + return analyzer + +if __name__ == '__main__': + main(sys.argv[1:]) + +""" +# example (pdb can use readline) +PYTHONPATH=../src/ python3 -m pdb run.py gold:en_dev.conllu fs:zout.g1,zout.ef vocab: +x1 = filter sents "sum(z.s1.ucorr for z in d.rtoks)0 else 0, len(d.g.childs))" +""" diff --git a/tasks/zdpar/ef/analysis/see0904.py b/tasks/zdpar/ef/analysis/see0904.py new file mode 100644 index 0000000..49508db --- /dev/null +++ b/tasks/zdpar/ef/analysis/see0904.py @@ -0,0 +1,150 @@ +# + +# see the fixes + +import sys +import json +try: + from .ann import * +except: + from ann import * + +# write results to std-out +def output_one(d): + s = json.dumps(d) + print(s) + +# ===== +# compare one pair of parses +def main(args): + conf = AnalysisConf(args) + # ===== + if conf.load_name == "": + # recalculate them + # read them + zlog("Read them all ...") + gold_parses = list(yield_ones(conf.gold)) + sys_parses = [list(yield_ones(z)) for z in conf.fs] + # ===== + # stat them + zlog("Stat them all ...") + all_sents, all_tokens = get_objs(gold_parses, sys_parses, conf.getter) + analyzer = ParsingAnalyzer(conf.ana, all_sents, all_tokens, conf.labeled) + if conf.save_name != "": + analyzer.do_save(conf.save_name) + else: + analyzer = ParsingAnalyzer(conf.ana, None, None, conf.labeled) + analyzer.do_load(conf.load_name) + all_sents, all_tokens = analyzer.vars.sents, analyzer.vars.tokens + # ===== + # ===== + # collect the tokens according to the fixes + # one token (and its edge) can be classified into 4/5 classes: reverse-fix, co-fix, attach-fix, label-fix / no-fix? + # ----- + # basic: analyze for both mono- and cross- parsing for multiple languages, any relations between? + # question 0: 1) how many errors there? how many fixes are there, especially co-fix? 2) if reverse-fix get fixed, will the co-fixes be also fixed? (s1 as baseline, s2 as better model) + # question 1: what is s2 better at? reverse or attach? any specific type? + # question 1.5: extra surprise: confidence? + # question 2: will extra link stats help parsing and get similar error reduction (for example, for simpler patterns like N-attach)? + # + # the layering is: sent -> fix -> token + num_sys = len(conf.fs) + # analyze pairs + assert len(conf.fs) == 2, "Here we only analyzing pairwise info!" + q1_stats = [defaultdict(int), defaultdict(int)] # stats for sys1 & sys2 + for one_sent in all_sents: + # first stat for each system + cur_len = one_sent.len + toks, rtoks = one_sent.toks, one_sent.rtoks + for sys_idx in range(num_sys): + other_sys_idx = 1-sys_idx # the other system idx + _cur_stat = q1_stats[sys_idx] + # basic info + _cur_stat["sent"] += 1 + _cur_stat["token"] += cur_len + _cur_stat["uerr"] += len([z for z in rtoks if not z.ss[sys_idx].ucorr]) + _cur_stat["lerr"] += len([z for z in rtoks if not z.ss[sys_idx].lcorr]) + # for each fixes + cur_fixes = one_sent.ptrees[sys_idx].fixes + for one_fix in cur_fixes: + one_fix_type = one_fix.type + _cur_stat[f"fix_{one_fix_type}"] += 1 + _cur_stat[f"fix_{one_fix_type}_corrs"] += len(one_fix.corrections) + if one_fix_type == "heading": + # then especially for co-fixes (unlabeled errors) + # the fix is mainly correct the gold edge of (gold_m, gold_h) + # gold_m is sure to be right, but gold_h can still be not-fixed + gold_m, gold_h = one_fix.changes[1].m, one_fix.changes[0].m + assert one_fix.changes[0].old_h == gold_m + assert not toks[gold_m].ss[sys_idx].ucorr and not toks[gold_h].ss[sys_idx].ucorr + # + hfix, hfix_cofix, hfix_cofix2 = 1, 0, 0 + hit_center_m_flag = False + for z in one_fix.corrections: + if z == gold_m: + hit_center_m_flag = True + else: + hfix_cofix += 1 + if z != gold_h: + hfix_cofix2 += 1 + assert hit_center_m_flag + # + _cur_stat[f"hfix"] += hfix + _cur_stat[f"hfix_cofix"] += hfix_cofix + _cur_stat[f"hfix_cofix2"] += hfix_cofix2 + if toks[gold_m].ss[other_sys_idx].ucorr: # if in the other system, center-m is correct + _cur_stat[f"hfix_other"] += hfix + _cur_stat[f"hfix_other_cofix"] += hfix_cofix + _cur_stat[f"hfix_other_cofix2"] += hfix_cofix2 + count_other_cofix_right = 0 + count_other_cofix_right2 = 0 + for z in one_fix.corrections: + if toks[z].ss[other_sys_idx].ucorr: + if z != gold_m: + count_other_cofix_right += 1 + if z != gold_h: + count_other_cofix_right2 += 1 + _cur_stat[f"hfix_other_cofix_right"] += count_other_cofix_right + _cur_stat[f"hfix_other_cofix_right2"] += count_other_cofix_right2 + # ===== + # after collection, do some calculations + q1_res = [] + for _cur_stat in q1_stats: + sent, token = _cur_stat["sent"], _cur_stat["token"] + _cur_res = {"sent": sent, "token": token, "uas": 1.-_cur_stat["uerr"]/token, "las": 1.-_cur_stat["lerr"]/token, + "uerr": _cur_stat["uerr"]/token, "lerr": _cur_stat["lerr"]/token, + # number of fix per sent + "fps_head": _cur_stat["fix_heading"]/sent, + "fps_attach": _cur_stat["fix_attaching"]/sent, + # number of corrections per sentence + "cps_head": _cur_stat["fix_heading_corrs"]/sent, + "cps_attach": _cur_stat["fix_attaching_corrs"]/sent, + # number of corrections per fix + "cpf_head": _cur_stat["fix_heading_corrs"]/_cur_stat["fix_heading"], + "cpf_attach": _cur_stat["fix_attaching_corrs"]/_cur_stat["fix_attaching"], + # cofix correct rate + "cofix_rate": _cur_stat[f"hfix_other_cofix_right"]/(_cur_stat[f"hfix_other_cofix"]+1e-3), + } + q1_res.append(_cur_res) + Helper.printd(_cur_res, " || ") + _cmp_res = {k: (q1_res[0][k]-q1_res[1][k])/q1_res[0][k] for k in q1_res[0]} + Helper.printd(_cmp_res, " || ") + + # todo(note): hypothesis, bert features help fix attachment errors more? + # <- near and co-occur words have higher chance to have grammatical relationships + # <- frequent specific patterns have higher chance to ... + # +1 bag of in-between word as weak context? + # both overall-attachment analysis and heading/attachment analysis + # bert knows the co-location of links but not the head? + + # back-edge #num-gap for cross-lingual parsing + # three classes: reverse-edge, fix-by-reverse attachment, remaining attachment + + # calculate stats and breakdown and print examples-ids + +if __name__ == '__main__': + main(sys.argv[1:]) + +# example seeing +# for zcl in en ar eu zh 'fi' he hi it ja ko ru sv tr; do echo ${zcl}; PYTHONPATH=../src/ python3 see.py gold:../data/UD_RUN/ud24/${zcl}_test.conllu fs:../finals24/mg1_b0_${zcl}_1/output.test.txt,../finals24/mg1_b1_${zcl}_1/output.test.txt; done +# for zcl in en ar eu zh 'fi' he hi it ja ko ru sv tr; do echo ${zcl}; PYTHONPATH=../src/ python3 see.py gold:../data/UD_RUN/ud24/${zcl}_test.conllu fs:../cl_runs_from_tc/final_g1/cg1_b0_en_1/crout_${zcl}.out,../cl_runs_from_tc/final_g1/cg1_b1_en_1/crout_${zcl}.out; done diff --git a/tasks/zdpar/ef/analysis/see0911_fix.py b/tasks/zdpar/ef/analysis/see0911_fix.py new file mode 100644 index 0000000..9c4b9b5 --- /dev/null +++ b/tasks/zdpar/ef/analysis/see0911_fix.py @@ -0,0 +1,246 @@ +# + +# manually build analysis: see and stat and group all the fixes + +try: + from .ann import * +except: + from ann import * +import sys +import json +import numpy as np +import pdb + +# ===== +LANGS = "en ar eu zh fi he hi it ja ko ru sv tr".split() +SYSTEMS = ["b0", "b1"] +RUNS = [1, 2, 3, 4, 5] + +# ===== +# supervised results +class FileGetterMono: + def get_gold(self, cur_lang, cur_sys, cur_run): + return f"../data/UD_RUN/ud24/{cur_lang}_test.conllu" + + def get_pred(self, cur_lang, cur_sys, cur_run): + return f"../finals24/mg1_{cur_sys}_{cur_lang}_{cur_run}/output.test.txt" + +# crossed results from English as source +class FileGetterCross: + def get_gold(self, cur_lang, cur_sys, cur_run): + return f"../data/UD_RUN/ud24/{cur_lang}_test.conllu" + + def get_pred(self, cur_lang, cur_sys, cur_run): + return f"../cl_runs_from_tc/final_g1/cg1_{cur_sys}_en_{cur_run}/crout_{cur_lang}.out" + +# ===== +# stat for one pair of results +def get_stats(all_sents, all_tokens): + _stats = [defaultdict(int), defaultdict(int)] # sys1, sys2 + for one_sent in all_sents: + # first stat for each system + cur_len = one_sent.len + toks, rtoks = one_sent.toks, one_sent.rtoks + for sys_idx in range(len(SYSTEMS)): + other_sys_idx = 1 - sys_idx # the other system idx + _cur_stat = _stats[sys_idx] + # basic info + _cur_stat["sent"] += 1 + _cur_stat["token"] += cur_len + _cur_stat["ucorr"] += len([z for z in rtoks if z.ss[sys_idx].ucorr]) + _cur_stat["lcorr"] += len([z for z in rtoks if z.ss[sys_idx].lcorr]) + _cur_stat["uerr"] += len([z for z in rtoks if not z.ss[sys_idx].ucorr]) + _cur_stat["lerr"] += len([z for z in rtoks if not z.ss[sys_idx].lcorr]) + # for each fixes + cur_fixes = one_sent.ptrees[sys_idx].fixes + for one_fix in cur_fixes: + one_fix_type = one_fix.type + _cur_stat[f"fix_{one_fix_type}"] += 1 + _cur_stat[(one_fix_type, one_fix.category[0])] += 1 # specific fix type+category! + _cur_stat[f"fix_{one_fix_type}_corrs"] += len(one_fix.corrections) + if one_fix_type == "heading": + # then especially for co-fixes (unlabeled errors) + # the fix is mainly correct the gold edge of (gold_m, gold_h) + # gold_m is sure to be right, but gold_h can still be not-fixed + gold_m, gold_h = one_fix.changes[1].m, one_fix.changes[0].m + assert one_fix.changes[0].old_h == gold_m + assert not toks[gold_m].ss[sys_idx].ucorr and not toks[gold_h].ss[sys_idx].ucorr + # + hfix, hfix_cofix, hfix_cofix2 = 1, 0, 0 + hit_center_m_flag = False + for z in one_fix.corrections: + if z == gold_m: + hit_center_m_flag = True + else: + hfix_cofix += 1 + if z != gold_h: + hfix_cofix2 += 1 + assert hit_center_m_flag + # + _cur_stat[f"hfix"] += hfix + _cur_stat[f"hfix_cofix"] += hfix_cofix + _cur_stat[f"hfix_cofix2"] += hfix_cofix2 + if toks[gold_m].ss[other_sys_idx].ucorr: # if in the other system, center-m is correct + _cur_stat[f"hfix_other"] += hfix + _cur_stat[f"hfix_other_cofix"] += hfix_cofix + _cur_stat[f"hfix_other_cofix2"] += hfix_cofix2 + count_other_cofix_right = 0 + count_other_cofix_right2 = 0 + for z in one_fix.corrections: + if toks[z].ss[other_sys_idx].ucorr: + if z != gold_m: + count_other_cofix_right += 1 + if z != gold_h: + count_other_cofix_right2 += 1 + _cur_stat[f"hfix_other_cofix_right"] += count_other_cofix_right + _cur_stat[f"hfix_other_cofix_right2"] += count_other_cofix_right2 + # ===== + # then pre-calculate sth + for _cur_stat in _stats: + sent, token = _cur_stat["sent"], _cur_stat["token"] + _cur_stat.update({ + # uas/las + "uas": _cur_stat["ucorr"]/token, "las": _cur_stat["lcorr"]/token, + # number of corrections per fix + "cpf_head": _cur_stat["fix_heading_corrs"] / _cur_stat["fix_heading"], + "cpf_attach": _cur_stat["fix_attaching_corrs"] / _cur_stat["fix_attaching"], + # cofix correct rate + "cofix_rate": _cur_stat[f"hfix_other_cofix_right"] / (_cur_stat[f"hfix_other_cofix"] + 1e-3),}) + return _stats + +# +class AveragedValue: + def __init__(self, values): + self.mean = np.mean(values) + self.dev = np.std(values) + self.values = values + + def __float__(self): + return float(self.mean) + + def __repr__(self): + return f"{self.mean:.4f}({self.dev:.4f})" + +# +def show_results(final_results, keys=None): + if keys is None: + # default keys + keys = ["uas", "las", "fix_attaching", "fix_heading", "cpf_head", "cofix_rate"] + # keys = [("heading", k) for k in ['Others', 'Root', 'Punct', 'Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] + # keys = [("attaching", k) for k in ['Others', 'Root', 'Punct', 'Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] + lines = [] + res0 = AveragedValue([0.]) + for cur_lang, cur_res in final_results.items(): + lines.append([cur_lang, 0] + [str(cur_res[0].get(k, "--")) for k in keys]) + lines.append([cur_lang, 1] + [str(cur_res[1].get(k, "--")) for k in keys]) + lines.append([cur_lang, -1]) + for k in keys: + r1, r2 = cur_res[0].get(k, res0), cur_res[1].get(k, res0) + lines[-1].append(f"{(float(r2)-float(r1)) / (float(r1)+1e-3):.4f}") + x = pd.DataFrame(lines, columns=["lang", "sys"] + keys) + return x + +# detailed results +# keys = [("heading", k) for k in ['Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] +# keys = [("attaching", k) for k in ['Others', 'Root', 'Punct', 'Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] +# todo(note): again, assume comparing two systems +def detailed_result(final_results, extra_keys=()): + # ===== + # pretty-print: number / number-per-sentence / percentage + def _pp(num, divs=()): + num = float(num) + ss = f"{num:.4f}" + for one_div in divs: + one_div = float(one_div) + ss += f"/{num/one_div:.4f}" + return ss + # ===== + # for actual-number / number-per-sent: + # uas las Num-UErr Num-FAtt(perc) Num-FHead FHead-CPF Num-FHead*CPF(perc) cofix-rate + # heading: for each: Num Perc/Rank xCPF-Perc/Rank; attaching for each: Num Perc/Rank + lines = [] + res0 = AveragedValue([0.]) + keys = ["uas", "las", "uerr", "Attach", "Head", "CPF", "Head*CPF", "Cofix"] + list(extra_keys) + for cur_lang, cur_res in final_results.items(): + results = [] + for i in range(len(cur_res)): + one_res = cur_res[i] + sent = one_res["sent"] + uas, las, uerr = one_res['uas'], one_res['las'], one_res['uerr'] + # + n_fix_attaching = one_res['fix_attaching'] + n_fix_heading, n_cpf_head = one_res['fix_heading'], one_res['cpf_head'] + mul_head_cpf = n_float(float(n_fix_heading) * float(n_cpf_head)) + cofix_rate = one_res['cofix_rate'] + # + results.append([uas, las, uerr, n_fix_attaching, n_fix_heading, n_cpf_head, mul_head_cpf, cofix_rate]) + lines.append([cur_lang, i] + [_pp(uas), _pp(las), _pp(uerr), _pp(n_fix_attaching, [sent, uerr]), _pp(n_fix_heading, [sent]), _pp(n_cpf_head), _pp(mul_head_cpf, [sent, uerr]), _pp(cofix_rate)]) + # extra keys + for one_extra_key in extra_keys: + cur_div = {"heading": n_fix_heading, "attaching": n_fix_attaching}[one_extra_key[0]] + cur_fix = one_res.get(one_extra_key, res0) + results[-1].append(cur_fix) + # lines[-1].append(_pp(cur_fix, [sent, cur_div])) + lines[-1].append(_pp(cur_fix, [cur_div])) + # diff comparing results[-1] on results[-2] + cur_diff = [cur_lang, -1] + for a, b in zip(results[-2], results[-1]): + try: + delta = (float(b)-float(a)) / (float(a)+1e-3) + delta = f"{delta*100:.2f}%" + except: + delta = "--" + cur_diff.append(delta) + lines.append(cur_diff) + x = pd.DataFrame(lines, columns=["lang", "sys"] + keys) + return x + +# ===== +def main(): + fgetter = FileGetterMono() + # fgetter = FileGetterCross() + # + final_results = {} + for cur_lang in LANGS: + gold_parses = list(yield_ones(fgetter.get_gold(cur_lang, None, None))) + tmp_lang_results = [] + all_keys = set() + for cur_run in RUNS: + zlog(f"Starting on {cur_lang} {cur_run}") + # read them + assert len(SYSTEMS) == 2, "Currently doing pairwise comparison!" + sys_parses = [list(yield_ones(fgetter.get_pred(cur_lang, z, cur_run))) for z in SYSTEMS] + # stat them + all_sents, all_tokens = get_objs(gold_parses, sys_parses, GetterConf()) + # + one_stats = get_stats(all_sents, all_tokens) + for x in one_stats: + all_keys.update(x.keys()) + tmp_lang_results.append(one_stats) + # get the averaged results + avg_lang_results = [] + for idx in range(len(SYSTEMS)): + _one_sys_avg_result = {k:AveragedValue([tmp_lang_results[r][idx][k] for r in range(len(RUNS))]) for k in all_keys} + avg_lang_results.append(_one_sys_avg_result) + final_results[cur_lang] = avg_lang_results + # ===== + # print the averaged things + PickleRW.to_file(final_results, "_avg_res.pkl") + show_results(final_results) + pdb.set_trace() + +if __name__ == '__main__': + main() + +""" +PYTHONPATH=../src/ python3 ~ +from s2 import *; import pickle; x = pickle.load(open("_avg_mono.0919.pkl", 'rb')); detailed_result(x) +keys1 = [("heading", k) for k in ['Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] +keys2 = [("attaching", k) for k in ['Others', 'Root', 'Punct', 'Mwe', 'Conj', 'Fun', 'N.Core', 'N.Other', 'C.Core', 'C.Other', 'Mod']] +z1=detailed_result(x, keys1) +z2=detailed_result(x, keys2) +z1.loc[:, ['lang'] + keys1] +z2.loc[:, ['lang'] + keys2] +z1.iloc[:, [0]+list(range(-8,0))] +z2.iloc[:, [0]+list(range(-8,0))] +""" diff --git a/tasks/zdpar/ef/parser/__init__.py b/tasks/zdpar/ef/parser/__init__.py new file mode 100644 index 0000000..0b14245 --- /dev/null +++ b/tasks/zdpar/ef/parser/__init__.py @@ -0,0 +1,6 @@ +# + +from .efp import EfParserConf, EfParser +from .g1p import G1ParserConf, G1Parser +from .g2p import G2ParserConf, G2Parser +from .s2p import S2ParserConf, S2Parser diff --git a/tasks/zdpar/ef/parser/efp.py b/tasks/zdpar/ef/parser/efp.py new file mode 100644 index 0000000..7d55838 --- /dev/null +++ b/tasks/zdpar/ef/parser/efp.py @@ -0,0 +1,418 @@ +# + +# the ef parser + +from typing import List + +from msp.utils import FileHelper, zlog +from msp.data import VocabPackage +from msp.nn import BK +from msp.zext.process_train import SVConf, ScheduledValue + +from ...common.data import ParseInstance +from ...common.model import BaseParserConf, BaseInferenceConf, BaseTrainingConf, BaseParser, DepLabelHelper +from ..systems import EfSearcher, EfSearchConf, EfState, EfAction, ScorerHelper +from ..scorer import Scorer, ScorerConf, SL0Layer, SL0Conf + +from .g1p import G1Parser, PreG1Conf + +# ===== +# confs + +# decoding conf +class EfInferenceConf(BaseInferenceConf): + def __init__(self): + super().__init__() + # by default beam search (no training-related settings) + # self.search_conf = EfSearchConf().init_from_kwargs(plain_k_arc=5, oracle_mode="", plain_beam_size=5, cost_type="") + self.search_conf = EfSearchConf().init_from_kwargs(plain_k_arc=1, oracle_mode="", plain_beam_size=1, cost_type="") + +# training conf +class EfTrainingConf(BaseTrainingConf): + def __init__(self): + super().__init__() + # by default beam+oracle search + self.search_conf = EfSearchConf() + self.loss_function = "hinge" + self.loss_div_step = True # loss /= steps or sents? + self.loss_div_weights = False # a further mode in loss_div_step mode to div by actual weights + self.loss_div_fullbatch = True # whether also count in no-error sentences when div + # special weight for loss0? + self.cost0_weight = SVConf().init_from_kwargs(val=1.0) + # include the action even if it is hit in the predicted ones (does not treat differently with cost0_weight) + self.include_hit_oracle = False + # no blame for label if the arc is wrong + self.noblame_label_on_arc_err = False + # + # feature dropout (to be convenient, only in final forward/backward run) + self.fdrop_chs = 0. + self.fdrop_par = 0. + +# parser conf +class EfParserConf(BaseParserConf): + def __init__(self): + super().__init__(EfInferenceConf(), EfTrainingConf()) + self.sc_conf = ScorerConf() + self.sl_conf = SL0Conf() + self.pre_g1_conf = PreG1Conf() + self.system_labeled = True + # ignore certain children + self.ef_ignore_chs = "" # none, punct, func # todo(+2): ignore parent labels? + # adjustable arc beam size (<1 means not adjustable!!) + self.aabs = SVConf().init_from_kwargs(val=0., max_val=5., mode="linear") + + def do_validate(self): + # specific setting + self.iconf.search_conf._system_labeled = self.system_labeled + self.tconf.search_conf._system_labeled = self.system_labeled + if not self.system_labeled: + # do not use label-related features or strategies + self.ef_ignore_chs = "" + self.sl_conf.use_label_feat = False + self.iconf.search_conf.cost_weight_label = 0. + self.tconf.search_conf.cost_weight_label = 0. + +# ===== +# model + +# the model +class EfParser(BaseParser): + def __init__(self, conf: EfParserConf, vpack: VocabPackage): + super().__init__(conf, vpack) + # ===== basic G1 Parser's loading (also possibly load g1's params) + self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf) + self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training + self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing + self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training + self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing + # + self.add_slayer() + # True if ignored + ignore_chs_label_mask = DepLabelHelper.select_label_idxes(conf.ef_ignore_chs, self.label_vocab.keys(), True, True) + # ===== For decoding ===== + self.inferencer = EfInferencer(self.scorer, self.slayer, conf.iconf, ignore_chs_label_mask) + # ===== For training ===== + self.cost0_weight = ScheduledValue("c0w", conf.tconf.cost0_weight) + self.add_scheduled_values(self.cost0_weight) + self.losser = EfLosser(self.scorer, self.slayer, conf.tconf, self.margin, self.cost0_weight, ignore_chs_label_mask) + # ===== adjustable beam size + self.arc_abs = ScheduledValue("aabs", conf.aabs) + self.add_scheduled_values(self.arc_abs) + # + self.num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + + # todo(+3): ugly changing beam size, here! + def refresh_batch(self, training: bool): + super().refresh_batch(training) + cur_arc_abs = int(self.arc_abs) + if cur_arc_abs >= 1: # only set if >=1 + self.inferencer.searcher.set_plain_arc_beam_size(cur_arc_abs) + self.losser.searcher.set_plain_arc_beam_size(cur_arc_abs) + + def build_decoder(self): + conf = self.conf + # ===== Decoding Scorer ===== + conf.sc_conf._input_dim = self.enc_output_dim + conf.sc_conf._num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + return Scorer(self.pc, conf.sc_conf) + + def build_slayer(self): + conf = self.conf + # ===== Structured Layer ===== + conf.sl_conf._input_dim = self.enc_output_dim + conf.sl_conf._num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + return SL0Layer(self.pc, conf.sl_conf) + + # ===== + # main procedures + + # get g1 scores as extra ones + def _get_g1_pack(self, insts: List[ParseInstance], score_arc_lambda: float, score_lab_lambda: float): + if score_arc_lambda <= 0. and score_lab_lambda <= 0.: + return None + else: + # only need scores + if self.g1parser is None or self.g1parser.g1_use_aux_scores: + arc_score, lab_score, _ = G1Parser.collect_aux_scores(insts, self.num_label) + a, b = (arc_score.suqeeze(-1), lab_score) + else: + a, b = self.g1parser.score_on_batch(insts) + # multiply by lambda here! + a *= score_arc_lambda + b *= score_lab_lambda + return a, b + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + with BK.no_grad_env(): + self.refresh_batch(False) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False) + # g1 score + g1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing) + # decode for parsing + self.inferencer.decode(insts, enc_repr, mask_arr, g1_pack, self.label_vocab) + # put jpos result (possibly) + self.jpos_decode(insts, jpos_pack) + # ----- + info = {"sent": len(insts), "tok": sum(map(len, insts))} + return info + + # training + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): + self.refresh_batch(training) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training) + mask_expr = BK.input_real(mask_arr) + # g1 score + g1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training) + # the parsing loss + parsing_loss, parsing_scores, info = self.losser.loss(annotated_insts, enc_repr, mask_arr, g1_pack) + # whether add jpos loss? + jpos_loss = self.jpos_loss(jpos_pack, mask_expr) + # + no_loss = True + final_loss = 0. + if parsing_loss is None: + info["loss_parse"] = 0. + else: + final_loss = final_loss + parsing_loss + info["loss_parse"] = BK.get_value(parsing_loss).item() + no_loss = False + if jpos_loss is None: + info["loss_jpos"] = 0. + else: + final_loss = final_loss + jpos_loss + info["loss_jpos"] = BK.get_value(jpos_loss).item() + no_loss = False + if parsing_scores is not None: + arc_scores, lab_scores = parsing_scores + reg_loss = self.reg_scores_loss(arc_scores, lab_scores) + if reg_loss is not None: + final_loss = final_loss + reg_loss + info["fb"] = 1 + if training and not no_loss: + info["fb_back"] = 1 + BK.backward(final_loss, loss_factor) + return info + +# ===== +# other components: decoder and trainer + +class EfInferencer: + def __init__(self, scorer: Scorer, slayer: SL0Layer, iconf: EfInferenceConf, ignore_chs_label_mask): + self.searcher = EfSearcher.build(iconf.search_conf, scorer, slayer) + self.hm_feature_getter0 = ScorerHelper.HmFeatureGetter(ignore_chs_label_mask, 0., 0.) # nodrop for search + + # inplaced writing results + def decode(self, insts: List[ParseInstance], enc_repr, mask_arr, g1_pack, label_vocab): + # ===== + def _set_ef_extra(inst: ParseInstance, end_state: EfState): + path_states = end_state.get_path() + mods = [x.action.mod for x in path_states] + inst.extra_pred_misc["ef_order"] = mods + scores = [x.score for x in path_states] + inst.extra_pred_misc["ef_score"] = scores + # ===== + # todo(warn): may need sg if require k-best + ags = self.searcher.start(insts, self.hm_feature_getter0, enc_repr, mask_arr, g1_pack, margin=0.) + self.searcher.go(ags) + # put the results inplaced + for one_inst, one_ag in zip(insts, ags): + best_state = max(one_ag.ends[0], key=lambda s: s.score_accu) + one_inst.pred_heads.set_vals(best_state.list_arc) # directly int-val for heads + # todo(warn): already the correct labels, no need to transform + one_inst.pred_labels.build_vals(best_state.list_label, label_vocab) + # auxiliary info for ef action order + _set_ef_extra(one_inst, best_state) + return ags + +class EfLosser: + def __init__(self, scorer: Scorer, slayer: SL0Layer, tconf: EfTrainingConf, margin: ScheduledValue, + cost0_weight: ScheduledValue, ignore_chs_label_mask): + self.searcher = EfSearcher.build(tconf.search_conf, scorer, slayer) + self.scorer = scorer + self.slayer = slayer + self.margin = margin + self.cost0_weight = cost0_weight + assert tconf.loss_function == "hinge", "Currently only support max-margin" + self.system_labeled = self.searcher.system_labeled + # + self.loss_div_fullbatch = tconf.loss_div_fullbatch + self.loss_div_step = tconf.loss_div_step + self.loss_div_weights = tconf.loss_div_weights + self.include_hit_oracle = tconf.include_hit_oracle + self.noblame_label_on_arc_err = tconf.noblame_label_on_arc_err + # + self.hm_feature_getter0 = ScorerHelper.HmFeatureGetter(ignore_chs_label_mask, 0., 0.) # nodrop for search + self.hm_feature_getter = ScorerHelper.HmFeatureGetter(ignore_chs_label_mask, tconf.fdrop_chs, tconf.fdrop_par) + + # helper function for adding the lists, mainly for adjusting cost-aware weights, collect for one piece + # todo(note): in one piece, there cannot be actions with the same mod + def _add_lists(self, plain_states: List[EfState], oracle_states: List[EfState], action_list: List, + arc_weight_list: List, label_weight_list: List, bidxes_list: List, bidx, cost0_weight): + cur_size = len(plain_states) + plain_actions, oracle_actions = [x.action for x in plain_states], [x.action for x in oracle_states] + action_list.extend(plain_actions) + action_list.extend(oracle_actions) + bidxes_list.extend([bidx] * (cur_size * 2)) + include_hit_oracle = self.include_hit_oracle + noblame_label_on_arc_err = self.noblame_label_on_arc_err + # assign weights for the loss function + if cost0_weight == 1.: + cur_weight_list = [1.] * cur_size + [-1.] * cur_size + arc_weight_list.extend(cur_weight_list) + label_weight_list.extend(cur_weight_list) + arc_valid_weights, label_valid_weights = cur_size, cur_size + else: + arc_valid_weights, label_valid_weights = 0, 0 # accumulated weights + arc_cost0_mod_set, label_cost0_mod_set = set(), set() # cost==0 mods' set + # always based on the predicted ones + for cur_s in plain_states: + cur_mod = cur_s.action.mod + cur_wrong_arc, cur_wrong_label = cur_s.wrong_al + cur_w_arc = cur_w_label = 1. + if cur_wrong_arc == 0: + cur_w_arc = cost0_weight + arc_cost0_mod_set.add(cur_mod) + if cur_wrong_label == 0: + cur_w_label = cost0_weight + label_cost0_mod_set.add(cur_mod) + elif noblame_label_on_arc_err: + # in this mode, mainly blame the arc error + cur_w_label = cost0_weight + label_cost0_mod_set.add(cur_mod) + # record them + arc_valid_weights += cur_w_arc + label_valid_weights += cur_w_label + arc_weight_list.append(cur_w_arc) + label_weight_list.append(cur_w_label) + # for the oracles, depending on the strategy + if include_hit_oracle: + # average over all oracles + oracle_w_arc, oracle_w_label = -arc_valid_weights/cur_size, -label_valid_weights/cur_size + arc_weight_list.extend([oracle_w_arc]*cur_size) + label_weight_list.extend([oracle_w_label]*cur_size) + else: + # special treatment for the hits + # arc + arc_cost0_mods = [int(a.mod in arc_cost0_mod_set) for a in oracle_actions] + arc_cost0_count = sum(arc_cost0_mods) + arc_hascost_count = cur_size - arc_cost0_count + arc_avg_w = -(arc_valid_weights-cost0_weight*arc_cost0_count)/arc_hascost_count \ + if arc_hascost_count>0 else 0. + arc_weight_list.extend([-cost0_weight if z>0 else arc_avg_w for z in arc_cost0_mods]) + # label + label_cost0_mods = [int(a.mod in label_cost0_mod_set) for a in oracle_actions] + label_cost0_count = sum(label_cost0_mods) + label_hascost_count = cur_size - label_cost0_count + label_avg_w = -(label_valid_weights-cost0_weight*label_cost0_count)/label_hascost_count \ + if label_hascost_count>0 else 0. + label_weight_list.extend([-cost0_weight if z>0 else label_avg_w for z in label_cost0_mods]) + return arc_valid_weights, label_valid_weights + + # todo(warn): first do a cost-augmented search and then do forward-backward + # todo(+N): still a slight diff in search and fb because of non-fix dropout of Att-module! + def loss(self, insts: List[ParseInstance], enc_repr, mask_arr, g1_pack): + # todo(WARN): may need sg if using other loss functions + # first-round search + cur_margin = self.margin.value + cur_cost0_weight = self.cost0_weight.value + with BK.no_grad_env(): + ags = self.searcher.start(insts, self.hm_feature_getter0, enc_repr, mask_arr, g1_pack, margin=cur_margin) + self.searcher.go(ags) + # then forward and backward + # collect only loss-related actions + toks_all, sent_all = 0, len(insts) + pieces_all, pieces_no_cost, pieces_serr, pieces_valid = 0, 0, 0, 0 + toks_valid = 0 + arc_valid_weights, label_valid_weights = 0., 0. + action_list, arc_weight_list, label_weight_list, bidxes_list = [], [], [], [] + bidx = 0 + # ===== + score_getter = self.searcher.ender.plain_ranker + oracler_ranker = self.searcher.ender.oracle_ranker + # ===== + for one_inst, one_ag in zip(insts, ags): + cur_size = len(one_inst) # excluding ROOT + toks_all += cur_size + # for all the pieces + for sp in one_ag.special_points: + plain_finals, oracle_finals = sp.plain_finals, sp.oracle_finals + best_plain = max(plain_finals, key=score_getter) + best_oracle = max(plain_finals+oracle_finals, key=oracler_ranker) + cost_plain, cost_oracle = best_plain.cost_accu, best_oracle.cost_accu + score_plain, score_oracle = score_getter(best_plain), score_getter(best_oracle) + # if cost_oracle > 0.: # the gold one cannot be searched? + # sent_oracle_has_cost += 1 + # add them + pieces_all += 1 + if cost_plain <= cost_oracle: + pieces_no_cost += 1 + elif score_plain < score_oracle: + pieces_serr += 1 # search error + else: + pieces_valid += 1 + plain_states, oracle_states = best_plain.get_path(), best_oracle.get_path() + toks_valid += len(plain_states) + # Loss = score(best_plain) - score(best_oracle) + cur_aw, cur_lw = self._add_lists(plain_states, oracle_states, action_list, + arc_weight_list, label_weight_list, bidxes_list, bidx, cur_cost0_weight) + arc_valid_weights += cur_aw + label_valid_weights += cur_lw + bidx += 1 + # collect the losses + info = {"sent": sent_all, "tok": toks_all, "tok_valid": toks_valid, "vw_arc": arc_valid_weights, "vw_lab": label_valid_weights, + "pieces_all": pieces_all, "pieces_no_cost": pieces_no_cost, "pieces_serr": pieces_serr, "pieces_valid": pieces_valid} + if toks_valid == 0: + return None, None, info + final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores = \ + self._loss(enc_repr, action_list, arc_weight_list, label_weight_list, bidxes_list) + # todo(+1): other indicators? + info["loss_sum_arc"] = BK.get_value(final_arc_loss_sum).item() + info["loss_sum_lab"] = BK.get_value(final_label_loss_sum).item() + # how to div + if self.loss_div_step: + if self.loss_div_weights: + cur_div_arc = max(arc_valid_weights, 1.) + cur_div_lab = max(label_valid_weights, 1.) + else: + cur_div_arc = cur_div_lab = (toks_all if self.loss_div_fullbatch else toks_valid) + else: + # todo(warn): here use pieces rather than sentences + cur_div_arc = cur_div_lab = (pieces_all if self.loss_div_fullbatch else pieces_valid) + final_loss = final_arc_loss_sum/cur_div_arc + final_label_loss_sum/cur_div_lab + return final_loss, (arc_scores, label_scores), info + + # the weighted sum + def _loss(self, enc_repr, action_list: List[EfAction], arc_weight_list: List[float], label_weight_list: List[float], + bidxes_list: List[int]): + # 1. collect (batched) features; todo(note): use prev state for scoring + hm_features = self.hm_feature_getter.get_hm_features(action_list, [a.state_from for a in action_list]) + # 2. get new sreprs + scorer = self.scorer + s_enc = self.slayer + bsize_range_t = BK.input_idx(bidxes_list) + node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], enc_repr, bsize_range_t) + node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], enc_repr, bsize_range_t) + # label loss + if self.system_labeled: + node_lh_expr, _ = scorer.transform_space_label(node_h_srepr, True, False) + _, node_lm_pack = scorer.transform_space_label(node_m_srepr, False, True) + label_scores_full = scorer.score_label(node_lm_pack, node_lh_expr) # [*, Lab] + label_scores = BK.gather_one_lastdim(label_scores_full, [a.label for a in action_list]).squeeze(-1) + final_label_loss_sum = (label_scores * BK.input_real(label_weight_list)).sum() + else: + label_scores = final_label_loss_sum = BK.zeros([]) + # arc loss + node_ah_expr, _ = scorer.transform_space_arc(node_h_srepr, True, False) + _, node_am_pack = scorer.transform_space_arc(node_m_srepr, False, True) + arc_scores = scorer.score_arc(node_am_pack, node_ah_expr).squeeze(-1) + final_arc_loss_sum = (arc_scores * BK.input_real(arc_weight_list)).sum() + # score reg + return final_arc_loss_sum, final_label_loss_sum, arc_scores, label_scores + +# todo(WARN): very interesting! it seems that with the decoder fixed, ef mode can be trained good enough (almost similar to g1), but letting it trainable actually cannot reach that high, hyper-parameter problems or sth else? +# --> another point is that with fix-decoder g1 model, direct ef-decoding can get similar results, but now worse. (seems natural, previously we acutally have a weak/general? decoder that is not that strong) + +# b tasks/zdpar/ef/parser/efp.py:154 diff --git a/tasks/zdpar/ef/parser/g1p.py b/tasks/zdpar/ef/parser/g1p.py new file mode 100644 index 0000000..bc26eed --- /dev/null +++ b/tasks/zdpar/ef/parser/g1p.py @@ -0,0 +1,459 @@ +# + +# first-order graph parser: as init & pruner + +from typing import List +import numpy as np +from copy import deepcopy +import traceback + +from msp.utils import Constants, FileHelper, zlog, Conf +from msp.data import VocabPackage +from msp.nn import BK +from msp.zext.seq_helper import DataPadder + +from ...common.data import ParseInstance +from ...common.model import BaseParserConf, BaseInferenceConf, BaseTrainingConf, BaseParser +from ..scorer import Scorer, ScorerConf +from ...algo import nmst_unproj, nmarginal_unproj + +# ===== +# conf + +class PruneG1Conf(Conf): + def __init__(self): + self.pruning_labeled = True # todo(note): for marginal mode, do not change this which may cause decoding errors + # simple topk pruning + self.pruning_use_topk = False + self.pruning_perc = 1.0 + self.pruning_topk = 4 + self.pruning_gap = 20. + # marginal pruning + self.pruning_use_marginal = True + self.pruning_mthresh = 0.02 + self.pruning_mthresh_rel = True + +# decoding conf +class G1InferenceConf(BaseInferenceConf): + def __init__(self): + super().__init__() + # for plus pruning mode + self.use_pruning = False + self.pruning_conf = PruneG1Conf() + # extra outputs + self.output_marginals = True + +# training conf +class G1TraningConf(BaseTrainingConf): + def __init__(self): + super().__init__() + # loss functions + self.loss_function = "hinge" + self.loss_div_tok = True # loss divide by token or by sent? + +# overall parser conf +class G1ParserConf(BaseParserConf): + def __init__(self): + super().__init__(G1InferenceConf(), G1TraningConf()) + self.sc_conf = ScorerConf() + self.debug_use_aux_scores = False # only for pruning/scoring mode + +# for g1 as pre-training +class PreG1Conf(Conf): + def __init__(self): + # for basic model + self.g1_pretrain_path = "" # basic g1 model for prune/init/score + self.g1_pretrain_init = False # whether init g1 model to the current model + # whether add basic scores + self.lambda_g1_arc_training = 0. + self.lambda_g1_arc_testing = 0. + self.lambda_g1_lab_training = 0. + self.lambda_g1_lab_testing = 0. + # use aux scores for g1 (must be attached to the inst); only for pruning/scoring mode + self.g1_use_aux_scores = False + +# ===== +# model + +# the model +class G1Parser(BaseParser): + def __init__(self, conf: G1ParserConf, vpack: VocabPackage): + super().__init__(conf, vpack) + # todo(note): the neural parameters are exactly the same as the EF one + self.scorer_helper = GScorerHelper(self.scorer) + self.predict_padder = DataPadder(2, pad_vals=0) + # + self.g1_use_aux_scores = conf.debug_use_aux_scores # assining here is only for debugging usage, otherwise assigning outside + self.num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + # + self.loss_hinge = (self.conf.tconf.loss_function == "hinge") + if not self.loss_hinge: + assert self.conf.tconf.loss_function == "prob", "This model only supports hinge or prob" + + def build_decoder(self): + conf = self.conf + # ===== Decoding Scorer ===== + conf.sc_conf._input_dim = self.enc_output_dim + conf.sc_conf._num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + return Scorer(self.pc, conf.sc_conf) + + # ===== + # main procedures + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + iconf = self.conf.iconf + pconf = iconf.pruning_conf + with BK.no_grad_env(): + self.refresh_batch(False) + if iconf.use_pruning: + # todo(note): for the testing of pruning mode, use the scores instead + if self.g1_use_aux_scores: + valid_mask, arc_score, label_score, mask_expr, _ = G1Parser.score_and_prune(insts, self.num_label, pconf) + else: + valid_mask, arc_score, label_score, mask_expr, _ = self.prune_on_batch(insts, pconf) + valid_mask_f = valid_mask.float() # [*, len, len] + mask_value = Constants.REAL_PRAC_MIN + full_score = arc_score.unsqueeze(-1) + label_score + full_score += (mask_value * (1. - valid_mask_f)).unsqueeze(-1) + info_pruning = G1Parser.collect_pruning_info(insts, valid_mask_f) + jpos_pack = [None, None, None] + else: + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False) + mask_expr = BK.input_real(mask_arr) + full_score = self.scorer_helper.score_full(enc_repr) + info_pruning = None + # ===== + self._decode(insts, full_score, mask_expr, "g1") + # put jpos result (possibly) + self.jpos_decode(insts, jpos_pack) + # ----- + info = {"sent": len(insts), "tok": sum(map(len, insts))} + if info_pruning is not None: + info.update(info_pruning) + return info + + # training + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): + self.refresh_batch(training) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training) + mask_expr = BK.input_real(mask_arr) + # the parsing loss + arc_score = self.scorer_helper.score_arc(enc_repr) + lab_score = self.scorer_helper.score_label(enc_repr) + full_score = arc_score + lab_score + parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr) + # other loss? + jpos_loss = self.jpos_loss(jpos_pack, mask_expr) + reg_loss = self.reg_scores_loss(arc_score, lab_score) + # + info["loss_parse"] = BK.get_value(parsing_loss).item() + final_loss = parsing_loss + if jpos_loss is not None: + info["loss_jpos"] = BK.get_value(jpos_loss).item() + final_loss = parsing_loss + jpos_loss + if reg_loss is not None: + final_loss = final_loss + reg_loss + info["fb"] = 1 + if training: + BK.backward(final_loss, loss_factor) + return info + + # ===== + def _decode(self, insts: List[ParseInstance], full_score, mask_expr, misc_prefix): + # decode + mst_lengths = [len(z) + 1 for z in insts] # +=1 to include ROOT for mst decoding + mst_lengths_arr = np.asarray(mst_lengths, dtype=np.int32) + mst_heads_arr, mst_labels_arr, mst_scores_arr = nmst_unproj(full_score, mask_expr, mst_lengths_arr, + labeled=True, ret_arr=True) + if self.conf.iconf.output_marginals: + # todo(note): here, we care about marginals for arc + # lab_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True) + arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1) + bsize, max_len = BK.get_shape(mask_expr) + idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1) + idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0) + output_marg = arc_marginals[idxes_bs_expr, idxes_m_expr, BK.input_idx(mst_heads_arr)] + mst_marg_arr = BK.get_value(output_marg) + else: + mst_marg_arr = None + # ===== assign, todo(warn): here, the labels are directly original idx, no need to change + for one_idx, one_inst in enumerate(insts): + cur_length = mst_lengths[one_idx] + one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length]) # directly int-val for heads + one_inst.pred_labels.build_vals(mst_labels_arr[one_idx][:cur_length], self.label_vocab) + one_scores = mst_scores_arr[one_idx][:cur_length] + one_inst.pred_par_scores.set_vals(one_scores) + # extra output + one_inst.extra_pred_misc[misc_prefix+"_score"] = one_scores.tolist() + if mst_marg_arr is not None: + one_inst.extra_pred_misc[misc_prefix+"_marg"] = mst_marg_arr[one_idx][:cur_length].tolist() + + # here, only adopt hinge(max-margin) loss; mostly adopted from previous graph parser + def _loss(self, annotated_insts: List[ParseInstance], full_score_expr, mask_expr, valid_expr=None): + bsize, max_len = BK.get_shape(mask_expr) + # gold heads and labels + gold_heads_arr, _ = self.predict_padder.pad([z.heads.vals for z in annotated_insts]) + # todo(note): here use the original idx of label, no shift! + gold_labels_arr, _ = self.predict_padder.pad([z.labels.idxes for z in annotated_insts]) + gold_heads_expr = BK.input_idx(gold_heads_arr) # [BS, Len] + gold_labels_expr = BK.input_idx(gold_labels_arr) # [BS, Len] + # + idxes_bs_expr = BK.arange_idx(bsize).unsqueeze(-1) + idxes_m_expr = BK.arange_idx(max_len).unsqueeze(0) + # scores for decoding or marginal + margin = self.margin.value + decoding_scores = full_score_expr.clone().detach() + decoding_scores = self.scorer_helper.postprocess_scores(decoding_scores, mask_expr, margin, gold_heads_expr, gold_labels_expr) + if self.loss_hinge: + mst_lengths_arr = np.asarray([len(z)+1 for z in annotated_insts], dtype=np.int32) + pred_heads_expr, pred_labels_expr, _ = nmst_unproj(decoding_scores, mask_expr, mst_lengths_arr, labeled=True, ret_arr=False) + # ===== add margin*cost, [bs, len] + gold_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] + pred_final_scores = full_score_expr[idxes_bs_expr, idxes_m_expr, pred_heads_expr, pred_labels_expr] + margin*(gold_heads_expr!=pred_heads_expr).float() + margin*(gold_labels_expr!=pred_labels_expr).float() # plus margin + hinge_losses = pred_final_scores - gold_final_scores + valid_losses = ((hinge_losses * mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1) # [*, 1] + final_losses = hinge_losses * valid_losses + else: + lab_marginals = nmarginal_unproj(decoding_scores, mask_expr, None, labeled=True) + lab_marginals[idxes_bs_expr, idxes_m_expr, gold_heads_expr, gold_labels_expr] -= 1. + grads_masked = lab_marginals*mask_expr.unsqueeze(-1).unsqueeze(-1)*mask_expr.unsqueeze(-2).unsqueeze(-1) + final_losses = (full_score_expr * grads_masked).sum(-1).sum(-1) # [bs, m] + # divide loss by what? + num_sent = len(annotated_insts) + num_valid_tok = sum(len(z) for z in annotated_insts) + # exclude non-valid ones: there can be pruning error + if valid_expr is not None: + final_valids = valid_expr[idxes_bs_expr, idxes_m_expr, gold_heads_expr] # [bs, m] of (0. or 1.) + final_losses = final_losses * final_valids + tok_valid = float(BK.get_value(final_valids[:, 1:].sum())) + assert tok_valid <= num_valid_tok + tok_prune_err = num_valid_tok - tok_valid + else: + tok_prune_err = 0 + # collect loss with mask, also excluding the first symbol of ROOT + final_losses_masked = (final_losses * mask_expr)[:, 1:] + final_loss_sum = BK.sum(final_losses_masked) + if self.conf.tconf.loss_div_tok: + final_loss = final_loss_sum / num_valid_tok + else: + final_loss = final_loss_sum / num_sent + final_loss_sum_val = float(BK.get_value(final_loss_sum)) + info = {"sent": num_sent, "tok": num_valid_tok, "tok_prune_err": tok_prune_err, "loss_sum": final_loss_sum_val} + return final_loss, info + + # ===== + # special preloading + @staticmethod + def special_pretrain_load(m: BaseParser, path, strict): + if FileHelper.isfile(path): + try: + zlog(f"Trying to load pretrained model from {path}") + m.load(path, strict) + zlog(f"Finished loading pretrained model from {path}") + return True + except: + zlog(traceback.format_exc()) + zlog("Failed loading, keep the original ones.") + else: + zlog(f"File does not exist for pretraining loading: {path}") + return False + + # init for the pre-trained G1, return g1parser and also modify m's params + @staticmethod + def pre_g1_init(m: BaseParser, pg1_conf: PreG1Conf, strict=True): + # ===== basic G1 Parser's loading + # todo(WARN): construct the g1conf here instead of loading for simplicity, since the scorer architecture should be the same + g1conf = G1ParserConf() + g1conf.bt_conf = deepcopy(m.conf.bt_conf) + g1conf.sc_conf = deepcopy(m.conf.sc_conf) + g1conf.validate() + # todo(note): specific setting + g1conf.g1_use_aux_scores = pg1_conf.g1_use_aux_scores + # + g1parser = G1Parser(g1conf, m.vpack) + if not G1Parser.special_pretrain_load(g1parser, pg1_conf.g1_pretrain_path, strict): + g1parser = None + # current init + if pg1_conf.g1_pretrain_init: + G1Parser.special_pretrain_load(m, pg1_conf.g1_pretrain_path, strict) + return g1parser + + # collect and batch scores + @staticmethod + def collect_aux_scores(insts: List[ParseInstance], output_num_label): + score_tuples = [z.extra_features["aux_score"] for z in insts] + num_label = score_tuples[0][1].shape[-1] + max_len = max(len(z)+1 for z in insts) + mask_value = Constants.REAL_PRAC_MIN + bsize = len(insts) + arc_score_arr = np.full([bsize, max_len, max_len], mask_value, dtype=np.float32) + lab_score_arr = np.full([bsize, max_len, max_len, output_num_label], mask_value, dtype=np.float32) + mask_arr = np.full([bsize, max_len], 0., dtype=np.float32) + for bidx, one_tuple in enumerate(score_tuples): + one_score_arc, one_score_lab = one_tuple + one_len = one_score_arc.shape[1] + arc_score_arr[bidx, :one_len, :one_len] = one_score_arc + lab_score_arr[bidx, :one_len, :one_len, -num_label:] = one_score_lab + mask_arr[bidx, :one_len] = 1. + return BK.input_real(arc_score_arr).unsqueeze(-1), BK.input_real(lab_score_arr), BK.input_real(mask_arr) + + # pruner: [bs, slen, slen], [bs, slen, slen, Lab] + @staticmethod + def prune_with_scores(arc_score, label_score, mask_expr, pconf: PruneG1Conf, arc_marginals=None): + prune_use_topk, prune_use_marginal, prune_labeled, prune_perc, prune_topk, prune_gap, prune_mthresh, prune_mthresh_rel = \ + pconf.pruning_use_topk, pconf.pruning_use_marginal, pconf.pruning_labeled, pconf.pruning_perc, pconf.pruning_topk, \ + pconf.pruning_gap, pconf.pruning_mthresh, pconf.pruning_mthresh_rel + full_score = arc_score + label_score + final_valid_mask = BK.constants(BK.get_shape(arc_score), 0, dtype=BK.uint8).squeeze(-1) + # (put as argument) arc_marginals = None # [*, mlen, hlen] + if prune_use_marginal: + if arc_marginals is None: # does not provided, calculate from scores + if prune_labeled: + # arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).max(-1)[0] + # use sum of label marginals instead of max + arc_marginals = nmarginal_unproj(full_score, mask_expr, None, labeled=True).sum(-1) + else: + arc_marginals = nmarginal_unproj(arc_score, mask_expr, None, labeled=True).squeeze(-1) + if prune_mthresh_rel: + # relative value + max_arc_marginals = arc_marginals.max(-1)[0].log().unsqueeze(-1) + m_valid_mask = (arc_marginals.log() - max_arc_marginals) > float(np.log(prune_mthresh)) + else: + # absolute value + m_valid_mask = (arc_marginals > prune_mthresh) # [*, len-m, len-h] + final_valid_mask |= m_valid_mask + if prune_use_topk: + # prune by "in topk" and "gap-to-top less than gap" for each mod + if prune_labeled: # take argmax among label dim + tmp_arc_score, _ = full_score.max(-1) + else: + # todo(note): may be modified inplaced, but does not matter since will finally be masked later + tmp_arc_score = arc_score.squeeze(-1) + # first apply mask + mask_value = Constants.REAL_PRAC_MIN + mask_mul = (mask_value * (1. - mask_expr)) # [*, len] + tmp_arc_score += mask_mul.unsqueeze(-1) + tmp_arc_score += mask_mul.unsqueeze(-2) + maxlen = BK.get_shape(tmp_arc_score, -1) + tmp_arc_score += mask_value * BK.eye(maxlen) + prune_topk = min(prune_topk, int(maxlen * prune_perc + 1), maxlen) + if prune_topk >= maxlen: + topk_arc_score = tmp_arc_score + else: + topk_arc_score, _ = BK.topk(tmp_arc_score, prune_topk, dim=-1, sorted=False) # [*, len, k] + min_topk_arc_score = topk_arc_score.min(-1)[0].unsqueeze(-1) # [*, len, 1] + max_topk_arc_score = topk_arc_score.max(-1)[0].unsqueeze(-1) # [*, len, 1] + arc_score_thresh = BK.max_elem(min_topk_arc_score, max_topk_arc_score - prune_gap) # [*, len, 1] + t_valid_mask = (tmp_arc_score > arc_score_thresh) # [*, len-m, len-h] + final_valid_mask |= t_valid_mask + return final_valid_mask, arc_marginals + + # combining the above two + @staticmethod + def score_and_prune(insts: List[ParseInstance], output_num_label, pconf: PruneG1Conf): + arc_score, lab_score, mask_expr = G1Parser.collect_aux_scores(insts, output_num_label) + valid_mask, arc_marginals = G1Parser.prune_with_scores(arc_score, lab_score, mask_expr, pconf) + return valid_mask, arc_score.squeeze(-1), lab_score, mask_expr, arc_marginals + + # ===== + # special mode: first-order model as scorer/pruner + def score_on_batch(self, insts: List[ParseInstance]): + with BK.no_grad_env(): + self.refresh_batch(False) + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False) + # mask_expr = BK.input_real(mask_arr) + arc_score = self.scorer_helper.score_arc(enc_repr) + label_score = self.scorer_helper.score_label(enc_repr) + return arc_score.squeeze(-1), label_score + + # todo(note): the union of two types of pruner! + def prune_on_batch(self, insts: List[ParseInstance], pconf: PruneG1Conf): + with BK.no_grad_env(): + self.refresh_batch(False) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False) + mask_expr = BK.input_real(mask_arr) + arc_score = self.scorer_helper.score_arc(enc_repr) + label_score = self.scorer_helper.score_label(enc_repr) + final_valid_mask, arc_marginals = G1Parser.prune_with_scores(arc_score, label_score, mask_expr, pconf) + return final_valid_mask, arc_score.squeeze(-1), label_score, mask_expr, arc_marginals + + # ===== + @staticmethod + def collect_pruning_info(insts: List[ParseInstance], valid_mask_f): + # two dimensions: coverage and pruning-effect + maxlen = BK.get_shape(valid_mask_f, -1) + # 1. coverage + valid_mask_f_flattened = valid_mask_f.view([-1, maxlen]) # [bs*len, len] + cur_mod_base = 0 + all_mods, all_heads = [], [] + for cur_idx, cur_inst in enumerate(insts): + for m, h in enumerate(cur_inst.heads.vals[1:], 1): + all_mods.append(m+cur_mod_base) + all_heads.append(h) + cur_mod_base += maxlen + cov_count = len(all_mods) + cov_valid = BK.get_value(valid_mask_f_flattened[all_mods, all_heads].sum()).item() + # 2. pruning-rate + # todo(warn): to speed up, these stats are approximate because of including paddings + # edges + pr_edges = int(np.prod(BK.get_shape(valid_mask_f))) + pr_edges_valid = BK.get_value(valid_mask_f.sum()).item() + # valid as structured heads + pr_o2_sib = pr_o2_g = pr_edges + pr_o3_gsib = maxlen*pr_edges + valid_chs_counts, valid_par_counts = valid_mask_f.sum(-2), valid_mask_f.sum(-1) # [*, len] + valid_gsibs = valid_chs_counts*valid_par_counts + pr_o2_sib_valid = BK.get_value(valid_chs_counts.sum()).item() + pr_o2_g_valid = BK.get_value(valid_par_counts.sum()).item() + pr_o3_gsib_valid = BK.get_value(valid_gsibs.sum()).item() + return {"cov_count": cov_count, "cov_valid": cov_valid, "pr_edges": pr_edges, "pr_edges_valid": pr_edges_valid, + "pr_o2_sib": pr_o2_sib, "pr_o2_g": pr_o2_g, "pr_o3_gsib": pr_o3_gsib, + "pr_o2_sib_valid": pr_o2_sib_valid, "pr_o2_g_valid": pr_o2_g_valid, "pr_o3_gsib_valid": pr_o3_gsib_valid} + +# ===== +class GScorerHelper: + def __init__(self, scorer: Scorer): + self.scorer = scorer + + # first order full score: return [*, len-m, len-h, label] + def score_full(self, enc_repr): + arc_score = self.scorer.transform_and_arc_score(enc_repr) + label_score = self.scorer.transform_and_label_score(enc_repr) + # todo(note): apply masks/margins later + return arc_score + label_score + + def score_arc(self, enc_repr): + arc_score = self.scorer.transform_and_arc_score(enc_repr) + # todo(note): apply masks/margins later + return arc_score + + def score_label(self, enc_repr): + label_score = self.scorer.transform_and_label_score(enc_repr) + # todo(note): apply masks/margins later + return label_score + + # apply mask and add margins for the score in training + # todo(note): inplaced process! + def postprocess_scores(self, scores_expr, mask_expr, margin, gold_heads_expr, gold_labels_expr): + final_full_scores = scores_expr + # first apply mask + mask_value = Constants.REAL_PRAC_MIN + mask_mul = (mask_value * (1.-mask_expr)).unsqueeze(-1) # [*, len, 1] + final_full_scores += mask_mul.unsqueeze(-2) + final_full_scores += mask_mul.unsqueeze(-3) + # then margin + if margin > 0.: + full_shape = BK.get_shape(final_full_scores) + # combine the first two dim, and minus margin correspondingly + combined_size = full_shape[0] * full_shape[1] + combiend_score_expr = final_full_scores.view([combined_size] + full_shape[-2:]) + arange_idx_expr = BK.arange_idx(combined_size) + combiend_score_expr[arange_idx_expr, gold_heads_expr.view(-1)] -= margin + combiend_score_expr[arange_idx_expr, gold_heads_expr.view(-1), gold_labels_expr.view(-1)] -= margin + final_full_scores = combiend_score_expr.view(full_shape) + return final_full_scores + +# tasks/zdpar/ef/parser/g1p.py:138 diff --git a/tasks/zdpar/ef/parser/g2p.py b/tasks/zdpar/ef/parser/g2p.py new file mode 100644 index 0000000..3a23a3a --- /dev/null +++ b/tasks/zdpar/ef/parser/g2p.py @@ -0,0 +1,657 @@ +# + +# the graph(v2, order>=2) parser + +from typing import List +import numpy as np + +from msp.utils import Constants, zlog, Helper +from msp.data import VocabPackage +from msp.nn import BK +from msp.zext.seq_helper import DataPadder + +from ...common.data import ParseInstance +from ...common.model import BaseParserConf, BaseInferenceConf, BaseTrainingConf, BaseParser +from ..scorer import Scorer, ScorerConf, SL0Layer, SL0Conf +from ...algo import hop_decode + +from .g1p import G1Parser, PreG1Conf, PruneG1Conf + +# ===== +# conf + +# todo(note): notes about g2p/efp and g1p +# 1) first, we may need g1p as pruner/scorer (g2p requires pruner), which may come from g1model or aux-score-files +# -- for usage for training, should use aux files which comes from leave-one-out styled scoring +# RELATED-OPTIONS: PruneG1Conf, PreG1Conf.{g1_use_aux_scores, lambda_g1_training, lambda_g1_testing} +# 2) then, there is also an option for init the current models with pre-trained go1 model +# RELATED-OPTIONS: PreG1Conf.g1_pretrain_* + +# decoding conf +class G2InferenceConf(BaseInferenceConf): + def __init__(self): + super().__init__() + # prune + self.pruning_conf = PruneG1Conf() + # mini-dec budget + self.mb_dec_lb = 128 # batch decoding length budget + self.mb_dec_sb = Constants.INT_PRAC_MAX # scoring once budget (number of scoring parts) + +# training conf +class G2TraningConf(BaseTrainingConf): + def __init__(self): + super().__init__() + # loss functions + self.loss_div_tok = True # loss divide by token or by sent? + # filtering out pruned parts + self.filter_pruned = True + self.filter_margin = True + +# overall parser conf +class G2ParserConf(BaseParserConf): + def __init__(self): + super().__init__(G2InferenceConf(), G2TraningConf()) + self.sc_conf = ScorerConf() + self.sl_conf = SL0Conf() + self.pre_g1_conf = PreG1Conf() + self.system_labeled = True + # which graph-based model? + self.gm_type = "o3gsib" # o1/o2sib/o2g/o3gsib + self.gm_projective = False + + def do_validate(self): + # specific scorer + self.sl_conf.use_label_feat = False # no label features, otherwise will explode + # currently only support 1 sib + self.sl_conf.chs_num = 1 + self.sl_conf.chs_f = "sum" + +# ===== +# model + +# the model +# todo(note): for simplicity, aggregate the features (sib, grandparent) at the head-node's srepr +class G2Parser(BaseParser): + def __init__(self, conf: G2ParserConf, vpack: VocabPackage): + super().__init__(conf, vpack) + # todo(note): the neural parameters are exactly the same as the EF one + # ===== basic G1 Parser's loading + # todo(note): there can be parameter mismatch (but all of them in non-trained part, thus will be fine) + self.g1parser = G1Parser.pre_g1_init(self, conf.pre_g1_conf) + self.lambda_g1_arc_training = conf.pre_g1_conf.lambda_g1_arc_training + self.lambda_g1_arc_testing = conf.pre_g1_conf.lambda_g1_arc_testing + self.lambda_g1_lab_training = conf.pre_g1_conf.lambda_g1_lab_training + self.lambda_g1_lab_testing = conf.pre_g1_conf.lambda_g1_lab_testing + # + self.add_slayer() + self.dl = G2DL(self.scorer, self.slayer, conf) + # + self.predict_padder = DataPadder(2, pad_vals=0) + self.num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + + def build_decoder(self): + conf = self.conf + # ===== Decoding Scorer ===== + conf.sc_conf._input_dim = self.enc_output_dim + conf.sc_conf._num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + return Scorer(self.pc, conf.sc_conf) + + def build_slayer(self): + conf = self.conf + # ===== Structured Layer ===== + conf.sl_conf._input_dim = self.enc_output_dim + conf.sl_conf._num_label = self.label_vocab.trg_len(True) # todo(WARN): use the original idx + return SL0Layer(self.pc, conf.sl_conf) + + # ===== + # main procedures + + # get g1 prunings and extra-scores + # -> if no g1parser provided, then use aux scores; otherwise, it depends on g1_use_aux_scores + def _get_g1_pack(self, insts: List[ParseInstance], score_arc_lambda: float, score_lab_lambda: float): + pconf = self.conf.iconf.pruning_conf + if self.g1parser is None or self.g1parser.g1_use_aux_scores: + valid_mask, arc_score, lab_score, _, _ = G1Parser.score_and_prune(insts, self.num_label, pconf) + else: + valid_mask, arc_score, lab_score, _, _ = self.g1parser.prune_on_batch(insts, pconf) + if score_arc_lambda <= 0. and score_lab_lambda <= 0.: + go1_pack = None + else: + arc_score *= score_arc_lambda + lab_score *= score_lab_lambda + go1_pack = (arc_score, lab_score) + # [*, slen, slen], ([*, slen, slen], [*, slen, slen, Lab]) + return valid_mask, go1_pack + + # make real valid masks (inplaced): valid(byte): [bs, len-m, len-h]; mask(float): [bs, len] + def _make_final_valid(self, valid_expr, mask_expr): + maxlen = BK.get_shape(mask_expr, -1) + # first apply masks + mask_expr_byte = mask_expr.byte() + valid_expr &= mask_expr_byte.unsqueeze(-1) + valid_expr &= mask_expr_byte.unsqueeze(-2) + # then diag + mask_diag = 1-BK.eye(maxlen).byte() + valid_expr &= mask_diag + # root not as mod + valid_expr[:, 0] = 0 + # only allow root->root (for grandparent feature) + valid_expr[:, 0, 0] = 1 + return valid_expr + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + # iconf = self.conf.iconf + with BK.no_grad_env(): + self.refresh_batch(False) + # pruning and scores from g1 + valid_mask, go1_pack = self._get_g1_pack(insts, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, False) + mask_expr = BK.input_real(mask_arr) + # decode + final_valid_expr = self._make_final_valid(valid_mask, mask_expr) + ret_heads, ret_labels, _, _ = self.dl.decode(insts, enc_repr, final_valid_expr, go1_pack, False, 0.) + # collect the results together + all_heads = Helper.join_list(ret_heads) + if ret_labels is None: + # todo(note): simply get labels from the go1-label classifier; must provide g1parser + if go1_pack is None: + _, go1_pack = self._get_g1_pack(insts, 1., 1.) + _, go1_label_max_idxes = go1_pack[1].max(-1) # [bs, slen, slen] + pred_heads_arr, _ = self.predict_padder.pad(all_heads) # [bs, slen] + pred_heads_expr = BK.input_idx(pred_heads_arr) + pred_labels_expr = BK.gather_one_lastdim(go1_label_max_idxes, pred_heads_expr).squeeze(-1) + all_labels = BK.get_value(pred_labels_expr) # [bs, slen] + else: + all_labels = np.concatenate(ret_labels, 0) + # ===== assign, todo(warn): here, the labels are directly original idx, no need to change + for one_idx, one_inst in enumerate(insts): + cur_length = len(one_inst)+1 + one_inst.pred_heads.set_vals(all_heads[one_idx][:cur_length]) # directly int-val for heads + one_inst.pred_labels.build_vals(all_labels[one_idx][:cur_length], self.label_vocab) + # one_inst.pred_par_scores.set_vals(all_scores[one_idx][:cur_length]) + # ===== + # put jpos result (possibly) + self.jpos_decode(insts, jpos_pack) + # ----- + info = {"sent": len(insts), "tok": sum(map(len, insts))} + return info + + # training + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): + self.refresh_batch(training) + # pruning and scores from g1 + valid_mask, go1_pack = self._get_g1_pack(annotated_insts, self.lambda_g1_arc_training, self.lambda_g1_lab_training) + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(annotated_insts, training) + mask_expr = BK.input_real(mask_arr) + # the parsing loss + final_valid_expr = self._make_final_valid(valid_mask, mask_expr) + parsing_loss, parsing_scores, info = \ + self.dl.loss(annotated_insts, enc_repr, final_valid_expr, go1_pack, True, self.margin.value) + info["loss_parse"] = BK.get_value(parsing_loss).item() + final_loss = parsing_loss + # other loss? + jpos_loss = self.jpos_loss(jpos_pack, mask_expr) + if jpos_loss is not None: + info["loss_jpos"] = BK.get_value(jpos_loss).item() + final_loss = parsing_loss + jpos_loss + if parsing_scores is not None: + reg_loss = self.reg_scores_loss(*parsing_scores) + if reg_loss is not None: + final_loss = final_loss + reg_loss + info["fb"] = 1 + if training: + BK.backward(final_loss, loss_factor) + return info + +# ===== +# helper for different models + +# conventions: +# [m, h, sib, gp]: sib=m means first child; ROOT's parent is ROOT again; and force single root +# warnings: +# use the same scorer as ef, but different semantics + +# decoder and losser +class G2DL: + def __init__(self, scorer: Scorer, slayer: SL0Layer, conf: G2ParserConf): + self.scorer = scorer + self.slayer = slayer + self.gm_type = conf.gm_type + self.projective = conf.gm_projective + self.system_labeled = conf.system_labeled + # system specific + self.helper = {"o1": G2O1Helper(), "o2sib": G2O2sibHelper(), "o2g": G2O2gHelper(), "o3gsib": G2O3gsibHelper()}[self.gm_type] + self.margin_div = {"o1": 1, "o2sib": 2, "o2g": 2, "o3gsib": 3}[self.gm_type] + self.use_sib = "sib" in self.gm_type + self.use_gp = "g" in self.gm_type + # others + tconf: G2TraningConf = conf.tconf + self.filter_pruned = tconf.filter_pruned + self.filter_margin = tconf.filter_margin + self.loss_div_tok = tconf.loss_div_tok + iconf: G2InferenceConf = conf.iconf + self.mb_dec_lb = iconf.mb_dec_lb + self.mb_dec_sb = iconf.mb_dec_sb + + # get parsing loss (perceptron styled) + # todo(+3): assume the same margin for both arc and label + def loss(self, insts: List[ParseInstance], enc_expr, final_valid_expr, go1_pack, training: bool, margin: float): + # first do decoding and related preparation + with BK.no_grad_env(): + _, _, g_packs, p_packs = self.decode(insts, enc_expr, final_valid_expr, go1_pack, training, margin) + # flatten the packs (remember to rebase the indexes) + gold_pack = self._flatten_packs(g_packs) + pred_pack = self._flatten_packs(p_packs) + if self.filter_pruned: + # filter out non-valid (pruned) edges, to avoid prune error + mod_unpruned_mask, gold_mask = self.helper.get_unpruned_mask(final_valid_expr, gold_pack) + pred_mask = mod_unpruned_mask[pred_pack[0], pred_pack[1]] # filter by specific mod + gold_pack = [(None if z is None else z[gold_mask]) for z in gold_pack] + pred_pack = [(None if z is None else z[pred_mask]) for z in pred_pack] + # calculate the scores for loss + gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = gold_pack + pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes = pred_pack + gold_arc_score, gold_label_score_all = self._get_basic_score(enc_expr, gold_b_idxes, gold_m_idxes, gold_h_idxes, + gold_sib_idxes, gold_gp_idxes) + pred_arc_score, pred_label_score_all = self._get_basic_score(enc_expr, pred_b_idxes, pred_m_idxes, pred_h_idxes, + pred_sib_idxes, pred_gp_idxes) + # whether have labeled scores + if self.system_labeled: + gold_label_score = BK.gather_one_lastdim(gold_label_score_all, gold_lab_idxes).squeeze(-1) + pred_label_score = BK.gather_one_lastdim(pred_label_score_all, pred_lab_idxes).squeeze(-1) + ret_scores = (gold_arc_score, pred_arc_score, gold_label_score, pred_label_score) + pred_full_scores, gold_full_scores = pred_arc_score+pred_label_score, gold_arc_score+gold_label_score + else: + ret_scores = (gold_arc_score, pred_arc_score) + pred_full_scores, gold_full_scores = pred_arc_score, gold_arc_score + # hinge loss: filter-margin by loss*margin to be aware of search error + if self.filter_margin: + with BK.no_grad_env(): + mat_shape = BK.get_shape(enc_expr)[:2] # [bs, slen] + heads_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_h_idxes) + heads_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_h_idxes) + error_count = (heads_gold != heads_pred).float() + if self.system_labeled: + labels_gold = self._get_tmp_mat(mat_shape, 0, BK.int64, gold_b_idxes, gold_m_idxes, gold_lab_idxes) + labels_pred = self._get_tmp_mat(mat_shape, 0, BK.int64, pred_b_idxes, pred_m_idxes, pred_lab_idxes) + error_count += (labels_gold != labels_pred).float() + scores_gold = self._get_tmp_mat(mat_shape, 0., BK.float32, gold_b_idxes, gold_m_idxes, gold_full_scores) + scores_pred = self._get_tmp_mat(mat_shape, 0., BK.float32, pred_b_idxes, pred_m_idxes, pred_full_scores) + # todo(note): here, a small 0.1 is to exclude zero error: anyway they will get zero gradient + sent_mask = ((scores_gold.sum(-1) - scores_pred.sum(-1)) <= (margin * error_count.sum(-1) + 0.1)).float() + num_valid_sent = float(BK.get_value(sent_mask.sum())) + final_loss_sum = (pred_full_scores*sent_mask[pred_b_idxes] - gold_full_scores*sent_mask[gold_b_idxes]).sum() + else: + num_valid_sent = len(insts) + final_loss_sum = (pred_full_scores - gold_full_scores).sum() + # prepare final loss + # divide loss by what? + num_sent = len(insts) + num_valid_tok = sum(len(z) for z in insts) + if self.loss_div_tok: + final_loss = final_loss_sum / num_valid_tok + else: + final_loss = final_loss_sum / num_sent + final_loss_sum_val = float(BK.get_value(final_loss_sum)) + info = {"sent": num_sent, "sent_valid": num_valid_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val} + return final_loss, ret_scores, info + + def _flatten_packs(self, packs): + NUM_RET_PACK = 6 # discard the first mb-size + ret_packs = [[] for _ in range(NUM_RET_PACK)] + cur_base_idx = 0 + for one_pack in packs: + mb_size = one_pack[0] + ret_packs[0].append(one_pack[1]+cur_base_idx) + for i in range(1, NUM_RET_PACK): + ret_packs[i].append(one_pack[i+1]) + cur_base_idx += mb_size + ret = [(None if z[0] is None else BK.concat(z, 0)) for z in ret_packs] + return ret + + # get tmp mat of (bsize, mod) for special calculations + def _get_tmp_mat(self, shape, val, dtype, idx0, idx1, vals): + x = BK.constants(shape, val, dtype=dtype) + x[idx0, idx1] = vals + return x + + # enc: [bs, len, D], valid: [bs, len-m, len-h], mask: [bs, len] + # todo(+N): currently does not support multiple high-order parts + def decode(self, insts: List[ParseInstance], enc_expr, final_valid_expr, go1_pack, training: bool, margin: float): + # ===== + has_go1 = go1_pack is not None + if has_go1: + go1_arc_scores, go1_lab_scores = go1_pack + else: + go1_arc_scores = go1_lab_scores = None + # collect scores (split into mini-batches to save memory) + batch_size, max_length = BK.get_shape(final_valid_expr)[:2] + cur_base_bidx = 0 + ret_heads, ret_labels, ret_gold_packs, ret_pred_packs = [], [], [], [] + # dynamically deciding the mini-batch size based on budget + cur_mb_batch_size = max(int(self.mb_dec_lb / max_length), 1) + while cur_base_bidx < batch_size: + next_base_bidx = min(batch_size, cur_base_bidx+cur_mb_batch_size) + # decode for current mini-batch + mb_insts = insts[cur_base_bidx:next_base_bidx] + mb_enc_expr = enc_expr[cur_base_bidx:next_base_bidx] # [mb, slen, D] + mb_final_valid_expr = final_valid_expr[cur_base_bidx:next_base_bidx] # [mb, slen, slen] + if has_go1: + mb_go1_pack = (go1_arc_scores[cur_base_bidx:next_base_bidx], go1_lab_scores[cur_base_bidx:next_base_bidx]) + else: + mb_go1_pack = None + res_heads, res_labels, gold_pack, pred_pack = \ + self._decode(mb_insts, mb_enc_expr, mb_final_valid_expr, mb_go1_pack, training, margin) + # todo(note): temply put them like this, later will flatten specifically + ret_heads.append(res_heads) # List[List[List[Int]]] + ret_labels.append(res_labels) # List[np.array[mb_size, slen]] + ret_gold_packs.append(gold_pack) # List[idxes-pack] + ret_pred_packs.append(pred_pack) # List[idxes-pack] + cur_base_bidx = next_base_bidx + if not self.system_labeled: + ret_labels = None + return ret_heads, ret_labels, ret_gold_packs, ret_pred_packs + + # decode for one mini-batch, to be implemented + # list[insts], [*, slen, D], [*, slen, slen], go1_pack(arc/label scores) -> List[results/heads] + def _decode(self, mb_insts: List[ParseInstance], mb_enc_expr, mb_valid_expr, mb_go1_pack, training: bool, margin: float): + # ===== + use_sib, use_gp = self.use_sib, self.use_gp + # ===== + mb_size = len(mb_insts) + mat_shape = BK.get_shape(mb_valid_expr) + max_slen = mat_shape[-1] + # step 1: extract the candidate features + batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = self.helper.get_cand_features(mb_valid_expr) + # ===== + # step 2: high order scoring + # step 2.1: basic scoring, [*], [*, Lab] + arc_scores, lab_scores = self._get_basic_score(mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes) + cur_system_labeled = (lab_scores is not None) + # step 2.2: margin + # get gold labels, which can be useful for later calculating loss + if training: + gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes = \ + [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_insts(mb_insts, use_sib, use_gp)] + # add the margins to the scores: (m,h), (m,sib), (m,gp) + cur_margin = margin / self.margin_div + self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_lab_idxes, + batch_idxes, m_idxes, h_idxes, arc_scores, lab_scores, cur_margin, cur_margin) + if use_sib: + self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_sib_idxes, gold_lab_idxes, + batch_idxes, m_idxes, sib_idxes, arc_scores, lab_scores, cur_margin, cur_margin) + if use_gp: + self._add_margin_inplaced(mat_shape, gold_b_idxes, gold_m_idxes, gold_gp_idxes, gold_lab_idxes, + batch_idxes, m_idxes, gp_idxes, arc_scores, lab_scores, cur_margin, cur_margin) + # may be useful for later training + gold_pack = (mb_size, gold_b_idxes, gold_m_idxes, gold_h_idxes, gold_sib_idxes, gold_gp_idxes, gold_lab_idxes) + else: + gold_pack = None + # step 2.3: o1scores + if mb_go1_pack is not None: + go1_arc_scores, go1_lab_scores = mb_go1_pack + # todo(note): go1_arc_scores is not added here, but as the input to the dec-algo + if cur_system_labeled: + lab_scores += go1_lab_scores[batch_idxes, m_idxes, h_idxes] + else: + go1_arc_scores = None + # step 2.4: max out labels; todo(+N): or using logsumexp here? + if cur_system_labeled: + max_lab_scores, max_lab_idxes = lab_scores.max(-1) + final_scores = arc_scores + max_lab_scores # [*], final input arc scores + else: + max_lab_idxes = None + final_scores = arc_scores + # ===== + # step 3: actual decode + res_heads = [] + for sid, inst in enumerate(mb_insts): + slen = len(inst) + 1 # plus one for the art-root + arr_o1_masks = BK.get_value(mb_valid_expr[sid, :slen, :slen].int()) + arr_o1_scores = BK.get_value(go1_arc_scores[sid, :slen, :slen].double()) if (go1_arc_scores is not None) else None + cur_bidx_mask = (batch_idxes == sid) + input_pack = [m_idxes, h_idxes, sib_idxes, gp_idxes, final_scores] + one_heads = self.helper.decode_one(slen, self.projective, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask) + res_heads.append(one_heads) + # ===== + # step 4: get labels back and pred_pack + pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, _ = \ + [(None if z is None else BK.input_idx(z)) for z in self._get_idxes_from_preds(res_heads, None, use_sib, use_gp)] + if cur_system_labeled: + # obtain hit components + pred_hit_mask = self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_h_idxes, batch_idxes, m_idxes, h_idxes) + if use_sib: + pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_sib_idxes, batch_idxes, m_idxes, sib_idxes) + if use_gp: + pred_hit_mask &= self._get_hit_mask(mat_shape, pred_b_idxes, pred_m_idxes, pred_gp_idxes, batch_idxes, m_idxes, gp_idxes) + # get pred labels (there should be only one hit per mod!) + pred_labels = BK.constants_idx([mb_size, max_slen], 0) + pred_labels[batch_idxes[pred_hit_mask], m_idxes[pred_hit_mask]] = max_lab_idxes[pred_hit_mask] + res_labels = BK.get_value(pred_labels) + pred_lab_idxes = pred_labels[pred_b_idxes, pred_m_idxes] + else: + res_labels = None + pred_lab_idxes = None + pred_pack = (mb_size, pred_b_idxes, pred_m_idxes, pred_h_idxes, pred_sib_idxes, pred_gp_idxes, pred_lab_idxes) + # return + return res_heads, res_labels, gold_pack, pred_pack + + # ===== + # helper functions + + # get high order scores: [*, slen, D], *[*] + # split into multiple times if necessary + def _get_basic_score(self, mb_enc_expr, batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes): + allp_size = BK.get_shape(batch_idxes, 0) + all_arc_scores, all_lab_scores = [], [] + cur_pidx = 0 + while cur_pidx < allp_size: + next_pidx = min(allp_size, cur_pidx + self.mb_dec_sb) + # first calculate srepr + s_enc = self.slayer + cur_batch_idxes = batch_idxes[cur_pidx:next_pidx] + h_expr = mb_enc_expr[cur_batch_idxes, h_idxes[cur_pidx:next_pidx]] + m_expr = mb_enc_expr[cur_batch_idxes, m_idxes[cur_pidx:next_pidx]] + s_expr = mb_enc_expr[cur_batch_idxes, sib_idxes[cur_pidx:next_pidx]].unsqueeze(-2) \ + if (sib_idxes is not None) else None # [*, 1, D] + g_expr = mb_enc_expr[cur_batch_idxes, gp_idxes[cur_pidx:next_pidx]] if (gp_idxes is not None) else None + head_srepr = s_enc.calculate_repr(h_expr, g_expr, None, None, s_expr, None, None, None) + mod_srepr = s_enc.forward_repr(m_expr) + # then get the scores + arc_score = self.scorer.transform_and_arc_score_plain(mod_srepr, head_srepr).squeeze(-1) + all_arc_scores.append(arc_score) + if self.system_labeled: + lab_score = self.scorer.transform_and_label_score_plain(mod_srepr, head_srepr) + all_lab_scores.append(lab_score) + cur_pidx = next_pidx + final_arc_score = BK.concat(all_arc_scores, 0) + final_lab_score = BK.concat(all_lab_scores, 0) if self.system_labeled else None + return final_arc_score, final_lab_score + + # prepare the idxes from gold annotations of instances + def _get_idxes_from_insts(self, insts: List[ParseInstance], use_sib, use_gp): + b_idxes, m_idxes, h_idxes = [], [], [] + sib_idxes, gp_idxes = ([] if use_sib else None), ([] if use_gp else None) + lab_idxes = [] + for b, inst in enumerate(insts): + cur_len = len(inst) # length is len(heads)-1 + b_idxes.extend(b for _ in range(cur_len)) + # exclude the Node 0 as art-root + m_idxes.extend(range(1, cur_len+1)) + h_idxes.extend(inst.heads.vals[1:]) + lab_idxes.extend(inst.labels.idxes[1:]) + if use_sib: + sib_idxes.extend(inst.sibs[1:]) + if use_gp: + gp_idxes.extend(inst.gps[1:]) + return b_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes, lab_idxes + + # prepare the idxes from predicted heads + def _get_idxes_from_preds(self, heads: List[List], labels: List[List], use_sib, use_gp): + has_label = labels is not None + b_idxes, m_idxes, h_idxes = [], [], [] + sib_idxes, gp_idxes = ([] if use_sib else None), ([] if use_gp else None) + lab_idxes = ([] if has_label else None) + b = 0 + for idx in range(len(heads)): + one_heads = heads[idx] + cur_len = len(one_heads)-1 # length is len(heads)-1 + b_idxes.extend(b for _ in range(cur_len)) + # exclude the Node 0 as art-root + m_idxes.extend(range(1, cur_len + 1)) + h_idxes.extend(one_heads[1:]) + if has_label: + lab_idxes.extend(labels[idx][1:]) + if use_sib: + children_left, children_right = ParseInstance.get_children(one_heads) + sib_idxes.extend(ParseInstance.get_sibs(children_left, children_right)[1:]) + if use_gp: + gp_idxes.extend(ParseInstance.get_gps(one_heads)[1:]) + b += 1 + return b_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes, lab_idxes + + # add margin for one group of features, modified inplaced! + def _add_margin_inplaced(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, hit_labels, + query_idxes0, query_idxes1, query_idxes2, arc_scores, lab_scores, + arc_margin: float, lab_margin: float): + # arc + gold_arc_mat = BK.constants(shape, 0.) + gold_arc_mat[hit_idxes0, hit_idxes1, hit_idxes2] = arc_margin + gold_arc_margins = gold_arc_mat[query_idxes0, query_idxes1, query_idxes2] + arc_scores -= gold_arc_margins + if lab_scores is not None: + # label + gold_lab_mat = BK.constants_idx(shape, 0) # 0 means the padding idx + gold_lab_mat[hit_idxes0, hit_idxes1, hit_idxes2] = hit_labels + gold_lab_margin_idxes = gold_lab_mat[query_idxes0, query_idxes1, query_idxes2] + lab_scores[BK.arange_idx(BK.get_shape(gold_lab_margin_idxes, 0)), gold_lab_margin_idxes] -= lab_margin + return + + # get hit mask + def _get_hit_mask(self, shape, hit_idxes0, hit_idxes1, hit_idxes2, query_idxes0, query_idxes1, query_idxes2): + hit_mat = BK.constants(shape, 0, dtype=BK.uint8) + hit_mat[hit_idxes0, hit_idxes1, hit_idxes2] = 1 + hit_mask = hit_mat[query_idxes0, query_idxes1, query_idxes2] + return hit_mask + +# ===== +# for specific systems + +# [m, h] +class G2O1Helper: + def get_cand_features(self, final_valid_expr): + # expand to get features: (m,h) + batch_idxes, m_idxes, h_idxes = [x.squeeze(-1) for x in final_valid_expr.nonzero().split(1, -1)] + vmask = (m_idxes != 0) # m cannot be 0 + return batch_idxes[vmask], m_idxes[vmask], h_idxes[vmask], None, None + + def get_unpruned_mask(self, valid_expr, gold_pack): + batch_idxes, m_idxes, h_idxes, _, _, _ = gold_pack + gold_mask = valid_expr[batch_idxes, m_idxes, h_idxes] + gold_mask = gold_mask.byte() + mod_unpruned_mask = BK.constants(BK.get_shape(valid_expr)[:2], 0, dtype=BK.uint8) + mod_unpruned_mask[batch_idxes[gold_mask], m_idxes[gold_mask]] = 1 + return mod_unpruned_mask, gold_mask + + def decode_one(self, slen: int, projective: bool, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask): + m_idxes, h_idxes, _, _, final_scores = input_pack + if arr_o1_scores is None: + arr_o1_scores = np.full([slen, slen], 0., dtype=np.double) + # direct add to the scores + m_idxes, h_idxes, final_scores = m_idxes[cur_bidx_mask], h_idxes[cur_bidx_mask], final_scores[cur_bidx_mask] + arr_o1_scores[BK.get_value(m_idxes), BK.get_value(h_idxes)] += BK.get_value(final_scores) + return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None, None, None) + +# [m, h, sib] +class G2O2sibHelper: + def get_cand_features(self, final_valid_expr): + # expand to get features: (m,h), (h,sib) -> [m, h, sib] + expanded_valid_expr = final_valid_expr.unsqueeze(-1) * final_valid_expr.transpose(-1, -2).unsqueeze(-3) + batch_idxes, m_idxes, h_idxes, sib_idxes = [x.squeeze(-1) for x in expanded_valid_expr.nonzero().split(1, -1)] + del expanded_valid_expr + vmask = (m_idxes != 0) # m cannot be 0 + vmask &= (sib_idxes != 0) # sib cannot be 0 + # sib is inner than m + vmask &= (((m_idxes <= sib_idxes) & (sib_idxes < h_idxes)) | ((h_idxes < sib_idxes) & (sib_idxes <= m_idxes))) + return batch_idxes[vmask], m_idxes[vmask], h_idxes[vmask], sib_idxes[vmask], None + + def get_unpruned_mask(self, valid_expr, gold_pack): + batch_idxes, m_idxes, h_idxes, sib_idxes, _, _ = gold_pack + gold_mask = valid_expr[batch_idxes, m_idxes, h_idxes] + gold_mask *= valid_expr[batch_idxes, sib_idxes, h_idxes] + gold_mask = gold_mask.byte() + mod_unpruned_mask = BK.constants(BK.get_shape(valid_expr)[:2], 0, dtype=BK.uint8) + mod_unpruned_mask[batch_idxes[gold_mask], m_idxes[gold_mask]] = 1 + return mod_unpruned_mask, gold_mask + + def decode_one(self, slen: int, projective: bool, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask): + m_idxes, h_idxes, sib_idxes, _, final_scores = input_pack + o2sib_pack = [m_idxes[cur_bidx_mask].int(), h_idxes[cur_bidx_mask].int(), sib_idxes[cur_bidx_mask].int(), final_scores[cur_bidx_mask].double()] + o2sib_arr_pack = [BK.get_value(z) for z in o2sib_pack] + return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, o2sib_arr_pack, None, None) + +# [m, h, gp] +class G2O2gHelper: + def get_cand_features(self, final_valid_expr): + # expand to get features: (m,h), (h,gp) -> [m, h, gp] + expanded_valid_expr = final_valid_expr.unsqueeze(-1) * final_valid_expr.unsqueeze(-3) + batch_idxes, m_idxes, h_idxes, gp_idxes = [x.squeeze(-1) for x in expanded_valid_expr.nonzero().split(1, -1)] + del expanded_valid_expr + vmask = (m_idxes != 0) # m cannot be 0 + # sib is inner than m + vmask &= (m_idxes != gp_idxes) # gp cannot be m + return batch_idxes[vmask], m_idxes[vmask], h_idxes[vmask], None, gp_idxes[vmask] + + def get_unpruned_mask(self, valid_expr, gold_pack): + batch_idxes, m_idxes, h_idxes, _, gp_idxes, _ = gold_pack + gold_mask = valid_expr[batch_idxes, m_idxes, h_idxes] + gold_mask *= valid_expr[batch_idxes, h_idxes, gp_idxes] + gold_mask = gold_mask.byte() + mod_unpruned_mask = BK.constants(BK.get_shape(valid_expr)[:2], 0, dtype=BK.uint8) + mod_unpruned_mask[batch_idxes[gold_mask], m_idxes[gold_mask]] = 1 + return mod_unpruned_mask, gold_mask + + def decode_one(self, slen: int, projective: bool, arr_o1_masks, arr_o1_scores, input_pack, cur_bidx_mask): + m_idxes, h_idxes, _, gp_idxes, final_scores = input_pack + o2g_pack = [m_idxes[cur_bidx_mask].int(), h_idxes[cur_bidx_mask].int(), gp_idxes[cur_bidx_mask].int(), final_scores[cur_bidx_mask].double()] + o2g_arr_pack = [BK.get_value(z) for z in o2g_pack] + return hop_decode(slen, projective, arr_o1_masks, arr_o1_scores, None, o2g_arr_pack, None) + +# [m, h, sib, gp] +class G2O3gsibHelper: + def get_cand_features(self, final_valid_expr): + # expand to get features: (m,h), (h,sib), (h,gp) -> [m, h, sib, gp] + expanded_valid_expr = final_valid_expr.unsqueeze(-1).unsqueeze(-1) * \ + final_valid_expr.transpose(-1, -2).unsqueeze(-3).unsqueeze(-1) * \ + final_valid_expr.unsqueeze(-3).unsqueeze(-2) + batch_idxes, m_idxes, h_idxes, sib_idxes, gp_idxes = [x.squeeze(-1) for x in + expanded_valid_expr.nonzero().split(1, -1)] + del expanded_valid_expr + vmask = (m_idxes != 0) # m cannot be 0 + vmask &= (sib_idxes != 0) # sib cannot be 0 + # sib is inner than m + vmask &= (((m_idxes<=sib_idxes) & (sib_idxes if no g1parser provided, then use aux scores; otherwise, it depends on g1_use_aux_scores + def _get_g1_pack(self, insts: List[ParseInstance], mask_expr, score_arc_lambda: float, score_lab_lambda: float): + if self.g1parser is None or self.g1parser.g1_use_aux_scores: + arc_score, lab_score, _ = G1Parser.collect_aux_scores(insts, self.num_label) + else: + arc_score, lab_score = self.g1parser.score_on_batch(insts) + arc_score = arc_score.unsqueeze(-1) + # pruner1 (decoding pruning) + valid_mask_d, arc_marginals = G1Parser.prune_with_scores(arc_score, lab_score, mask_expr, self.conf.dprune) + # pruner2 (structured one) -- no need to recalculate arc_marginals + valid_mask_s, _ = G1Parser.prune_with_scores(arc_score, lab_score, mask_expr, self.conf.sprune, arc_marginals) + # + if score_arc_lambda <= 0. and score_lab_lambda <= 0.: + go1_pack = None + else: + arc_score *= score_arc_lambda + lab_score *= score_lab_lambda + go1_pack = (arc_score, lab_score) + # [*, slen, slen], ([*, slen, slen], [*, slen, slen, Lab]), [*, m, h] + assert arc_marginals is not None + return valid_mask_d, valid_mask_s, go1_pack, arc_marginals + + # make real valid masks (inplaced): valid(byte): [bs, len-m, len-h]; mask(float): [bs, len] + def _make_final_valid(self, valid_expr, mask_expr): + maxlen = BK.get_shape(mask_expr, -1) + # first apply masks + mask_expr_byte = mask_expr.byte() + valid_expr &= mask_expr_byte.unsqueeze(-1) + valid_expr &= mask_expr_byte.unsqueeze(-2) + # then diag + mask_diag = 1-BK.eye(maxlen).byte() + valid_expr &= mask_diag + # root not as mod, todo(note): here no [0,0] since no need + valid_expr[:, 0] = 0 + return valid_expr.float() + + # common calculations for both decoding and training + def _score(self, insts: List[ParseInstance], training: bool, lambda_g1_arc: float, lambda_g1_lab: float): + # encode + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, training) + mask_expr = BK.input_real(mask_arr) + # pruning and scores from g1 + valid_mask_d, valid_mask_s, go1_pack, arc_marginals = self._get_g1_pack(insts, mask_expr, lambda_g1_arc, lambda_g1_lab) + # s-encode (using s-mask) + final_valid_expr_s = self._make_final_valid(valid_mask_s, mask_expr) + senc_repr = self.slayer(enc_repr, final_valid_expr_s, arc_marginals) + # decode + arc_score = self.scorer_helper.score_arc(senc_repr) + lab_score = self.scorer_helper.score_label(senc_repr) + full_score = arc_score + lab_score + # add go1 scores and apply pruning (using d-mask) + final_valid_expr_d = valid_mask_d.float() # no need to mask out others here! + mask_value = Constants.REAL_PRAC_MIN + if go1_pack is not None: + go1_arc_score, go1_label_score = go1_pack + full_score += go1_arc_score.unsqueeze(-1) + go1_label_score + full_score += (mask_value * (1. - final_valid_expr_d)).unsqueeze(-1) + # [*, m, h, lab], (original-scores), [*, m], [*, m, h] + return full_score, (arc_score, lab_score), jpos_pack, mask_expr, final_valid_expr_d, final_valid_expr_s + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + # iconf = self.conf.iconf + with BK.no_grad_env(): + self.refresh_batch(False) + full_score, _, jpos_pack, mask_expr, _, _ = \ + self._score(insts, False, self.lambda_g1_arc_testing, self.lambda_g1_lab_testing) + # collect the results together + # ===== + self._decode(insts, full_score, mask_expr, "s2") + # put jpos result (possibly) + self.jpos_decode(insts, jpos_pack) + # ----- + info = {"sent": len(insts), "tok": sum(map(len, insts))} + return info + + # training + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1., **kwargs): + self.refresh_batch(training) + # todo(note): here always using training lambdas + full_score, original_scores, jpos_pack, mask_expr, valid_mask_d, _ = \ + self._score(annotated_insts, False, self.lambda_g1_arc_training, self.lambda_g1_lab_training) + parsing_loss, info = self._loss(annotated_insts, full_score, mask_expr, valid_mask_d) + # other loss? + jpos_loss = self.jpos_loss(jpos_pack, mask_expr) + reg_loss = self.reg_scores_loss(*original_scores) + # + info["loss_parse"] = BK.get_value(parsing_loss).item() + final_loss = parsing_loss + if jpos_loss is not None: + info["loss_jpos"] = BK.get_value(jpos_loss).item() + final_loss = parsing_loss + jpos_loss + if reg_loss is not None: + final_loss = final_loss + reg_loss + info["fb"] = 1 + if training: + BK.backward(final_loss, loss_factor) + return info + diff --git a/tasks/zdpar/ef/scorer/__init__.py b/tasks/zdpar/ef/scorer/__init__.py new file mode 100644 index 0000000..027577b --- /dev/null +++ b/tasks/zdpar/ef/scorer/__init__.py @@ -0,0 +1,4 @@ +# + +from .scorer import Scorer, ScorerConf +from .slayer import SL0Conf, SL0Layer, SL1Conf, SL1Layer diff --git a/tasks/zdpar/ef/scorer/scorer.py b/tasks/zdpar/ef/scorer/scorer.py new file mode 100644 index 0000000..d544a39 --- /dev/null +++ b/tasks/zdpar/ef/scorer/scorer.py @@ -0,0 +1,166 @@ +# + +from msp.utils import Conf, Helper +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, BiAffineScorer, AttConf + +# ----- +# scorer +# the scorer only cares about giving (raw) scores for the head-mod pairs +# all the neural model parameters belongs here! + +# ----- +# confs + +# the head-mod pair scorer +class ScorerConf(Conf): + def __init__(self): + self._input_dim = -1 # enc's (input) last dimension + self._num_label = -1 # number of labels + # space transferring + self.arc_space = 512 # arc_space==0 means no_arc_score + self.lab_space = 128 + # ----- + # final biaffine scoring + self.transform_act = "elu" + self.ff_hid_size = 0 + self.ff_hid_layer = 0 + self.use_biaffine = True + self.use_ff = True + self.use_ff2 = False + self.biaffine_div = 1. + self.biaffine_init_ortho = False + # distance clip? + self.arc_dist_clip = -1 + self.arc_use_neg = False + + def get_dist_aconf(self): + return AttConf().init_from_kwargs(clip_dist=self.arc_dist_clip, use_neg_dist=self.arc_use_neg) + +# todo(WARN): about no_arc_score (deprecated and deleted) +""" +Actually, after further thinking, arc+label score is a reasonable choice, since if there are only label-score, then there seem to be no param-sharing things in the final scoring. Actually, on PTB, NoPOS/arc=0 is hard to converge with default hp (starting to be reasonable after Epoch-70+), although +POS it converges. Therefore, when parameterizing NN, think about the intuition and make it easier to train with proper param-sharing. +-> My guess for the non-convergence is that it is hard to learn localness from purely lexicalized input; arc-scorer may be easier to learn localness since most arcs are local, POS can also provide helpful info on this aspect, maybe. +""" + +# ----- +# actual model components + +# scoring for pairs of head-mod on structured enc repr +# orig-enc -> s-enc -> h/m-cache +class Scorer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: ScorerConf): + super().__init__(pc, None, None) + # options + input_dim = sconf._input_dim + arc_space = sconf.arc_space + lab_space = sconf.lab_space + ff_hid_size = sconf.ff_hid_size + ff_hid_layer = sconf.ff_hid_layer + use_biaffine = sconf.use_biaffine + use_ff = sconf.use_ff + use_ff2 = sconf.use_ff2 + biaffine_div = sconf.biaffine_div + biaffine_init_ortho = sconf.biaffine_init_ortho + transform_act = sconf.transform_act + # + self.input_dim = input_dim + self.num_label = sconf._num_label + # attach/arc + self.arc_m = self.add_sub_node("am", Affine(pc, input_dim, arc_space, act=transform_act)) + self.arc_h = self.add_sub_node("ah", Affine(pc, input_dim, arc_space, act=transform_act)) + self.arc_scorer = self.add_sub_node( + "as", BiAffineScorer(pc, arc_space, arc_space, 1, ff_hid_size, ff_hid_layer=ff_hid_layer, + use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, + biaffine_div=biaffine_div, biaffine_init_ortho=biaffine_init_ortho)) + # only add distance for arc + if sconf.arc_dist_clip > 0: + # todo(+N): how to include dist feature? + # self.dist_helper = self.add_sub_node("dh", AttDistHelper(pc, sconf.get_dist_aconf(), arc_space)) + self.dist_helper = None + raise NotImplemented("TODO") + else: + self.dist_helper = None + # labeling + self.lab_m = self.add_sub_node("lm", Affine(pc, input_dim, lab_space, act=transform_act)) + self.lab_h = self.add_sub_node("lh", Affine(pc, input_dim, lab_space, act=transform_act)) + self.lab_scorer = self.add_sub_node( + "ls", BiAffineScorer(pc, lab_space, lab_space, self.num_label, ff_hid_size, ff_hid_layer=ff_hid_layer, + use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, + biaffine_div=biaffine_div, biaffine_init_ortho=biaffine_init_ortho)) + + # get separate params of MLP and scorer + def get_split_params(self): + params0 = Helper.join_list(z.get_parameters() for z in [self.arc_m, self.arc_h, self.lab_m, self.lab_h]) + params1 = Helper.join_list(z.get_parameters() for z in [self.arc_scorer, self.lab_scorer]) + return params0, params1 + + # transform to specific space (+pre-computation for mod) (currently arc/label * m/h) + # [*, input_dim] -> *[*, space_dim] + + def transform_space_arc(self, senc_expr, calc_h: bool=True, calc_m: bool=True): + # for head + if calc_h: + ah_expr = self.arc_h(senc_expr) + else: + ah_expr = None + # for mod + if calc_m: + am_expr = self.arc_m(senc_expr) + am_pack = self.arc_scorer.precompute_input0(am_expr) + else: + am_pack = None + return ah_expr, am_pack + + def transform_space_label(self, senc_expr, calc_h: bool=True, calc_m: bool=True): + # for arc + if calc_h: + lh_expr = self.lab_h(senc_expr) + else: + lh_expr = None + # for mod + if calc_m: + lm_expr = self.lab_m(senc_expr) + lm_pack = self.lab_scorer.precompute_input0(lm_expr) + else: + lm_pack = None + return lh_expr, lm_pack + + # ===== + # separate scores + + def score_arc(self, am_pack, ah_expr, mask_m=None, mask_h=None): + return self.arc_scorer.postcompute_input1(am_pack, ah_expr, mask0=mask_m, mask1=mask_h) + + def score_label(self, lm_pack, lh_expr, mask_m=None, mask_h=None): + return self.lab_scorer.postcompute_input1(lm_pack, lh_expr, mask0=mask_m, mask1=mask_h) + + # ===== + # special mode for first order full score + + def transform_and_arc_score(self, senc_expr, mask_expr=None): + ah_expr = self.arc_h(senc_expr) + am_expr = self.arc_m(senc_expr) + arc_full_score = self.arc_scorer.paired_score(am_expr, ah_expr, mask_expr, mask_expr) + return arc_full_score + + def transform_and_label_score(self, senc_expr, mask_expr=None): + lh_expr = self.lab_h(senc_expr) + lm_expr = self.lab_m(senc_expr) + lab_full_score = self.lab_scorer.paired_score(lm_expr, lh_expr, mask_expr, mask_expr) + return lab_full_score + + # ===== + # plain mode for scoring + + def transform_and_arc_score_plain(self, mod_srepr, head_srepr, mask_expr=None): + ah_expr = self.arc_h(head_srepr) + am_expr = self.arc_m(mod_srepr) + arc_full_score = self.arc_scorer.plain_score(am_expr, ah_expr, mask_expr, mask_expr) + return arc_full_score + + def transform_and_label_score_plain(self, mod_srepr, head_srepr, mask_expr=None): + lh_expr = self.lab_h(head_srepr) + lm_expr = self.lab_m(mod_srepr) + lab_full_score = self.lab_scorer.plain_score(lm_expr, lh_expr, mask_expr, mask_expr) + return lab_full_score diff --git a/tasks/zdpar/ef/scorer/slayer.py b/tasks/zdpar/ef/scorer/slayer.py new file mode 100644 index 0000000..12e4f83 --- /dev/null +++ b/tasks/zdpar/ef/scorer/slayer.py @@ -0,0 +1,258 @@ +# + +# middle-part structured layer + +from typing import List +import numpy as np + +from msp.utils import Conf, zcheck +from msp.nn import BK +from msp.nn.layers import BasicNode, Affine, Embedding, MultiHeadAttention, AttConf +from msp.zext.seq_helper import DataPadder + +# ===== +# first the one for ef and g2 + +# the one that calculates the feature-enhanced repr +class SL0Conf(Conf): + def __init__(self): + self._input_dim = -1 # enc's (input) last dimension + self._num_label = -1 # number of labels + # model dimensions + self.dim_label = 30 + # + self.use_label_feat = True # whether use feats + self.use_chs = True # whether use children feats + self.chs_num = 0 # how many (recent) ch's siblings to consider, 0 means all: [-num:] + self.chs_f = "att" # how to represent the ch set: att=attention, sum=sum, ff=feadfoward (always 0 vector if no ch) + self.chs_att = AttConf().init_from_kwargs(d_kqv=256, att_dropout=0.1, head_count=2) + self.use_par = True # whether use parent feats + # + self.zero_extra_output_params = False # whether makes extra output params zero when init + +# another helper class for children set calculation +class ChsReprer(BasicNode): + def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf): + super().__init__(pc, None, None) + # concat enc and label if use-label + input_dim = rconf._input_dim + # todo(warn): always add label related params + self.dim = (input_dim + rconf.dim_label) + self.is_att, self.is_sum, self.is_ff = [rconf.chs_f == z for z in ["att", "sum", "ff"]] + self.ff_reshape = [-1, self.dim * rconf.chs_num] # only for ff + # todo(+N): not elegant, but add the params to make most things compatible when reloading + self.mha = self.add_sub_node("fn", MultiHeadAttention(pc, self.dim, input_dim, self.dim, rconf.chs_att)) + if self.is_att: + # todo(+N): the only possibly inconsistent drop is the one for attention, may consider disabling it. + self.fnode = self.mha + elif self.is_sum: + self.fnode = None + elif self.is_ff: + zcheck(rconf.chs_num > 0, "Err: Cannot ff with 0 child") + self.fnode = self.add_sub_node("fn", Affine(pc, self.dim * rconf.chs_num, self.dim, act="elu")) + else: + raise NotImplementedError(f"UNK chs method: {rconf.chs_f}") + + def get_output_dims(self, *input_dims): + return (self.dim, ) + + def zeros(self, batch): + return BK.zeros((batch, self.dim)) + + # [*, D(q)], [*, max-ch, D(kv)], [*, max-ch], [*] + def __call__(self, chs_input_state_t, chs_input_mem_t, chs_mask_t, chs_valid_t): + if self.is_att: + # [*, max-ch, size], ~, [*, 1, size], [*, max-ch] -> [*, size] + ret = self.fnode(chs_input_mem_t, chs_input_mem_t, chs_input_state_t.unsqueeze(-2), chs_mask_t).squeeze(-2) + # ignore head for the rest + elif self.is_sum: + # ignore head + if chs_mask_t is None: + ret = chs_input_mem_t.sum(-2) + else: + ret = (chs_input_mem_t*chs_mask_t.unsqueeze(-1)).sum(-2) + elif self.is_ff: + if chs_mask_t is None: + reshaped_input_state_t = chs_input_mem_t.view(self.ff_reshape) + else: + reshaped_input_state_t = (chs_input_mem_t*chs_mask_t.unsqueeze(-1)).view(self.ff_reshape) + ret = self.fnode(reshaped_input_state_t) + else: + ret = None + if chs_valid_t is None: + return ret + else: + # out-most mask: [*, D_DL] * [*, 1] + return ret * (chs_valid_t.unsqueeze(-1)) + +# a helper class to calculate structured repr for one node +# todo(note): currently a simple architecture: cur+par+children_summary +class SL0Layer(BasicNode): + def __init__(self, pc: BK.ParamCollection, rconf: SL0Conf): + super().__init__(pc, None, None) + self.dim = rconf._input_dim # both input/output dim + # padders for child nodes + self.chs_start_posi = -rconf.chs_num + self.ch_idx_padder = DataPadder(2, pad_vals=0, mask_range=2) # [*, num-ch] + self.ch_label_padder = DataPadder(2, pad_vals=0) + # + self.label_embeddings = self.add_sub_node("label", Embedding(pc, rconf._num_label, rconf.dim_label, fix_row0=False)) + self.dim_label = rconf.dim_label + # todo(note): now adopting flatten groupings for basic, and then that is all, no more recurrent features + # group 1: [cur, chs, par] -> head_pre_size + self.use_chs = rconf.use_chs + self.use_par = rconf.use_par + self.use_label_feat = rconf.use_label_feat + # components (add the parameters anyway) + # todo(note): children features: children + (label of mod->children) + self.chs_reprer = self.add_sub_node("chs", ChsReprer(pc, rconf)) + self.chs_ff = self.add_sub_node("chs_ff", Affine(pc, self.chs_reprer.get_output_dims()[0], self.dim, act="tanh")) + # todo(note): parent features: parent + (label of parent->mod) + # todo(warn): always add label related params + par_ff_inputs = [self.dim, rconf.dim_label] + self.par_ff = self.add_sub_node("par_ff", Affine(pc, par_ff_inputs, self.dim, act="tanh")) + # no other groups anymore! + if rconf.zero_extra_output_params: + self.par_ff.zero_params() + self.chs_ff.zero_params() + + # calculating the structured representations, giving raw repr-tensor + # 1) [*, D]; 2) [*, D], [*], [*]; 3) [*, chs-len, D], [*, chs-len], [*, chs-len], [*] + def calculate_repr(self, cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t): + ret_t = cur_t # [*, D] + # padding 0 if not using labels + dim_label = self.dim_label + # child features + if self.use_chs and chs_t is not None: + if self.use_label_feat: + chs_label_rt = self.label_embeddings(chs_label_t) # [*, max-chs, dlab] + else: + labels_shape = BK.get_shape(chs_t) + labels_shape[-1] = dim_label + chs_label_rt = BK.zeros(labels_shape) + chs_input_t = BK.concat([chs_t, chs_label_rt], -1) + chs_feat0 = self.chs_reprer(cur_t, chs_input_t, chs_mask_t, chs_valid_mask_t) + chs_feat = self.chs_ff(chs_feat0) + ret_t += chs_feat + # parent features + if self.use_par and par_t is not None: + if self.use_label_feat: + cur_label_t = self.label_embeddings(label_t) # [*, dlab] + else: + labels_shape = BK.get_shape(par_t) + labels_shape[-1] = dim_label + cur_label_t = BK.zeros(labels_shape) + par_feat = self.par_ff([par_t, cur_label_t]) + if par_mask_t is not None: + par_feat *= par_mask_t.unsqueeze(-1) + ret_t += par_feat + return ret_t + + # todo(note): if no other features, then no change for the repr! + def forward_repr(self, cur_t): + return cur_t + + # preparation: padding for chs/par + def pad_chs(self, idxes_list: List[List], labels_list: List[List]): + start_posi = self.chs_start_posi + if start_posi < 0: # truncate + idxes_list = [x[start_posi:] for x in idxes_list] + # overall valid mask + chs_valid = [(0. if len(z)==0 else 1.) for z in idxes_list] + # if any valid children in the batch + if all(x>0 for x in chs_valid): + padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(idxes_list) # [*, max-ch], [*, max-ch] + if self.use_label_feat: + if start_posi < 0: # truncate + labels_list = [x[start_posi:] for x in labels_list] + padded_chs_labels, _ = self.ch_label_padder.pad(labels_list) # [*, max-ch] + chs_label_t = BK.input_idx(padded_chs_labels) + else: + chs_label_t = None + chs_idxes_t, chs_mask_t, chs_valid_mask_t = \ + BK.input_idx(padded_chs_idxes), BK.input_real(padded_chs_mask), BK.input_real(chs_valid) + return chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t + else: + return None, None, None, None + + def pad_par(self, idxes: List, labels: List): + par_idxes_t = BK.input_idx(idxes) + labels_t = BK.input_idx(labels) + # todo(note): specifically, <0 means non-exist + # todo(note): an interesting bug, the bug is ">=" was wrongly written as "<", in this way, 0 will act as the parent of those who actually do not have parents and are to be attached, therefore maybe patterns of "parent=0" will get much positive scores + # todo(note): ACTUALLY, mainly because of the difference in search and forward-backward!! + par_mask_t = (par_idxes_t>=0).float() + par_idxes_t.clamp_(0) # since -1 will be illegal idx + labels_t.clamp_(0) + return par_idxes_t, labels_t, par_mask_t + +# ===== +# second the specific one for s2 + +class SL1Conf(Conf): + def __init__(self): + self._input_dim = -1 # enc's (input) last dimension + self._num_label = -1 # number of labels + # + self.use_par = True + self.use_chs = True + self.sl_par_att = AttConf().init_from_kwargs(d_kqv=256, att_dropout=0., head_count=2) + self.sl_chs_att = AttConf().init_from_kwargs(d_kqv=256, att_dropout=0., head_count=2) + self.mix_marginals_head_count = 0 # how many heads to mix + self.mix_marginals_rate = 1. # mix into the self attentions + # + self.zero_extra_output_params = False # whether makes extra output params zero when init + +# additional structured-masked layer +class SL1Layer(BasicNode): + def __init__(self, pc: BK.ParamCollection, slconf: SL1Conf): + super().__init__(pc, None, None) + self.dim = slconf._input_dim + self.use_par = slconf.use_par + self.use_chs = slconf.use_chs + # parent and children attentional senc + self.node_par = self.add_sub_node("npar", MultiHeadAttention(pc, self.dim, self.dim, self.dim, slconf.sl_par_att)) + self.node_chs = self.add_sub_node("nchs", MultiHeadAttention(pc, self.dim, self.dim, self.dim, slconf.sl_chs_att)) + self.ff_par = self.add_sub_node("par_ff", Affine(pc, self.dim, self.dim, act="tanh")) + self.ff_chs = self.add_sub_node("chs_ff", Affine(pc, self.dim, self.dim, act="tanh")) + # todo(note): currently simply sum them! + self.mix_marginals_head_count = slconf.mix_marginals_head_count + self.mix_marginals_rate = slconf.mix_marginals_rate + if slconf.zero_extra_output_params: + self.ff_par.zero_params() + self.ff_chs.zero_params() + + def get_output_dims(self, *input_dims): + return (self.dim, ) + + # calculation for one type of node + def _calc_one_node(self, node_att: BasicNode, node_ff: BasicNode, enc_expr, qk_mask, qk_marginals): + # [*, len, D] + hidden = node_att(enc_expr, enc_expr, enc_expr, mask_qk=qk_mask, eprob_qk=qk_marginals, + eprob_mix_rate=self.mix_marginals_rate, eprob_head_count=self.mix_marginals_head_count) + q_mask = (qk_mask.float().sum(-1)>0.).float() # [*, len], at least one for query + hidden = hidden * (q_mask.unsqueeze(-1)) # [*, len, D], zeros for emtpy ones + output = node_ff(hidden) + return output + + # [*, slen, D], [*, len-m, len-h] + def __call__(self, enc_expr, valid_expr, arc_marg_expr): + # ===== avoid NAN + def _sum_marg(m, dim): + s = m.sum(dim).unsqueeze(dim) + s += (s < 1e-5).float() * 1e-5 + return s + # ===== + output = enc_expr + arc_marg_expr = arc_marg_expr * valid_expr # only keep the after-pruning ones + if self.use_par: + m_mask = valid_expr # [*, lem-m, len-h] + m_marg = arc_marg_expr / _sum_marg(arc_marg_expr, -1) + senc_par_expr = self._calc_one_node(self.node_par, self.ff_par, enc_expr, m_mask, m_marg) + output = output + senc_par_expr + if self.use_chs: + h_mask = BK.copy(valid_expr.transpose(-1, -2)) # [*, len-h, len-m] + h_marg = (arc_marg_expr / _sum_marg(arc_marg_expr, -2)).transpose(-1, -2) + senc_chs_expr = self._calc_one_node(self.node_chs, self.ff_chs, enc_expr, h_mask, h_marg) + output = output + senc_chs_expr + return output diff --git a/tasks/zdpar/ef/systems/__init__.py b/tasks/zdpar/ef/systems/__init__.py new file mode 100644 index 0000000..ace48a9 --- /dev/null +++ b/tasks/zdpar/ef/systems/__init__.py @@ -0,0 +1,3 @@ +# + +from .base_search import EfSearcher, EfSearchConf, EfState, EfAction, ScorerHelper diff --git a/tasks/zdpar/ef/systems/base_search.py b/tasks/zdpar/ef/systems/base_search.py new file mode 100644 index 0000000..4b7ddd4 --- /dev/null +++ b/tasks/zdpar/ef/systems/base_search.py @@ -0,0 +1,569 @@ +# + +# base system + +from typing import List + +from msp.utils import Helper, Constants, Conf, zlog, Random +from msp.nn import BK +from msp.search.lsg import BfsAgenda, BfsSearcher, BfsExpander, BfsLocalSelector, BfsGlobalArranger, BfsEnder + +from ..scorer import Scorer, SL0Layer +from .base_system import EfState, EfAction, EfOracler, EfCoster, EfSignaturer, ParseInstance, Graph +from .systems import StateBuilder + +# ===== +# searching +# todo(warn): all the caches are modified inplaced to save memory, +# thus the nn-graph should be reconstructed later for backward, +# therefore, does not support local normalization here! +# Two choices: 1. Three runs for training: search+forward+backward, 2. using Cache for batched management + +# +class ScorerHelper: + # get features from action and state + class HmFeatureGetter: + def __init__(self, ignore_chs_label_mask, fdrop_chs: float, fdrop_par: float): + self.ignore_chs_label_mask = ignore_chs_label_mask + if fdrop_chs>0.: + self.chs_getter_f = self.get_fchs_dropped + self.chs_rand_gen = Random.stream_bool(fdrop_chs) + else: + self.chs_getter_f = self.get_fchs_base + if fdrop_par>0.: + self.par_getter_f = self.get_fpar_dropped + self.par_rand_gen = Random.stream_bool(fdrop_par) + else: + self.par_getter_f = self.get_fpar_base + + def get_hm_features(self, actions: List[EfAction], states: List[EfState]): + hm_features = ([[], [], [], [], []], [[], [], [], [], []]) # head, mod + chs_getter_f, par_getter_f = self.chs_getter_f, self.par_getter_f + for action, state in zip(actions, states): + hm_idxes = action.head, action.mod + for idx, features_group in zip(hm_idxes, hm_features): + features_group[0].append(idx) + par_idx, par_label = par_getter_f(state.list_arc[idx], state.list_label[idx]) + features_group[1].append(par_idx) + features_group[2].append(par_label) + chs_idxes, chs_labels = chs_getter_f(state.idxes_chs[idx], state.labels_chs[idx]) + features_group[3].append(chs_idxes) + features_group[4].append(chs_labels) + return hm_features + + # ===== + def get_fpar_base(self, par_idx, par_label): + # todo(+2): currently no label filtering here! + return par_idx, par_label + + def get_fchs_base(self, chs_idxes, chs_labels): + ignore_chs_label_mask = self.ignore_chs_label_mask + ret_idxes, ret_labels = [], [] + for a, b in zip(chs_idxes, chs_labels): + if not ignore_chs_label_mask[b]: # excluded if True in ignore_mask + ret_idxes.append(a) + ret_labels.append(b) + return ret_idxes, ret_labels + + def get_fpar_dropped(self, par_idx, par_label): + if next(self.par_rand_gen): + return -1, -1 # todo(note): -1 means None for par + else: + return par_idx, par_label + + def get_fchs_dropped(self, chs_idxes, chs_labels): + ignore_chs_label_mask = self.ignore_chs_label_mask + ret_idxes, ret_labels = [], [] + for a,b in zip(chs_idxes, chs_labels): + if (b not in ignore_chs_label_mask) and (not next(self.chs_rand_gen)): + ret_idxes.append(a) + ret_labels.append(b) + return ret_idxes, ret_labels + + # from features -> srepr [*, D] + @staticmethod + def calc_repr(s_enc: SL0Layer, features_group, enc_expr, bidxes_expr): + cur_idxes, par_idxes, labels, chs_idxes, chs_labels = features_group + # get padded idxes: [*] or [*, ?] + cur_idxes_t = BK.input_idx(cur_idxes) + par_idxes_t, label_t, par_mask_t = s_enc.pad_par(par_idxes, labels) + chs_idxes_t, chs_label_t, chs_mask_t, chs_valid_mask_t = s_enc.pad_chs(chs_idxes, chs_labels) + # gather enc-expr: [*, D], [*, D], [*, max-chs, D] + dim1_range_t = bidxes_expr + dim2_range_t = dim1_range_t.unsqueeze(-1) + cur_t = enc_expr[dim1_range_t, cur_idxes_t] + par_t = enc_expr[dim1_range_t, par_idxes_t] + chs_t = None if chs_idxes_t is None else enc_expr[dim2_range_t, chs_idxes_t] + # update reprs: [*, D] + new_srepr = s_enc.calculate_repr(cur_t, par_t, label_t, par_mask_t, chs_t, chs_label_t, chs_mask_t, chs_valid_mask_t) + return cur_idxes_t, new_srepr + +# scoring cache +class EfScoringCacheArc: + def __init__(self, scorer: Scorer, slayer: SL0Layer, system_labeled: bool): + self.scorer = scorer + self.slayer = slayer + self.system_labeled = system_labeled + self.num_label = scorer.num_label + # scoring caches + self.head_arc_cache = None # tmp repr-cache for decoder for as-head + self.head_label_cache = None + self.mod_arc_cache = None # tmp repr-cache for decoder for as-mod + self.mod_label_cache = None + # scores + self.arc_scores = None # paired arc scores (m, h): [*, len-m, len-h] + # extra scores (from g1) + self.g1_arc_scores = self.g1_lab_scores = None + + def init_cache(self, enc_repr, g1_pack): + # dec cache, *[orig_bsize, max_slen, ?] + # todo(note): since at the start there are no structured features + enc_s_repr = self.slayer.forward_repr(enc_repr) + self.head_arc_cache, self.mod_arc_cache = self.scorer.transform_space_arc(enc_s_repr) + if self.system_labeled: + self.head_label_cache, self.mod_label_cache = self.scorer.transform_space_label(enc_s_repr) + # arc score (no mask applied here) + head_inputs = self.head_arc_cache.unsqueeze(-3) # [*, 1, len-h, D] + mod_inputs = [z.unsqueeze(-2) for z in self.mod_arc_cache] # [*, len-m, 1, ?] + self.arc_scores = self.scorer.score_arc(mod_inputs, head_inputs).squeeze(-1) # [*, len-m, len-h] + if g1_pack is not None: + self.g1_arc_scores, self.g1_lab_scores = g1_pack # [*, lem-m, len-h], [*, len-m, len-h, L] + + def arange_cache(self, bidxes_device): + self.head_arc_cache = self.head_arc_cache.index_select(0, bidxes_device) + self.mod_arc_cache = [z.index_select(0, bidxes_device) for z in self.mod_arc_cache] + if self.system_labeled: + self.head_label_cache = self.head_label_cache.index_select(0, bidxes_device) + self.mod_label_cache = [z.index_select(0, bidxes_device) for z in self.mod_label_cache] + self.arc_scores = self.arc_scores.index_select(0, bidxes_device) + if self.g1_arc_scores is not None: + self.g1_arc_scores = self.g1_arc_scores.index_select(0, bidxes_device) + if self.g1_lab_scores is not None: + self.g1_lab_scores = self.g1_lab_scores.index_select(0, bidxes_device) + + # update caches and scores; [*], [*, D] + def update_cache_and_score(self, node_idxes_t, bsize_range_t, node_srepr, update_as_head: bool, update_as_mod: bool): + system_labeled = self.system_labeled + # get caches of [*, D] + node_ah_expr, node_am_pack = self.scorer.transform_space_arc(node_srepr, update_as_head, update_as_mod) + if system_labeled: + node_lh_expr, node_lm_pack = self.scorer.transform_space_label(node_srepr, update_as_head, update_as_mod) + # todo(note): inplaced update, which means not backward on this (only forward for search graph) + dim1_range_t = bsize_range_t + if update_as_head: + self.head_arc_cache[dim1_range_t, node_idxes_t] = node_ah_expr + if system_labeled: + self.head_label_cache[dim1_range_t, node_idxes_t] = node_lh_expr + # [*, L] + scores_as_head = self.scorer.score_arc(self.mod_arc_cache, node_ah_expr.unsqueeze(-2)).squeeze(-1) + self.arc_scores[dim1_range_t, :, node_idxes_t] = scores_as_head + if update_as_mod: + for v_to_be_filled, v_to_fill in zip(self.mod_arc_cache, node_am_pack): + v_to_be_filled[dim1_range_t, node_idxes_t] = v_to_fill + if system_labeled: + for v_to_be_filled, v_to_fill in zip(self.mod_label_cache, node_lm_pack): + v_to_be_filled[dim1_range_t, node_idxes_t] = v_to_fill + # [*, L] + scores_as_mod = self.scorer.score_arc([x.unsqueeze(-2) for x in node_am_pack], self.head_arc_cache).squeeze(-1) + self.arc_scores[dim1_range_t, node_idxes_t] = scores_as_mod + + # label scores: [*, k] + def get_selected_label_scores(self, idxes_m_t, idxes_h_t, bsize_range_t, oracle_mask_t, oracle_label_t, + arc_margin: float, label_margin: float): + # todo(note): in this mode, no repeated arc_margin + dim1_range_t = bsize_range_t + dim2_range_t = dim1_range_t.unsqueeze(-1) + if self.system_labeled: + selected_m_cache = [z[dim2_range_t, idxes_m_t] for z in self.mod_label_cache] + selected_h_repr = self.head_label_cache[dim2_range_t, idxes_h_t] + ret = self.scorer.score_label(selected_m_cache, selected_h_repr) # [*, k, labels] + if label_margin > 0.: + oracle_label_idxes = oracle_label_t[dim2_range_t, idxes_m_t, idxes_h_t].unsqueeze(-1) # [*, k, 1] of int + ret.scatter_add_(-1, oracle_label_idxes, BK.constants(oracle_label_idxes.shape, -label_margin)) + else: + # todo(note): otherwise, simply put zeros (with idx=0 as the slightly best to be consistent) + ret = BK.zeros(BK.get_shape(idxes_m_t) + [self.num_label]) + ret[:, :, 0] += 0.01 + if self.g1_lab_scores is not None: + ret += self.g1_lab_scores[dim2_range_t, idxes_m_t, idxes_h_t] + return ret + + # full arc scores (with margin): [*, m, n] + def get_arc_scores(self, oracle_mask_t, margin: float): + if margin <= 0.: + ret = self.arc_scores + else: + ret = self.arc_scores - margin * oracle_mask_t + if self.g1_arc_scores is not None: + # todo(warn): no adding inplaced to avoid change self scores + ret = ret + self.g1_arc_scores + return ret + +# (batched) running cache (main for repr and scoring) +class EfRunningCache: + def __init__(self, scorer: Scorer, slayer: SL0Layer, hm_feature_getter, max_slen, orig_bsize, + enc_repr, enc_mask_arr, g1_pack, insts, system_labeled): + self.scorer = scorer + self.slayer = slayer + self.hm_feature_getter = hm_feature_getter + # repr + self.enc_repr = None # repr output from encoder (not s_enc) for the current running + # scoring, todo(+2): choose by the flag here, not elegant though! + self.scoring_cache = EfScoringCacheArc(scorer, slayer, system_labeled) + # masks + self.scoring_fixed_mask_ct = None # fixed mask (0 for self_loop, root_mod and sent_mask) + self.scoring_mask_ct = None # mask before scoring (cpu_tensor) + # oracles (as masks) + # todo(note): for the current systems, the oracle masks can be fixed + self.oracle_mask_t = None # mask for oracle + self.oracle_mask_ct = None # same as previous, used for selection-validation + self.oracle_label_t = None # idxes for labels (this can be calculated at init) + # others + self.step = 0 + self.max_slen = max_slen + self.orig_bsize = orig_bsize + self.cur_bsize = -1 + self.bsize_range_t = None + self.update_bsize(orig_bsize) + # ----- + self.init_cache(enc_repr, enc_mask_arr, insts, g1_pack) + + def update_step(self): + self.step += 1 + + def update_bsize(self, new_bsize): + if new_bsize != self.cur_bsize: + self.cur_bsize = new_bsize + self.bsize_range_t = BK.arange_idx(new_bsize) + + # re-arrange the caches to make the first-dim batch corresponded + def arange_cache(self, bidxes): + new_bsize = len(bidxes) + # if the idxes are already fine, then no need to select + if not Helper.check_is_range(bidxes, self.cur_bsize): + # mask is on CPU to make assigning easier + bidxes_ct = BK.input_idx(bidxes, BK.CPU_DEVICE) + self.scoring_fixed_mask_ct = self.scoring_fixed_mask_ct.index_select(0, bidxes_ct) + self.scoring_mask_ct = self.scoring_mask_ct.index_select(0, bidxes_ct) + self.oracle_mask_ct = self.oracle_mask_ct.index_select(0, bidxes_ct) + # other things are all on target-device (possibly GPU) + bidxes_device = BK.to_device(bidxes_ct) + self.enc_repr = self.enc_repr.index_select(0, bidxes_device) + self.scoring_cache.arange_cache(bidxes_device) + # oracles + self.oracle_mask_t = self.oracle_mask_t.index_select(0, bidxes_device) + self.oracle_label_t = self.oracle_label_t.index_select(0, bidxes_device) + # update bsize + self.update_bsize(new_bsize) + + # get init fixed masks + def _init_fixed_mask(self, enc_mask_arr): + tmp_device = BK.CPU_DEVICE + # by token mask + mask_ct = BK.input_real(enc_mask_arr, device=tmp_device) # [*, len] + full_mask_ct = mask_ct.unsqueeze(-1) * mask_ct.unsqueeze(-2) # [*, len-mod, len-head] + # no self loop + full_mask_ct *= (1.-BK.eye(self.max_slen, device=tmp_device)) + # no root as mod; todo(warn): assume it is 3D + full_mask_ct[:, 0, :] = 0. + return full_mask_ct + + # init reprs and first-order scores + def init_cache(self, enc_repr, enc_mask_arr, insts, g1_pack): + # init caches and scores, [orig_bsize, max_slen, D] + self.enc_repr = enc_repr + self.scoring_fixed_mask_ct = self._init_fixed_mask(enc_mask_arr) + # init other masks + self.scoring_mask_ct = BK.copy(self.scoring_fixed_mask_ct) + full_shape = BK.get_shape(self.scoring_mask_ct) + # init oracle masks + oracle_mask_ct = BK.constants(full_shape, value=0., device=BK.CPU_DEVICE) + # label=0 means nothing, but still need it to avoid index error (dummy oracle for wrong/no-oracle states) + oracle_label_ct = BK.constants(full_shape, value=0, dtype=BK.int64, device=BK.CPU_DEVICE) + for i, inst in enumerate(insts): + EfOracler.init_oracle_mask(inst, oracle_mask_ct[i], oracle_label_ct[i]) + self.oracle_mask_t = BK.to_device(oracle_mask_ct) + self.oracle_mask_ct = oracle_mask_ct + self.oracle_label_t = BK.to_device(oracle_label_ct) + # scoring cache + self.scoring_cache.init_cache(enc_repr, g1_pack) + + # update reprs and scores incrementally + def update_cache(self, flattened_states: List[EfState]): + assert len(flattened_states) == self.cur_bsize, "Err: Mismatched batch size!" + # after adding an edge, head obtain an additional child and mod obtains an additional parent + # todo(+3): can these be simplified? + # 1. collect (batched) features; todo(note): use current state for updating scoring + hm_features = self.hm_feature_getter.get_hm_features([x.action for x in flattened_states], flattened_states) + # 2. get new sreprs + # todo(note): no recurrence or recursive here, therefore using the original enc_repr + s_enc = self.slayer + node_h_idxes_t, node_h_srepr = ScorerHelper.calc_repr(s_enc, hm_features[0], self.enc_repr, self.bsize_range_t) + node_m_idxes_t, node_m_srepr = ScorerHelper.calc_repr(s_enc, hm_features[1], self.enc_repr, self.bsize_range_t) + # 3. update cache and score + # todo(note): does not matter for the interactions, since inter-influenced parts are to-be-masked-out + # todo(+3): for some specific mode, can further reduce some calculations + self.scoring_cache.update_cache_and_score(node_h_idxes_t, self.bsize_range_t, node_h_srepr, True, True) + # mod cannot be mod again, thus no need to update (scores will be masked out later) + self.scoring_cache.update_cache_and_score(node_m_idxes_t, self.bsize_range_t, node_m_srepr, True, False) + + # ===== + # todo(warn): here margin is utilized to let oracle ones have less scores + + # label scores: [*, k] + def get_selected_label_scores(self, idxes_m_t, idxes_h_t, arc_margin: float, label_margin: float): + return self.scoring_cache.get_selected_label_scores(idxes_m_t, idxes_h_t, self.bsize_range_t, self.oracle_mask_t, self.oracle_label_t, arc_margin, label_margin) + + # full arc scores (with margin): [*, m, n] + def get_arc_scores(self, arc_margin: float): + return self.scoring_cache.get_arc_scores(self.oracle_mask_t, arc_margin) + +# handling repr/arc-score update +class EfExpander(BfsExpander): + def __init__(self, scorer: Scorer): + # self.scorer = scorer + self.cache: EfRunningCache = None + # self.margin: float = None + # self.cost_weight_arc: float = None + self.mw_arc: float = None + + def refresh(self, cache, margin, cost_weight_arc, cost_weight_label): + self.cache = cache + # self.margin = margin + # self.cost_weight_arc = cost_weight_arc + self.mw_arc = margin * cost_weight_arc + + def expand(self, ags: List[BfsAgenda]): + # flatten things out + flattened_states = [] + for ag in ags: + flattened_states.extend(ag.beam) + flattened_states.extend(ag.gbeam) + # read/write cache to change status + cur_cache = self.cache + # update reprs and scores + EfState.set_running_bidxes(flattened_states) + if cur_cache.step > 0: + bidxes = [s.prev.running_bidx for s in flattened_states] + # extend previous cache to the current bsize (dimension 0 at batch-dim) + cur_cache.arange_cache(bidxes) + # update cahces and scores + # todo(+N): for certain modes, some calculations are not needed + cur_cache.update_cache(flattened_states) + cur_cache.update_step() + # get new masks and final scores + scoring_mask_ct = cur_cache.scoring_mask_ct + for sidx, state in enumerate(flattened_states): + state.update_cands_mask(scoring_mask_ct[sidx]) # inplace mask update + scoring_mask_ct *= cur_cache.scoring_fixed_mask_ct # apply the fixed masks + scoring_mask_device = BK.to_device(scoring_mask_ct) + cur_arc_scores = cur_cache.get_arc_scores(self.mw_arc) + Constants.REAL_PRAC_MIN*(1.-scoring_mask_device) + # todo(+N): possible normalization for the scores + return flattened_states, cur_arc_scores, scoring_mask_ct + +# handling selecting and label-scoring +class EfLocalSelector(BfsLocalSelector): + def __init__(self, scorer, plain_mode, oracle_mode, plain_k_arc, plain_k_label, oracle_k_arc, oracle_k_label, + oracler: EfOracler, system_labeled): + super().__init__(plain_mode, oracle_mode, plain_k_arc, plain_k_label, oracle_k_arc, oracle_k_label, oracler) + # self.scorer = scorer + self.cache: EfRunningCache = None + # + # self.margin = 0. + self.system_labeled = system_labeled + self.mw_arc: float = None + self.mw_label: float = None + + def refresh(self, cache, margin, cost_weight_arc, cost_weight_label): + self.cache = cache + # self.margin = margin + self.mw_arc = margin * cost_weight_arc + self.mw_label = margin * cost_weight_label + + # get new state from the selected edge + # todo(+N): currently not recording the ScoreSlice here! + def _new_states(self, flattened_states: List[EfState], scoring_mask_ct, + topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes): + topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes = \ + (BK.get_value(z) for z in (topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes)) + new_states = [] + # for each batch element + for one_state, one_mask, one_arc_scores, one_ms, one_hs, one_label_scores, one_labels in \ + zip(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h, topk_label_scores, topk_label_idxes): + one_new_states = [] + # for each of the k arc selection + for cur_arc_score, cur_m, cur_h, cur_label_scores, cur_labels in \ + zip(one_arc_scores, one_ms, one_hs, one_label_scores, one_labels): + # first need that selection to be valid + cur_arc_score, cur_m, cur_h = cur_arc_score.item(), cur_m.item(), cur_h.item() + if one_mask[cur_m, cur_h].item() > 0.: + # for each of the label + for this_label_score, this_label in zip(cur_label_scores, cur_labels): + this_label_score, this_label = this_label_score.item(), this_label.item() + # todo(note): actually add new state; do not include label score if label does not come from ef + cur_all_score = (cur_arc_score+this_label_score) if self.system_labeled else cur_arc_score + this_new_state = one_state.build_next(action=EfAction(cur_h, cur_m, this_label), score=cur_all_score) + one_new_states.append(this_new_state) + new_states.append(one_new_states) + return new_states + + # get oracle mask from OracleManager + def _get_oracle_mask(self, flattened_states): + # todo(note): fixed oracle masks + cur_cache = self.cache + return cur_cache.oracle_mask_t, cur_cache.oracle_label_t + + # ===== + # these two only return flattened results + + def select_plain(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]: + flattened_states, cur_arc_scores, scoring_mask_ct = candidates + cur_cache = self.cache + cur_bsize = len(flattened_states) + cur_slen = cur_cache.max_slen + cur_arc_scores_flattend = cur_arc_scores.view([cur_bsize, -1]) # [bs, Lm*Lh] + if mode == "topk": + # arcs [*, k] + topk_arc_scores, topk_arc_idxes = BK.topk( + cur_arc_scores_flattend, min(k_arc, BK.get_shape(cur_arc_scores_flattend, -1)), dim=-1, sorted=False) + topk_m, topk_h = topk_arc_idxes / cur_slen, topk_arc_idxes % cur_slen # [m, h] + # labels [*, k, k'] + cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, self.mw_arc, self.mw_label) + topk_label_scores, topk_label_idxes = BK.topk( + cur_label_scores, min(k_label, BK.get_shape(cur_label_scores, -1)), dim=-1, sorted=False) + return self._new_states(flattened_states, scoring_mask_ct, topk_arc_scores, topk_m, topk_h, + topk_label_scores, topk_label_idxes) + elif mode == "": + return [[]] * cur_bsize + # todo(+N): other modes like sampling to be implemented: sample, topk-sample + else: + raise NotImplementedError(mode) + + def select_oracle(self, ags: List[BfsAgenda], candidates, mode, k_arc, k_label) -> List[List]: + flattened_states, cur_arc_scores, scoring_mask_ct = candidates + cur_cache = self.cache + cur_bsize = len(flattened_states) + cur_slen = cur_cache.max_slen + if mode == "topk": + # todo(note): there can be multiple oracles, select topk(usually top1) in this mode. + # get and apply oracle mask + cur_oracle_mask_t, cur_oracle_label_t = self._get_oracle_mask(flattened_states) + # [bs, Lm*Lh] + cur_oracle_arc_scores = (cur_arc_scores + Constants.REAL_PRAC_MIN*(1.-cur_oracle_mask_t)).view([cur_bsize, -1]) + # arcs [*, k] + topk_arc_scores, topk_arc_idxes = BK.topk(cur_oracle_arc_scores, k_arc, dim=-1, sorted=False) + topk_m, topk_h = topk_arc_idxes / cur_slen, topk_arc_idxes % cur_slen # [m, h] + # labels [*, k, 1] + # todo(note): here we gather labels since one arc can only have one oracle label + cur_label_scores = cur_cache.get_selected_label_scores(topk_m, topk_h, 0., 0.) # [*, k, labels] + topk_label_idxes = cur_oracle_label_t[cur_cache.bsize_range_t.unsqueeze(-1), topk_m, topk_h].unsqueeze(-1) # [*, k, 1] + # todo(+N): here is the trick to avoid repeated calculations, maybe not correct when using full dynamic oracle + topk_label_scores = BK.gather(cur_label_scores, topk_label_idxes, -1) - self.mw_label + # todo(+N): here use both masks, which may lead to no oracles! Can we simply drop the oracle_mask? + return self._new_states(flattened_states, scoring_mask_ct*cur_cache.oracle_mask_ct, topk_arc_scores, + topk_m, topk_h, topk_label_scores, topk_label_idxes) + elif mode == "": + return [[]] * cur_bsize + # todo(+N): other modes like sampling to be implemented: sample, topk-sample, gather + else: + raise NotImplementedError(mode) + +# the same as the general one +EfGlobalArranger = BfsGlobalArranger + +# the most basic ender +class EfEnder(BfsEnder): + def __init__(self, ending_mode): + super().__init__(ending_mode) + + def is_end(self, state: EfState): + return state.num_rest <= 0 + +# +class EfSearchConf(Conf): + def __init__(self): + # general system / expander + self.ef_mode = "free" # which ef system + self.nf_dist = 2 # distance for n2f mode + self._system_labeled = True # set by outside!! + # local selector + self.plain_mode = "topk" # topk/sample/... (empty "" means Nope) + self.plain_k_arc = 1 + self.plain_k_label = 1 + self.oracle_mode = "topk" + self.oracle_k_arc = 1 + self.oracle_k_label = 1 + # global selector + self.plain_beam_size = 1 + self.gold_beam_size = 1 + self.cost_type = "full" # "" means no coster + self.cost_weight_arc = 1. # arc-cost for margin + self.cost_weight_label = 1. # label-cost for margin + self.sig_type = "labeled" # labeled/unlabeled/"" + # ender + self.ending_mode = "plain" # plain/eu/bso/maxv + pass + +# todo(warn): currently can only search from scratch +# build(searcher); foreach-run { start; go; } +class EfSearcher(BfsSearcher): + def __init__(self, sconf: EfSearchConf, scorer: Scorer, slayer: SL0Layer): + super().__init__() + self.scorer: Scorer = scorer + self.slayer: SL0Layer = slayer + self.state_builder: StateBuilder = None + # + self.system_labeled = sconf._system_labeled + self.cost_weight_arc, self.cost_weight_label = sconf.cost_weight_arc, sconf.cost_weight_label + + def __repr__(self): + return f"" + + # todo(+3): ugly!! here set both local/global ones (to the same value of arc) + def set_plain_arc_beam_size(self, k): + assert k>=1 + self.local_selector.plain_k_arc = k + self.global_arranger.plain_beam_size = k + + # init the states and agendas + build cache + refresh components + def start(self, insts: List[ParseInstance], hm_feature_getter, enc_repr, enc_mask_arr, g1_pack, margin=0., require_sg=False): + # build cache and refresh + max_slen = enc_mask_arr.shape[-1] # padded max sent length + cache = EfRunningCache(self.scorer, self.slayer, hm_feature_getter, max_slen, len(insts), + enc_repr, enc_mask_arr, g1_pack, insts, self.system_labeled) + self.expander.refresh(cache, margin, self.cost_weight_arc, self.cost_weight_label) + self.local_selector.refresh(cache, margin, self.cost_weight_arc, self.cost_weight_label) + self.global_arranger.refresh(margin) + self.ender.refresh(margin) + # build init states and agendas + agendas = [] + for orig_bidx, inst in enumerate(insts): + sg = Graph() if require_sg else None + init_state = self.state_builder.build(sg=sg, inst=inst, orig_bidx=orig_bidx, max_slen=max_slen) + agendas.append(BfsAgenda(inst, init_beam=[init_state])) + return agendas + + # ===== + # the building/creating of the Searcher + @staticmethod + def build(sconf: EfSearchConf, scorer: Scorer, slayer: SL0Layer): + s = EfSearcher(sconf, scorer, slayer) + oracler = EfOracler() + if sconf.cost_type == "": + coster = None + else: + coster = EfCoster(sconf.cost_weight_arc, sconf.cost_weight_label) + if sconf.sig_type == "": + signaturer = None + else: + signaturer = EfSignaturer(sconf.sig_type == "labeled") + s.expander = EfExpander(scorer) + s.local_selector = EfLocalSelector(scorer, sconf.plain_mode, sconf.oracle_mode, sconf.plain_k_arc, sconf.plain_k_label, sconf.oracle_k_arc, sconf.oracle_k_label, oracler, sconf._system_labeled) + s.global_arranger = EfGlobalArranger(sconf.plain_beam_size, sconf.gold_beam_size, coster, signaturer) + s.ender = EfEnder(sconf.ending_mode) + s.state_builder = StateBuilder(sconf.ef_mode, sconf.nf_dist) + # + zlog("Finish building the searcher: " + str(s)) + return s + +# tasks/zdpar/ef/systems/base_search:397 diff --git a/tasks/zdpar/ef/systems/base_system.py b/tasks/zdpar/ef/systems/base_system.py new file mode 100644 index 0000000..1bc5bae --- /dev/null +++ b/tasks/zdpar/ef/systems/base_system.py @@ -0,0 +1,206 @@ +# + +# base State + +from typing import List +from array import array + +from msp.nn import BK +from msp.search.lsg import State, Action, Graph, Oracler, Coster, Signaturer + +from ...common.data import ParseInstance + +# ===== +# basic units + +class EfState(State): + def __init__(self, prev: 'EfState'=None, action: 'EfAction'=None, score=0., sg: Graph=None, + inst: ParseInstance=None, max_slen=-1, orig_bidx=-1): + super().__init__(prev, action, score, sg) + # todo(+N): can it be more efficient to maintain (not-copy) the structure info? + # record the basic info + if prev is None: + self.inst: ParseInstance = inst + self.num_tok = len(inst) + 1 # num of tokens plus artificial root + self.num_rest = self.num_tok - 1 # num of arcs remained to attach + self.list_arc = [-1] * self.num_tok # attached heads + self.list_label = [-1] * self.num_tok # attached labels + self.idxes_chs = [[] for _ in range(self.num_tok)] # child token idx list (by adding order) + self.labels_chs = [[] for _ in range(self.num_tok)] # should be the same size and corresponds to "ch_nodes" + else: + self.inst = prev.inst + self.num_tok = prev.num_tok + self.num_rest = prev.num_rest - 1 # currently always Attach actions + self.list_arc = prev.list_arc.copy() + self.list_label = prev.list_label.copy() + # self.idxes_chs = [x.copy() for x in prev.idxes_chs] + # self.labels_chs = [x.copy() for x in prev.labels_chs] + # todo(warn): only shallow copy here, later copy on write! + self.idxes_chs = prev.idxes_chs.copy() + self.labels_chs = prev.labels_chs.copy() + # only attach actions + cur_head, cur_mod, cur_label = action.head, action.mod, action.label + assert self.list_arc[cur_mod]<0 + self.list_arc[cur_mod] = cur_head + self.list_label[cur_mod] = cur_label + # self.idxes_chs[cur_head].append(cur_mod) + # self.labels_chs[cur_head].append(cur_label) + # copy on write + self.idxes_chs[cur_head] = self.idxes_chs[cur_head] + [cur_mod] + self.labels_chs[cur_head] = self.labels_chs[cur_head] + [cur_label] + # ===== + # other calculate as needed values + # useful for batching + self.running_bidx: int = None # running batch-idx (position in running batches) + # original batch-idx (instances) todo(warn): to be set at the start + if prev is not None: + self.orig_bidx = prev.orig_bidx + else: + self.orig_bidx = orig_bidx + # max length in the current batch; todo(+N): super non-elegant + if prev is not None: + self.max_slen = prev.max_slen + else: + self.max_slen = max_slen + # related with cost/oracle (whether current arc/label is correct) + self.wrong_al = None + + # ===== + # helpers + + @staticmethod + def set_running_bidxes(flattened_states: List['EfState']): + for idx, s in enumerate(flattened_states): + s.running_bidx = idx + + def set_orig_bidx(self, bidx: int): + self.orig_bidx = bidx + + # ===== + # specific procedures (default as EfFree System) + + # extend from self to another state + def build_next(self, action: 'EfAction', score: float): + raise NotImplementedError() + + # todo(note): prev_mask is prev's cands, and the illegal ones are already excluded at the very start + def update_cands_mask(self, prev_mask): + raise NotImplementedError() + + def update_oracle_mask(self, prev_arc_mask, prev_label): + raise NotImplementedError() + + # self.cost/self.cost_accu + # todo(+N): systems other than the top-down one are not arc-decomp and have future-loss because of cycle constraint! + # Therefore, this cost is less than the actual (plus-future) cost + # but maybe the (partial) oracle is still the correct oracle? + def set_cost(self, weight_arc: float, weight_label: float): + if self.cost_accu is not None: + return self.cost_accu + # for this system, simply consider only the last action will be fine + if self.prev is None: + self.cost_accu = self.cost = 0. + else: + prev_cost_accu = self.prev.set_cost(weight_arc, weight_label) + this_head, this_mod, this_label = self.action._key + gold_heads, gold_labels = self.inst.heads.vals, self.inst.labels.idxes + wrong_arc, wrong_label = int(gold_heads[this_mod]!=this_head), int(gold_labels[this_mod]!=this_label) + # todo(note): wrong_arc always means wrong_label! + wrong_label = min(wrong_arc+wrong_label, 1) + self.wrong_al = (wrong_arc, wrong_label) + self.cost = weight_arc * wrong_arc + weight_label * wrong_label + self.cost_accu = prev_cost_accu + self.cost + return self.cost_accu + + # todo(+2): are there more efficient ways? + # self.sig + def set_sig(self, labeled: bool): + if self.sig is not None: + return self.sig + atype = "b" if (self.num_tok<127) else "h" # todo(note): hope most sentences will be shorter than 127 + sig_list = (self.list_arc + self.list_label) if labeled else self.list_arc + sig = array(atype, sig_list).tobytes() + self.sig = sig + return sig + + # todo(WARN): currently in this system, costs are already added in Selector, but it still ignores further loss! + def get_loss_aug_score(self, margin): + # todo(+N): return self.score_accu + margin*self.future_loss + return self.score_accu + +# only Attach action +class EfAction(Action): + def __init__(self, head, mod, label): + super().__init__() + # info + self.head = head + self.mod = mod + self.label = label + # + self._key = (head, mod, label) + + def __repr__(self): + return f"{self.head}->{self.mod}[{self.label}]" + + def __hash__(self): + return hash(self._key) + + def __eq__(self, other): + return self._key == other._key + +# ===== +# + +# todo(warn): not used here! +# # group of candidates from one State, wait for Selector +# class EfCandidates(Candidates): +# def __init__(self, state: EfState, cands_mask: np.ndarray): +# self.state = state +# self.cands_mask = cands_mask + +# ===== +# other components + +class EfOracler(Oracler): + def set_oracle_masks(self, states: List[EfState], prev_oracle_mask_ct, prev_oracle_label_ct): + for idx, s in enumerate(states): + s.update_oracle_mask(prev_oracle_mask_ct[idx], prev_oracle_label_ct[idx]) + + # most of the time, only need to init once and the oracle masks need not changes + # todo(note): this is assigned at CPU + @staticmethod + def init_oracle_mask(inst: ParseInstance, prev_arc_mask, prev_label): + gold_heads = inst.heads.vals[1:] + gold_labels = inst.labels.idxes[1:] + gold_idxes = [i + 1 for i in range(len(gold_heads))] + prev_arc_mask[gold_idxes, gold_heads] = 1. + prev_label[gold_idxes, gold_heads] = BK.input_idx(gold_labels, BK.CPU_DEVICE) + return prev_arc_mask, prev_label + + def __repr__(self): + return "" + +class EfCoster(Coster): + def __init__(self, weight_arc: float, weight_label: float): + self.weight_arc = weight_arc + self.weight_label = weight_label + + def set_costs(self, states: List[EfState]): + weight_arc, weight_label = self.weight_arc, self.weight_label + for s in states: + s.set_cost(weight_arc, weight_label) + + def __repr__(self): + return f"" + +class EfSignaturer(Signaturer): + def __init__(self, labeled: bool): + self.labeled = labeled + + def set_sigs(self, states: List[EfState]): + labeled = self.labeled + for s in states: + s.set_sig(labeled) + + def __repr__(self): + return f"" diff --git a/tasks/zdpar/ef/systems/systems.py b/tasks/zdpar/ef/systems/systems.py new file mode 100644 index 0000000..9fced3e --- /dev/null +++ b/tasks/zdpar/ef/systems/systems.py @@ -0,0 +1,273 @@ +# + +# specific ef systems, currently only contain simple "reduce-free" ones +# todo(note): therefore, currently in all systems, wrong attachments will not include more "fated future cost" + +from array import array +import numpy as np + +from msp.utils import zlog +from msp.nn import BK +from .base_system import EfState, EfAction, Graph, ParseInstance + +# ----- +# helper to ensure no cycle + +# used mainly for excluding cycles +# todo(note): consistent status: 1) cache_uppermost works for all, 2) cache_descendants works for all unattached ones + +# ----- +# todo(warn): the mask based one is deprecated! +# class NoCycleCache00: +# def __init__(self, num_tok=None): +# if num_tok: +# self.cache_uppermost = BK.arange_idx(num_tok, device=BK.CPU_DEVICE) +# # descendants (including self); actually mask of [h,m], todo(note): only valid for unattached ones +# self.cache_descendants = BK.diagflat(BK.constants((num_tok,), 1, dtype=BK.uint8, device=BK.CPU_DEVICE)) +# else: +# self.cache_uppermost = self.cache_descendants = None +# +# def clone(self): +# x = NoCycleCache00() +# x.cache_uppermost = self.cache_uppermost.clone() +# x.cache_descendants = self.cache_descendants.clone() +# return x +# +# def update(self, mod, head): +# cache_uppermost, cache_descendants_masks = self.cache_uppermost, self.cache_descendants +# cur_upm_node = cache_uppermost[head] +# cur_descendants_mask = cache_descendants_masks[mod] +# # update +# cache_uppermost[cur_descendants_mask] = cur_upm_node +# # todo(note): only care about the uppermost one since others have been masked or no need (by single-rule) +# cache_descendants_masks[cur_upm_node, cur_descendants_mask] = 1 + +# the list based one +class NoCycleCache: + def __init__(self, num_tok=None): + if num_tok: + self.cache_uppermost = np.arange(num_tok, dtype=np.int32) + # descendants (including self); actually mask of [h,m], todo(note): only valid for unattached ones + self.cache_descendants = [[i] for i in range(num_tok)] + else: + self.cache_uppermost = self.cache_descendants = None + + def clone(self): + x = NoCycleCache() + x.cache_uppermost = self.cache_uppermost.copy() + x.cache_descendants = self.cache_descendants.copy() + return x + + def update(self, mod, head): + cache_uppermost, cache_descendants = self.cache_uppermost, self.cache_descendants + cur_upm_node = cache_uppermost[head] + cur_descendants = cache_descendants[mod] + # update + cache_uppermost[cur_descendants] = cur_upm_node + # todo(note): only care about the uppermost one since others have been masked or no need (by single-rule) + # copy on write!! + cache_descendants[cur_upm_node] = cache_descendants[cur_upm_node] + cur_descendants + +# ===== +# specific system: free-style + +class EfFreeState(EfState): + def __init__(self, prev: 'EfFreeState'=None, action: EfAction=None, score=0., sg: Graph=None, + inst: ParseInstance=None, max_slen=-1, orig_bidx=-1): + super().__init__(prev, action, score, sg, inst, max_slen, orig_bidx) + if prev is None: + self.nc_cache = NoCycleCache(self.max_slen) + else: + self.nc_cache = prev.nc_cache.clone() + self.nc_cache.update(action.mod, action.head) + + def build_next(self, action: EfAction, score: float): + return EfFreeState(self, action, score) + + # incrementally minus + def update_cands_mask(self, prev_mask): + # todo(note): for this mode, never add new cands but gradually eliminate illegal ones + if self.prev is not None: + # single head + recent_mod, recent_head = self.action.mod, self.action.head + prev_mask[recent_mod] = 0. + # no cycle + cache_uppermost, cache_descendants = self.nc_cache.cache_uppermost, self.nc_cache.cache_descendants + cur_upm_node = cache_uppermost[recent_head] + cur_descendants = cache_descendants[recent_mod] + prev_mask[cur_upm_node, cur_descendants] = 0. # these are enough + return prev_mask + +# ===== +# specific system: top-down + +class EfTdState(EfState): + def __init__(self, prev: 'EfTdState'=None, action: EfAction=None, score=0., sg: Graph=None, + inst: ParseInstance=None, max_slen=-1, orig_bidx=-1): + super().__init__(prev, action, score, sg, inst, max_slen, orig_bidx) + # + if prev is None: + self.attached_cache = [] # nodes that are already attached + else: + self.attached_cache = prev.attached_cache + [action.mod] + + def build_next(self, action: EfAction, score: float): + return EfTdState(self, action, score) + + # incrementally adding + def update_cands_mask(self, prev_mask): + if self.prev is None: + prev_mask.zero_() # clear all + prev_mask[:, 0] = 1. # only ROOT to others at the very first + else: + # allow one more head but disallow the new as mod + recent_mod = self.action.mod + prev_mask[:, recent_mod] = 1. + prev_mask[self.attached_cache, recent_mod] = 0. + prev_mask[recent_mod] = 0. + # todo(note): no need to exclude cycles since for top-down, cycles go to ROOT and are already excluded + return prev_mask + +# ===== +# specific system: directional: left-right/right-left + +class EfDirState(EfState): + def __init__(self, is_l2r, prev: 'EfDirState'=None, action: EfAction=None, score=0., sg: Graph=None, + inst: ParseInstance=None, max_slen=-1, orig_bidx=-1): + super().__init__(prev, action, score, sg, inst, max_slen, orig_bidx) + # + self.is_l2r = is_l2r + if is_l2r: + self.dir_next_idx = self.length + 1 + self.dir_step = 1 + else: + self.dir_next_idx = self.num_tok - 1 - self.length + self.dir_step = -1 + # + if prev is None: + self.nc_cache = NoCycleCache(self.max_slen) + else: + self.nc_cache = prev.nc_cache.clone() + self.nc_cache.update(action.mod, action.head) + + def build_next(self, action: EfAction, score: float): + return EfDirState(self.is_l2r, self, action, score) + + # build new ones everytime + def update_cands_mask(self, prev_mask): + if self.prev is None: + prev_mask.zero_() + cur_next_idx = self.dir_next_idx + prev_mask[cur_next_idx] = 1. + # clear last step + if cur_next_idx-self.dir_step < len(prev_mask): + prev_mask[cur_next_idx-self.dir_step] = 0. # Once-A-Bug: check for r2l mode + # eliminate cycle + cur_descendants = self.nc_cache.cache_descendants[cur_next_idx] + if len(cur_descendants) > 0: + prev_mask[cur_next_idx, cur_descendants] = 0. + return prev_mask + + # todo(note): in this system, no paths can be merged since path determines structures + def set_sig(self, labeled: bool): + return None + +# ===== +# specific system: bottom-up(distance) -> actually near-to-far(near-first) +# -- distance according to the unattached nodes list +# todo(+N): currently does not fused with reduction, that is, not strictly bottom-up; still the same oracle? +# but the problem is that when there are no oracles for some non-projective ones + +class EfNfState(EfState): + def __init__(self, dist, prev: 'EfNfState'=None, action: EfAction=None, score=0., sg: Graph=None, + inst: ParseInstance=None, max_slen=-1, orig_bidx=-1): + super().__init__(prev, action, score, sg, inst, max_slen, orig_bidx) + # + self.dist = dist # dist allowed for attaching + # non-cycle cache & left/right un-attach neighbours + max_len = self.num_tok + if prev is None: + self.nc_cache = NoCycleCache(self.max_slen) + # + self.left_link = list(range(-1, max_len-1)) + self.right_link = list(range(1, max_len+1)) + self.right_link[-1] = -1 # -1 means NULL + else: + self.nc_cache = prev.nc_cache.clone() + self.nc_cache.update(action.mod, action.head) + # + self.left_link = prev.left_link.copy() + self.right_link = prev.right_link.copy() + mod = self.action.mod + left, right = self.left_link[mod], self.right_link[mod] + self.right_link[left] = right # 0 cannot be mod + if right>0: # only if right is valid + self.left_link[right] = left + + def build_next(self, action: EfAction, score: float): + return EfNfState(self.dist, self, action, score) + + # incrementally adding: update for newly introduced "near nodes" + def update_cands_mask(self, prev_mask): + D = self.dist + max_len = self.num_tok + if self.prev is None: + prev_mask.zero_() + for i in range(1, max_len): + prev_mask[i, max(0,i-D):min(max_len,i+D+1)] = 1. # [i-D, i+D] + else: + recent_mod, recent_head = self.action.mod, self.action.head + cache_uppermost, cache_descendants = self.nc_cache.cache_uppermost, self.nc_cache.cache_descendants + # add new ones + # -- first collect the unattached neighbours in range + left_neighbours, right_neighbours = [recent_mod], [recent_mod] + tmp_left_idx = tmp_right_idx = recent_mod + for i in range(D): + if tmp_left_idx >= 0: + tmp_left_idx = self.left_link[tmp_left_idx] + left_neighbours.append(tmp_left_idx) + if tmp_right_idx >= 0: + tmp_right_idx = self.right_link[tmp_right_idx] + right_neighbours.append(tmp_right_idx) + # -- make both left2right + left_neighbours.reverse() + # -- traverse the neighbours and add the new range (but remember to remove cycles) + for i in range(D): + # [..., left0, left1, ...] <=> [..., right1, right0, ...] + left0, left1 = left_neighbours[i], left_neighbours[i+1] + right1, right0 = right_neighbours[i], right_neighbours[i+1] + # expand to new ones + # add left span to right (discard -1 since 0 is always the sentinel) + if left0>=0 and right0>=0: + prev_mask[right0, left0:left1] = 1. + prev_mask[right0, cache_descendants[right0]] = 0. + # add right span to the left (remember to consider the rest of the sentence) + if right1>=0 and left0>0: # 0 cannot have parent + right0 = max_len if right0<0 else right0 + prev_mask[left0, right1:right0+1] = 1. + prev_mask[left0, cache_descendants[left0]] = 0. + # similar to the Free-mode: single-head and non-cycle constraint for general purpose + prev_mask[recent_mod] = 0. + cur_upm_node = cache_uppermost[recent_head] + cur_descendants = cache_descendants[recent_mod] + prev_mask[cur_upm_node, cur_descendants] = 0. # these are enough + return prev_mask + +# ====== +# system factory + +class StateBuilder: + def __init__(self, mode, nf_dist): + self.mode = mode + self.nf_dist = nf_dist + self.build = { + "free": EfFreeState, + "t2d": EfTdState, + "l2r": lambda **kwargs: EfDirState(True, **kwargs), + "r2l": lambda **kwargs: EfDirState(False, **kwargs), + "n2f": lambda **kwargs: EfNfState(nf_dist, **kwargs) + }[mode] + + def __repr__(self): + mode, nf_dist = self.mode, self.nf_dist + return "" if mode=="n2f" else ">") diff --git a/tasks/zdpar/graph/parser.py b/tasks/zdpar/graph/parser.py index 4b25942..9ac994f 100644 --- a/tasks/zdpar/graph/parser.py +++ b/tasks/zdpar/graph/parser.py @@ -3,247 +3,82 @@ from typing import List import numpy as np -from msp.utils import Conf, zlog, JsonRW, Random, zfatal, Constants -from msp.data import VocabPackage, MultiHelper -from msp.model import Model +from msp.utils import zfatal, Constants +from msp.data import VocabPackage from msp.nn import BK -from msp.nn import refresh as nn_refresh -from msp.nn.layers import RefreshOptions -from msp.nn.modules import EmbedConf, MyEmbedder, EncConf, MyEncoder -from msp.zext.process_train import RConf, ScheduledValue, SVConf from msp.zext.seq_helper import DataPadder -from ..common.data import ParseInstance from .scorer import GraphScorerConf, GraphScorer +from ..common.data import ParseInstance +from ..common.model import BaseParserConf, BaseInferenceConf, BaseTrainingConf, BaseParser from ..algo import nmst_unproj, nmst_proj, nmst_greedy, nmarginal_unproj, nmarginal_proj -# overall parser conf -class GraphParserConf(Conf): - def __init__(self): - # - self.iconf = InferenceConf() - self.tconf = TraningConf() - # Model - self.emb_conf = EmbedConf().init_from_kwargs(dim_word=300, dim_char=30, dim_extras="50", extra_names="pos") - self.enc_conf = EncConf().init_from_kwargs(enc_rnn_type="lstm2", enc_hidden=1024, enc_rnn_layer=3) - self.sc_conf = GraphScorerConf() - # ===== - # other options - # inputs - self.char_max_length = 45 - # dropouts - self.drop_embed = 0.33 - self.dropmd_embed = 0. - self.drop_hidden = 0.33 - self.gdrop_rnn = 0.33 # gdrop (always fixed for recurrent connections) - self.idrop_rnn = 0.33 # idrop for rnn - self.fix_drop = True # fix drop for one run for each dropout - self.singleton_unk = 0.5 # replace singleton words with UNK when training - self.singleton_thr = 2 # only replace singleton if freq(val) <= this (also decay with 1/freq) - # ===== - # output modeling (local/global/single, proj/unproj, prob/hinge) - self.output_normalizing = "local" # local/global/single/hlocal - # self.iconf.dec_algorithm = "?" - # self.tconf.loss_function = "?" +# ===== +# confs # decoding conf -class InferenceConf(Conf): +class InferenceConf(BaseInferenceConf): def __init__(self): - # overall - self.batch_size = 32 - self.infer_single_length = 100 # single-inst batch if >= this length + super().__init__() # method self.dec_algorithm = "unproj" # proj/unproj/greedy self.dec_single_neg = False # also consider neg links for single-norm (but this might make it unstable for labels?) # training conf -class TraningConf(RConf): +class TraningConf(BaseTrainingConf): def __init__(self): super().__init__() - # about files - self.no_build_dict = False - self.load_model = False - self.load_process = False - # batch arranger - self.batch_size = 32 - self.train_skip_length = 120 - self.shuffle_train = True - # optimizer - self.optim = "adam" - self.sgd_momentum = 0.85 # for "sgd" - self.adam_betas = [0.9, 0.9] # for "adam" - self.adam_eps = 1e-4 # for "adam" - self.grad_clip = 5.0 - # overwrite default ones - self.patience = 8 - self.anneal_times = 10 - self.max_epochs = 200 # loss functions & general - self.margin = SVConf().init_from_kwargs(init_val=0.0) self.loss_div_tok = True # loss divide by token or by sent? self.loss_function = "prob" # prob/hinge/mr: probability or hinge/perceptron based or min-risk(partially supported) self.loss_single_sample = 2.0 # sampling negative ones (<1: rate, >=2: number) to balance for single-norm +# overall parser conf +class GraphParserConf(BaseParserConf): + def __init__(self): + super().__init__(InferenceConf(), TraningConf()) + # decoding part + self.sc_conf = GraphScorerConf() + # ===== + # output modeling (local/global/single, proj/unproj, prob/hinge) + self.output_normalizing = "local" # local/global/single/hlocal + # self.iconf.dec_algorithm = "?" + # self.tconf.loss_function = "?" + +# ===== +# model + # the model (handle inference & fber by itself, since usually does not involve complex search problem) -class GraphParser(Model): - def __init__(self, gconf: GraphParserConf, vpack: VocabPackage): - self.gconf = gconf - self.vpack = vpack - # ===== Vocab ===== - self.word_vocab = vpack.get_voc("word") - self.char_vocab = vpack.get_voc("char") - self.pos_vocab = vpack.get_voc("pos") - self.label_vocab = vpack.get_voc("label") - # ===== Model ===== - self.pc = BK.ParamCollection() - # embedding - self.emb = MyEmbedder(self.pc, gconf.emb_conf, vpack) - emb_output_dim = self.emb.get_output_dims()[0] - # encoder - # todo(0): feed compute-on-the-fly hp - gconf.enc_conf._input_dim = emb_output_dim - self.enc = MyEncoder(self.pc, gconf.enc_conf) - enc_output_dim = self.enc.get_output_dims()[0] - # decoder - gconf.sc_conf._input_dim = enc_output_dim - gconf.sc_conf._num_label = self.label_vocab.trg_len() - self.scorer = GraphScorer(self.pc, gconf.sc_conf) +class GraphParser(BaseParser): + def __init__(self, conf: GraphParserConf, vpack: VocabPackage): + super().__init__(conf, vpack) # ===== Input Specification ===== - # inputs (word, char, pos) and vocabulary - self.need_word = self.emb.has_word - self.need_char = self.emb.has_char - # todo(warn): currently only allow extra fields for POS - self.need_pos = False - if len(self.emb.extra_names) > 0: - assert len(self.emb.extra_names) == 1 and self.emb.extra_names[0]=="pos" - self.need_pos = True - # - self.word_padder = DataPadder(2, pad_vals=self.word_vocab.pad, mask_range=2) - self.char_padder = DataPadder(3, pad_lens=(0, 0, gconf.char_max_length), pad_vals=self.char_vocab.pad) - self.pos_padder = DataPadder(2, pad_vals=self.pos_vocab.pad) # both head/label padding with 0 (does not matter what, since will be masked) self.predict_padder = DataPadder(2, pad_vals=0) self.hlocal_padder = DataPadder(3, pad_vals=0.) # - # ===== For training ===== - # schedule values - tconf = gconf.tconf - self.margin = ScheduledValue("margin", tconf.margin) - self._scheduled_values = [self.margin] - # set optimizer (the init lr here does not matter!) - self.scorer.pc.optimizer_set(tconf.optim, 0., tconf) - # - # others - self.previous_refresh_training = True - # - # todo(warn): hlocal has problems intuitively, maybe not suitable for graph-parser + # todo(warn): adding-styled hlocal has problems intuitively, maybe not suitable for graph-parser self.norm_single, self.norm_local, self.norm_global, self.norm_hlocal = \ - [gconf.output_normalizing==z for z in ["single", "local", "global", "hlocal"]] - self.loss_prob, self.loss_hinge, self.loss_mr = [gconf.tconf.loss_function==z for z in ["prob", "hinge", "mr"]] - self.alg_proj, self.alg_unproj, self.alg_greedy = [gconf.iconf.dec_algorithm==z for z in ["proj", "unproj", "greedy"]] + [conf.output_normalizing==z for z in ["single", "local", "global", "hlocal"]] + self.loss_prob, self.loss_hinge, self.loss_mr = [conf.tconf.loss_function==z for z in ["prob", "hinge", "mr"]] + self.alg_proj, self.alg_unproj, self.alg_greedy = [conf.iconf.dec_algorithm==z for z in ["proj", "unproj", "greedy"]] # - # called before each mini-batch - def refresh_batch(self, training): - # refresh graph - # todo(warn): make sure to remember clear this one - nn_refresh() - # refresh node rop - mconf = self.gconf - if not training: - if not self.previous_refresh_training: - # todo(+1): currently no need to refresh testing mode multiple times - return - self.previous_refresh_training = False - embed_rop = other_rop = RefreshOptions(training=False) # default no dropout - else: - embed_rop = RefreshOptions(hdrop=mconf.drop_embed, dropmd=mconf.dropmd_embed, fix_drop=mconf.fix_drop) - other_rop = RefreshOptions(hdrop=mconf.drop_embed, idrop=mconf.idrop_rnn, gdrop=mconf.gdrop_rnn, fix_drop=mconf.fix_drop) - # todo(warn): once-bug, don't forget this one!! - self.previous_refresh_training = True - # - self.emb.refresh(embed_rop) - self.enc.refresh(other_rop) - self.scorer.refresh(other_rop) - - def update(self, lrate): - self.pc.optimizer_update(lrate) - - def get_scheduled_values(self): - return self._scheduled_values - - # load and save models - # todo(warn): no need to load confs here - def load(self, path): - self.scorer.pc.load(path) - # self.gconf = JsonRW.load_from_file(path+".json") - zlog("Load GraphParser model from %s." % path, func="io") - - def save(self, path): - self.scorer.pc.save(path) - JsonRW.save_to_file(self.gconf, path+".json") - zlog("Save GraphParser model to %s." % path, func="io") - - # ===== - # common procedures - def _prepare_input(self, insts, training): - word_arr, char_arr, extra_arrs = None, None, [] - # ===== specially prepare for the words - wv = self.word_vocab - W_UNK = wv.unk - UNK_REP_RATE = self.gconf.singleton_unk - UNK_REP_THR = self.gconf.singleton_thr - word_act_idxes = [] - if training and UNK_REP_RATE>0.: # replace unfreq/singleton words with UNK - for one_inst in insts: - one_act_idxes = [] - for one_idx in one_inst.words.idxes: - one_freq = wv.idx2val(one_idx) - if one_freq is not None and one_freq >= 1 and one_freq <= UNK_REP_THR: - if Random.random_bool(UNK_REP_RATE/one_freq): - one_idx = W_UNK - one_act_idxes.append(one_idx) - word_act_idxes.append(one_act_idxes) - else: - word_act_idxes = [z.words.idxes for z in insts] - # todo(warn): still need the masks - word_arr, mask_arr = self.word_padder.pad(word_act_idxes) - # ===== - if not self.need_word: - word_arr = None - if self.need_char: - chars = [z.chars.idxes for z in insts] - char_arr, _ = self.char_padder.pad(chars) - if self.need_pos: - poses = [z.poses.idxes for z in insts] - pos_arr, _ = self.pos_padder.pad(poses) - extra_arrs.append(pos_arr) - # - input_repr = self.emb(word_arr, char_arr, extra_arrs) - # [BS, Len, Dim], [BS, Len] - return input_repr, mask_arr - - # - def pred2real_labels(self, preds): - return [self.label_vocab.trg_pred2real(z) for z in preds] + def build_decoder(self): + conf = self.conf + conf.sc_conf._input_dim = self.enc_output_dim + conf.sc_conf._num_label = self.label_vocab.trg_len(False) + return GraphScorer(self.pc, conf.sc_conf) - def real2pred_labels(self, reals): - return [self.label_vocab.trg_real2pred(z) for z in reals] - - # shared calculations before final scoring + # shared calculations for final scoring # -> the scores are masked with PRAC_MIN (by the scorer) for the paddings, but not handling diag here! def _prepare_score(self, insts, training): - # ===== calculate - # [BS, Len, Di], [BS, Len] - input_repr, mask_arr = self._prepare_input(insts, training) - # [BS, Len, De] - enc_repr = self.enc(input_repr, mask_arr) - # + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, training) mask_expr = BK.input_real(mask_arr) # am_expr, ah_expr, lm_expr, lh_expr = self.scorer.transform_space(enc_repr) scoring_expr_pack = self.scorer.transform_space(enc_repr) - return scoring_expr_pack, mask_expr + return scoring_expr_pack, mask_expr, jpos_pack # scoring procedures def _score_arc_full(self, scoring_expr_pack, mask_expr, training, margin, gold_heads_expr=None): @@ -289,7 +124,12 @@ def _score_label_selected(self, scoring_expr_pack, mask_expr, training, margin, # for global-norm + hinge(perceptron-like)-loss # [*, m, h, L], [*, m] - def _losses_global_hinge(self, full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr): + def _losses_global_hinge(self, full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr): + return GraphParser.get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr) + + # ===== export + @staticmethod + def get_losses_global_hinge(full_score_expr, gold_heads_expr, gold_labels_expr, pred_heads_expr, pred_labels_expr, mask_expr, clamping=True): # combine the last two dimension full_shape = BK.get_shape(full_score_expr) # [*, m, h*L] @@ -302,8 +142,15 @@ def _losses_global_hinge(self, full_score_expr, gold_heads_expr, gold_labels_exp gold_scores = BK.gather_one_lastdim(combiend_score_expr, gold_combined_idx_expr).squeeze(-1) pred_scores = BK.gather_one_lastdim(combiend_score_expr, pred_combined_idx_expr).squeeze(-1) # todo(warn): be aware of search error! - hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.) - return hinge_losses + # hinge_losses = BK.clamp(pred_scores - gold_scores, min=0.) # this is previous version + hinge_losses = pred_scores - gold_scores # [*, len] + if clamping: + valid_losses = ((hinge_losses*mask_expr)[:, 1:].sum(-1) > 0.).float().unsqueeze(-1) # [*, 1] + return hinge_losses * valid_losses + else: + # for this mode, will there be problems of search error? Maybe rare. + return hinge_losses + # ===== # for global-norm + prob-loss def _losses_global_prob(self, full_score_expr, gold_heads_expr, gold_labels_expr, marginals_expr, mask_expr): @@ -377,7 +224,7 @@ def _decode(self, full_score_expr, maske_expr, lengths_arr): elif self.alg_greedy: return nmst_greedy(full_score_expr, maske_expr, lengths_arr, labeled=True, ret_arr=True) else: - zfatal("Unknown decoding algorithm " + self.gconf.iconf.dec_algorithm) + zfatal("Unknown decoding algorithm " + self.conf.iconf.dec_algorithm) return None # expr[BS, m, h, L], arr[BS] -> expr[BS, m, h, L] @@ -387,17 +234,26 @@ def _marginal(self, full_score_expr, maske_expr, lengths_arr): elif self.alg_proj: marginals_expr = nmarginal_proj(full_score_expr, maske_expr, lengths_arr, labeled=True) else: - zfatal("Unsupported marginal-calculation for the decoding algorithm of " + self.gconf.iconf.dec_algorithm) + zfatal("Unsupported marginal-calculation for the decoding algorithm of " + self.conf.iconf.dec_algorithm) marginals_expr = None return marginals_expr # ===== main methods: training and decoding + # only getting scores + def score_on_batch(self, insts: List[ParseInstance]): + with BK.no_grad_env(): + self.refresh_batch(False) + scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False) + full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.) + full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.) + return full_arc_score, full_label_score + # full score and inference - def inference_on_batch(self, insts: List[ParseInstance]): + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): with BK.no_grad_env(): self.refresh_batch(False) # ===== calculate - scoring_expr_pack, mask_expr = self._prepare_score(insts, False) + scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(insts, False) full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, False, 0.) full_label_score = self._score_label_full(scoring_expr_pack, mask_expr, False, 0.) # normalizing scores @@ -410,7 +266,7 @@ def inference_on_batch(self, insts: List[ParseInstance]): # normalize at m dimension, ignore each nodes's self-finish step. full_score = BK.log_softmax(full_arc_score, -2).unsqueeze(-1) + BK.log_softmax(full_label_score, -1) elif self.norm_single and self.loss_prob: - if self.gconf.iconf.dec_single_neg: + if self.conf.iconf.dec_single_neg: # todo(+2): add all-neg for prob explanation full_arc_probs = BK.sigmoid(full_arc_score) full_label_probs = BK.sigmoid(full_label_score) @@ -427,6 +283,10 @@ def inference_on_batch(self, insts: List[ParseInstance]): mst_heads_arr, mst_labels_arr, mst_scores_arr = self._decode(full_score, mask_expr, np.asarray(mst_lengths, dtype=np.int32)) if final_exp_score: mst_scores_arr = np.exp(mst_scores_arr) + # jpos prediction (directly index, no converting as in parsing) + jpos_preds_expr = jpos_pack[2] + has_jpos_pred = jpos_preds_expr is not None + jpos_preds_arr = BK.get_value(jpos_preds_expr) if has_jpos_pred else None # ===== assign info = {"sent": len(insts), "tok": sum(mst_lengths)-len(insts)} mst_real_labels = self.pred2real_labels(mst_labels_arr) @@ -435,11 +295,13 @@ def inference_on_batch(self, insts: List[ParseInstance]): one_inst.pred_heads.set_vals(mst_heads_arr[one_idx][:cur_length]) # directly int-val for heads one_inst.pred_labels.build_vals(mst_real_labels[one_idx][:cur_length], self.label_vocab) one_inst.pred_par_scores.set_vals(mst_scores_arr[one_idx][:cur_length]) + if has_jpos_pred: + one_inst.pred_poses.build_vals(jpos_preds_arr[one_idx][:cur_length], self.bter.pos_vocab) return info # list(mini-batch) of annotated instances # optional results are written in-place? return info. - def fb_on_batch(self, annotated_insts, training=True): + def fb_on_batch(self, annotated_insts, training=True, loss_factor=1, **kwargs): self.refresh_batch(training) margin = self.margin.value # gold heads and labels @@ -448,7 +310,7 @@ def fb_on_batch(self, annotated_insts, training=True): gold_heads_expr = BK.input_idx(gold_heads_arr) # [BS, Len] gold_labels_expr = BK.input_idx(gold_labels_arr) # [BS, Len] # ===== calculate - scoring_expr_pack, mask_expr = self._prepare_score(annotated_insts, training) + scoring_expr_pack, mask_expr, jpos_pack = self._prepare_score(annotated_insts, training) full_arc_score = self._score_arc_full(scoring_expr_pack, mask_expr, training, margin, gold_heads_expr) # final_losses = None @@ -462,7 +324,7 @@ def fb_on_batch(self, annotated_insts, training=True): losses_heads = BK.loss_nll(full_arc_score, gold_heads_expr) losses_labels = BK.loss_nll(select_label_score, gold_labels_expr) elif self.norm_single: - single_sample = self.gconf.tconf.loss_single_sample + single_sample = self.conf.tconf.loss_single_sample losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=False) losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=False) # simply adding @@ -472,7 +334,7 @@ def fb_on_batch(self, annotated_insts, training=True): losses_heads = BK.loss_hinge(full_arc_score, gold_heads_expr) losses_labels = BK.loss_hinge(select_label_score, gold_labels_expr) elif self.norm_single: - single_sample = self.gconf.tconf.loss_single_sample + single_sample = self.conf.tconf.loss_single_sample losses_heads = self._losses_single(full_arc_score, gold_heads_expr, single_sample, is_hinge=True, margin=margin) losses_labels = self._losses_single(select_label_score, gold_labels_expr, single_sample, is_hinge=True, margin=margin) # simply adding @@ -502,7 +364,7 @@ def fb_on_batch(self, annotated_insts, training=True): if self.alg_proj: # todo(+N): deal with search-error-like problem, discard unproj neg losses (score>weighted-avg), # but this might be too loose, although the unproj edges are few? - gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs.vals for z in annotated_insts]) + gold_unproj_arr, _ = self.predict_padder.pad([z.unprojs for z in annotated_insts]) gold_unproj_expr = BK.input_real(gold_unproj_arr) # [BS, Len] comparing_expr = Constants.REAL_PRAC_MIN * (1. - gold_unproj_expr) final_losses = BK.max_elem(final_losses, comparing_expr) @@ -512,7 +374,7 @@ def fb_on_batch(self, annotated_insts, training=True): pred_labels_expr = BK.input_idx(pred_labels_arr) # [BS, Len] # final_losses = self._losses_global_hinge(full_score, gold_heads_expr, gold_labels_expr, - pred_heads_expr, pred_labels_expr) + pred_heads_expr, pred_labels_expr, mask_expr) elif self.loss_mr: # todo(+N): Loss = -Reward = \sum marginals, which requires gradients on marginal-one-edge, or marginal-two-edges raise NotImplementedError("Not implemented for global-loss + mr.") @@ -532,13 +394,17 @@ def fb_on_batch(self, annotated_insts, training=True): losses_arc[:, 1] += losses_arc[:, 0] final_losses = losses_arc + losses_labels # + # jpos loss? (the same mask as parsing) + jpos_losses_expr = jpos_pack[1] + if jpos_losses_expr is not None: + final_losses += jpos_losses_expr # collect loss with mask, also excluding the first symbol of ROOT final_losses_masked = (final_losses * mask_expr)[:, 1:] final_loss_sum = BK.sum(final_losses_masked) # divide loss by what? num_sent = len(annotated_insts) num_valid_tok = sum(len(z) for z in annotated_insts) - if self.gconf.tconf.loss_div_tok: + if self.conf.tconf.loss_div_tok: final_loss = final_loss_sum / num_valid_tok else: final_loss = final_loss_sum / num_sent @@ -546,21 +412,8 @@ def fb_on_batch(self, annotated_insts, training=True): final_loss_sum_val = float(BK.get_value(final_loss_sum)) info = {"sent": num_sent, "tok": num_valid_tok, "loss_sum": final_loss_sum_val} if training: - BK.backward(final_loss) + BK.backward(final_loss, loss_factor) return info - # ===== - # special routine - def aug_words_and_embs(self, aug_vocab, aug_wv): - orig_vocab = self.word_vocab - orig_arr = self.emb.word_embed.E.detach().cpu().numpy() - # todo(+2): find same-spelling words in the original vocab if not-hit in the extra_embed? - aug_arr = aug_vocab.filter_embed(aug_wv, init_nohit=0.) - new_vocab, new_arr = MultiHelper.aug_vocab_and_arr(orig_vocab, orig_arr, aug_vocab, aug_arr) - # assign - self.word_vocab = new_vocab - self.emb.word_embed.replace_weights(new_arr) - return new_vocab - # pdb: # b tasks/zdpar/graph/parser:232 diff --git a/tasks/zdpar/graph/scorer.py b/tasks/zdpar/graph/scorer.py index e8baade..768123b 100644 --- a/tasks/zdpar/graph/scorer.py +++ b/tasks/zdpar/graph/scorer.py @@ -4,7 +4,7 @@ from msp.utils import Conf, zcheck from msp.nn import BK -from msp.nn.layers import BasicNode, Affine, BiAffineScorer +from msp.nn.layers import BasicNode, Affine, BiAffineScorer, AttConf, AttDistHelper # class GraphScorerConf(Conf): @@ -19,6 +19,17 @@ def __init__(self): self.ff_hid_layer = 0 self.use_biaffine = True self.use_ff = True + self.use_ff2 = False + self.biaffine_div = 1. + # + self.transform_act = "elu" + self.biaffine_init_ortho = False + # distance clip? + self.arc_dist_clip = -1 + self.arc_use_neg = False + + def get_dist_aconf(self): + return AttConf().init_from_kwargs(clip_dist=self.arc_dist_clip, use_neg_dist=self.arc_use_neg) # [*, len, D] -> attach-score [*, m-len, h-len], label-score [*, m-len, h-len, L] class GraphScorer(BasicNode): @@ -32,15 +43,25 @@ def __init__(self, pc: BK.ParamCollection, sconf: GraphScorerConf): ff_hid_layer = sconf.ff_hid_layer use_biaffine = sconf.use_biaffine use_ff = sconf.use_ff + use_ff2 = sconf.use_ff2 + self.arc_dist_clip = sconf.arc_dist_clip # scorer + biaffine_div = sconf.biaffine_div + biaffine_init_ortho = sconf.biaffine_init_ortho + transform_act = sconf.transform_act # attach/arc - self.arc_m = self.add_sub_node("am", Affine(pc, input_dim, arc_space, act="elu")) - self.arc_h = self.add_sub_node("ah", Affine(pc, input_dim, arc_space, act="elu")) - self.arc_scorer = self.add_sub_node("as", BiAffineScorer(pc, arc_space, arc_space, 1, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff)) + self.arc_m = self.add_sub_node("am", Affine(pc, input_dim, arc_space, act=transform_act)) + self.arc_h = self.add_sub_node("ah", Affine(pc, input_dim, arc_space, act=transform_act)) + self.arc_scorer = self.add_sub_node("as", BiAffineScorer(pc, arc_space, arc_space, 1, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, biaffine_div=biaffine_div, biaffine_init_ortho=biaffine_init_ortho)) + # only add distance for arc + if sconf.arc_dist_clip > 0: + self.dist_helper = self.add_sub_node("dh", AttDistHelper(pc, sconf.get_dist_aconf(), arc_space)) + else: + self.dist_helper = None # labeling - self.lab_m = self.add_sub_node("lm", Affine(pc, input_dim, lab_space, act="elu")) - self.lab_h = self.add_sub_node("lh", Affine(pc, input_dim, lab_space, act="elu")) - self.lab_scorer = self.add_sub_node("ls", BiAffineScorer(pc, lab_space, lab_space, sconf._num_label, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff)) + self.lab_m = self.add_sub_node("lm", Affine(pc, input_dim, lab_space, act=transform_act)) + self.lab_h = self.add_sub_node("lh", Affine(pc, input_dim, lab_space, act=transform_act)) + self.lab_scorer = self.add_sub_node("ls", BiAffineScorer(pc, lab_space, lab_space, sconf._num_label, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, biaffine_div=biaffine_div, biaffine_init_ortho=biaffine_init_ortho)) # transform to specific space (currently arc/label * m/h) # [*, len, input_dim] -> *[*, len, space_dim] @@ -54,7 +75,12 @@ def transform_space(self, enc_expr): # ===== scorers # score all pairs: [*, len1, D1], [*, len2, D2], [*, len1], [*, len2] -> [*, len1, len2] def score_arc_all(self, am_expr, ah_expr, m_mask_expr, h_mask_expr): - arc_scores = self.arc_scorer.paired_score(am_expr, ah_expr, m_mask_expr, h_mask_expr) + if self.dist_helper: + lenm, lenh = BK.get_shape(am_expr, -2), BK.get_shape(ah_expr, -2) + ah_rel1, _ = self.dist_helper(lenm, lenh) + else: + ah_rel1 = None + arc_scores = self.arc_scorer.paired_score(am_expr, ah_expr, m_mask_expr, h_mask_expr, rel1_t=ah_rel1) ret = arc_scores.squeeze(-1) # squeeze the last one return ret @@ -63,11 +89,11 @@ def score_label_all(self, lm_expr, lh_expr, m_mask_expr, h_mask_expr): lab_scores = self.lab_scorer.paired_score(lm_expr, lh_expr, m_mask_expr, h_mask_expr) return lab_scores - # score selected paired inputs: [*, len, D1], [*, len, D2], [*, len] -> [*, len] - def score_arc_select(self, am_expr_sel, ah_expr_sel, mask): - arc_scores = self.arc_scorer.plain_score(am_expr_sel, ah_expr_sel, mask, None) - ret = arc_scores.squeeze(-1) # squeeze the last one - return ret + # # score selected paired inputs: [*, len, D1], [*, len, D2], [*, len] -> [*, len] + # def score_arc_select(self, am_expr_sel, ah_expr_sel, mask): + # arc_scores = self.arc_scorer.plain_score(am_expr_sel, ah_expr_sel, mask, None) + # ret = arc_scores.squeeze(-1) # squeeze the last one + # return ret # score selected paris' all labels: ... -> [*, len, N] def score_label_select(self, am_expr_sel, ah_expr_sel, mask): diff --git a/tasks/zdpar/graph/stest.py b/tasks/zdpar/graph/stest.py new file mode 100644 index 0000000..56597c4 --- /dev/null +++ b/tasks/zdpar/graph/stest.py @@ -0,0 +1,6 @@ +# + +# special testing for graph parsers +# comparing the scores and outputs of different output normalizations + +# todo(+N) how to compare different models? diff --git a/tasks/zdpar/main/analyze.py b/tasks/zdpar/main/analyze.py new file mode 100644 index 0000000..aac8162 --- /dev/null +++ b/tasks/zdpar/main/analyze.py @@ -0,0 +1,110 @@ +# + +# especially analyzing on graph-like models: g1p/s2p (with-model analysis) +# like decoding, but mainly for some analysis: like pruning, marginals, ... + +import numpy as np +import pandas as pd + +from msp.utils import StatRecorder, Helper, zlog + +from ..ef.parser import G1Parser, G1ParserConf +from ..ef.parser.g1p import PruneG1Conf, ParseInstance +from ..common.confs import DepParserConf +from .test import prepare_test + +class AnalyzeConf(DepParserConf): + def __init__(self, parser_type, args): + # extra ones + self.zprune = PruneG1Conf() + # + super().__init__(parser_type, args) + +class ZObject: + pass + +def main(args): + conf, model, vpack, test_iter = prepare_test(args, AnalyzeConf) + # make sure the model is order 1 graph model, otherwise cannot run through + assert isinstance(model, G1Parser) and isinstance(conf.pconf, G1ParserConf) + # ===== + # helpers + all_stater = StatRecorder(False) + def _stat(k, v): + all_stater.record_kv(k, v) + # check agreement + def _agree2(a, b, name): + agreement = (np.asarray(a) == np.asarray(b)) + num_agree = int(agreement.sum()) + _stat(name, num_agree) + # do not care about efficiency here! + step2_pack = [] + for cur_insts in test_iter: + # score and prune + valid_mask, arc_score, label_score, mask_expr, marginals = model.prune_on_batch(cur_insts, conf.zprune) + # greedy on raw scores + greedy_label_scores, greedy_label_mat_idxes = label_score.max(-1) # [*, m, h] + greedy_all_scores, greedy_arc_idxes = (arc_score+greedy_label_scores).max(-1) # [*, m] + greedy_label_idxes = greedy_label_mat_idxes.gather(-1, greedy_arc_idxes.unsqueeze(-1)).squeeze(-1) # [*, m] + # greedy on marginals (arc only) + greedy_marg_arc_scores, greedy_marg_arc_idxes = marginals.max(-1) # [*, m] + entropy_marg = - (marginals * (marginals + 1e-10 * (marginals==0.).float()).log()).sum(-1) # [*, m] + # decode + model.inference_on_batch(cur_insts) + # ===== + z = ZObject() + keys = list(locals().keys()) + for k in keys: + v = locals()[k] + try: + setattr(z, k, v.cpu().detach().numpy()) + except: + pass + # ===== + for idx in range(len(cur_insts)): + one_inst: ParseInstance = cur_insts[idx] + one_len = len(one_inst) + 1 # [1, len) + _stat("all_edges", one_len-1) + arc_gold = one_inst.heads.vals[1:] + arc_mst = one_inst.pred_heads.vals[1:] + arc_gma = z.greedy_marg_arc_idxes[idx][1:one_len] + # step 1: decoding agreement, how many edges agree: gold, mst-decode, greedy-marginal + arcs = {"gold": arc_gold, "mst": arc_mst, "gma": arc_gma} + cmp_keys = sorted(arcs.keys()) + for i in range(len(cmp_keys)): + for j in range(i+1, len(cmp_keys)): + n1, n2 = cmp_keys[i], cmp_keys[j] + _agree2(arcs[n1], arcs[n2], f"{n1}_{n2}") + # step 2: confidence + arc_agree = (np.asarray(arc_gold) == np.asarray(arc_mst)) + arc_marginals_mst = z.marginals[idx][range(1, one_len), arc_mst] + arc_marginals_gold = z.marginals[idx][range(1, one_len), arc_gold] + arc_entropy = z.entropy_marg[idx][1:one_len] + for tidx in range(one_len-1): + step2_pack.append([int(arc_agree[tidx]), min(1., float(arc_marginals_mst[tidx])), + min(1., float(arc_marginals_gold[tidx])), float(arc_entropy[tidx])]) + # step 2: bucket by marginals + if True: + NUM_BUCKET = 10 + df = pd.DataFrame(step2_pack, columns=['agree', 'm_mst', 'm_gold', 'entropy']) + z = df.sort_values(by='m_mst', ascending=False) + z.to_csv('res.csv') + for cur_b in range(NUM_BUCKET): + interval = 1. / NUM_BUCKET + r0, r1 = cur_b*interval, (cur_b+1)*interval + cur_v = df[(df.m_mst>=r0) & ((df.m_mst dconf.pretrain_init_nohit, dconf.pretrain_file, ... +2. testing +// (currently 19.03.07, prefer using the Option 3 above since no need to use special data) +For testing, can plus augment embeddings and vocabs. +=> dconf.test_extra_pretrain_files, dconf.code_*, ... +""" diff --git a/tasks/zdpar/main/train.py b/tasks/zdpar/main/train.py index 8e355bf..9fb797f 100644 --- a/tasks/zdpar/main/train.py +++ b/tasks/zdpar/main/train.py @@ -5,7 +5,7 @@ from msp.data import MultiCatStreamer, InstCacher from ..common.confs import DepParserConf, init_everything, build_model -from ..common.data import get_data_reader +from ..common.data import get_data_reader, get_multisoure_data_reader from ..common.vocab import DConf, ParserVocabPackage from ..common.run import index_stream, batch_stream, ParserTrainingRunner @@ -15,28 +15,51 @@ def main(args): dconf, pconf = conf.dconf, conf.pconf tconf = pconf.tconf iconf = pconf.iconf + # dev/test can be non-existing! + if not dconf.dev and dconf.test: + utils.zwarn("No dev but give test, actually use test as dev (for early stopping)!!") + dt_golds, dt_codes, dt_aux_reprs, dt_aux_scores, dt_cuts = [], [], [], [], [] + for file, code, aux_repr, aux_score, one_cut in \ + [(dconf.dev, dconf.code_dev, dconf.aux_repr_dev, dconf.aux_score_dev, dconf.cut_dev), + (dconf.test, dconf.code_test, dconf.aux_repr_test, dconf.aux_score_test, "")]: # no cut for test! + if len(file)>0: + utils.zlog(f"Add file `{file}(code={code}, aux_repr={aux_repr}, aux_scpre={aux_score}, cut={one_cut})'" + f" as dt-file #{len(dt_golds)}.") + dt_golds.append(file) + dt_codes.append(code) + dt_aux_reprs.append(aux_repr) + dt_aux_scores.append(aux_score) + dt_cuts.append(one_cut) # data - train_streamer = get_data_reader(dconf.train, dconf.input_format, dconf.code_train, dconf.use_label0) - dev_streamer = get_data_reader(dconf.dev, dconf.input_format, dconf.code_dev, dconf.use_label0) - test_streamer = get_data_reader(dconf.test, dconf.input_format, dconf.code_test, dconf.use_label0) + if dconf.multi_source: + _reader_getter = get_multisoure_data_reader + else: + _reader_getter = get_data_reader + train_streamer = _reader_getter(dconf.train, dconf.input_format, dconf.code_train, dconf.use_label0, + dconf.aux_repr_train, dconf.aux_score_train, cut=dconf.cut_train) + dt_streamers = [_reader_getter(f, dconf.input_format, c, dconf.use_label0, aux_r, aux_s, cut=one_cut) + for f, c, aux_r, aux_s, one_cut in zip(dt_golds, dt_codes, dt_aux_reprs, dt_aux_scores, dt_cuts)] # vocab if tconf.no_build_dict: vpack = ParserVocabPackage.build_by_reading(dconf) else: # include dev/test only for convenience of including words hit in pre-trained embeddings - vpack = ParserVocabPackage.build_from_stream(dconf, train_streamer, MultiCatStreamer([dev_streamer, test_streamer])) + vpack = ParserVocabPackage.build_from_stream(dconf, train_streamer, MultiCatStreamer(dt_streamers)) vpack.save(dconf.dict_dir) - # index the data - to_cache = dconf.cache_data - train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache), tconf, True) - dev_iter = batch_stream(index_stream(dev_streamer, vpack, to_cache), iconf, False) - test_iter = batch_stream(index_stream(test_streamer, vpack, to_cache), iconf, False) # model model = build_model(conf.partype, conf, vpack) + # index the data + train_inst_preparer = model.get_inst_preper(True) + test_inst_preparer = model.get_inst_preper(False) + to_cache = dconf.cache_data + to_cache_shuffle = dconf.to_cache_shuffle + # todo(note): make sure to cache both train and dev to save time for cached computation + train_iter = batch_stream(index_stream(train_streamer, vpack, to_cache, to_cache_shuffle, train_inst_preparer), tconf, True) + dt_iters = [batch_stream(index_stream(z, vpack, to_cache, to_cache_shuffle, test_inst_preparer), iconf, False) for z in dt_streamers] # training runner - tr = ParserTrainingRunner(tconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=[dconf.dev, dconf.test], dev_out_format=dconf.output_format) + tr = ParserTrainingRunner(tconf, model, vpack, dev_outfs=dconf.output_file, dev_goldfs=dt_golds, dev_out_format=dconf.output_format) if tconf.load_model: tr.load(dconf.model_load_name, tconf.load_process) # go - tr.run(train_iter, [dev_iter, test_iter]) + tr.run(train_iter, dt_iters) utils.zlog("The end of Training.") diff --git a/tasks/zdpar/stat/collect.py b/tasks/zdpar/stat/collect.py new file mode 100644 index 0000000..843dbd1 --- /dev/null +++ b/tasks/zdpar/stat/collect.py @@ -0,0 +1,108 @@ +# + +from msp.utils import zlog, Conf, JsonRW, PickleRW, zopen +import json, glob + +from .svocab import StatVocab +from .scorpus import DataReader +from .smodel import StatModel, StatApplyConf + +# +class StatConf(Conf): + def __init__(self, args): + self.main_prefix = "" # main_dir/main_name + # step 1: build the overall vocab + self.input_files = [] # can be glob + self.lc = True + self.thr_minlen = 5 + self.thr_maxlen = 80 + self.word_minc = 5 # word-count thresh + self.word_softcut = 10000 # general vocab size (todo(note): should also decide on corpus size) + self.vocab_prefix = "vocab" + self.no_build_vocab = False # directly load and skip building + # step 2: build the model (feature and feature counts) + self.feat_compact = True # whether compact the feat to save space? + self.win_size = 10 # max distance of center & context tokens + self.bins = [1, 2, 4, 10] # how to bin the distances + self.neg_dist = True # use negative distance + self.binary_nb = 1 # whether in the mode of binary nb + self.fc_max = 0 # cutting for the nearby word for Center (0 as off) + self.ft_max = 0 # cutting for the nearby word for conText (0 as off) + self.lex_min = 1 # only consider those >= this for tok (center) + self.lex_min2 = 1 # only consider those >= this for tok2 (context) + self.meet_punct_max = 1 # how many puncts as features (<=0 as off) + self.meet_lex_thresh = 0 # only lex min StatModel: + import pickle + # return PickleRW.from_file(file) + with zopen(file, 'rb') as fd: + return pickle.load(fd) + +def main(args): + conf = StatConf(args) + input_files = conf.input_files + main_prefix = conf.main_prefix + # step 1: get vocab + zlog("Step 1: vocab") + vocab_prefix = main_prefix + conf.vocab_prefix + if conf.no_build_vocab: + vocab = StatVocab() + JsonRW.from_file(vocab, vocab_prefix+".json") + else: + reader0 = DataReader(conf) + vocab_orig = StatVocab(name="word") + for tokens in reader0.yield_data(input_files): + vocab_orig.add_all(tokens) + # cut and write vocab + # first sort and write overall one + vocab_orig.sort_and_cut(sort=True) + vocab_orig.write_txt((main_prefix + vocab_prefix + "_orig.txt")) + JsonRW.to_file(vocab_orig, (main_prefix + vocab_prefix + "_orig.json")) + # also write msp.Vocab for analysis usage + JsonRW.to_file(vocab_orig.to_zvoc(), (main_prefix + vocab_prefix + "_orig.voc")) + # + vocab_cut = vocab_orig.copy() + vocab_cut.sort_and_cut(mincount=conf.word_minc, soft_cut=conf.word_softcut, sort=True) + vocab_cut.write_txt((main_prefix + vocab_prefix + ".txt")) + JsonRW.to_file(vocab_cut, (main_prefix + vocab_prefix + ".json")) + vocab = vocab_cut + # step 2: build corpus and raw counts + # todo(note): next: + # look at which words (freq) are good/bad for cross-lingual results; + # collect features one pass from left to right: count(word), count(word, featured_word), count(f_word) + # how to normalize? and how to penalize according to features, in train or test? + zlog("Step 2: model") + model_prefix = main_prefix + conf.model_prefix + if conf.no_build_model: + model = PickleRW.from_file(model_prefix+".pic") + else: + reader1 = DataReader(conf) + model = StatModel(conf, vocab) + for tokens in reader1.yield_data(input_files): + model.add_sent(tokens) + # + PickleRW.to_file(model, model_prefix+".pic") + # step 3: apply model + # todo(note): at using time ... + zlog("Finish") + return model + +# PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat.collect input_files:? diff --git a/tasks/zdpar/stat/decode.py b/tasks/zdpar/stat/decode.py new file mode 100644 index 0000000..3c25f95 --- /dev/null +++ b/tasks/zdpar/stat/decode.py @@ -0,0 +1,165 @@ +# + +# decoding with the extra help from outside sources +# special with g1 models + +import numpy as np +import pandas as pd + +from msp.utils import StatRecorder, Helper, zlog, zopen +from msp.nn import BK + +from ..ef.parser import G1Parser, G1ParserConf +from ..ef.parser.g1p import PruneG1Conf, ParseInstance +from ..common.confs import DepParserConf +from ..main.test import prepare_test + +# todo(note): directly use the specific algo +# from ..algo.mst import mst_unproj +from ..algo.nmst import mst_unproj + +from msp.zext.dpar import ParserEvaler +from ..common.data import ParseInstance, get_data_writer + +from .collect import load_model, StatConf, StatVocab +from .smodel import StatModel, StatApplyConf + +# speical decoding +class SDConf(DepParserConf): + def __init__(self, parser_type, args): + self.zprune = PruneG1Conf() + self.smodel = "" + self.aconf = StatApplyConf() + # other options + self.apply_pruning = True # apply pruning for the sd scores with model pruning masks + self.combine_marginals = True # combine with raw scores or marginals? + # + super().__init__(parser_type, args) + +def main(args): + conf, model, vpack, test_iter = prepare_test(args, SDConf) + # make sure the model is order 1 graph model, otherwise cannot run through + assert isinstance(model, G1Parser) and isinstance(conf.pconf, G1ParserConf) + # ===== + # helpers + all_stater = StatRecorder(False) + def _stat(k, v): + all_stater.record_kv(k, v) + # ===== + # explicitly doing decoding here + if conf.smodel: + zlog(f"Load StatModel from {conf.smodel}") + smodel: StatModel = load_model(conf.smodel) + else: + zlog(f"Blank model for debug") + dummy_vocab = StatVocab() + dummy_vocab.sort_and_cut() + smodel: StatModel = StatModel(StatConf([]), dummy_vocab) + aconf = conf.aconf + # other options + apply_pruning = conf.apply_pruning + combine_marginals = conf.combine_marginals + all_insts = [] + for cur_insts in test_iter: + # score and prune + valid_mask, arc_score, label_score, mask_expr, marginals = model.prune_on_batch(cur_insts, conf.zprune) + # only modifying arc score! + valid_mask_arr = BK.get_value(valid_mask) # [bs, slen, slen] + arc_score_arr = BK.get_value(arc_score) # [bs, slen, slen] + label_score_arr = BK.get_value(label_score) # [bs, slen, slen, L] + marginals_arr = BK.get_value(marginals) # [bs, slen, slen] + # for each inst + for one_idx, one_inst in enumerate(cur_insts): + tokens = one_inst.words.vals + if smodel.lc: + tokens = [str.lower(z) for z in tokens] + cur_len = len(tokens) + cur_arange = np.arange(cur_len) + # get current things: [slen, slen] + one_valid_mask_arr = valid_mask_arr[one_idx, :cur_len, :cur_len] + one_arc_score_arr = arc_score_arr[one_idx, :cur_len, :cur_len] + one_label_score_arr = label_score_arr[one_idx, :cur_len, :cur_len] + one_marginals_arr = marginals_arr[one_idx, :cur_len, :cur_len] + # get scores from smodel + one_sd_scores = smodel.apply_sent(tokens, aconf) # [slen, slen] + if apply_pruning: + one_sd_scores *= one_valid_mask_arr # TODO(WARN): 0 or -inf? + orig_arc_score = (one_marginals_arr if combine_marginals else one_arc_score_arr) + final_arc_score = one_sd_scores + orig_arc_score + # first decoding with arc scores + mst_heads_arr, _, _ = mst_unproj(np.expand_dims(final_arc_score, axis=0), + np.array([cur_len], dtype=np.int32), labeled=False) + mst_heads_arr = mst_heads_arr[0] # [slen] + # then get each one's argmax label + argmax_label_arr = one_label_score_arr[cur_arange, mst_heads_arr].argmax(-1) # [slen] + # put in the results + one_inst.pred_heads.set_vals(mst_heads_arr) # directly int-val for heads + one_inst.pred_labels.build_vals(argmax_label_arr, model.label_vocab) + # extra output + one_inst.extra_pred_misc["orig_score"] = orig_arc_score[cur_arange, mst_heads_arr].tolist() + one_inst.extra_pred_misc["sd_score"] = one_sd_scores[cur_arange, mst_heads_arr].tolist() + # ===== + # special analyzing with the results and the gold (only for analyzing) + gold_heads = one_inst.heads.vals + _stat("num_sent", 1) + _stat("num_token", (cur_len-1)) + _stat("num_pairs", (cur_len-1)*(cur_len-1)) + _stat("num_pairs_valid", one_valid_mask_arr.sum()) # remaining ones after the pruning (pruning rate) + _stat("num_gold_valid", one_valid_mask_arr[cur_arange, gold_heads][1:].sum()) # pruning coverage + # about the sd scores + _stat("num_sd_nonzero", (one_sd_scores>0.).sum()) + _stat("num_sd_correct", (one_sd_scores>0.)[cur_arange, gold_heads][1:].sum()) + # ===== + all_insts.extend(cur_insts) + # ===== + # write and eval + # sorting by idx of reading + all_insts.sort(key=lambda x: x.inst_idx) + # write + dconf = conf.dconf + if dconf.output_file: + with zopen(dconf.output_file, "w") as fd: + data_writer = get_data_writer(fd, dconf.output_format) + data_writer.write(all_insts) + # eval + evaler = ParserEvaler() + eval_arg_names = ["poses", "heads", "labels", "pred_poses", "pred_heads", "pred_labels"] + for one_inst in all_insts: + # todo(warn): exclude the ROOT symbol; the model should assign pred_* + real_values = one_inst.get_real_values_select(eval_arg_names) + evaler.eval_one(*real_values) + report_str, res = evaler.summary() + zlog(report_str, func="result") + zlog("zzzzztest: testing result is " + str(res)) + # ===== + d = all_stater.summary(get_v=True, get_str=False) + d["z_prune_rate"] = d["num_pairs_valid"] / d["num_pairs"] + d["z_prune_coverage"] = d["num_gold_valid"] / d["num_token"] + d["z_sd_precision"] = d["num_sd_correct"] / d["num_sd_nonzero"] + d["z_sd_recall"] = d["num_sd_correct"] / d["num_token"] + Helper.printd(d, "\n\n") + +""" +SRC_DIR="../src/" +DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU=-1 +DEVICE=-1 +MODEL_DIR=./ru10k/ + +for combine_marginals in 0 1; do +for final_lambda in 0.1 0.25 0.5 0.75 1.; do +for feat_log in 0 1; do +for feat_alpha in 1. 0.75 0.5 0.25; do +for word_beta in 1. 0.5 0.; do +for dist_decay_v in 1. 0.9 0.5; do +for final_norm in 1 0; do +CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR}:. python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat.decode ${MODEL_DIR}/_conf device:${DEVICE} dict_dir:${MODEL_DIR} model_load_name:${MODEL_DIR}/zmodel.best partype:g1 log_file: test:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu test_extra_pretrain_files:${DATA_DIR}/wiki.multi.${CUR_LANG}.filtered.vec smodel:./model.pic zprune.pruning_mthresh:0.1 combine_marginals:${combine_marginals} final_lambda:${final_lambda} feat_log:${feat_log} feat_alpha:${feat_alpha} word_beta:${word_beta} dist_decay_v:${dist_decay_v} final_norm:${final_norm} +done; done; done; done; done; done; done |& tee _log1002 + +# out of the loop +CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR}:. python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat.decode ${MODEL_DIR}/_conf device:${DEVICE} dict_dir:${MODEL_DIR} model_load_name:${MODEL_DIR}/zmodel.best partype:g1 log_file: test:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu test_extra_pretrain_files:${DATA_DIR}/wiki.multi.${CUR_LANG}.filtered.vec smodel:./model.pic zprune.pruning_mthresh:0.1 + +# b tasks/zdpar/stat/decode:78, cur_len>10 +""" diff --git a/tasks/zdpar/stat/scorpus.py b/tasks/zdpar/stat/scorpus.py new file mode 100644 index 0000000..386c54e --- /dev/null +++ b/tasks/zdpar/stat/scorpus.py @@ -0,0 +1,48 @@ +# + +from msp.utils import zlog, zopen, Helper + +class DataReader: + def __init__(self, conf): + self.lc = conf.lc + self.thr_minlen = conf.thr_minlen + self.thr_maxlen = conf.thr_maxlen + # + self.stats = {"orig_sent": 0, "orig_tok": 0, "sent": 0, "tok": 0} + self.report_freq = 5 + + def yield_data(self, files): + # + if not isinstance(files, (list, tuple)): + files = [files] + # + cur_num = 0 + for f in files: + cur_num += 1 + zlog("-----\nDataReader: [#%d] Start reading file %s." % (cur_num, f)) + with zopen(f) as fd: + for z in self._yield_tokens(fd): + yield z + if cur_num % self.report_freq == 0: + zlog("** DataReader: [#%d] Summary till now:" % cur_num) + Helper.printd(self.stats) + zlog("=====\nDataReader: End reading ALL (#%d) ==> Summary ALL:" % cur_num) + Helper.printd(self.stats) + + def _yield_tokens(self, fin): + ss = self.stats + for line in fin: + line = line.strip() + if len(line) > 0: + # todo(note): sep by seq of spaces! + tokens = line.split() + tok_num = len(tokens) + ss["orig_sent"] += 1 + ss["orig_tok"] += tok_num + if tok_num <= self.thr_maxlen and tok_num >= self.thr_minlen: + ss["sent"] += 1 + ss["tok"] += tok_num + # todo(note): current only processing is lowercase + if self.lc: + tokens = [c.lower() for c in tokens] + yield tokens diff --git a/tasks/zdpar/stat/sdec.py b/tasks/zdpar/stat/sdec.py new file mode 100644 index 0000000..4085a65 --- /dev/null +++ b/tasks/zdpar/stat/sdec.py @@ -0,0 +1,46 @@ +# + +# sth about decoding + +import numpy as np + +# undirected mst algorithm: Kruskal +def mst_kruskal(scores): + cur_len = len(scores) # scores should be [cur_len, cur_len], and only use the i0= win_size + self.num_bin = len(bins) * 2 # directional + self.dist2bin = [-1] * (2*win_size_p1) # [0, win_size_p1) + [win_size_p1, 2*win_size_p1) + prev_i = -1 + for bin_idx, bin_ceil in enumerate(bins): + for x in range(prev_i+1, bin_ceil+1): + self.dist2bin[x] = bin_idx + self.dist2bin[x+win_size_p1] = bin_idx+len(bins) + prev_i = bin_ceil + assert all(z>=0 for z in self.dist2bin) + dist2bin_str = ", ".join([f"{s}:{v}" for s,v in enumerate(self.dist2bin)]) + zlog(f"Build dist-bins with win_size={win_size}, bins=(#{len(bins)}){bins}, dist2bin={dist2bin_str}") + # ===== + # prepare models for each word + self.binary_nb = conf.binary_nb # whether 0/1 for each feature and word as in binary nb + self.num_words = len(vocab) + self.word_counts = [0] * self.num_words + self.feat_counts = [Counter() for i in range(self.num_words)] + # feat storing: compact or not? + self.feat_compact = conf.feat_compact + + def add_sent(self, tokens: List[str]): + # ===== + # prepare + conf = self.conf + # nearby features + fc_max = conf.fc_max # near Center + ft_max = conf.ft_max # near conText + lex_min = conf.lex_min + lex_min2 = conf.lex_min2 + win_size = conf.win_size + win_size_p1 = win_size+1 + meet_punct_max = conf.meet_punct_max + meet_lex_thresh = conf.meet_lex_thresh + meet_lex_freq_max = conf.meet_lex_freq_max + vocab = self.vocab + dist2bin = self.dist2bin + neg_dist_adding = win_size_p1 if conf.neg_dist else 0 + # first change to idx, 0 means UNK; also collect related info + cur_len = len(tokens) + tok_idxes = [vocab.get(t, 0) for t in tokens] # the token itself + is_punct = [vocab.is_punct[i] for i in tok_idxes] # whether is_punct + # ===== + # start collecting: center ... -> ... context // context ... <- ... center + collections = [[] for _ in range(cur_len)] + for posi, tok in enumerate(tok_idxes): + if tok >= lex_min and not is_punct[posi]: # only consider lexicons that are not the most freq (and also exclude unk=-1) + # the meet tokens in between + cur_meet_punct = 0 + cur_meet_lex_min = meet_lex_thresh+1 # idx of the most freq word met + cur_meet_freq_num = 0 # number of how many freq words met + for posi2 in range(posi+1, min(cur_len, posi+win_size_p1)): + cur_meet_punct = min(cur_meet_punct + is_punct[posi2], meet_punct_max) + tok2 = tok_idxes[posi2] + if tok2 >= lex_min2 and not is_punct[posi2]: # similarly for context token + cur_distance = posi2 - posi + cur_fleft, cur_fright = tok_idxes[posi+1], tok_idxes[posi2-1] + # add features: (context-tok, meet-punct, fc-tok, ft-tok, dist-bin) + # tok -> tok2 + feat1 = (tok2, cur_meet_punct, min(fc_max, cur_fleft), min(ft_max, cur_fright), + dist2bin[cur_distance], cur_meet_lex_min, cur_meet_freq_num) + feat1 = self._compact(feat1) + collections[posi].append(feat1) + # tok <- tok2 + feat2 = (tok, cur_meet_punct, min(fc_max, cur_fright), min(ft_max, cur_fleft), + dist2bin[neg_dist_adding + cur_distance], cur_meet_lex_min, cur_meet_freq_num) + feat2 = self._compact(feat2) + collections[posi2].append(feat2) + if tok2 > 0 and tok2 <= meet_lex_thresh: # not UNK + cur_meet_lex_min = min(cur_meet_lex_min, tok2) + cur_meet_freq_num = min(cur_meet_freq_num+1, meet_lex_freq_max) + # add to the repo as specific model + for one_tok, one_collections in zip(tok_idxes, collections): + if one_tok >= lex_min: + if self.binary_nb: + wc, vs = 1, set(one_collections) # in binary nb, only count 0/1 for each feat + else: + wc, vs = len(one_collections), one_collections + self.word_counts[one_tok] += wc + d = self.feat_counts[one_tok] + for z in vs: + d[z] += 1 + + # todo(note): the input may contain a artificial root at the start, this does not matter + # since it will be UNK and we use relative distance + def apply_sent(self, tokens: List[str], aconf: StatApplyConf): + # ===== + # prepare + conf = self.conf + # nearby features + fc_max = conf.fc_max # near Center + ft_max = conf.ft_max # near conText + lex_min = conf.lex_min + lex_min2 = conf.lex_min2 + win_size = conf.win_size + win_size_p1 = win_size + 1 + meet_punct_max = conf.meet_punct_max + meet_lex_thresh = conf.meet_lex_thresh + meet_lex_freq_max = conf.meet_lex_freq_max + vocab = self.vocab + dist2bin = self.dist2bin + neg_dist_adding = win_size_p1 if conf.neg_dist else 0 + # first change to idx, 0 means UNK; also collect related info + cur_len = len(tokens) + tok_idxes = [vocab.get(t, 0) for t in tokens] # the token itself + is_punct = [vocab.is_punct[i] for i in tok_idxes] # whether is_punct + # ===== + # how to score + scores = np.zeros([cur_len, cur_len], dtype=np.float32) + for posi, tok in enumerate(tok_idxes): + if tok>=lex_min and not is_punct[posi]: # only consider lexicons that are not the most freq (and also exclude unk=-1) + cur_meet_punct = 0 + cur_meet_lex_min = meet_lex_thresh+1 # idx of the most freq word met + cur_meet_freq_num = 0 # number of how many freq words met + for posi2 in range(posi+1, min(cur_len, posi+win_size_p1)): + cur_meet_punct = min(cur_meet_punct + is_punct[posi2], meet_punct_max) + tok2 = tok_idxes[posi2] + if tok2>=lex_min2 and not is_punct[posi2]: # similarly for context token + cur_distance = posi2 - posi + cur_fleft, cur_fright = tok_idxes[posi+1], tok_idxes[posi2-1] + # add features: (context-tok, meet-punct, fc-tok, ft-tok, dist-bin, lex-min, lex-freq) + # tok -> tok2 + feat1 = (tok2, cur_meet_punct, min(fc_max, cur_fleft), min(ft_max, cur_fright), + dist2bin[cur_distance], cur_meet_lex_min, cur_meet_freq_num) + feat1 = self._compact(feat1) + score1 = self._score(tok, feat1, cur_distance, aconf) + # tok <- tok2 + feat2 = (tok, cur_meet_punct, min(fc_max, cur_fright), min(ft_max, cur_fleft), + dist2bin[neg_dist_adding+cur_distance], cur_meet_lex_min, cur_meet_freq_num) + feat2 = self._compact(feat2) + score2 = self._score(tok2, feat2, cur_distance, aconf) # still using orig distance + # average and symmetric for the link score + final_score = (score1 + score2)/2. + scores[posi, posi2] = final_score + scores[posi2, posi] = final_score + if tok2 > 0 and tok2 <= meet_lex_thresh: # not UNK + cur_meet_lex_min = min(cur_meet_lex_min, tok2) + cur_meet_freq_num = min(cur_meet_freq_num+1, meet_lex_freq_max) + # + if aconf.final_norm: + if aconf.final_norm_exp: + scores2 = aconf.final_norm_v ** scores + scores2 = np.where(scores>0., scores2, 0.) # 0. is still zero + else: + scores2 = scores ** aconf.final_norm_v + # normalize and again average symmetrically + scores = scores2 / (scores2.sum(-1, keepdims=True)+1e-5) + scores2 / (scores2.sum(0, keepdims=True)+1e-5) + scores /= 2 + return scores * aconf.final_lambda + + # + def _score(self, tok, feat, distance, aconf: StatApplyConf): + count_feat = self.feat_counts[tok][feat] + aconf.feat_add + if aconf.feat_log: + div0 = aconf.feat_alpha * np.log(count_feat) + else: + div0 = count_feat ** aconf.feat_alpha + if aconf.word_beta > 0.: + count_word = self.word_counts[tok] + div1 = count_word ** aconf.word_beta + else: + div1 = 1. + if aconf.dist_decay_exp: + dd = aconf.dist_decay_v ** distance + else: + dd = distance ** aconf.dist_decay_v + if div1 == 0.: + return 0. + return dd * div0 / div1 + + # storing feat compactly or not + # currently: (context-tok, meet-punct, fc-tok, ft-tok, dist-bin, meet-lex-thresh, meet-lex-freq) + def _compact(self, feat_tuple): + if self.feat_compact: + r = 0 + for x, bits in zip(feat_tuple, (18, 2, 10, 10, 5, 10, 3)): # make the sum <64 + # for x, bits in zip(feat_tuple, (18, 2, 0, 0, 5, 10, 3)): # not using fc/ft + r = (r< +class StatVocab: + UNK_str = "[]" + + def __init__(self, name=None, pre_words=None, no_adding=False): + self.name = str(name) + self.count_all = 0 + self.no_adding = no_adding + self.w2c = {} # word -> count + self.i2w = [] # idx -> word (only the words that survive cutting) + self.w2i = None # word -> idx + self.i2c = None # idx -> count + self.is_punct = None # [idx -> int(bool)] + if pre_words is not None: + self.i2w = list(pre_words) + self.w2c = {w:0 for w in self.i2w} + self._calc() + elif no_adding: + zlog("Warn: no adding mode for an empty Vocab!!") + + def _calc(self): + # todo(note): put UNK as 0 + unk_str = StatVocab.UNK_str + if unk_str in self.w2c: + # loading, not need to add + assert self.w2c[unk_str] == 0 + assert self.i2w[0] == unk_str + else: + self.w2c[unk_str] = 0 + self.i2w = [unk_str] + self.i2w + self.w2i = {w:i for i,w in enumerate(self.i2w)} + self.i2c = [self.w2c[w] for w in self.i2w] + self.is_punct = [0] * len(self.i2w) + PUNCT_LIST = """[]!"#$%&'()*+,./:;<=>?@[\\^_`{|}~„“”«»²–、。・()「」『』:;!?〜,《》·""" + PUNCT_SET = set(PUNCT_LIST) + for widx, word in enumerate(self.i2w): + if all(c in PUNCT_SET for c in word): # if all characters are punct for this word + self.is_punct[widx] = 1 + + def __len__(self): + return len(self.i2w) + + def __repr__(self): + return f"Vocab {self.name}: size={len(self)}({len(self.w2c)})/count={self.count_all}." + + def __contains__(self, item): + return item in self.w2i + + def __getitem__(self, item): + return self.w2i[item] + + def get(self, item, d): + return self.w2i.get(item, d) + + def copy(self): + n = StatVocab(self.name + "_Copy") + n.count_all = self.count_all + n.w2c = self.w2c.copy() + n.i2w = self.i2w.copy() + n._calc() + return n + + def add_all(self, ws, cc=1, ensure_exist=False): + for w in ws: + self.add(w, cc, ensure_exist) + + def add(self, w, cc=1, ensure_exist=False): + orig_cc = self.w2c.get(w) + exist = (orig_cc is not None) + if self.no_adding or ensure_exist: + assert exist, "Non exist key %s" % (w,) + if exist: + self.w2c[w] = orig_cc + cc + else: + self.i2w.append(w) + self.w2c[w] = cc + self.count_all += cc + + # cur ones that are =mincount] + if sort: + final_i2w.sort(key=lambda x: (-self.w2c[x], x)) # sort by (count, self) to make it fixed + if soft_cut and len(final_i2w)>soft_cut: + new_minc = self.w2c[final_i2w[soft_cut-1]] # boundary counting value + cur_idx = soft_cut + while cur_idx=new_minc: + cur_idx += 1 + final_i2w = final_i2w[:cur_idx] # soft cutting by ranking & keep boundary values + # + self.i2w = final_i2w + self.count_all = sum(self.w2c[w] for w in final_i2w) + self.w2c = {w:self.w2c[w] for w in final_i2w} + self._calc() + zlog("Post-cut Vocab-stat: " + str(self)) + + # + def yield_infos(self): + accu_count = 0 + for i, w in enumerate(self.i2w): + count = self.w2c[w] + accu_count += count + perc = count / self.count_all * 100 + accu_perc = accu_count / self.count_all * 100 + yield (i, w, count, perc, accu_count, accu_perc) + + def write_txt(self, fname): + with zopen(fname, "w") as fd: + for pack in self.yield_infos(): + i, w, count, perc, accu_count, accu_perc = pack + ss = f"{i} {w} {count}({perc:.3f}) {accu_count}({accu_perc:.3f})\n" + fd.write(ss) + zlog("Write (txt) to %s: %s" % (fname, str(self))) + + # other i/o + def to_zvoc(self) -> Vocab: + builder = VocabBuilder("word", default_val=0) + for w,c in self.w2c.items(): + builder.feed_one(w, c) + voc: Vocab = builder.finish(sort_by_count=True) + return voc + + def to_builtin(self): + return (self.name, self.count_all, self.no_adding, self.w2c, self.i2w) + + def from_builtin(self, v): + self.name, self.count_all, self.no_adding, self.w2c, self.i2w = v + self._calc() diff --git a/tasks/zdpar/stat2/__init__.py b/tasks/zdpar/stat2/__init__.py new file mode 100644 index 0000000..b860811 --- /dev/null +++ b/tasks/zdpar/stat2/__init__.py @@ -0,0 +1,11 @@ +# + +# todo-list +# topical influence (use actual bert's predicted output to delete special semantics?) -> change words: first fix non-changed, then find repeated topic words, then change topic and hard words one per segment -> (191103: change too much may hurt) +# predict-leaf (use first half of order as leaf, modify the scores according to this first half?) -> (191104: still not good) +# other reduce criterion? -> (191105: not too much diff, do not get it too complex...) +# vocab based reduce? simply <= thresh? -> (191106: ok, help a little, but not our target) +# cky + decide direction later +# only change topical words? +# direct parse subword seq? (group and take max/avg-score) +# cluster (hierachical bag of words, direction?, influence range, grouped influence) diff --git a/tasks/zdpar/stat2/decode2.py b/tasks/zdpar/stat2/decode2.py new file mode 100644 index 0000000..2501157 --- /dev/null +++ b/tasks/zdpar/stat2/decode2.py @@ -0,0 +1,303 @@ +# + +# look at what structures can be found from representations + +from typing import List, Dict +import numpy as np +import pickle +from collections import OrderedDict, Counter + +from scipy.stats import spearmanr, pearsonr +import pandas as pd + +from msp.utils import StatRecorder, Helper, zlog, zopen, Conf +from msp.nn import BK, NIConf +from msp.nn import init as nn_init +from msp.nn.modules.berter import BerterConf, Berter + +from tasks.zdpar.common.confs import DepParserConf +from tasks.zdpar.common.data import ParseInstance, get_data_writer +from tasks.zdpar.main.test import prepare_test, get_data_reader, init_everything +from tasks.zdpar.ef.parser import G1Parser + +from .helper_basic import SentProcessor, main_loop, show_heatmap, SDBasicConf +from .helper_feature import FeaturerConf, Featurer +from .helper_decode import get_decoder + +# speical decoding2 +class SD2Conf(SDBasicConf): + def __init__(self, args): + super().__init__() + # ===== + # show results + self.show_heatmap = False + self.show_result = False + # the processing + self.no_diag = True # make sure to zero out the self-influence scores + self.no_punct = False # ignore punctuations (PUNCT and SYM) for decoding + # analysis 1: corr + self.corr_min_surf_dist = 1 # analyze the ones >=this + # analysis 2: local min + self.local_min_thresh = 0. + # analysis 3: directional + self.dir_nonzero_thresh = 0.001 + # analysis 3.5: overall influence + self.no_nearby = False # exclude direct nearby words (since they usually have larger influence) + # decoding + self.dec_method = "m2" + self.dec_rmh_fun = True # remove the scores for function words (by POS tags) + self.dec_proj = True + self.dec_Olambda = 1. # lambda for origin one + self.dec_Tlambda = 1. # lambda for transpose + self.dec_Rlambda_inf = 0. # lambda for root influence + self.dec_Rlambda_cls = 0. # lambda for root cls + # ===== + # + self.update_from_args(args) + self.validate() + +# ===== +class SentProcessor2(SentProcessor): + def __init__(self, conf: SD2Conf): + super().__init__() + self.conf = conf + self.decoder = get_decoder(conf.dec_method) + # self.content_pos_set = {"NOUN", "VERB", "ADJ", "PROPN"} + self.content_pos_set = {"NOUN", "VERB", "PROPN"} + + # ===== + # todo(note): no arti-ROOT in inputs + def test_one_sent(self, one_inst: ParseInstance, **kwargs): + conf = self.conf + which_fold = conf.which_fold + sent = one_inst.words.vals[1:] + uposes = one_inst.poses.vals[1:] + dep_heads = one_inst.heads.vals[1:] + dep_labels = one_inst.labels.vals[1:] + # + slen = len(sent) + slen_arange = np.arange(slen) + distances = one_inst.extra_features["sd2_scores"][:, :, which_fold] # [s, s+1] + # ===== + # step 0: pre-processing + orig_distances = np.copy(distances) # [s, s+1], no 0 mask for puncts or diag + pairwise_masks = np.ones([slen, slen], dtype=np.bool) # [s, s] + # ----- + if conf.no_diag: + arange_t = slen_arange.astype(np.int32) + distances[:, 1:][arange_t, arange_t] = 0. + pairwise_masks[arange_t, arange_t] = 0 + punct_masks = np.array([(z=="PUNCT" or z=="SYM") for z in uposes], dtype=np.bool) + punct_idxes = punct_masks.nonzero()[0] # starting from 0, no root + nonpunct_masks = (~punct_masks) + nonpunct_idxes = nonpunct_masks.nonzero()[0] + content_masks = np.array([(z in self.content_pos_set) for z in uposes], dtype=np.bool) + fun_masks = (~content_masks) + if conf.no_punct: + distances[:, 1:][:, punct_idxes] = 0. + distances[:, 0:][punct_idxes, :] = 0. + pairwise_masks[:, punct_idxes] = 0 + pairwise_masks[punct_idxes, :] = 0 + # ===== + # step 0.9: prepare various distances for comparing + surface_distances = np.abs(slen_arange[:, np.newaxis] - slen_arange[np.newaxis, :]) # [s,s] + syntax_distances, syntax_paths, syntax_depths = self._parse_heads(dep_heads) # [s,s], [s,], [s] + # ===== + ret_info = {"orig": (sent, uposes, dep_heads, dep_labels, syntax_distances, syntax_depths)} + # step 1: correlation analysis + # ----- + def _add_corrs(a, b, name): + # ret_info[name+"_rsp"] = spearmanr(a, b)[0] + # ret_info[name+"_rp"] = pearsonr(a, b)[0] + # todo(note): not stable if len not enough or syn-distances(a) are all the same + ret_info["corr_" + name] = 0.5 if (len(a)<=5 or np.std(a)==0) else spearmanr(a, b)[0] + # ----- + corr_min_surf_dist = self.conf.corr_min_surf_dist + all_nonzeros = pairwise_masks.nonzero() + upper_mask = (all_nonzeros[0] < all_nonzeros[1]-corr_min_surf_dist) + lower_mask = (all_nonzeros[0] > all_nonzeros[1]+corr_min_surf_dist) + upper_idxes0, upper_idxes1 = [z[upper_mask] for z in all_nonzeros] + lower_idxes0, lower_idxes1 = [z[lower_mask] for z in all_nonzeros] + # ----- + nifs = -distances[:, 1:] # negative influence value + nifs_both = -(distances[:, 1:]+distances[:, 1:].T)/2. # average for both directions + # syntax vs. surface (upper triu) + _add_corrs(syntax_distances[upper_idxes0, upper_idxes1], surface_distances[upper_idxes0, upper_idxes1], "syn_surf") + # + for target_distance, target_name in zip([syntax_distances, surface_distances], ["syn", "surf"]): + # upper: syntax/surface vs. neg-influence + _add_corrs(target_distance[upper_idxes0, upper_idxes1], nifs[upper_idxes0, upper_idxes1], target_name+"_nif_upper") + # lower: syntax/surface vs. neg-influence + _add_corrs(target_distance[lower_idxes0, lower_idxes1], nifs[lower_idxes0, lower_idxes1], target_name+"_nif_lower") + # averaged: syntax/surface vs. neg-influence (upper triu) + _add_corrs(target_distance[upper_idxes0, upper_idxes1], nifs_both[upper_idxes0, upper_idxes1], target_name+"_nif_avg") + # ===== + # step 2: long range dep + orig_scores = orig_distances[:, 1:] + padded_orig_scores = np.pad(orig_scores, ((1,1),(1,1)), 'constant', constant_values=0.) # min score is 0. + # [s,s], minimum of the diff to four neighbours + local_min_diff = np.stack([orig_scores-padded_orig_scores[1:-1, 2:], orig_scores-padded_orig_scores[1:-1, :-2], + orig_scores-padded_orig_scores[2:, 1:-1], orig_scores-padded_orig_scores[:-2, 1:-1]], + -1).min(-1) + interest_points_masks = ((local_min_diff>conf.local_min_thresh) & pairwise_masks) # excluding non-valids + interest_points_masks = (interest_points_masks & interest_points_masks.T) # bidirectional mask + interest_points_idx0, interest_points_idx1 = interest_points_masks.nonzero() + interest_points_idx_mask = (interest_points_idx0 < interest_points_idx1) # no direction, make idx0=0) & (~punct_masks)] + head_idxes = head_idxes[(head_idxes>=0) & (~punct_masks)] + # (without-head, mod) - (without-mod, head) + hm_score_diff1 = influence_scores[head_idxes, mod_idxes] - influence_scores[mod_idxes, head_idxes] + # ===== + # step 3.5: direction by influencing others & depth + overall_influence_mask = np.copy(pairwise_masks) # no self, no punct + # delete nearby? + if conf.no_nearby: + tmp_range_idxes = np.arange(slen) + overall_influence_mask[tmp_range_idxes[1:], tmp_range_idxes[1:]-1] = 0. + overall_influence_mask[tmp_range_idxes[:-1], tmp_range_idxes[:-1]+1] = 0. + average_influence_score = (influence_scores * overall_influence_mask).sum(-1) / (overall_influence_mask.sum(-1) + 1e-5) + hm_score_diff2 = average_influence_score[head_idxes] - average_influence_score[mod_idxes] + # correlation of depth and influence + _add_corrs(np.array(syntax_depths)[mod_idxes], average_influence_score[mod_idxes], "inf_depth") + # ----- + for prefix, hm_score_diff in zip(["dir_", "dir2_"], [hm_score_diff1, hm_score_diff2]): + ret_info[prefix+"num"] = len(hm_score_diff) + ret_info[prefix+"diff_arr"] = hm_score_diff + # ret_info[prefix+"depth"] = np.array(syntax_depths)[mod_idxes] + # _add_corrs(ret_info["dir_depth"], ret_info[prefix+"diff_arr"], prefix+"diff_depth") + dir_mask_posi, dir_mask_neg = (hm_score_diff>conf.dir_nonzero_thresh), (hm_score_diff<-conf.dir_nonzero_thresh) + ret_info[prefix+"positive"] = dir_mask_posi.sum().item() + ret_info[prefix+"negative"] = dir_mask_neg.sum().item() + ret_info[prefix+"nearzero"] = (~(dir_mask_posi | dir_mask_neg)).sum().item() + ret_info[prefix+"avg_pos"] = hm_score_diff[dir_mask_posi] + ret_info[prefix+"avg_neg"] = hm_score_diff[dir_mask_neg] + # step 3.6: best influence and root? + ret_info["depth_best_inf"] = syntax_depths[np.argmax(average_influence_score).item()] + ret_info["depth_best_cls"] = syntax_depths[np.argmax(distances[:,0]).item()] + # ----- + # step 4: inference + # score pre-processing + original_scores = influence_scores + processed_scores = conf.dec_Olambda * original_scores + conf.dec_Tlambda * (original_scores.T) + root_scores = conf.dec_Rlambda_cls * distances[:,0] + conf.dec_Rlambda_inf * average_influence_score + # delete fun word as h + if conf.dec_rmh_fun: + fun_mask = np.array([z not in self.content_pos_set for z in uposes]).astype(np.bool) + processed_scores[fun_mask] = 0. + root_scores[fun_mask] = 0. + # exclude punct for decoding? + if conf.no_punct: + original_scores = original_scores[nonpunct_idxes, :][:, nonpunct_idxes] + processed_scores = processed_scores[nonpunct_idxes, :][:, nonpunct_idxes] + root_scores = root_scores[nonpunct_idxes] + if len(processed_scores)>0: + # decode + output_heads, output_root = self.decoder(processed_scores, original_scores, root_scores, conf.dec_proj) + # sink fun? + # restore punct? + if conf.no_punct: + assert len(output_heads) == len(nonpunct_idxes) + output_root = nonpunct_idxes[output_root] + real_heads = [output_root+1] * slen # put puncts to the real root + # todo(note): be careful about the root offset (only in heads) + for m,h in enumerate(output_heads): + if h==0: + real_heads[nonpunct_idxes[m]] = 0 + else: + real_heads[nonpunct_idxes[m]] = nonpunct_idxes[h-1]+1 + output_heads = real_heads + else: + # only punctuations or empty sentence? + output_root = 0 + output_heads = [0] * slen + ret_info["output"] = (output_heads, output_root) + # ===== eval + ret_info.update(self._eval(dep_heads, output_heads, nonpunct_masks, content_masks, fun_masks)) + if conf.show_result: + self._show(one_inst, output_heads, ret_info, conf.show_heatmap, distances, True) + # + self.infos.append(ret_info) + return ret_info + + def summary(self): + # get averaged(stddev) scores + ret = OrderedDict() + ret["#Sent"] = len(self.infos) + ret["#Token"] = sum(len(z["orig"][0]) for z in self.infos) + count_np = 0 + for z in self.infos: + count_np += sum(p not in ["PUNCT", "SYM"] for p in z["orig"][1]) + ret["#TokenNP"] = count_np + # step 1: correlation analysis + for n in ["syn_surf"] + [a+b for a in ["syn", "surf"] for b in ["_nif_upper", "_nif_lower", "_nif_avg"]]: + values = [z["corr_" + n] for z in self.infos] + ret["S1_" + n] = (np.mean(values), np.std(values)) + # step 2: long range dep + # skip this one since not much patterns here + # step 3: direction for dep pairs + corr_values = [z["corr_" + "inf_depth"] for z in self.infos] + ret["S3_corr_inf_depth"] = (np.mean(corr_values), np.std(corr_values)) + for prefix in ["dir_", "dir2_"]: + for name in ["positive", "negative", "nearzero"]: + values = [z[prefix+name] / z[prefix+"num"] for z in self.infos if z[prefix+"num"]>0] + ret["S3_" + prefix + name] = (np.mean(values), np.std(values)) + for name in ["avg_pos", "avg_neg"]: + values = np.concatenate([z[prefix+name] for z in self.infos]) + ret["S3_" + prefix + name] = (np.mean(values), np.std(values)) + # best inf vs. depth + ret["S3_depth_best_inf"] = Counter([z["depth_best_inf"] for z in self.infos]) + ret["S3_depth_best_cls"] = Counter([z["depth_best_cls"] for z in self.infos]) + # ===== + ret.update(self._sum_eval()) + return ret + +# +def main(args): + conf = SD2Conf(args) + sp = SentProcessor2(conf) + main_loop(conf, sp) + +# +""" +SRC_DIR="../src/" +DATA_DIR="../data/UD_RUN/ud24/" +#DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU=2 +DEVICE=0 +MODEL_DIR=./ru10k/ +# bert +CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode2 device:${DEVICE} input_file:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu output_pic:_tmp.pic +## input_file:en_cut.ppos.conllu +# bert with precomputed pic +PYTHONPATH=../src/ python3 ../src/tasks/cmd.py zdpar.stat2.decode2 input_file:_en_dev.ppos.pic already_pre_computed:1 +# another example +PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode2 input_file:_en_dev.ppos.pic already_pre_computed:1 dec_Rlambda_inf:0 dec_Rlambda_cls:0 dec_method:m2 dec_Tlambda:0 dec_proj:1 output_file:_tmp.conllu2 +# model +#CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode2 ${MODEL_DIR}/_conf device:${DEVICE} dict_dir:${MODEL_DIR} model_load_name:${MODEL_DIR}/zmodel.best partype:g1 log_file: test:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu test_extra_pretrain_files:${DATA_DIR}/wiki.multi.${CUR_LANG}.filtered.vec +# +b tasks/zdpar/stat2/decode2:348 +""" diff --git a/tasks/zdpar/stat2/decode3.py b/tasks/zdpar/stat2/decode3.py new file mode 100644 index 0000000..9caaa2e --- /dev/null +++ b/tasks/zdpar/stat2/decode3.py @@ -0,0 +1,198 @@ +# + +# special decoding 3: in a hierarchical way + +from typing import List, Dict +import numpy as np +from collections import OrderedDict, Counter + +from msp.utils import Helper, zlog +from msp.data.vocab import Vocab +from tasks.zdpar.common.data import ParseInstance +from .helper_basic import SentProcessor, main_loop, show_heatmap, SDBasicConf +from .helper_decode import get_decoder, get_fdecoder, IRConf, IterReducer + +# +class SD3Conf(SDBasicConf): + def __init__(self, args): + super().__init__() + # + self.show_heatmap = False + self.show_result = False + self.no_diag = True + # ===== + # reducer + self.ir_conf = IRConf() + # ===== + # leaf layer decoding + self.use_fdec = False + self.fdec_what = "reduce" # reduce based or pos based or vocab based + self.fdec_vocab_file = "" + self.fdec_vr_thresh = 100 # vocab reduce thresh, <= this + self.fdec_ir_perc = 0.5 # how many percentage put in fdec for ir mode + self.fdec_method = "max" # right/left/max + self.fdec_oper = "add" # operation between O/T + self.fdec_Olambda = 1. # if use score, lambda for origin one + self.fdec_Tlambda = 1. # if use score, lambda for transpose one + # main layer decoding + self.mdec_aggregate_scores = False # aggregate score from fdec + self.mdec_method = "m2" # upper layer decoding method: l2r/r2l/rand/m1/m2/m3/m3s/m4(recursive-reduce) + self.mdec_proj = True + self.mdec_oper = "add" # operation between O/T + self.mdec_Olambda = 1. # if use score, lambda for origin one + self.mdec_Tlambda = 1. # if use score, lambda for transpose one + # + self.update_from_args(args) + self.validate() + +# ===== +class SentProcessor3(SentProcessor): + def __init__(self, conf: SD3Conf): + super().__init__() + self.conf = conf + self.mdecoder = get_decoder(conf.mdec_method) + self.fdecoder = get_fdecoder(conf.fdec_method) + # + # self.content_pos_set = {"NOUN", "VERB", "ADJ", "PROPN"} + self.content_pos_set = {"NOUN", "VERB", "PROPN"} + # + oper_dict = {"add": np.add, "max": np.maximum, "min": np.minimum} + self.fdec_oper = oper_dict[conf.fdec_oper] + self.mdec_oper = oper_dict[conf.mdec_oper] + # + self.ir_reducer = IterReducer(conf.ir_conf) + # + if conf.fdec_vocab_file: + self.r_vocab = Vocab.read(conf.fdec_vocab_file) + else: + self.r_vocab = None + + # ===== + # todo(note): no arti-ROOT in inputs + def test_one_sent(self, one_inst: ParseInstance): + conf = self.conf + which_fold = conf.which_fold + sent = one_inst.words.vals[1:] + uposes = one_inst.poses.vals[1:] + dep_heads = one_inst.heads.vals[1:] + dep_labels = [z.split(":")[0] for z in one_inst.labels.vals[1:]] # todo(note): only first level + # + ret_info = {"orig": (sent, uposes, dep_heads, dep_labels)} + slen = len(sent) + slen_arange = np.arange(slen) + orig_distances = one_inst.extra_features["sd2_scores"][:, :, which_fold] # [s, s+1] + # ===== + # step 0: pre-processing + aug_influence_scores = np.copy(orig_distances) # [s,s] + if conf.no_diag: + arange_t = slen_arange.astype(np.int32) + aug_influence_scores[:,1:][arange_t, arange_t] = 0. + influence_scores = aug_influence_scores[:,1:] + # ===== + # step 0.5: separate the layers + punct_masks = np.array([(z == "PUNCT" or z == "SYM") for z in uposes], dtype=np.bool) + nonpunct_masks = (~punct_masks) + content_masks = np.array([(z in self.content_pos_set) for z in uposes], dtype=np.bool) + content_idxes = content_masks.nonzero()[0] + fun_masks = (~content_masks) + fun_idxes = fun_masks.nonzero()[0] + # ===== + # step 1: leaf layer + if conf.use_fdec: + if conf.fdec_what == "pos": + fdec_masks = np.copy(fun_masks) + elif conf.fdec_what == "reduce": + _, ir_reduce_idxes = self.ir_reducer.reduce(influence_scores, orig_distances[:, 0]) + fdec_masks = np.zeros(slen, dtype=np.bool) + fdec_masks[ir_reduce_idxes[:int(slen*conf.fdec_ir_perc)]] = True + elif conf.fdec_what == "vocab": + vthr = conf.fdec_vr_thresh + fdec_masks = np.asarray([self.r_vocab.get(z, vthr+1)<=vthr for z in sent], dtype=np.bool) + else: + raise NotImplementedError() + # prepare fun scores + original_scores = influence_scores + fdec_scores = self.fdec_oper(conf.fdec_Olambda * original_scores, conf.fdec_Tlambda * (original_scores.T)) + fdec_scores[fdec_masks] = 0. # no inner interactions + # [slen], content words all have 0 as head + fdec_preds = self.fdecoder(fdec_masks, fun_scores=fdec_scores, orig_scores=original_scores) + group_idxes = [[i] for i in range(slen)] # [slen] of list of children-group (include self) + for m, h in enumerate(fdec_preds): + if fdec_masks[m] and h>0: + group_idxes[h-1].append(m) + mdec_idxes = (~fdec_masks).nonzero()[0] + else: + fdec_masks = np.zeros(slen, dtype=np.bool) + fdec_preds = [0] * slen + group_idxes = [[i] for i in range(slen)] + mdec_idxes = list(range(slen)) + # ===== + # step 2: main layer + root_scores = np.zeros(slen) # TODO(!) + if conf.mdec_aggregate_scores: + pass # TODO(!) + if conf.mdec_method == "m3s": + # special decoding starting from a middle step of m3 + pass # TODO(!) + else: + # decoding the contracted sent (only mdec_idxes) + original_scores = influence_scores + processed_scores = self.mdec_oper(conf.mdec_Olambda * original_scores, conf.mdec_Tlambda * (original_scores.T)) + original_scores = original_scores[mdec_idxes, :][:, mdec_idxes] + processed_scores = processed_scores[mdec_idxes, :][:, mdec_idxes] + root_scores = root_scores[mdec_idxes] + if len(processed_scores)>0: + output_heads, output_root = self.mdecoder(processed_scores, original_scores, root_scores, conf.mdec_proj) + # restore all + assert len(output_heads) == len(mdec_idxes) + output_root = mdec_idxes[output_root] + real_heads = fdec_preds.copy() # copy from fun preds + for m,h in enumerate(output_heads): + if h==0: + real_heads[mdec_idxes[m]] = 0 + else: + real_heads[mdec_idxes[m]] = mdec_idxes[h-1]+1 + output_heads = real_heads + else: + # only punctuations or empty sentence? + output_root = 0 + output_heads = [0] * slen + # ===== eval + ret_info["output"] = (output_heads, output_root) + # eval for UAS + ret_info.update(self._eval(dep_heads, output_heads, nonpunct_masks, content_masks, fun_masks, fdec_masks)) + # eval for reducing + ret_info.update(self._eval_order(dep_heads, content_masks, "POS", sent, dep_labels)) + ir_reduce_order, _ = self.ir_reducer.reduce(influence_scores, orig_distances[:, 0]) + ret_info.update(self._eval_order(dep_heads, ir_reduce_order, "IR", sent, dep_labels)) + # + if conf.show_result: + self.eval_order_verbose(dep_heads, ir_reduce_order, sent, dep_labels) + self._show(one_inst, output_heads, ret_info, conf.show_heatmap, aug_influence_scores, True) + # + self.infos.append(ret_info) + return ret_info + + def summary(self): + ret = {} + ret.update(self._sum_eval()) + ret.update(self._sum_eval_order("POS")) + ret.update(self._sum_eval_order("IR")) + # # tmp look + # from .helper_basic import UD2_DEP_LABELS + # for lab in ["<", ">"] + UD2_DEP_LABELS: + # zlog(f"# =====\nLooking at Label={lab}") + # Helper.printd(self._sum_eval_order("IR"+lab)) + return ret + +# +def main(args): + conf = SD3Conf(args) + sp = SentProcessor3(conf) + main_loop(conf, sp) + +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:en_cut.ppos.conllu +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:_en_dev.ppos.pic already_pre_computed:1 +# PYTHONPATH=../src/ python3 -m pdb ../src/tasks/cmd.py zdpar.stat2.decode3 input_file:_en_dev.ppos.pic already_pre_computed:1 show_result:1 show_heatmap:1 min_len:10 max_len:20 rand_input:1 rand_seed:100 + +# b tasks/zdpar/stat2/decode3:138 diff --git a/tasks/zdpar/stat2/helper_basic.py b/tasks/zdpar/stat2/helper_basic.py new file mode 100644 index 0000000..043c3c8 --- /dev/null +++ b/tasks/zdpar/stat2/helper_basic.py @@ -0,0 +1,317 @@ +# + +from typing import List, Dict +import numpy as np +import pickle +from collections import OrderedDict, Counter, defaultdict + +from scipy.stats import spearmanr, pearsonr +import pandas as pd + +from msp.utils import StatRecorder, Helper, zlog, zopen, Conf +from msp.nn import BK, NIConf +from msp.nn import init as nn_init +from msp.nn.modules.berter import BerterConf, Berter +from msp.data.vocab import Vocab + +from tasks.zdpar.common.confs import DepParserConf +from tasks.zdpar.common.data import ParseInstance, get_data_writer +from tasks.zdpar.main.test import prepare_test, get_data_reader, init_everything +from tasks.zdpar.ef.parser import G1Parser + +from .helper_feature import FeaturerConf, Featurer +from .helper_decode import get_decoder + +# speical decoding +class SDBasicConf(Conf): + def __init__(self): + # ===== + # io + self.input_file = "" + self.output_file = "" + self.output_pic = "" + self.rand_input = False + self.rand_seed = 12345 + # filter + self.min_len = 0 + self.max_len = 100000 + # feature extractor + self.already_pre_computed = False + self.fake_scores = False # create all 0. distance scores for debugging + self.fconf = FeaturerConf() + self.which_fold = -1 + # msp + self.niconf = NIConf() # nn-init-conf + # todo(+N): how to deal with strange influences of low-freq words? (topic influence?) + # replace by bert + self.sent_repl_times = 0 + # replace unk with vocab + self.vocab_file = "" + self.unk_repl_token = "[UNK]" + self.unk_repl_upos = [] + self.unk_repl_thresh = -1 # replace unk if <=this + self.unk_repl_split_thresh = 100 # replace unk if a word is split larger than this + # ===== + self.debug_no_gold = False + self.processing = True # no processing maybe means precomputing things + # ===== + +# ===== +# helpers + +# showing with heatmap +def show_heatmap(sent, distances, col_add_cls): + import matplotlib.pyplot as plt + from .helper_draw import heatmap, annotate_heatmap + fig, ax = plt.subplots() + im, cbar = heatmap(distances, sent, (["[CLS]"]+sent) if col_add_cls else sent, ax=ax, + cmap="YlGn", cbarlabel="distance") + texts = annotate_heatmap(im, valfmt="{x:.2f}") + fig.tight_layout() + plt.show() + +# +UD2_DEP_LABELS = ['acl', 'advcl', 'advmod', 'amod', 'appos', 'aux', 'case', 'cc', 'ccomp', 'clf', 'compound', 'conj', 'cop', 'csubj', 'dep', 'det', 'discourse', 'dislocated', 'expl', 'fixed', 'flat', 'goeswith', 'iobj', 'list', 'mark', 'nmod', 'nsubj', 'nummod', 'obj', 'obl', 'orphan', 'parataxis', 'punct', 'reparandum', 'root', 'vocative', 'xcomp'] + +class SentProcessor: + def __init__(self): + self.infos = [] + + def test_one_sent(self, one_inst: ParseInstance): + raise NotImplementedError("Need to specify in sub-classes!!") + + def summary(self): + raise NotImplementedError("Need to specify in sub-classes!!") + + # uas eval + def _eval(self, gold_heads, pred_heads, nonpunct_masks, content_masks, fun_masks, fdec_masks): + slen = len(gold_heads) + ret = defaultdict(int) + all_masks = np.ones(slen) + for midx in range(slen): + h_gold = gold_heads[midx] - 1 + h_pred = pred_heads[midx] - 1 + h_corr_d = int(h_gold == h_pred) + h_corr_u = int((h_gold == h_pred) or (h_pred >= 0 and gold_heads[h_pred] == midx + 1)) # for UUAS + # + for one_name, one_masks in zip(["all", "np", "ct", "fun", "fdec"], + [all_masks, nonpunct_masks, content_masks, fun_masks, fdec_masks]): + if one_masks[midx]: + ret[one_name+"_dhit"] += h_corr_d + ret[one_name+"_uhit"] += h_corr_u + ret[one_name+"_count"] += 1 + else: + ret[one_name + "_dhit"] += 0 + ret[one_name + "_uhit"] += 0 + ret[one_name + "_count"] += 0 + ret = dict(ret) + return ret + + # check the reduce orders of gold h/m + def _eval_order(self, gold_heads, reduce_order, prefix_name: str, sent, dep_labels): + ret = defaultdict(int) + for m, h in enumerate(gold_heads): + h -= 1 + if h<0: + continue + r_m, r_h = reduce_order[m], reduce_order[h] + if prefix_name=="IR" and r_m==r_h: + raise RuntimeError() + corr, unk, wrong = int(r_mr_h) + for infix, hit in zip(["", "<", ">", dep_labels[m]], [True, r_m=len(gold_heads)//2, 1]): + if hit: + cur_prefix = prefix_name + infix + ret[cur_prefix + "_all"] += 1 + ret[cur_prefix + "_corr"] += corr + ret[cur_prefix + "_unk"] += unk + ret[cur_prefix + "_wrong"] += wrong + ret = dict(ret) + return ret + + def eval_order_verbose(self, gold_heads, reduce_order, sent, dep_labels): + all_corr = 0 + for m in np.argsort(reduce_order): # sort by reduce order + h = gold_heads[m]-1 + if h<0: + continue + r_m, r_h = reduce_order[m], reduce_order[h] + corr = r_m < r_h + print(f"m={m}({sent[m]}), h={h}({sent[h]}), LAB={dep_labels[m]}, order={r_m}-{r_h}-{corr}") + all_corr += int(corr) + print(f"Final: {all_corr}/{len(sent)}") + + def _sum_eval(self): + r = {} + for cc in "du": + for one_name in ["all", "np", "ct", "fun", "fdec"]: + hit, count = sum(z.get(f"{one_name}_{cc}hit", 0) for z in self.infos), \ + sum(z.get(f"{one_name}_count", 0) for z in self.infos) + r[f"SEVAL_{cc}_{one_name}_UAS"] = hit / max(1, count) + r[f"SEVAL_{cc}_{one_name}_Detail"] = f"{hit}/{count}={hit/max(1, count)}" + return r + + def _sum_eval_order(self, prefix_name): + r = {} + for one_name in ["corr", "unk", "wrong"]: + hit, count = sum(z.get(f"{prefix_name}_{one_name}",0) for z in self.infos), \ + sum(z.get(f"{prefix_name}_all", 0) for z in self.infos) + r[f"SEVAL_{prefix_name}_{one_name}_%"] = hit / max(1, count) + r[f"SEVAL_{prefix_name}_{one_name}_Detail"] = f"{hit}/{count}={hit/max(1, count)}" + return r + + def _show(self, one_inst, pred_heads, info, show_hmap: bool, distances, set_trace): + sent, uposes, dep_heads = one_inst.words.vals[1:], one_inst.poses.vals[1:], one_inst.heads.vals[1:] + slen = len(sent) + aug_sent = ["R"] + sent + all_tokens = [("R", "--", -1, "--", -1, "--")] + \ + [(sent[i], uposes[i], dep_heads[i], aug_sent[dep_heads[i]], + pred_heads[i], aug_sent[pred_heads[i]]) + for i in range(slen)] + print(pd.DataFrame(all_tokens).to_string()) + print("; ".join([f"{one_name}:{info[one_name+'_dhit']}/{info[one_name+'_uhit']}/{info[one_name+'_count']}" + for one_name in ["all", "np", "ct", "fun"]])) + # ===== + if show_hmap: + print(sent) + print(uposes) + print(one_inst.extra_features.get("feat_seq", None)) + for one_ri, one_seq in enumerate(one_inst.extra_features.get("sd3_repls", [])): + print(f"#R{one_ri}: {one_seq}") + show_heatmap(sent, distances, True) + if set_trace: + import pdb + pdb.set_trace() + + def _parse_heads(self, dep_heads: List[int]): + # todo(note): a naive root-path matching method + slen = len(dep_heads) + up_paths = [None for _ in range(slen)] + for m, h in enumerate(dep_heads): + # clear the up path + cur_m, cur_h = m, h + cur_stack = [] + while cur_m >= 0 and up_paths[cur_m] is None: + cur_stack.append(cur_m) + cur_m = cur_h - 1 + cur_h = dep_heads[cur_h - 1] + # assign them + upper_path = up_paths[cur_m] if cur_m >= 0 else [] # the upper path that is known + for one_i2, one_m in enumerate(cur_stack): + assert up_paths[one_m] is None + up_paths[one_m] = cur_stack[one_i2:] + upper_path + # once we know the paths, we know the depth + depths = [len(z) for z in up_paths] # starting from 1 + # then assign pairwise distance by finding common prefix + # todo(note): not efficient since there are much repeated computations + syntax_distances = np.zeros([slen, slen]) + for i1 in range(slen): + for i2 in range(i1 + 1, slen): + common_count = sum(a == b for a, b in zip(reversed(up_paths[i1]), reversed(up_paths[i2]))) + one_dist = depths[i1] + depths[i2] - 2 * common_count + syntax_distances[i1, i2] = one_dist + syntax_distances[i2, i1] = one_dist + return syntax_distances, up_paths, depths + +# ===== +# data reader +def yield_data(filename): + if filename is None or filename == "": + zlog("Start to read raw sentence from stdin") + while True: + line = input(">> ") + if len(line) == 0: + break + sent = line.split() + cur_len = len(sent) + one = ParseInstance(sent, ["_"] * cur_len, [0] * cur_len, ["_"] * cur_len) + yield one + # todo(note): judged by special ending!! + elif filename.endswith(".pic"): + zlog(f"Start to read pickle from file {filename}") + with zopen(filename, 'rb') as fd: + while True: + try: + one = pickle.load(fd) + yield one + except EOFError: + break + else: + # otherwise read collnu + zlog(f"Start to read conllu from file {filename}") + for one in get_data_reader(filename, "conllu", "", True): + yield one + +# ===== +def main_loop(conf: SDBasicConf, sp: SentProcessor): + np.seterr(all='raise') + nn_init(conf.niconf) + np.random.seed(conf.rand_seed) + # + # will trigger error otherwise, save time of loading model + featurer = None if conf.already_pre_computed else Featurer(conf.fconf) + output_pic_fd = zopen(conf.output_pic, 'wb') if conf.output_pic else None + all_insts = [] + vocab = Vocab.read(conf.vocab_file) if conf.vocab_file else None + unk_repl_upos_set = set(conf.unk_repl_upos) + with BK.no_grad_env(): + input_stream = yield_data(conf.input_file) + if conf.rand_input: + inputs = list(input_stream) + np.random.shuffle(inputs) + input_stream = inputs + for one_inst in input_stream: + # ----- + # make sure the results are the same; to check whether we mistakenly use gold in that jumble of analysis + if conf.debug_no_gold: + one_inst.heads.vals = [0] * len(one_inst.heads.vals) + if len(one_inst.heads.vals) > 2: + one_inst.heads.vals[2] = 1 # avoid err in certain analysis + one_inst.labels.vals = ["_"] * len(one_inst.labels.vals) + # ----- + if len(one_inst)>=conf.min_len and len(one_inst)<=conf.max_len: + folded_distances = one_inst.extra_features.get("sd2_scores") + if folded_distances is None: + if conf.fake_scores: + one_inst.extra_features["sd2_scores"] = np.zeros(featurer.output_shape(len(one_inst.words.vals[1:]))) + else: + # ===== replace certain words? + word_seq = one_inst.words.vals[1:] + upos_seq = one_inst.poses.vals[1:] + if conf.unk_repl_thresh > 0: + word_seq = [(conf.unk_repl_token if (u in unk_repl_upos_set and + vocab.getval(w, 0)<=conf.unk_repl_thresh) else w) + for w,u in zip(word_seq, upos_seq)] + if conf.unk_repl_split_thresh<10: + berter_toker = featurer.berter.tokenizer + word_seq = [conf.unk_repl_token if (u in unk_repl_upos_set and + len(berter_toker.tokenize(w))>conf.unk_repl_split_thresh) else w + for w,u in zip(word_seq, upos_seq)] + # ===== auto repl by bert? + sent_repls = [word_seq] + sent_fixed = [np.zeros(len(word_seq)).astype(np.bool)] + for _ in range(conf.sent_repl_times): + new_sent, new_fixed = featurer.repl_sent(sent_repls[-1], sent_fixed[-1]) + sent_repls.append(new_sent) + sent_fixed.append(new_fixed) # once fixed, always fixed + one_inst.extra_features["sd3_repls"] = sent_repls + one_inst.extra_features["sd3_fixed"] = sent_fixed + # ===== score + folded_distances = featurer.get_scores(sent_repls[-1]) + one_inst.extra_features["sd2_scores"] = folded_distances + one_inst.extra_features["feat_seq"] = word_seq + if output_pic_fd is not None: + pickle.dump(one_inst, output_pic_fd) + if conf.processing: + one_info = sp.test_one_sent(one_inst) + # put prediction + one_inst.pred_heads.set_vals([0] + list(one_info["output"][0])) + one_inst.pred_labels.set_vals(["_"] * len(one_inst.labels.vals)) + all_insts.append(one_inst) + if output_pic_fd is not None: + output_pic_fd.close() + if conf.output_file: + with zopen(conf.output_file, 'w') as wfd: + data_writer = get_data_writer(wfd, "conllu") + data_writer.write(all_insts) + # ----- + Helper.printd(sp.summary()) diff --git a/tasks/zdpar/stat2/helper_cluster.py b/tasks/zdpar/stat2/helper_cluster.py new file mode 100644 index 0000000..98a78a1 --- /dev/null +++ b/tasks/zdpar/stat2/helper_cluster.py @@ -0,0 +1,50 @@ +# + +from typing import List +import numpy as np + +# +class OneCluster: + def __init__(self, nodes: List[int], head: int): + self.nodes = nodes + self.head = head + + def merge(self, other_cluster: 'OneCluster'): + self.nodes.extend(other_cluster.nodes) + # still keep self.head + +# greedy nearby clustering with head +# bottom up decoding +class ClusterManager: + def build_start_clusters(self, start_groups: List[List], start_heads: List[int]): + return [OneCluster(a,b) for a,b in zip(start_groups, start_heads)] + + # slen is overall length, start_* is the current clusters, matrices are all in [h(slen),m(slen)] convention + def cluster(self, slen: int, start_clusters: List[OneCluster], allowed_matrix, link_scores, head_scores): + ret = [0] * slen + procedures = [] # List[(h, m)] + cur_clusters = start_clusters + while True: + # merge candidates between (i, i+1) + merge_candidates = [] + merge_scores = [] + for midx in range(len(cur_clusters)-1): + c1, c2 = cur_clusters[midx], cur_clusters[midx+1] + c1_nodes, c1_head = c1.nodes, c1.head + c2_nodes, c2_head = c2.nodes, c2.head + if allowed_matrix[c1_head, c2_head] or allowed_matrix[c2_head, c1_head]: + merge_candidates.append(midx) + # here using the linking scores + one_score = None # TODO(!) + merge_scores.append(one_score) + # break if nothing left to merge + if len(merge_candidates) == 0: + break + # greedily pick the best score + best_idx = np.argmax(merge_scores).item() + best_c1_idx = merge_candidates[best_idx] + best_c2_idx = best_c1_idx + 1 + # TODO(!) + # determine the direction + # prepare for the next iter + return ret, procedures diff --git a/tasks/zdpar/stat2/helper_decode.py b/tasks/zdpar/stat2/helper_decode.py new file mode 100644 index 0000000..c55413f --- /dev/null +++ b/tasks/zdpar/stat2/helper_decode.py @@ -0,0 +1,249 @@ +# + +# various decoding methods +# todo(note): conventions +# the input scores utilize the lower-corner with pre-processing +# the original scores are also provided +# root scores are extra ones for the root +# for return: heads are offset-ed, but root is not + +import numpy as np + +from msp.utils import Conf +from ..algo.nmst import mst_unproj, mst_proj + +# ===== +# part 1: main decoder + +# ===== +# chain baselines +def chain_l2r(scores, orig_scores, root_scores, is_proj: bool): + slen = len(scores) + return list(range(slen)), 0 + +def chain_r2l(scores, orig_scores, root_scores, is_proj: bool): + slen = len(scores) + return [x+2 for x in range(slen-1)]+[0], slen-1 + +def decode_rand(scores, orig_scores, root_scores, is_proj: bool): + slen = len(scores) + input_scores = np.random.rand(1, slen+1, slen+1).astype(np.float32) + input_length = np.array([slen+1]).astype(np.int32) + if is_proj: + heads, _, _ = mst_proj(input_scores, input_length, False) + else: + heads, _, _ = mst_unproj(input_scores, input_length, False) + ret_heads = heads[0][1:] + ret_roots = [i for i,h in enumerate(ret_heads) if h==0] + return ret_heads, ret_roots[0] # simple the first root + +# ===== +# method 1: undirectional + selecting root + +# undirected mst algorithm: Kruskal +def mst_kruskal(scores): + cur_len = len(scores) # scores should be [cur_len, cur_len], and only use the i0 best_score: + best_score = cur_score + best_root = cur_root + best_heads = cur_heads + return best_heads, best_root + +# ===== +# method 2: directly using directional MST (can be non-proj or proj) + +# [slen, slen], [slen, slen], [slen] +def decode_m2(scores, orig_scores, root_scores, is_proj: bool): + # make scores into [m,h] and augment root lines + slen = len(scores) + pad_scores0 = np.concatenate([root_scores[np.newaxis, :], scores.T], 0) # [slen+1, slen] + pad_scores1 = np.concatenate([np.zeros([slen+1, 1]), pad_scores0], 1) # [slen+1, slen+1] + input_scores = pad_scores1[np.newaxis, :, :].astype(np.float32) + input_length = np.array([slen+1]).astype(np.int32) + if is_proj: + heads, _, _ = mst_proj(input_scores, input_length, False) + else: + heads, _, _ = mst_unproj(input_scores, input_length, False) + ret_heads = heads[0][1:] + ret_roots = [i for i,h in enumerate(ret_heads) if h==0] + return ret_heads, ret_roots[0] # simple the first root + +# ===== +# method 3: easy first styled clustering + +# [slen, slen], [slen, slen], [slen] +def decode_m3(scores, orig_scores, root_scores, is_proj: bool): + assert is_proj + +# ===== +# method extra: sink function words + +# ===== +# ===== + +def get_decoder(method): + return { + "l2r": chain_l2r, "r2l": chain_r2l, "rand": decode_rand, + "m1": decode_m1, "m2": decode_m2, "m3": decode_m3, + }[method] + +# ===== +# part 2: fun decoder + +# attach all fun to the right content word (unless boundary) +def fdec_right(fun_masks, **kwargs): + slen = len(fun_masks) + ret = [0] * slen + prev_content_idx = -1 + prev_fun_list = [] + for widx, is_fun in enumerate(fun_masks): + if is_fun: + prev_fun_list.append(widx) + else: + for w in prev_fun_list: + ret[w] = widx+1 + prev_fun_list.clear() + prev_content_idx = widx + # attach the rest + for w in prev_fun_list: + ret[w] = prev_content_idx+1 + return ret + +# attach all fun to the left content word (unless boundary) +def fdec_left(fun_masks, **kwargs): + slen = len(fun_masks) + ret = [0] * slen + # first deal with the first group + widx = 0 + while widx=slen: + return ret # all fun words + # + prev_content_idx = widx + for w in range(widx): + ret[w] = prev_content_idx+1 + # + while widx0: + cur_scores = orig_scores[remaining_idxes][:,remaining_idxes] + cur_reduce_tidx = np.argmin(self.get_rank_keys(cur_scores, extra_scores[remaining_idxes])) + cur_reduce_idx = remaining_idxes[cur_reduce_tidx] + reduce_idxes.append(cur_reduce_idx) + remaining_idxes.remove(cur_reduce_idx) # must be there + assert len(reduce_idxes) == slen + else: + # todo(note): directly rank all + reduce_idxes = np.argsort(self.get_rank_keys(orig_scores, extra_scores)) + reduce_order = [None] * slen + for one_order, one_idx in enumerate(reduce_idxes): + reduce_order[one_idx] = one_order + return reduce_order, reduce_idxes diff --git a/tasks/zdpar/stat2/helper_draw.py b/tasks/zdpar/stat2/helper_draw.py new file mode 100644 index 0000000..118acbe --- /dev/null +++ b/tasks/zdpar/stat2/helper_draw.py @@ -0,0 +1,134 @@ +# + +# from matplotlib example + +# from __future__ import unicode_literals + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np + +# support unicode +# matplotlib.rc('font', family='Arial') + +# matplotlib.rc('font', **{'sans-serif': 'Arial', 'family': 'sans-serif'}) +# matplotlib.rc('font', family='TakaoPGothic') + +def heatmap(data, row_labels, col_labels, ax=None, + cbar_kw={}, cbarlabel="", **kwargs): + """ + Create a heatmap from a numpy array and two lists of labels. + + Parameters + ---------- + data + A 2D numpy array of shape (N, M). + row_labels + A list or array of length N with the labels for the rows. + col_labels + A list or array of length M with the labels for the columns. + ax + A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If + not provided, use current axes or create a new one. Optional. + cbar_kw + A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. + cbarlabel + The label for the colorbar. Optional. + **kwargs + All other arguments are forwarded to `imshow`. + """ + + if not ax: + ax = plt.gca() + + # Plot the heatmap + im = ax.imshow(data, **kwargs) + + # Create colorbar + cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw) + cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") + + # We want to show all ticks... + ax.set_xticks(np.arange(data.shape[1])) + ax.set_yticks(np.arange(data.shape[0])) + # ... and label them with the respective list entries. + ax.set_xticklabels(col_labels) + ax.set_yticklabels(row_labels) + + # Let the horizontal axes labeling appear on top. + ax.tick_params(top=True, bottom=False, + labeltop=True, labelbottom=False) + + # Rotate the tick labels and set their alignment. + plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", + rotation_mode="anchor") + + # Turn spines off and create white grid. + for edge, spine in ax.spines.items(): + spine.set_visible(False) + + ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) + ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) + ax.grid(which="minor", color="w", linestyle='-', linewidth=3) + ax.tick_params(which="minor", bottom=False, left=False) + + return im, cbar + + +def annotate_heatmap(im, data=None, valfmt="{x:.2f}", + textcolors=["black", "white"], + threshold=None, **textkw): + """ + A function to annotate a heatmap. + + Parameters + ---------- + im + The AxesImage to be labeled. + data + Data used to annotate. If None, the image's data is used. Optional. + valfmt + The format of the annotations inside the heatmap. This should either + use the string format method, e.g. "$ {x:.2f}", or be a + `matplotlib.ticker.Formatter`. Optional. + textcolors + A list or array of two color specifications. The first is used for + values below a threshold, the second for those above. Optional. + threshold + Value in data units according to which the colors from textcolors are + applied. If None (the default) uses the middle of the colormap as + separation. Optional. + **kwargs + All other arguments are forwarded to each call to `text` used to create + the text labels. + """ + + if not isinstance(data, (list, np.ndarray)): + data = im.get_array() + + # Normalize the threshold to the images color range. + if threshold is not None: + threshold = im.norm(threshold) + else: + threshold = im.norm(data.max())/2. + + # Set default alignment to center, but allow it to be + # overwritten by textkw. + kw = dict(horizontalalignment="center", + verticalalignment="center") + kw.update(textkw) + + # Get the formatter in case a string is supplied + if isinstance(valfmt, str): + valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) + + # Loop over the data and create a `Text` for each "pixel". + # Change the text's color depending on the data. + texts = [] + for i in range(data.shape[0]): + for j in range(data.shape[1]): + kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) + text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) + texts.append(text) + + return texts diff --git a/tasks/zdpar/stat2/helper_feature.py b/tasks/zdpar/stat2/helper_feature.py new file mode 100644 index 0000000..b1b21d6 --- /dev/null +++ b/tasks/zdpar/stat2/helper_feature.py @@ -0,0 +1,163 @@ +# + +# the featurer extracts the features and calculate the repr distances (influences scores) + +from typing import List, Dict +import numpy as np +import os + +from msp.utils import StatRecorder, Helper, zlog, zopen, Conf +from msp.nn import BK +from msp.nn.modules.berter import BerterConf, Berter +from msp.data.vocab import Vocab +from tasks.zdpar.ef.parser import G1Parser +from tasks.zdpar.common.data import ParseInstance + +# +class FeaturerConf(Conf): + def __init__(self): + # todo(+N): later we can use other encoders, but currently only with bert + self.use_bert = True + self.bconf = BerterConf().init_from_kwargs(bert_sent_extend=0, bert_root_mode=-1, bert_zero_pademb=True, + bert_layers="[0,1,2,3,4,5,6,7,8,9,10,11,12]") + # for bert + self.b_mask_mode = "first" # all/first/one/pass + self.b_mask_repl = "[MASK]" # or other things like [UNK], [PAD], ... + # for g1p + self.g1p_replace_unk = True # otherwise replace with 0 + # distance (influence scores) + self.dist_f = "cos" # cos or mse + # for bert repl + self.br_vocab_file = "" + self.br_rank_thresh = 100 # rank threshold >=this for non-topical words + + def do_validate(self): + assert self.use_bert, "currently only implemented the bert mode" + +# +class Featurer: + def __init__(self, conf: FeaturerConf): + self.conf = conf + self.berter = Berter(None, conf.bconf) + self.fold = len(conf.bconf.bert_layers) + # used for bert-repl: self.repl_sent() + self.br_vocab = Vocab.read(conf.br_vocab_file) if conf.br_vocab_file else None + self.br_rthresh = conf.br_rank_thresh + + # replace (simplify) sentence by replacing with bert predictions + # todo(note): remember that the main purpose is to remove surprising topical words but retain syntax structure + # todo(+N): the process might be specific to en or alphabet based languages + def repl_sent(self, sent: List[str], fixed_arr): + sent_len = len(sent) + orig_scores_arr, pred_scores_arr, pred_strs, orig_subword_lists = self.berter.predict_each_word(sent) + br_vocab, br_rthresh = self.br_vocab, self.br_rthresh + # determined method + # 1. first extend fixed (input+unchanged) ones + assert len(pred_strs) == sent_len + is_fixed = [(fixed_arr[i] or sent[i]==pred_strs[i]) for i in range(sent_len)] + # 2. check for topical words (and change it if possible) + # 2.1 collect all input words (all, not-freq, not-subword) + input_words_set = set() + for sub_words in orig_subword_lists: + w = sub_words[0] # todo(note): only take the first one + if br_vocab.get(w, br_rthresh) >= br_rthresh: + input_words_set.add(w) + # 2.2 collect all output words (changed-only, not-freq, not-subword) + # topical words are those both in output and input + topical_words_count = {} + for one_pred, one_is_fix in zip(pred_strs, is_fixed): + # todo(note): not fix and not subword and not-freq + if not one_is_fix and not one_pred.startswith("##") and br_vocab.get(one_pred, br_rthresh)>=br_rthresh: + topical_words_count[one_pred] = topical_words_count.get(one_pred, 0) + 1 + # only retain those >1 or hit input + topical_words_count = {k:v for k,v in topical_words_count.items() if (v>1 or k in input_words_set)} + # 2.3 adjust scores: 1) encourage changing input topical words, 2) disencourage changing into topical words + # todo(+N): what strategy for the scores? currently mainly NEG orig-logprob + changing_scores = -1 * orig_scores_arr + for one_idx in range(sent_len): + one_orig, one_pred, one_is_fix = sent[one_idx], pred_strs[one_idx], is_fixed[one_idx] + if one_is_fix or (one_pred in topical_words_count) or (one_pred.startswith("##")): + changing_scores[one_idx] = 0 # not changing fix words or into topical words or suffix + elif one_orig in topical_words_count: + # todo(+2): magical number 100, this should be enough + changing_scores[one_idx] += 100*(topical_words_count[one_orig]) # encourage changing from topical words + # in pdb: pp list(zip(is_fixed,sent,pred_strs,changing_scores)) + # 3. change one for each segment + prev_cands = [] # (idx, score) + ret_seq = sent.copy() + for one_idx in range(sent_len+1): + if one_idx>=sent_len or is_fixed[one_idx]: # hit boundary + if len(prev_cands)>0: + prev_cands.sort(key=lambda x: -x[-1]) + togo_idx, togo_score = prev_cands[0] + if togo_score > 0: # only change it if not forbidden + ret_seq[togo_idx] = pred_strs[togo_idx] + prev_cands.clear() + else: + prev_cands.append((one_idx, changing_scores[one_idx])) + # todo(+N): whether fix changed word? currently nope + return ret_seq, np.asarray(is_fixed).astype(np.bool) + + # get influence scores + def get_scores(self, sent: List[str]): + # get features + conf = self.conf + features = encode_bert(self.berter, sent, conf.b_mask_mode, conf.b_mask_repl) # [1+slen, 1+slen, fold*D] + orig_feature_shape = BK.get_shape(features) + folded_shape = orig_feature_shape[:-1] + [self.fold, orig_feature_shape[-1]//self.fold] + all_features = features.view(folded_shape) # [1+slen, 1+slen, fold, D] + # get scores: [slen(without-which), slen+1(for-all-tokens), fold] + if conf.dist_f == "cos": + # normalize and dot + all_features = all_features / (BK.sqrt((all_features ** 2).sum(-1, keepdim=True)) + 1e-7) + scores = 1. - (all_features[1:] * all_features[0].unsqueeze(0)).sum(-1) + elif conf.dist_f == "mse": + scores = BK.sqrt(((all_features[1:] - all_features[0].unsqueeze(0)) ** 2).mean(-1)) + else: + raise NotImplementedError() + return scores.cpu().numpy() + + def output_shape(self, slen): + return [slen, slen+1, len(self.conf.bconf.bert_layers)] + +# ===== +# encoding with different models, no root in input sent + +# +def encode_bert(berter: Berter, sent: List[str], b_mask_mode, b_mask_repl): + assert berter.bconf.bert_sent_extend == 0 + assert berter.bconf.bert_root_mode == -1 + # prepare inputs + all_sentences = [berter.subword_tokenize(sent, True)] + for i in range(len(sent)): + all_sentences.append(berter.subword_tokenize(sent.copy(), True, mask_idx=i, mask_mode=b_mask_mode, mask_repl=b_mask_repl)) + # all_sentences.append(berter.subword_tokenize(sent.copy(), True, mask_idx=-1, mask_mode=b_mask_mode, mask_repl=b_mask_repl)) + # get outputs + all_features = berter.extract_features(all_sentences) # List[arr[1+slen, D]] + # all_features = berter.extract_feature_simple_mode(all_sentences) # List[arr[1+slen, D]] + # ===== + # post-processing for pass mode + if b_mask_mode=="pass": + for i in range(1, len(all_features)): + all_features[i] = np.insert(all_features[i], i, 0., axis=0) # [slen, D] -> [slen+1, D] + return BK.input_real(np.stack(all_features, 0)) # [1(whole)+slen, 1(R)+slen, D] + +# +def encode_model(model: G1Parser, sent: List[str], g1p_replace_unk): + word_vocab = model.vpack.get_voc("word") + REPL = word_vocab.unk if g1p_replace_unk else 0 + ROOT = word_vocab[ParseInstance.ROOT_SYMBOL] # must be there + sent_idxes = [ROOT] + [word_vocab.get_else_unk(z) for z in sent] # 1(R)+slen + # prepare inputs + all_sentences = [sent_idxes] + for i in range(len(sent)): + new_one = sent_idxes.copy() + new_one[i+1] = REPL # here 1 as offset for ROOT + all_sentences.append(new_one) + # get embeddings and encodings + word_arr = np.asarray(all_sentences) # [1+slen, 1+slen], no need to pad + emb_repr = model.bter.emb(word_arr=word_arr) # no other features, [1+slen, 1+slen, D0] + enc_repr = model.bter.enc(emb_repr) # no needfor mask, [1+slen, 1+slen, D1] + return enc_repr # [1(whole)+slen, 1(R)+slen, D] + +# b tasks/zdpar/stat2/helper_feature:99 diff --git a/tasks/zdpar/stat2/run.sh b/tasks/zdpar/stat2/run.sh new file mode 100644 index 0000000..47b6fe4 --- /dev/null +++ b/tasks/zdpar/stat2/run.sh @@ -0,0 +1,33 @@ +# + +# some runnings for stat2 + +# _run1028.sh +SRC_DIR="../src/" +#DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU= +DEVICE=-1 +for b_mask_mode in all first one pass; do +for b_mask_repl in MASK UNK PAD; do +for dist_f in cos mse; do +pic_name="_${CUR_LANG}_dev.${b_mask_mode}.${b_mask_repl}.${dist_f}.pic" +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:${DATA_DIR}/${CUR_LANG}_dev.ppos.conllu output_pic:${pic_name} b_mask_mode:${b_mask_mode} "b_mask_repl:[${b_mask_repl}]" dist_f:${dist_f} +done; done; done |& tee _log1028 + +# _run1029.sh +SRC_DIR="../src/" +#DATA_DIR="../data/UD_RUN/ud24/" +DATA_DIR="../data/ud24/" +CUR_LANG=en +RGPU= +DEVICE=-1 +for fpic in *.pic; do +for which_fold in {1..9} {10..12}; do +for use_fdec in 0 1; do + echo "ZRUN ${fpic} ${which_fold} ${use_fdec}" + CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zdpar.stat2.decode3 device:${DEVICE} input_file:$fpic already_pre_computed:1 which_fold:${which_fold} use_fdec:${use_fdec} +done +done +done |& tee _log1029 diff --git a/tasks/zdpar/transition/topdown/__init__.py b/tasks/zdpar/transition/topdown/__init__.py new file mode 100644 index 0000000..416e35a --- /dev/null +++ b/tasks/zdpar/transition/topdown/__init__.py @@ -0,0 +1,7 @@ +# + +# top-down stack-ptr parser + +# outside implicit graph + inner core search graph + +# todo(WARN): this module is not updated and might be incompatible with new changes! diff --git a/tasks/zdpar/transition/topdown/analysis/ana0.py b/tasks/zdpar/transition/topdown/analysis/ana0.py new file mode 100644 index 0000000..4382a2e --- /dev/null +++ b/tasks/zdpar/transition/topdown/analysis/ana0.py @@ -0,0 +1,462 @@ +# + +# analyze on the conllu files of 1) gold 2) system-output, 3) system2-output +# -- sentence-wise, token-wise, TD step-wise + +import sys +import json +import numpy as np +from collections import defaultdict, Counter +from msp.zext.dpar.conllu_reader import ConlluReader +from msp.utils import zopen, zlog, StatRecorder, Helper, Conf, PickleRW + +# from tasks.zdpar.transition.topdown.decoder import ParseInstance, TdState + +class ZObject(object): + def __init__(self, m=None): + if m is not None: + self.update(m) + + def update(self, m): + for k, v in m.items(): + setattr(self, k, v) + +# ===== +def yield_ones(file): + r = ConlluReader() + with zopen(file) as fd: + for p in r.yield_ones(fd): + # add root info + corder_valid_flag = False + for line in p.headlines: + try: + misc0 = json.loads(line)["ROOT-MISC"] + for s in misc0.split("|"): + k, v = s.split("=") + p.tokens[0].misc[k] = v + corder_valid_flag = True + break + except: + continue + # add children order + if corder_valid_flag: + for t in p.tokens: + corder_str = t.misc.get("CORDER") + if corder_str is not None: + t.corder = [int(x) for x in corder_str.split(",")] + else: + t.corder = [] + yield p + +# +class AnalysisConf(Conf): + def __init__(self, args): + self.gold = "" # gold parse file + self.f1 = "" # main system file + self.f2 = "" # system file 2 for comparison + self.cache_pkl = "" # cache stat to file + self.force_refresh_cache = False # ignore cache + self.output = "" + # analyze commands + self.loop = False # whether start a loop for interactive mode + self.grain = "steps" # sents, tokens, steps + self.fcode = "True" # filter code + self.mcode = "1" # map code + self.rcode = "" # reduce code + # ===== + self.update_from_args(args) + self.validate() + +def bool2int(x): + return 1 if x else 0 + +def norm_pos(x): + if x in {"NOUN", "PROPN", "PRON"}: + return "N" + else: + return x + +# +def main(args): + conf = AnalysisConf(args) + # read them + zlog("Read them all ...") + gold_parses = list(yield_ones(conf.gold)) + s1_parses = list(yield_ones(conf.f1)) + s2_parses = list(yield_ones(conf.f2)) + # analyze them + if conf.force_refresh_cache or conf.cache_pkl=="": + zlog("Re-stat them all ...") + ana_sents, ana_tokens, ana_steps = [], [], [] + assert len(gold_parses) == len(s1_parses) and len(gold_parses) == len(s2_parses) + for gp, s1p, s2p in zip(gold_parses, s1_parses, s2_parses): + gp_tokens = gp.tokens + s1p_tokens = s1p.tokens + s2p_tokens = s2p.tokens + this_length = len(gp_tokens) + assert [z.word for z in gp_tokens] == [z.word for z in s1p_tokens] + assert [z.word for z in gp_tokens] == [z.word for z in s2p_tokens] + # ===== + # token wise + this_tokens = [] + for n_idx in range(this_length): + is_root = n_idx==0 + n_s1p_node = s1p_tokens[n_idx] + n_gp_node = gp_tokens[n_idx] + n_s2p_node = s2p_tokens[n_idx] + # ----- + gp_head, gp_label, gp_children_set = n_gp_node.head, n_gp_node.label0, set(n_gp_node.childs) + s1p_head, s1p_label, s1p_children_set = n_s1p_node.head, n_s1p_node.label0, set(n_s1p_node.childs) + s2p_head, s2p_label, s2p_children_set = n_s2p_node.head, n_s2p_node.label0, set(n_s2p_node.childs) + s12_ueq = s1p_head==s2p_head + s12_leq = s12_ueq and (s1p_label==s2p_label) + props_tok = { + "gp": n_gp_node, "s1p": n_s1p_node, "s2p": n_s2p_node, + # gold label + "pos": norm_pos(n_gp_node.upos), "hpos": norm_pos("R0" if is_root else n_gp_node.get_head().upos), + "head": 0 if is_root else gp_head, "label": "R0" if is_root else gp_label, + # predictions + "s1p_label": "R0" if is_root else s1p_label, "s2p_label": "R0" if is_root else s2p_label, + "s1p_ucorr": bool2int(s1p_head==gp_head), "s1p_lcorr": bool2int(s1p_head==gp_head and s1p_label==gp_label), + "s2p_ucorr": bool2int(s2p_head==gp_head), "s2p_lcorr": bool2int(s2p_head==gp_head and s2p_label==gp_label), + "s12_ueq": bool2int(s12_ueq), "s12_leq": bool2int(s12_leq), + } + this_tokens.append(ZObject(props_tok)) + # ===== + # step level + this_steps = [] + # ===== go recursive + this_reduced = [False] * this_length + this_arc = [-1] * this_length + this_numbers = [0, 0] # (attach-num, reduce-num) + def _visit(n_s1p_node, stack): + n_idx = n_s1p_node.idx + stack.append(n_s1p_node) + res_h = this_tokens[n_idx] + # attach + prev_c_num = 0 + prev_c_uerr = 0 + prev_c_lerr = 0 + h_uerr = 1 - res_h.s1p_ucorr + for c_ord, c in enumerate(n_s1p_node.corder): + res_c = this_tokens[c] + s1p_uerr, s1p_lerr = 1-bool2int(res_c.s1p_ucorr), 1-bool2int(res_c.s1p_lcorr) + s2p_uerr, s2p_lerr = 1-bool2int(res_c.s2p_ucorr), 1-bool2int(res_c.s2p_lcorr) + real_h_reduced = bool2int(this_reduced[res_c.head]) # possible error because of reduced real head + # + if s1p_uerr: + err_type = "err_reduce" if real_h_reduced else "err_attach" + else: + err_type = "err_none" + # + props_step = { + "len": this_length, "step": len(this_steps), "depth": len(stack), "attach_num": this_numbers[0], + "type": "attach", "h": res_h, "c": res_c, "c_ord": c_ord, "h_uerr": h_uerr, + # int (count) + "prev_c_num": prev_c_num, "prev_c_uerr": prev_c_uerr, "prev_c_lerr": prev_c_lerr, + # bool -> int + "s1p_uerr": s1p_uerr, "s1p_lerr": s1p_lerr, "s2p_uerr": s2p_uerr, "s2p_lerr": s2p_lerr, + "real_h_reduced": real_h_reduced, "s1p_real_uerr": s1p_uerr*(1-real_h_reduced), + # which kind of error + "err": err_type, + } + prev_c_num += 1 + prev_c_uerr += s1p_uerr + prev_c_lerr += s1p_lerr + this_arc[c] = n_idx + this_numbers[0] += 1 + this_steps.append(ZObject(props_step)) + # recursive + _visit(s1p_tokens[c], stack) + # reduce non-root + if n_idx>0: + real_c_num = len(res_h.gp.childs) + miss_c_num = real_c_num - (prev_c_num - prev_c_uerr) + s2p_c_num, s2p_c_uerr, s2p_c_lerr, s2p_c_miss = 0, 0, 0, 0 + for s2p_c in res_h.s2p.childs: + s2p_c_num += 1 + res_s2p_c = this_tokens[s2p_c] + s2p_c_uerr += 1-bool2int(res_s2p_c.s2p_ucorr) + s2p_c_lerr += 1-bool2int(res_s2p_c.s2p_lcorr) + s2p_c_miss = real_c_num - (s2p_c_num - s2p_c_uerr) + props_step = { + "len": this_length, "step": len(this_steps), "depth": len(stack), "reduce_num": this_numbers[1], + "type": "reduce", "h": res_h, "h_uerr": h_uerr, + # int (count) + "real_c_num": real_c_num, + "prev_c_num": prev_c_num, "prev_c_uerr": prev_c_uerr, "prev_c_lerr": prev_c_lerr, + "miss_c_num": miss_c_num, "miss_real_c_num": len([1 for zz in res_h.gp.childs if this_arc[zz]<0]), + "s2p_c_num": s2p_c_num, "s2p_c_uerr": s2p_c_uerr, "s2p_c_lerr": s2p_c_lerr, "s2p_c_miss": s2p_c_miss, + } + this_steps.append(ZObject(props_step)) + this_reduced[n_idx] = True + this_numbers[1] += 1 + stack.pop() + # ===== + cur_root = s1p_tokens[0] + cur_stack = [] + _visit(cur_root, cur_stack) + # sentence level + assert len(this_steps) == 2*(len(this_tokens)-1) + assert len([z for z in this_steps if z.type=="attach"]) == len(this_tokens)-1 + this_sent = { + "tokens": this_tokens, "steps": this_steps + } + # + ana_steps.extend(this_steps) + ana_tokens.extend(this_tokens) + ana_sents.append(this_sent) + if conf.cache_pkl: + zlog(f"Write cache to {conf.cache_pkl}") + PickleRW.to_file([ana_sents, ana_tokens, ana_steps], conf.cache_pkl) + else: + zlog(f"Reload cache from {conf.cache_pkl}") + ana_sents, ana_tokens, ana_steps = PickleRW.from_file(conf.cache_pkl) + # ===== + # real analyzing + zlog("Analyzing them all ...") + fout = open(conf.output, "w") if conf.output else None + while True: + if conf.rcode != "": + ana_insts = {"sents": ana_sents, "tokens": ana_tokens, "steps": ana_steps}[conf.grain] + ff_f, ff_m, ff_r = [compile(z, "", "eval") for z in [conf.fcode, conf.mcode, conf.rcode]] + zlog(f"# compute with {conf.grain}: f={conf.fcode}, m={conf.mcode}, r={conf.rcode}") + count_all, count_hit = len(ana_insts), 0 + # todo(warn): use local names: 'z', 'cs' + cs = [] + for z in ana_insts: + if eval(ff_f): + count_hit += 1 + cs.append(eval(ff_m)) + r = eval(ff_r) + zlog(f"Rate = {count_hit} / {count_all} = {count_hit/count_all}") + zlog(f"Res = {r}") + if fout is not None: + fout.write(str(r)+"\n") + # + if conf.loop: + cmd = "" + while len(cmd)==0: + cmd = input(">> ").strip() + conf.update_from_args([z.strip() for z in cmd.split("\"") if z.strip()!=""]) + else: + break + if fout is not None: + fout.close() + +# ===== +# helper functions + +# helper for div +def hdiv(a, b): + if b == 0: + return (a, b, 0.) + else: + return (a, b, a/b) + +# helper for plain counter & distribution +def hdistr(cs, key="count", ret_format="", min_count=0): + r = Counter(cs) + c = len(cs) + get_keys_f = { + "count": lambda r: sorted(r.keys(), key=(lambda x: -r[x])), + "name": lambda r: sorted(r.keys()), + }[key] + thems = [(k, r[k], r[k]/c) for k in get_keys_f(r)] + # + ret_f = { + "": lambda x: x, + "s": lambda x: "\n".join([str(z) for z in x]), + }[ret_format] + thems = [z for z in thems if z[1]>=min_count] + return ret_f(thems) + +# helper for multi-level counter & distribution +def hmdistr(cs, key="count", ret_format="", min_count=0): + record = [0, {}] # str -> (count, next) + # collect + for one in cs: + cur_r = record + cur_r[0] += 1 # overall count + for i, x in enumerate(one): + if x not in cur_r[1]: + cur_r[1][x] = [0, {}] + cur_r = cur_r[1][x] + cur_r[0] += 1 + # recursive visit + get_keys_f = { + "count": lambda r: sorted(r.keys(), key=(lambda x: -r[x][0])), # sort by count + "name": lambda r: sorted(r.keys()), + }[key] + # + stack_counts = [] + stack_keys = [] + thems = [] + def _visit(cur_r): + level = len(stack_counts) + cur_count, cur_next = cur_r + stack_counts.append(cur_count) + for k in get_keys_f(cur_next): + stack_keys.append(k) + # stat + next_r = cur_next[k] + next_count = next_r[0] + thems.append((level, stack_keys.copy(), next_count, [next_count/z for z in stack_counts])) + _visit(next_r) + stack_keys.pop() + stack_counts.pop() + _visit(record) + # + ret_f = { + "": lambda x: x, + "s": lambda x: "\n".join([str(z) for z in x]), + }[ret_format] + thems = [z for z in thems if z[2] >= min_count] + return ret_f(thems) + +# +def hgroup(k, length, N, is_step=True): + step = k + all_steps = (2*length-2) if is_step else length-1 + g = int(step/all_steps*N) + return g + +# +def hcmp_err(a, b): + return {(0,0):"both_good", (1,1):"both_bad", (0,1):"sys_good", (1,0):"sys_bad"}[(a,b)] + +def hclip(x, a, b): + return min(max((x,a)),b) + +# ===== +# helper functions version 2: dumpers + +# for mcode +def hdm_err_break(key, z): + # unlabeled + s1_corr = 1-z.s1p_uerr + s2_corr = 1-z.s2p_uerr + both_corr = s1_corr * s2_corr + s1_err = z.s1p_uerr + s1_first_err = z.s1p_real_uerr + # labeled + s1_corr_L = 1-z.s1p_lerr + s2_corr_L = 1-z.s2p_lerr + both_corr_L = s1_corr_L*s2_corr_L + # + s12_diff, s12_diff_L = s1_corr-s2_corr, s1_corr_L-s2_corr_L + return (key, [both_corr, s1_corr, s2_corr, s12_diff, both_corr_L, s1_corr_L, s2_corr_L, s12_diff_L, s1_err, s1_first_err]) + +# for rcode +def hdr_err_break(cs): + r = {} + for key, arr in cs: + if key not in r: + r[key] = [] + r[key].append(arr) + ret = {k: np.array(arrs).mean(axis=0).tolist() + [len(arrs)+0., len(arrs)/len(cs)] for k, arrs in r.items()} + return json.dumps(ret) + +# formatting for printing +def hdformat_err_break(line): + import json + ret = json.loads(line) + for key, res in ret.items(): + both_corr, s1_corr, s2_corr, s12_diff, both_corr_L, s1_corr_L, s2_corr_L, s12_diff_L, s1_err, s1_first_err, *_ = res + print(f"{key} & {s1_corr*100:.2f}/{s1_corr_L*100:.2f} & {s2_corr*100:.2f}/{s2_corr_L*100:.2f} & {s12_diff*100:.2f}/{s12_diff_L*100:.2f} \\\\") + +# questions (interactive tryings) +""" +PYTHONPATH=../src/ python3 ana0.py gold:../data/ud22/en_test.conllu f1:en.td1.out f2:en.g.out loop:1 +# +# 1. UAS/LAS/Error-rate? +"grain:steps" "fcode:z.type=='attach'" "mcode:z.s1p_uerr" "rcode:hdiv(sum(cs), len(cs))" +"grain:steps" "fcode:z.type=='attach'" "mcode:z.s1p_lerr" "rcode:hdiv(sum(cs), len(cs))" +# real attach/reduce error +"grain:steps" "fcode:z.type=='attach'" "mcode:z.s1p_real_uerr" "rcode:hdiv(sum(cs), len(cs))" +"grain:steps" "fcode:z.type=='reduce'" "mcode:z.miss_real_c_num" "rcode:hdiv(sum(cs), len(cs))" +# +# 2. children order +# 2.1 stat of gold children +"grain:tokens" "fcode:z.hpos!='R0'" "mcode:(z.hpos, z.label)" "rcode:hmdistr(cs, ret_format='s', min_count=10)" +# 2.2 stat of pred chidren +"grain:steps" "fcode:z.type=='attach'" "mcode:(z.h.pos, z.c.s1p_label, z.c_ord)" "rcode:hmdistr(cs, ret_format='s', min_count=10)" +# +# 3. comparison +# 3.1 overall +"grain:steps" "fcode:z.type=='attach'" "mcode:(z.s1p_uerr, z.s2p_uerr, z.s1p_real_uerr, z.step)" "rcode:hmdistr(cs, ret_format='s', min_count=10)" +# 3.2 wrong ones +"grain:steps" "fcode:z.type=='attach' and z.s1p_uerr and not z.s2p_uerr" "mcode:(z.s1p_real_uerr, z.c.s1p.ddist)" "rcode:hmdistr(cs, ret_format='s', min_count=10)" +# +# 4. error breakdown +# what to breakdown? -> z.step, z.c.s1p.rdist, z.c.gp.rdist, z.c.s1p.ddist, z.c.gp.ddist, z.c.label, z.h.pos, z.c.pos, +# hgroup(z.step, z.len, 5), hgroup(z.attach_num, z.len, 5, False), hclip(z.c.s1p.ddist, -20, 20), hclip(z.c.gp.rdist, 0, 7) +# hcmp_err(z.s1p_uerr, z.s2p_uerr), z.s1p_real_uerr +"grain:steps" "fcode:z.type=='attach'" "mcode:(z.c.gp.rdist, hcmp_err(z.s1p_lerr, z.s2p_lerr),)" "rcode:hmdistr(cs, ret_format='s', min_count=0, key='name')" +"grain:steps" "fcode:z.type=='attach'" "mcode:(z.c.gp.rdist, z.s1p_uerr, z.s1p_lerr,)" "rcode:hmdistr(cs, ret_format='s', min_count=0, key='name')" +"grain:steps" "fcode:z.type=='attach'" "mcode:(z.c.gp.rdist, z.s2p_uerr, z.s2p_lerr,)" "rcode:hmdistr(cs, ret_format='s', min_count=0, key='name')" +# +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hgroup(z.step, z.len, 5), z)" "rcode:hdr_err_break(cs)" +""" + +# run in batch mode +""" +# -- +# file: cmd.txt +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hgroup(z.step, z.len, 5), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hgroup(z.attach_num, z.len, 5, False), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hclip(z.c.gp.rdist, 0, 7), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hclip(z.c.s1p.rdist, 0, 7), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hclip(z.c.gp.ddist, -10, 10), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(hclip(z.c.s1p.ddist, -10, 10), z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(z.c.label, z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(z.c.pos, z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(z.c.hpos, z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(z.c.s1p_label, z)" "rcode:hdr_err_break(cs)" +"grain:steps" "fcode:z.type=='attach'" "mcode:hdm_err_break(z.h.pos, z)" "rcode:hdr_err_break(cs)" +# +for wset in test train; do +for cl in ar bg ca hr cs da nl en et "fi" fr de he hi id it ko la lv no pl pt ro ru sk sl es sv uk; do +echo "RUN ${wset} ${cl}" +PYTHONPATH=../src/ python3 ana0.py gold:../data/ud22/${cl}_${wset}.conllu f1:${wset}/${cl}.td.out f2:${wset}/${cl}.g.out loop:1 output:${wset}/${cl}.td.json 0: + one_dict = json.loads(line) + for k,v in one_dict.items(): + if k not in results[idx]: + results[idx][k] = [] + results[idx][k].append(v) + # average + # final_results = {k:{} for k in results} + final_results = [{} for i in range(len(results))] + for k1 in results: + for k2 in results[k1]: + final_results[k1][k2] = np.array(results[k1][k2]).mean(0).tolist() + with open(f"./{MAIN_DIR}/avg.json", "w") as fd: + for one in final_results: + fd.write(json.dumps(one)+"\n") + print(final_results) + # + pp = (lambda x: pprint([(z, x[z]) for z in sorted(x.keys(), key=lambda k: x[k][1]-x[k][2])])) + return final_results + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/tasks/zdpar/transition/topdown/decoder.py b/tasks/zdpar/transition/topdown/decoder.py new file mode 100644 index 0000000..b813239 --- /dev/null +++ b/tasks/zdpar/transition/topdown/decoder.py @@ -0,0 +1,1516 @@ +# + +from typing import List +import numpy as np +from array import array +from copy import copy + +from msp.utils import Conf, zcheck, Random, zlog, Constants, GLOBAL_RECORDER +from msp.nn import BK, ExprWrapper, SliceManager, SlicedExpr +from msp.nn.layers import BasicNode, Affine, BiAffineScorer, Embedding, RnnNode, MultiHeadAttention, AttConf, AttDistHelper +from msp.zext.process_train import ScheduledValue +from msp.zext.seq_helper import DataPadder + +from ...common.data import ParseInstance +from .search.lsg import LinearState, LinearGraph, BfsLinearAgenda, BfsLinearSearcher + +# manage single search instance +class TdState(LinearState): + # todo(warn): global variable, not elegant and there are too many dependencies on this one!! + is_bfs: bool = None + + def __init__(self, prev: 'TdState'=None, action=None, score=0., sg: LinearGraph=None, padded=False, inst: ParseInstance=None): + super().__init__(prev, action, score, sg, padded) + # todo(+N): can it be more efficient to maintain (not-copy) the structure info? + # record the state for the top-down transition system + if prev is None: + self.inst = inst + self.depth = 0 # real depth + self.depth_eff = 0 # effective depth = depth if down else depth+1 + self.num_tok = len(inst) + 1 # num of tokens plus artificial root + self.num_rest = self.num_tok - 1 # num of arcs to attach + self.list_arc = [-1] * self.num_tok # attached heads + self.list_label = [-1] * self.num_tok # attached labels + self.reduced = [False] * self.num_tok # whether this node has been reduced + assert action is None + self.state_par = None # parent state node + self.state_last_cur = None # previous cur state (cur = last_cur + one-more-children) + self.idx_cur = 0 # cur token idx + # todo(warn): 0 serves as input ROOT label, and output reduce label, can never co-exist! + self.label_cur = 0 # cur label between self and parent + self.idx_par = 0 # parent token idx (0's parent is still 0) + self.idxes_chs = [] # child token idx list (by adding order) + self.labels_chs = [] # should be the same size and corresponds to "ch_nodes" + # + self.idx_ch_leftmost, self.idx_ch_rightmost = 0, 0 + # todo(+N): should make it another Class + # BFS info + assert TdState.is_bfs is not None + self.is_bfs = TdState.is_bfs + self.bfs_root = self + if self.is_bfs: + self.bfs_buffer = [0, []] # next-idx, next-info list + else: + self.inst = prev.inst + self.num_tok = prev.num_tok + self.num_rest = prev.num_rest + self.list_arc = prev.list_arc.copy() + self.list_label = prev.list_label.copy() + self.reduced = prev.reduced.copy() + # BFS + self.is_bfs = prev.is_bfs + self.bfs_root = prev.bfs_root + # + # action = (head, child or head means up, attach_label or None) + act_head, act_attach, act_label = action + assert act_head == prev.idx_cur + if act_head == act_attach: + if self.is_bfs: + # go next: reduce + bfs_buffer_idx, bfs_buffer_list = prev.bfs_buffer + new_bfs_buffer_list = bfs_buffer_list.copy() + if bfs_buffer_idx < len(new_bfs_buffer_list): + next_par_state, next_cidx, next_label = new_bfs_buffer_list[bfs_buffer_idx] + new_bfs_buffer_list[bfs_buffer_idx] = None + self.bfs_buffer = [bfs_buffer_idx+1, new_bfs_buffer_list] + else: + # go back to root again!! + next_par_state, next_cidx, next_label = self.bfs_root, 0, 0 + self.bfs_buffer = [0, []] + self.state_par = next_par_state # parent state + self.state_last_cur = None # no prev parallel states for this one + self.idx_cur = next_cidx + self.label_cur = next_label + self.idx_par = next_par_state.idx_cur + self.idxes_chs = [] + self.labels_chs = [] + self.reduced[act_head] = True + self.idx_ch_leftmost = self.idx_ch_rightmost = next_cidx + self.depth = self.depth_eff = -1 # todo(warn): not used!! + else: + # go up: reduce + last_cur_state = prev.state_par + self.state_par = last_cur_state.state_par + self.state_last_cur = last_cur_state + self.idx_cur = last_cur_state.idx_cur + self.label_cur = last_cur_state.label_cur + self.idx_par = last_cur_state.idx_par + self.idxes_chs = last_cur_state.idxes_chs.copy() + self.idxes_chs.append(act_head) + self.labels_chs = last_cur_state.labels_chs.copy() + self.labels_chs.append(self.list_label[act_head]) + self.reduced[act_head] = True # reduce + # leftmost and rightmost children or self + self.idx_ch_leftmost, self.idx_ch_rightmost = \ + min(act_head, last_cur_state.idx_ch_leftmost), max(act_head, last_cur_state.idx_ch_rightmost) + # + self.depth = prev.depth - 1 + self.depth_eff = prev.depth + else: + # attach + self.num_rest -= 1 + self.list_arc[act_attach] = act_head + self.list_label[act_attach] = act_label + # + if self.is_bfs: + # attach, but not go down, just record in buffer + self.state_par = prev.state_par + self.state_last_cur = prev + self.idx_cur = prev.idx_cur + self.label_cur = prev.label_cur + self.idx_par = prev.idx_par + # + self.idxes_chs = prev.idxes_chs.copy() + self.idxes_chs.append(act_attach) + self.labels_chs = prev.labels_chs.copy() + self.labels_chs.append(act_label) + # leftmost and rightmost children or self + self.idx_ch_leftmost, self.idx_ch_rightmost = \ + min(act_attach, prev.idx_ch_leftmost), max(act_attach, prev.idx_ch_rightmost) + self.depth = self.depth_eff = -1 # todo(warn): not used!! + # fill buffer for the children node + bfs_buffer_idx, bfs_buffer_list = prev.bfs_buffer + new_bfs_buffer_list = bfs_buffer_list.copy() # todo(warn): copy to make it correct for beam-search + new_bfs_buffer_list.append((self, act_attach, act_label)) + self.bfs_buffer = [bfs_buffer_idx, new_bfs_buffer_list] + else: + # go down: attach + self.state_par = prev + self.state_last_cur = None + self.idx_cur = act_attach + self.label_cur = act_label + self.idx_par = act_head + self.idxes_chs = [] + self.labels_chs = [] + # leftmost and rightmost children or self + self.idx_ch_leftmost, self.idx_ch_rightmost = act_attach, act_attach + # + self.depth = self.depth_eff = prev.depth + 1 + self.not_root = 0. if (self.state_par is None) else 1. # still have parents, not the root node + # ===== + # other calculate as needed values + # + # Scorer: possible rnn hidden layers slices + self.spine_rnn_cache = None + self.trailing_rnn_cache = None + # Scores: score slices + self.arc_score_slice = None + self.label_score_slice = None + # for Oracle: togo child list + self.oracle_ch_cache = None # pair of idx in the list (next to read children idx): (ch-idx, des-idx) + # for loss: loss_cur+loss_future + self.oracle_loss_cache = None # (loss-arc, loss-corr_arc-label, delta-arc, delta-label) + # for signature and merging + self.sig_cache = None # (arc-array, label-array, sig-bytes) + +# ===== +# manage the oracle related stuffs +# TODO(+N): different oracles for bfs decoding! +class TdOracleManager: + def __init__(self, oracle_strategy: str, oracle_projective: bool, free_dist_alpha: float, labeling_order: str): + self.oracle_strategy = oracle_strategy + self.free_dist_alpha = free_dist_alpha + self.is_l2r, self.is_i2o, self.is_free, self.is_label, self.is_n2f = \ + [oracle_strategy==z for z in ["l2r", "i2o", "free", "label", "n2f"]] + zcheck(any([self.is_l2r, self.is_i2o, self.is_free, self.is_label, self.is_n2f]), + f"Err: UNK oracle-strategy {oracle_strategy}") + zcheck(not oracle_projective, "Err: not implemented for projective oracle!") + # + self.label_order = {"core": TdOracleManager.LABELING_RANKINGS_CORE, + "freq": TdOracleManager.LABELING_RANKINGS_FREQ}[labeling_order] + + # todo(+N): to be improved maybe + # by core relation -> non-core + LABELING_RANKINGS_CORE = {v:i for i,v in enumerate([ + # core + "nsubj", "csubj", "obj", "iobj", "ccomp", "xcomp", + # nominal + "nmod", "appos", "nummod", "amod", "acl", + # non core + "obl", "vocative", "expl", "dislocated", "advcl", "advmod", "discourse", + # others + "fixed", "flat", "compound", "list", "parataxis", "orphan", "goeswith", "reparandum", + # functions + "conj", "cc", "aux", "cop", "mark", "det", "clf", "case", "punct", "root", "dep", + ])} + # by english frequency + LABELING_RANKINGS_FREQ = {v:i for i,v in enumerate(['punct', 'case', 'nmod', 'amod', 'det', 'obl', 'nsubj', 'root', 'advmod', 'conj', 'obj', 'cc', 'mark', 'aux', 'acl', 'nummod', 'flat', 'cop', 'advcl', 'xcomp', 'appos', 'compound', 'expl', 'ccomp', 'fixed', 'iobj', 'parataxis', 'dep', 'csubj', 'orphan', 'discourse', 'clf', 'goeswith', 'vocative', 'list', 'dislocated', 'reparandum'])} + + # ===== + # used for init stream + def init_inst(self, inst: ParseInstance): + # inst.get_children_mask_arr(False) + inst.set_children_info(self.oracle_strategy, label_ranking_dict=self.label_order, free_dist_alpha=self.free_dist_alpha) + return inst + + # used for refreshing before each running (currently mainly for shuffling free-mode children) + def refresh_insts(self, insts: List[ParseInstance]): + if self.is_free: + for inst in insts: + inst.shuffle_children_free() + elif self.is_n2f: + for inst in insts: + inst.shuffle_children_n2f() + + # ===== + # get oracles for the input states, this procedure is same for all the modes + # todo(+N): currently only return one oracle (determined or random) + # todo(+N): simple oracle, but is this one optimal for TD? + def get_oracles(self, states: List[TdState]): + ret_nodes = [] + ret_labels = [] + for s in states: + idx_cur = s.idx_cur + inst = s.inst + list_arc = s.list_arc + state_last_cur = s.state_last_cur + cur_ch_list = inst.children_list[idx_cur] + cur_de_list = inst.descendant_list[idx_cur] + len_ch_list = len(cur_ch_list) + len_de_list = len(cur_de_list) + # the start; note that ch_idx is the indirect index + if state_last_cur is None: + ch_idx = 0 + de_idx = 0 + else: + prev_ch_idx, prev_de_idx = state_last_cur.oracle_ch_cache + # todo(note): will not repeat since the attached one will be filtered out + ch_idx = prev_ch_idx + de_idx = prev_de_idx + # find attachable one + while ch_idx=0: + ch_idx += 1 + while de_idx=0: + de_idx += 1 + # assign + s.oracle_ch_cache = (ch_idx, de_idx) + if ch_idx 0: + # attach + oracle_mask_arr[sidx][valid_ch_list] = 1. + elif idx_cur == 0: + # todo(+N): again, let the ROOT collect all unattached ones + unfinished_ones = [z+1 for z, a in enumerate(list_arc[1:]) if a<0] + oracle_mask_arr[sidx][unfinished_ones] = 1. + else: + # reduce + oracle_mask_arr[sidx][idx_cur] = 1. + # fill labels for the chidren + tofill_labels = inst.labels.idxes + oracle_label_arr[sidx][:len(tofill_labels)] = tofill_labels + else: + # specific order, fill the oracle one + ret_nodes, _ = self.get_oracles(states) + oracle_mask_arr[np.arange(len(states)), ret_nodes] = 1. + oracle_mask_arr *= cands_mask_arr + return oracle_mask_arr, oracle_label_arr + + # ===== + # todo(warn): this loss calculation only corresponds to free-mode! + # for global learning or credit assignment, get the current+future loss for a state + def set_losses(self, state: TdState): + prev = state.prev + if prev is None: + state.oracle_loss_cache = (0, 0, 0, 0) + else: + if prev.oracle_loss_cache is None: + self.set_losses(prev) + loss_arc, loss_carc_label, _, _ = prev.oracle_loss_cache + act_head, act_attach, act_label = state.action + inst = state.inst + list_arc = state.list_arc + delta_arc, delta_label = 0, 0 + if act_head == act_attach: # reduce operation which can lead to missing losses + # todo(warn): only currently unattached children!! + # todo(+N): might be not efficient for looping all children + # missing_loss = max(0, len(inst.children_list[act_head]) - len(prev.idxes_chs)) + missing_loss = len([z for z in inst.children_list[act_head] if list_arc[z]<0]) + delta_arc, delta_label = missing_loss, 0 + else: + real_head = inst.heads.vals[act_attach] + if real_head != act_head: + if state.reduced[real_head]: + # not blamed here + delta_arc, delta_label = 0, 0 + else: + # new attach error, but no blaming for label loss + delta_arc, delta_label = 1, 0 + elif inst.labels.idxes[act_attach] != act_label: + delta_arc, delta_label = 0, 1 + loss_arc += delta_arc + loss_carc_label += delta_label + state.oracle_loss_cache = (loss_arc, loss_carc_label, delta_arc, delta_label) + +# ===== +# calculate signatures +class TdSigManager: + def __init__(self, sig_type:str="plain"): + self.counter = 0 + self.positive_bias = 2 # bias to make all the indices positive to fit in unsigned value + self.sig_plain, self.sig_none = [sig_type==z for z in ("plain", "none")] + zcheck(self.sig_plain or self.sig_none, f"UNK sig type: {sig_type}") + + def get_sig(self, state: TdState): + if self.sig_plain: + return self.get_sig_plain(state) + elif self.sig_none: + # no signature, everyone is different! + self.counter += 1 + return self.counter + else: + raise NotImplementedError() + + # ignore the sigs of all reduced children, only care about unattached nodes and concerned nodes in the spine + def get_sig_plain(self, state: TdState): + positive_bias = self.positive_bias + prev = state.prev + if prev is None: + num_tok = state.num_tok + atype = "B" if (num_tok+positive_bias<256) else "H" + # todo(note): plus bias here! + start_arr = array(atype, [-1+positive_bias] * num_tok) + state.sig_cache = (start_arr, start_arr, b'') + else: + if prev.sig_cache is None: + self.get_sig_plain(prev) + prev_arc_array, prev_label_array, _ = prev.sig_cache + new_arc_array, new_label_array = copy(prev_arc_array), copy(prev_label_array) + # tell by the action + act_head, act_attach, act_label = state.action + if act_head == act_attach: + # todo(note): ignore further level of children here, set 0! + for c in state.idxes_chs: + new_arc_array[c] = 0 + new_label_array[c] = 0 + else: + # todo(note): plus bias here! + new_arc_array[act_attach] = act_head + positive_bias + new_label_array[act_attach] = act_label + positive_bias + state.sig_cache = (new_arc_array, new_label_array, new_arc_array.tobytes()+new_label_array.tobytes()) + return state.sig_cache[-1] + +# ===== +# components + +# ----- +# scorer + +class TdScorerConf(Conf): + def __init__(self): + self._input_dim = -1 # enc's last dimension + self._num_label = -1 # number of labels + # + self.output_local_norm = True # whether norm at output + self.apply_mask_in_scores = True # masking out the non-candidates in the scoring phase + self.dim_label = 30 + # todo(warn): not used in Scorer, but in Selector! + self.score_reduce_label = False # whether include label scores for reduce-op (label=0) + self.score_reduce_fix_zero = False # fix reduce arc score to 0 + # space transferring + self.arc_space = 512 + self.lab_space = 128 + # pre-computation for the head side: (cur/chsib_set/par -> spineRNN -> TrailingRNN) + # todo(+N): pre-transformation for each type of nodes? but may be too much params. + self.head_pre_size = 512 # size of repr after pre-computation + self.head_pre_useff = True # ff for the first layer? otherwise simply sum + self.use_label_feat = True # whether using label feature for cur and children nodes + self.use_chsib = True # use children node's siblings + self.chsib_num = 5 # how many (recent) ch's siblings to consider, 0 means all: [-num:] + self.chsib_f = "att" # how to represent the ch set: att=attention, sum=sum, ff=feadfoward (always 0 vector if no ch) + self.chsib_att = AttConf().init_from_kwargs(d_kqv=256, att_dropout=0.1, head_count=2) + self.use_par = True # use parent of current head node (grand parent) + self.use_spine_rnn = False # use spine rnn for the spine-line of ancestor nodes + self.spine_rnn_type = "lstm2" + self.use_trailing_rnn = False # trailing rnn's size should be the same as the head's output size + self.trailing_rnn_type = "lstm2" + # ----- + # final biaffine scoring + self.ff_hid_size = 0 + self.ff_hid_layer = 0 + self.use_biaffine = True + self.use_ff = True + self.use_ff2 = False + self.biaffine_div = 1. + # distance clip? + self.arc_dist_clip = -1 + self.arc_use_neg = False + + def get_dist_aconf(self): + return AttConf().init_from_kwargs(clip_dist=self.arc_dist_clip, use_neg_dist=self.arc_use_neg) + +# another helper class for children set calculation +class TdChSibReprer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: TdScorerConf): + super().__init__(pc, None, None) + # concat enc and label if use-label + self.dim = (sconf._input_dim+sconf.dim_label) if sconf.use_label_feat else sconf._input_dim + self.is_att, self.is_sum, self.is_ff = [sconf.chsib_f==z for z in ["att", "sum", "ff"]] + self.ff_reshape = [-1, self.dim*sconf.chsib_num] # only for ff + if self.is_att: + self.fnode = self.add_sub_node("fn", MultiHeadAttention(pc, self.dim, self.dim, self.dim, sconf.chsib_att)) + elif self.is_sum: + self.fnode = None + elif self.is_ff: + zcheck(sconf.chsib_num>0, "Err: Cannot ff with 0 child") + self.fnode = self.add_sub_node("fn", Affine(pc, self.dim*sconf.chsib_num, self.dim, act="elu")) + else: + raise NotImplementedError(f"UNK chsib method: {sconf.chsib_f}") + + def get_output_dims(self, *input_dims): + return (self.dim, ) + + def zeros(self, batch): + return BK.zeros((batch, self.dim)) + + def __call__(self, chs_input_state_t, chs_input_mem_t, chs_mask_t, chs_valid_t): + if self.is_att: + # [*, max-ch, size], ~, [*, 1, size], [*, max-ch] -> [*, size] + ret = self.fnode(chs_input_mem_t, chs_input_mem_t, chs_input_state_t.unsqueeze(-2), chs_mask_t).squeeze(-2) + # ignore head for the rest + elif self.is_sum: + # ignore head + ret = (chs_input_mem_t*chs_mask_t.unsqueeze(-1)).sum(-2) + elif self.is_ff: + reshaped_input_state_t = (chs_input_mem_t*chs_mask_t.unsqueeze(-1)).view(self.ff_reshape) + ret = self.fnode(reshaped_input_state_t) + else: + ret = None + # out-most mask + # [*, D_DL] * [*, 1] + return ret * (chs_valid_t.unsqueeze(-1)) + +# helper class for head pre-computation +# (group1:cur/chsib_set/par + group2:spineRNN + group3:TrailingRNN) +class TdHeadReprer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: TdScorerConf): + super().__init__(pc, None, None) + self.head_pre_size = sconf.head_pre_size + self.use_label_feat = sconf.use_label_feat + # padders for child nodes + self.input_enc_dim = sconf._input_dim + self.chs_start_posi = -sconf.chsib_num + self.ch_idx_padder = DataPadder(2, pad_vals=0, mask_range=2) # [*, num-ch] + self.ch_label_padder = DataPadder(2, pad_vals=0) + # ===== + # todo(note): here 0 for the ROOT + self.label_embeddings = self.add_sub_node("label", Embedding(pc, sconf._num_label, sconf.dim_label, fix_row0=False)) + # ===== + # todo(note): now adopting flatten groupings for basic, spine-rnn and trailing-rnn + # group 1: [cur_node, chsib, par_node] -> head_pre_size + self.use_chsib = sconf.use_chsib + self.use_par = sconf.use_par + # cur node + hr_input_sizes = [sconf._input_dim+sconf.dim_label] if self.use_label_feat else [sconf._input_dim] + if sconf.use_chsib: + self.chsib_reprer = self.add_sub_node("chsib", TdChSibReprer(pc, sconf)) + hr_input_sizes.append(self.chsib_reprer.get_output_dims()[0]) + if sconf.use_par: + # no par label here! + hr_input_sizes.append(sconf._input_dim) + # group 2: Spine RNN + self.use_spine_rnn = sconf.use_spine_rnn + if self.use_spine_rnn: + self.spine_rnn = self.add_sub_node("spine", RnnNode.get_rnn_node(sconf.spine_rnn_type, pc, hr_input_sizes[0], self.head_pre_size)) + hr_input_sizes.append(self.head_pre_size) + # group 3: Trailing RNN + self.use_trailing_rnn = sconf.use_trailing_rnn + if self.use_trailing_rnn: + self.trailing_rnn = self.add_sub_node("trailing", RnnNode.get_rnn_node(sconf.trailing_rnn_type, pc, hr_input_sizes[0],self.head_pre_size)) + hr_input_sizes.append(self.head_pre_size) + # finally sum + if sconf.head_pre_useff: + self.final_ff = self.add_sub_node("f1", Affine(pc, hr_input_sizes, self.head_pre_size, act="elu")) + else: + # todo(+2) + self.final_ff = None + + # calculating head representations, state_bidx_bases is the flattened_idx_base + def __call__(self, states: List[TdState], state_bidx_bases_expr: BK.Expr, flattened_enc_expr: BK.Expr): + local_bsize = len(states) + head_repr_inputs = [] + # (group 1) [cur_node, chsib, par_node] + # -- cur_node + cur_idxes = [s.idx_cur for s in states] + cur_rel_pos_t = BK.input_idx(cur_idxes) + cur_idxes_real_t = cur_rel_pos_t + state_bidx_bases_expr + cur_expr_node = BK.select(flattened_enc_expr, cur_idxes_real_t) # [*, enc_size] + if self.use_label_feat: + cur_labels = [s.label_cur for s in states] + cur_labels_real_t = BK.input_idx(cur_labels) + cur_expr_label = self.label_embeddings(cur_labels_real_t) # [*, dim_label] + cur_enc_t = BK.concat([cur_expr_node, cur_expr_label], -1) + else: + cur_enc_t = cur_expr_node + head_repr_inputs.append(cur_enc_t) # [*, D+DL] + # -- chsib + if self.use_chsib: + # slicing does not check range + chs_idxes = [s.idxes_chs[self.chs_start_posi:] for s in states] + chs_valid = [(0. if len(z)==0 else 1.) for z in chs_idxes] + # padding + padded_chs_idxes, padded_chs_mask = self.ch_idx_padder.pad(chs_idxes) # [*, max-ch], [*, max-ch] + # todo(warn): if all no children + if padded_chs_idxes.shape[1] == 0: + chs_repr = self.chsib_reprer.zeros(local_bsize) + else: + # select chs reprs + padded_chs_idxes_t = BK.input_idx(padded_chs_idxes) + state_bidx_bases_expr.unsqueeze(-1) + output_shape = [-1, BK.get_shape(padded_chs_idxes_t, -1), self.input_enc_dim] + chs_enc_t = BK.select(flattened_enc_expr, padded_chs_idxes_t.view(-1)).view(output_shape) # [*,max-ch,D] + # other inputs + chs_mask_t = BK.input_real(padded_chs_mask) # [*, max-ch] + chs_valid_t = BK.input_real(chs_valid) # [*,] + # labels + if self.use_label_feat: + chs_labels = [s.labels_chs[self.chs_start_posi:] for s in states] + padded_chs_labels, _ = self.ch_label_padder.pad(chs_labels) # [*, max-ch] + chs_labels_t = self.label_embeddings(padded_chs_labels) # [*, max-ch, DL] + chs_input_mem_t = BK.concat([chs_enc_t, chs_labels_t], -1) # [*, max-ch, D+DL] + else: + chs_input_mem_t = chs_enc_t + # calculate representations + chs_input_state_t = cur_enc_t # [*, D+DL] + chs_repr = self.chsib_reprer(chs_input_state_t, chs_input_mem_t, chs_mask_t, chs_valid_t) # [*, D+DL] + # + head_repr_inputs.append(chs_repr) # [*, D+DL] + # -- par_node + if self.use_par: + # here, no use of parent's label and mask with zero-vector for roots' parents + par_idxes = [s.idx_par for s in states] + par_idxes_real = BK.input_idx(par_idxes) + state_bidx_bases_expr + par_masks = [s.not_root for s in states] + # [*, enc_size] + par_expr = BK.select(flattened_enc_expr, par_idxes_real) * BK.input_real(par_masks).unsqueeze(-1) + head_repr_inputs.append(par_expr) + # ----- + # (group 2): spine-only RNN + if self.use_spine_rnn: + # create zero-init states for RNN + empty_slice = ExprWrapper.make_unit(self.spine_rnn.zero_init_hidden(1)) + # collect previous RNN states from parent states and combine + # prev_hid_slices = [(empty_slice if s.state_par is None else s.state_par.spine_rnn_cache) for s in states] + prev_hid_slices = [(s.state_par.spine_rnn_cache if s.not_root else empty_slice) for s in states] + prev_hid_combined = SliceManager.combine_slices(prev_hid_slices, None) + # RNN step + layer2_hid_state = self.spine_rnn(cur_enc_t, prev_hid_combined, None) + # store + new_slices = ExprWrapper(layer2_hid_state, local_bsize).split() + for bidx, s in enumerate(states): + s.spine_rnn_cache = new_slices[bidx] + # todo(note): inner side of the RNN-units + head_repr_inputs.append(layer2_hid_state[0]) + # (group 3): all-history-trajectory RNN (the processing is similar to layer2) + # create zero-init states for RNN + if self.use_trailing_rnn: + # create zero-init states for RNN + empty_slice = ExprWrapper.make_unit(self.trailing_rnn.zero_init_hidden(1)) + # collect previous RNN states from parent states and combine + prev_hid_slices = [(empty_slice if s.prev is None else s.prev.trailing_rnn_cache) for s in states] + prev_hid_combined = SliceManager.combine_slices(prev_hid_slices, None) + # RNN step + layer3_hid_state = self.trailing_rnn(cur_enc_t, prev_hid_combined, None) + # store + new_slices = ExprWrapper(layer3_hid_state, local_bsize).split() + for bidx, s in enumerate(states): + s.trailing_rnn_cache = new_slices[bidx] + # todo(note): inner side of the RNN-units + head_repr_inputs.append(layer3_hid_state[0]) + # ==== finally + final_repr = self.final_ff(head_repr_inputs) # [*, dim] + return final_repr, cur_rel_pos_t + +# ----- +# the main scorer +# three steps for scoring: enc-prepare, arc-score, label-score +class TdScorer(BasicNode): + def __init__(self, pc: BK.ParamCollection, sconf: TdScorerConf): + super().__init__(pc, None, None) + # options + input_dim = sconf._input_dim + arc_space = sconf.arc_space + lab_space = sconf.lab_space + ff_hid_size = sconf.ff_hid_size + ff_hid_layer = sconf.ff_hid_layer + use_biaffine = sconf.use_biaffine + use_ff = sconf.use_ff + use_ff2 = sconf.use_ff2 + # + head_pre_size = sconf.head_pre_size + self.cands_padder = DataPadder(2, pad_vals=0, mask_range=2) # [*, num-cands] + # + self.input_dim = input_dim + self.num_label = sconf._num_label + self.output_local_norm = sconf.output_local_norm + self.apply_mask_in_scores = sconf.apply_mask_in_scores + self.score_reduce_label = sconf.score_reduce_label + self.score_reduce_fix_zero = sconf.score_reduce_fix_zero + # todo(warn): the biaffine direction is the opposite in graph one, but does not matter + biaffine_div = sconf.biaffine_div + # attach/arc + self.arc_m = self.add_sub_node("am", Affine(pc, input_dim, arc_space, act="elu")) + self.arc_h = self.add_sub_node("ah", Affine(pc, head_pre_size, arc_space, act="elu")) + self.arc_scorer = self.add_sub_node("as", BiAffineScorer(pc, arc_space, arc_space, 1, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, biaffine_div=biaffine_div)) + # only add distance for arc + if sconf.arc_dist_clip > 0: + self.dist_helper = self.add_sub_node("dh", AttDistHelper(pc, sconf.get_dist_aconf(), arc_space)) + else: + self.dist_helper = None + # labeling + self.lab_m = self.add_sub_node("lm", Affine(pc, input_dim, lab_space, act="elu")) + self.lab_h = self.add_sub_node("lh", Affine(pc, head_pre_size, lab_space, act="elu")) + self.lab_scorer = self.add_sub_node("ls", BiAffineScorer(pc, lab_space, lab_space, self.num_label, ff_hid_size, ff_hid_layer=ff_hid_layer, use_biaffine=use_biaffine, use_ff=use_ff, use_ff2=use_ff2, biaffine_div=biaffine_div)) + # head preparation + self.head_preper = self.add_sub_node("h", TdHeadReprer(pc, sconf)) + # + # alpha for local norm + self.local_norm_alpha = 1.0 + + # step 0: pre-compute for the mod candidate nodes + # transform to specific space (for labeling, using stacking-like architecture for the head side) + # [*, len, input_dim] -> *[*, len, space_dim] + def transform_space(self, enc_expr): + # flatten for later usage for head nodes + flattened_enc_expr = enc_expr.view([-1, self.input_dim]) + flattened_enc_stride = BK.get_shape(enc_expr, -2) # sentence length + # calculate for m nodes + am_expr = self.arc_m(enc_expr) + lm_expr = self.lab_m(enc_expr) + am_pack = self.arc_scorer.precompute_input0(am_expr) + lm_pack = self.lab_scorer.precompute_input0(lm_expr) + scoring_expr_pack = (enc_expr, flattened_enc_expr, flattened_enc_stride, am_pack, lm_pack) + return scoring_expr_pack + + # local normalization (with prob-raising alpha), return log-prob + def local_normalize(self, scores, alpha): + if alpha != 1.: + scores = BK.log_softmax(scores, -1) * alpha + return BK.log_softmax(scores, -1) + + def set_local_norm_alpha(self, alpha): + if alpha != self.local_norm_alpha: + self.local_norm_alpha = alpha + zlog(f"Setting scorer's alpha to {self.local_norm_alpha}") + + # step 1: compute for arc + # INPUT: ENC_PACK, List[*], [*], [*, slen] -> OUTPUT: score[*, slen], ARC_PACK + # todo(note): tradeoff use masks instead of idxes to lessen CPU-burden, nearly half GPU computation is masked out + def score_arc(self, cache_pack, states: List[TdState], state_bidxes_expr: BK.Expr, cand_masks_expr: BK.Expr): + bsize = len(states) + enc_expr, flattened_enc_expr, flattened_enc_stride, am_pack, lm_pack = cache_pack + # flattened idx bases + state_bidx_bases_expr = state_bidxes_expr * flattened_enc_stride + # prepare head + head_expr, rel_pos_t = self.head_preper(states, state_bidx_bases_expr, flattened_enc_expr) # [*, head_pre_size], [*] + ah_expr = self.arc_h(head_expr).unsqueeze(-2) # [*, 1, arc_space] + lh_expr = self.lab_h(head_expr).unsqueeze(-2) # [*, 1, lab_space] + # prepare cands + # [*, slen, *space] (todo(note): select at batch-idx instead of flattened-idx) + state_am_pack = [(BK.select(z, state_bidxes_expr) if z is not None else None) for z in am_pack] + # arc distance + if self.dist_helper is not None: + hm_dist = rel_pos_t.unsqueeze(-1) - BK.unsqueeze(BK.arange_idx(0, flattened_enc_stride), 0) + ah_rel1, _ = self.dist_helper.obatin_from_distance(hm_dist) # [*, L, arc_space] + else: + ah_rel1 = None + # [*, slen] + # todo(note): effective for local-norm mode, no effects for global models since later we will exclude things + score_candidate_mask = cand_masks_expr if self.apply_mask_in_scores else None + # full_arc_score = self.arc_scorer.postcompute_input1( + # state_am_pack, ah_expr, mask0=score_candidate_mask, rel1_t=ah_rel1).squeeze(-1) + full_arc_score = self.arc_scorer.postcompute_input1(state_am_pack, ah_expr, rel1_t=ah_rel1).squeeze(-1) + if self.score_reduce_fix_zero: + full_arc_score[BK.arange_idx(bsize), rel_pos_t] = 0. + # todo(warn): mask as the final step + if score_candidate_mask is not None: + full_arc_score += Constants.REAL_PRAC_MIN*(1.-score_candidate_mask) + if self.output_local_norm: + full_arc_score = self.local_normalize(full_arc_score, self.local_norm_alpha) + return full_arc_score, (state_bidx_bases_expr, state_bidxes_expr, cand_masks_expr, lh_expr) + + # step 2: compute for label (full or partial) + # labeling_cand_idxes=None means scoring for all cands, otherwise partially + # INPUT: ~, [*, len_cand]; OUTPUT: [*, ?, num-label] + # todo(+N): if self.score_reduce_fix_zero and self.score_reduce_label? + def score_label(self, cache_pack, cache_arc_pack, labeling_cand_idxes_expr: BK.Expr=None): + enc_expr, flattened_enc_expr, flattened_enc_stride, _, lm_pack = cache_pack + state_bidx_bases_expr, state_bidxes_expr, cand_masks_expr, lh_expr = cache_arc_pack + # + if labeling_cand_idxes_expr is None: + # [*, slen, label] (full score, also select at batch-idx) + state_lm_pack = [(BK.select(z, state_bidxes_expr) if z is not None else None) for z in lm_pack] + score_candidate_mask = cand_masks_expr if self.apply_mask_in_scores else None + full_label_score = self.lab_scorer.postcompute_input1(state_lm_pack, lh_expr, mask0=score_candidate_mask) + if self.output_local_norm: + full_label_score = self.local_normalize(full_label_score, self.local_norm_alpha) + return full_label_score + else: + # [*, ?, label] (partial score) + flattened_labeling_cand_idxes_expr = labeling_cand_idxes_expr + state_bidx_bases_expr.unsqueeze(-1) + output_shape = BK.get_shape(flattened_labeling_cand_idxes_expr) + [-1] + flat_idx = flattened_labeling_cand_idxes_expr.view(-1) # flatten first-two dims + selected_lm_pack = [(BK.select(z.view(-1, BK.get_shape(z, -1)), flat_idx).view(output_shape) + if z is not None else None) for z in lm_pack] # [*, ?, *space] + selected_label_score = self.lab_scorer.postcompute_input1(selected_lm_pack, lh_expr) + if self.output_local_norm: + selected_label_score = self.local_normalize(selected_label_score, self.local_norm_alpha) + return selected_label_score +# ----- + +# --- +# expander, get the candidates of each state for the next step +class TdExpander: + def expand(self, s: TdState): + raise NotImplementedError() + + @staticmethod + def get_expander(strategy, projective): + if projective: + # todo(+N) + raise NotImplementedError("Err: Projective methods have not been implemented yet.") + else: + _MAP = {"i2o": TdExpanderI2oNP, "l2r": TdExpanderL2rNP, "free": TdExpanderFreeNP, + "n2f": TdExpanderN2fNP, "n2f2": TdExpanderN2f2NP} + return _MAP[strategy]() + +# Non-projective inside to outside +class TdExpanderI2oNP(TdExpander): + def expand(self, s: TdState): + num_tok = s.num_tok + idx_cur = s.idx_cur + list_arc = s.list_arc + idxes_chs = s.idxes_chs + last_ch_node = idx_cur if (len(idxes_chs)==0) else idxes_chs[-1] + if last_ch_node>idx_cur: + # right branch + ret = [i for i in range(last_ch_node+1, num_tok) if list_arc[i]<0] + else: + # left branch, but can also jump to right branch + ret = [i for i in range(1, last_ch_node) if list_arc[i]<0] + [i for i in range(idx_cur+1, num_tok) if list_arc[i]<0] + if idx_cur>0 or TdState.is_bfs: + ret.append(idx_cur) + return ret + +# Non-projective left to right +class TdExpanderL2rNP(TdExpander): + def expand(self, s: TdState): + num_tok = s.num_tok + idx_cur = s.idx_cur + list_arc = s.list_arc + idxes_chs = s.idxes_chs + last_ch_node = 0 if (len(idxes_chs) == 0) else idxes_chs[-1] + ret = [i for i in range(last_ch_node+1, num_tok) if list_arc[i]<0] + if idx_cur>0 or TdState.is_bfs: + ret.append(idx_cur) + return ret + +# Non-projective free +class TdExpanderFreeNP(TdExpander): + def expand(self, s: TdState): + num_tok = s.num_tok + idx_cur = s.idx_cur + list_arc = s.list_arc + ret = [i for i in range(1, num_tok) if list_arc[i]<0] + if idx_cur>0 or TdState.is_bfs: + # including a reduce operation (todo(note): since cur_arc[cur_node] must be >=0) + ret.append(idx_cur) + return ret + +# Non-projective n2f: strict mode +class TdExpanderN2fNP(TdExpander): + def expand(self, s: TdState): + num_tok = s.num_tok + idx_cur = s.idx_cur + list_arc = s.list_arc + idxes_chs = s.idxes_chs + last_ch_node = idx_cur if (len(idxes_chs) == 0) else idxes_chs[-1] + last_ch_node_distance = abs(idx_cur-last_ch_node) + ret = [i for i in range(1, num_tok) if (list_arc[i]<0 and abs(idx_cur-i)>=last_ch_node_distance)] + if idx_cur>0 or TdState.is_bfs: + ret.append(idx_cur) + return ret + +# Non-projective n2f2: non-strict mode +class TdExpanderN2f2NP(TdExpander): + def expand(self, s: TdState): + num_tok = s.num_tok + idx_cur = s.idx_cur + list_arc = s.list_arc + ret = [i for i in range(1, s.idx_ch_leftmost) if list_arc[i]<0] + \ + [i for i in range(s.idx_ch_rightmost+1, num_tok) if list_arc[i]<0] + if idx_cur>0 or TdState.is_bfs: + ret.append(idx_cur) + return ret + +# --- +# local selector (scoring + local selector) +class TdLocalSelector: + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, oracle_manager=None): + self.scorer: TdScorer = scorer + self.num_label = self.scorer.num_label + self.score_reduce_label = self.scorer.score_reduce_label + # for mode with topk + self.local_arc_beam_size = local_arc_beam_size + self.local_label_beam_size = min(local_label_beam_size, self.scorer.num_label) + # for mode with oracle + self.oracle_manager: TdOracleManager = oracle_manager + # for ss-mode + self.rand_stream = Random.stream(Random.random_sample) + + # return selected states + # todo(warn): in mixed mode, run _score_and_choose in the batch, and ignore useless ones + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + raise NotImplementedError() + + # common scoring and selecting procedure (three modes, can select more for mixing) + # flatten out all the score expressions + # todo(warn): beam_size is only useful for topk, sampling always obtain 1 selection + # todo(warn): no score modification (+margin) in this local step! + # todo(warn): first arc-topk then label, will this lead to search-error? + def _score_and_choose(self, candidates, cache_pack, if_topk, if_tsample, if_sample, if_oracle, + selection_mask_arr=None, log_prob_sum=False): + # bsize * (TdState, orig-batch-idxes, candidate-masks) + flattened_states, state_bidxes, cands_mask_arr = candidates + state_bidxes_expr = BK.input_idx(state_bidxes) + cand_masks_expr = BK.input_real(cands_mask_arr) + # ===== + # First: Common arc score: [*, slen] + full_arc_score, cache_arc_pack = self.scorer.score_arc(cache_pack, flattened_states, state_bidxes_expr, cand_masks_expr) + bsize, slen = BK.get_shape(full_arc_score) + ew_full_arc_score = ExprWrapper(full_arc_score.view(-1), bsize*slen) # [bsize*slen] + nlabel = self.num_label + # todo(note): add masking (not inplaced) here for topk selection, after possible local-norm, + # the one before local-norm is controlled by the flag in scorer + selection_mask_expr = cand_masks_expr if (selection_mask_arr is None) else BK.input_real(selection_mask_arr) + if log_prob_sum: + assert selection_mask_arr is not None + oracle_log_prob_sum = (BK.exp(full_arc_score) * selection_mask_expr).sum(-1).log() # [bsize,] + else: + oracle_log_prob_sum = None + full_arc_score = full_arc_score + Constants.REAL_PRAC_MIN*(1.-selection_mask_expr) + # ===== + # Rest: different selecting method: (arc-K, label-K, scores_arc, idxes_arc, ew_label, scores_label, idxes_label) + select_rets = [None] * 4 + # ----- + # 1/2. topk/topk-sample selection + if if_topk or if_tsample: + arc_topk = min(slen-1, self.local_arc_beam_size) # todo(note): cannot select 0 (root) as child + # [*, arc_topk] + # todo(warn): be careful that there can be log0=-inf, which might lead to NAN? + topk_arc_scores, topk_arc_idxes = BK.topk(full_arc_score, arc_topk, dim=-1, sorted=False) + # ----- + # 1. topk + if if_topk: + # label score: [*, arc_topk, nlabel] + topk_label_score = self.scorer.score_label(cache_pack, cache_arc_pack, topk_arc_idxes) + ew_topk_label_score = ExprWrapper(topk_label_score.view(-1), bsize*arc_topk*nlabel) # [bs*arcK*nl] + label_topk = self.local_label_beam_size + # topk^2 [*, arc_topk, label_topk] + topk2_label_scores, topk2_label_idxes = BK.topk(topk_label_score, label_topk, dim=-1, sorted=False) + select_rets[0] = (arc_topk, label_topk, topk_arc_scores, topk_arc_idxes, + ew_topk_label_score, topk2_label_scores, topk2_label_idxes) + # 2. topk_sample (use gumble) + if if_tsample: + # sampled arc + tsample_arc_scores, tsample_arc_local_idxes = BK.category_sample(topk_arc_scores) + tsample_arc_idxes = BK.gather(topk_arc_idxes, tsample_arc_local_idxes, -1) + # label: [*, 1, nlabel] + tsample_label_score = self.scorer.score_label(cache_pack, cache_arc_pack, tsample_arc_idxes) + ew_tsample_label_score = ExprWrapper(tsample_label_score.view(-1), bsize*nlabel) # [bs*1*nl] + # todo(+2): directly sample here for simplicity: [*, 1, 1] + tsample2_label_scores, tsample2_label_idxes = BK.category_sample(tsample_label_score, -1) + select_rets[1] = (1, 1, tsample_arc_scores, tsample_arc_idxes, + ew_tsample_label_score, tsample2_label_scores, tsample2_label_idxes) + # ----- + # 3. sampling (use gumble) + # todo(+N): current only support 1 sampling instance + if if_sample: + # sampled arc [*, 1] + sample_arc_scores, sample_arc_idxes = BK.category_sample(full_arc_score) + # scoring labels [*, 1, nlabel] + sample_label_score = self.scorer.score_label(cache_pack, cache_arc_pack, sample_arc_idxes) + ew_sample_label_score = ExprWrapper(sample_label_score.view(-1), bsize*nlabel) + # sampled labels [*, 1, 1] + sample2_label_scores, sample2_label_idxes = BK.category_sample(sample_label_score, -1) + select_rets[2] = (1, 1, sample_arc_scores, sample_arc_idxes, + ew_sample_label_score, sample2_label_scores, sample2_label_idxes) + # ----- + # 4. oracle (need external help) + # todo(+N): current only support 1 oracle instance + if if_oracle: + oracle_arc_idxes, oracle_label_idxes = self.oracle_manager.get_oracles(flattened_states) + # [*, 1], [*, 1, 1] + oracle_arc_idxes, oracle_label_idxes = BK.input_idx(oracle_arc_idxes).unsqueeze(-1), \ + BK.input_idx(oracle_label_idxes).unsqueeze(-1).unsqueeze(-1) + # arc scores [*, 1] + oracle_arc_scores = BK.gather(full_arc_score, oracle_arc_idxes) + # label all scores [*, 1, nlabel] + oracle_label_score = self.scorer.score_label(cache_pack, cache_arc_pack, oracle_arc_idxes) + ew_oracle_label_score = ExprWrapper(oracle_label_score.view(-1), bsize*nlabel) + # label scores [*, 1, 1] + oracle2_label_scores = BK.gather(oracle_label_score, oracle_label_idxes) + select_rets[3] = (1, 1, oracle_arc_scores, oracle_arc_idxes, + ew_oracle_label_score, oracle2_label_scores, oracle_label_idxes) + # return them all + return (bsize, slen, nlabel, ew_full_arc_score), select_rets, oracle_log_prob_sum + + # similar to the previous one but with special mode! + def _score_and_sample_oracle(self, candidates, cache_pack, sel_mask_arr, sel_label_arr, log_prob_sum): + # bsize * (TdState, orig-batch-idxes, candidate-masks) + flattened_states, state_bidxes, cands_mask_arr = candidates + state_bidxes_expr = BK.input_idx(state_bidxes) + cand_masks_expr = BK.input_real(cands_mask_arr) + # ===== + # First: Common arc score: [*, slen] + full_arc_score, cache_arc_pack = self.scorer.score_arc(cache_pack, flattened_states, state_bidxes_expr, cand_masks_expr) + bsize, slen = BK.get_shape(full_arc_score) + ew_full_arc_score = ExprWrapper(full_arc_score.view(-1), bsize * slen) # [bsize*slen] + nlabel = self.num_label + # + selection_mask_expr = BK.input_real(sel_mask_arr) + sel_label_expr = BK.input_idx(sel_label_arr) # [bsize, LEN] + if log_prob_sum: + assert sel_mask_arr is not None + oracle_log_prob_sum = (BK.exp(full_arc_score) * selection_mask_expr).sum(-1).log() # [bsize,] + else: + oracle_log_prob_sum = None + full_arc_score = full_arc_score + Constants.REAL_PRAC_MIN * (1. - selection_mask_expr) + # step 1, sample arc + # sampled arc [*, 1] + sample_arc_scores, sample_arc_idxes = BK.category_sample(full_arc_score) + # scoring labels [*, 1, nlabel] + sample_label_score = self.scorer.score_label(cache_pack, cache_arc_pack, sample_arc_idxes) + ew_sample_label_score = ExprWrapper(sample_label_score.view(-1), bsize * nlabel) + # get oracle labels [*, 1, 1] + oracle_label_idxes = BK.gather(sel_label_expr, sample_arc_idxes).unsqueeze(-1) + oracle2_label_scores = BK.gather(sample_label_score, oracle_label_idxes) + select_rets = [(1, 1, sample_arc_scores, sample_arc_idxes, + ew_sample_label_score, oracle2_label_scores, oracle_label_idxes)] + # return them all + return (bsize, slen, nlabel, ew_full_arc_score), select_rets, oracle_log_prob_sum + + # ===== + # generate and fill-in the new candidate states from the local selections + # -- return list[(beam_expands, gbeam_expands)] + + # common procedure for one state + def _add_states(self, cands: List, state: TdState, nlabel, arc_k, label_k, cands_mask_arr_s, arc_scores_arr_s, arc_idxes_arr_s, label_scores_arr_s, label_idxes_arr_s, ew_arc, ew_label, ew_arc_idx_base, ew_label_idx_base, label_all_scores_t): + aidx = 0 + state_cur_idx = state.idx_cur + assert len(arc_scores_arr_s)==arc_k and len(arc_idxes_arr_s)==arc_k, "Err: Wrong arc-k" + for one_arc_score, one_arc_idx in zip(arc_scores_arr_s, arc_idxes_arr_s): + one_arc_score, one_arc_idx = float(one_arc_score), int(one_arc_idx) + # only allow legal cands, also let reduce operations have labels (but not outputed)! + # todo(+N): this means that reduce on ROOT is filtered out! + if cands_mask_arr_s[one_arc_idx]>0.: + arc_score_slice = SlicedExpr(ew_arc, ew_arc_idx_base + one_arc_idx) + cur_ew_label_idx_base = ew_label_idx_base + aidx * nlabel + if one_arc_idx == state_cur_idx: + if self.score_reduce_label: + # todo(warn): reduce operation, only accept padding label! # todo(+N): is this one expensive? + reduce_label_score = (label_all_scores_t[cur_ew_label_idx_base]).item() # L-scores[base+0] + reduce_label_slice = SlicedExpr(ew_label, cur_ew_label_idx_base) + else: + reduce_label_score = 0. + reduce_label_slice = None + new_state = TdState(prev=state, action=(one_arc_idx, one_arc_idx, 0), score=(one_arc_score+reduce_label_score)) + new_state.label_score_slice = reduce_label_slice + new_state.arc_score_slice = arc_score_slice + cands.append(new_state) + else: # loop for all the labels + assert len(label_scores_arr_s[aidx])==label_k and len(label_idxes_arr_s[aidx])==label_k, "Err: wrong label-K" + for one_label_score, one_label_idx in zip(label_scores_arr_s[aidx], label_idxes_arr_s[aidx]): + one_label_score, one_label_idx = float(one_label_score), int(one_label_idx) + new_state = TdState(prev=state, action=(state_cur_idx, one_arc_idx, one_label_idx), + score=(one_arc_score+one_label_score)) + new_state.arc_score_slice = arc_score_slice + new_state.label_score_slice = SlicedExpr(ew_label, cur_ew_label_idx_base + one_label_idx) + cands.append(new_state) + aidx += 1 + + # mode 1: only on plain beam (for topk/sample/oracle) + def _new_states_plain(self, ags: List[BfsLinearAgenda], cands_mask_arr: np.ndarray, general_pack, plain_pack, local_gold_collect): + bsize, slen, nlabel, ew_full_arc_score = general_pack + arc_k, label_k, arc_scores, arc_idxes, ew_label_scores, label_scores, label_idxes = plain_pack + stride_ew_label = arc_k*nlabel + # ----- + # ew_arc: [bsize*slen], arc_scores: [bsize, arc_k], + # ew_label: [bsize*arc_k*nlabel], label_scores: [bsize, ark_k, label_k] + arc_scores_arr, arc_idxes_arr, label_scores_arr, label_idxes_arr = \ + BK.get_value(arc_scores), BK.get_value(arc_idxes), BK.get_value(label_scores), BK.get_value(label_idxes) + label_alls_t = BK.get_cpu_tensor(ew_label_scores.val) + rets = [] + bidx = 0 + for ag in ags: + pcands = [] + # plain beam + for state in ag.beam: + self._add_states(pcands, state, nlabel, arc_k, label_k, cands_mask_arr[bidx], arc_scores_arr[bidx], + arc_idxes_arr[bidx], label_scores_arr[bidx], label_idxes_arr[bidx], + ew_full_arc_score, ew_label_scores, bidx*slen, bidx*stride_ew_label, label_alls_t) + bidx += 1 + # gold beam + assert len(ag.gbeam) == 0, "Wrong mode" + # bidx += len(ag.gbeam) + rets.append((pcands, [])) + # todo(warn): stat for debugging-like purpose + # if self.oracle_manager is not None: + # for cur_cand_sample in pcands: + # self._stat_ss(cur_cand_sample, None) + if local_gold_collect: + ag.local_golds.extend(pcands) + assert bidx == bsize, "Unmatched dim0 size!" + return rets + + # mode 2: scheduled sampling (for sample-oracle) + def _new_states_ss(self, ags: List[BfsLinearAgenda], cands_mask_arr: np.ndarray, general_pack, plain_pack, gold_pack, rate_plain, ss_strict_oracle: bool, ss_include_correct_rate: float): + bsize, slen, nlabel, ew_full_arc_score = general_pack + arc_k1, label_k1, arc_scores1, arc_idxes1, ew_label_scores1, label_scores1, label_idxes1 = plain_pack + arc_k2, label_k2, arc_scores2, arc_idxes2, ew_label_scores2, label_scores2, label_idxes2 = gold_pack + stride_ew_label1 = arc_k1 * nlabel + stride_ew_label2 = arc_k2 * nlabel + # ----- + # ew_arc: [bsize*slen], arc_scores: [bsize, arc_k], + # ew_label: [bsize*arc_k*nlabel], label_scores: [bsize, ark_k, label_k] + arc_scores_arr1, arc_idxes_arr1, label_scores_arr1, label_idxes_arr1 = \ + BK.get_value(arc_scores1), BK.get_value(arc_idxes1), BK.get_value(label_scores1), BK.get_value(label_idxes1) + arc_scores_arr2, arc_idxes_arr2, label_scores_arr2, label_idxes_arr2 = \ + BK.get_value(arc_scores2), BK.get_value(arc_idxes2), BK.get_value(label_scores2), BK.get_value(label_idxes2) + label_alls_t1 = BK.get_cpu_tensor(ew_label_scores1.val) + label_alls_t2 = BK.get_cpu_tensor(ew_label_scores2.val) + rets = [] + bidx = 0 + for ag in ags: + pcands = [] + # plain beam + for state in ag.beam: + cands1 = [] + cands2 = [] + self._add_states(cands1, state, nlabel, arc_k1, label_k1, cands_mask_arr[bidx], arc_scores_arr1[bidx], + arc_idxes_arr1[bidx], label_scores_arr1[bidx], label_idxes_arr1[bidx], + ew_full_arc_score, ew_label_scores1, bidx*slen, bidx*stride_ew_label1, label_alls_t1) + self._add_states(cands2, state, nlabel, arc_k2, label_k2, cands_mask_arr[bidx], arc_scores_arr2[bidx], + arc_idxes_arr2[bidx], label_scores_arr2[bidx], label_idxes_arr2[bidx], + ew_full_arc_score, ew_label_scores2, bidx*slen, bidx*stride_ew_label2, label_alls_t2) + # ===== + assert len(cands1)==1 and len(cands2)==1, "Too many cands for scheduled sampling mode!" + if ss_strict_oracle: + # for local losses + ag.local_golds.extend(cands2) + # next expanding: sampling with rate + if next(self.rand_stream) < rate_plain: + pcands.extend(cands1) + else: + pcands.extend(cands2) + else: + cur_cand_sample, cur_cand_oracle = cands1[0], cands2[0] + self.oracle_manager.set_losses(cur_cand_sample) + _, _, cur_delta_arc, cur_delta_label = cur_cand_sample.oracle_loss_cache + if (cur_delta_arc+cur_delta_label)==0: + # if sample a good state, always follow it, but whether include it? + if next(self.rand_stream) < ss_include_correct_rate: + ag.local_golds.append(cur_cand_sample) + pcands.append(cur_cand_sample) + else: + ag.local_golds.append(cur_cand_oracle) + pcands.append(cur_cand_sample if (next(self.rand_stream) < rate_plain) else cur_cand_oracle) + # todo(warn): stat for debugging-like purpose + # self._stat_ss(cur_cand_sample, cur_cand_oracle) + # ===== + bidx += 1 + # gold beam + assert len(ag.gbeam) == 0, "Wrong mode" + # bidx += len(ag.gbeam) + rets.append((pcands, [])) + assert bidx == bsize, "Unmatched dim0 size!" + return rets + + def _stat_ss(self, cur_cand_sample, cur_cand_oracle): + GLOBAL_RECORDER.record_kv("ss_step", 1) + self.oracle_manager.set_losses(cur_cand_sample) + cur_loss_arc, cur_loss_label, cur_delta_arc, cur_delta_label = [str(z) for z in cur_cand_sample.oracle_loss_cache] + GLOBAL_RECORDER.record_kv("ss_lossda_" + cur_delta_arc, 1) + GLOBAL_RECORDER.record_kv("ss_lossdl_" + cur_delta_label, 1) + # GLOBAL_RECORDER.record_kv("ss_lossla_"+cur_loss_arc, 1) + # GLOBAL_RECORDER.record_kv("ss_lossll_"+cur_loss_label, 1) + if cur_cand_sample.length >= 2 * cur_cand_sample.num_tok - 2: + GLOBAL_RECORDER.record_kv("ss_end", 1) + GLOBAL_RECORDER.record_kv("ss_endla_" + cur_loss_arc, 1) + GLOBAL_RECORDER.record_kv("ss_endll_" + cur_loss_label, 1) + + # mode 3: mixed plain+gold (for global learning with topk/sample + oracle) + def _new_states_both(self, ags: List[BfsLinearAgenda], cands_mask_arr: np.ndarray, general_pack, plain_pack, gold_pack): + bsize, slen, nlabel, ew_full_arc_score = general_pack + arc_k1, label_k1, arc_scores1, arc_idxes1, ew_label_scores1, label_scores1, label_idxes1 = plain_pack + arc_k2, label_k2, arc_scores2, arc_idxes2, ew_label_scores2, label_scores2, label_idxes2 = gold_pack + stride_ew_label1 = arc_k1 * nlabel + stride_ew_label2 = arc_k2 * nlabel + # ----- + # ew_arc: [bsize*slen], arc_scores: [bsize, arc_k], + # ew_label: [bsize*arc_k*nlabel], label_scores: [bsize, ark_k, label_k] + arc_scores_arr1, arc_idxes_arr1, label_scores_arr1, label_idxes_arr1 = \ + BK.get_value(arc_scores1), BK.get_value(arc_idxes1), BK.get_value(label_scores1), BK.get_value(label_idxes1) + arc_scores_arr2, arc_idxes_arr2, label_scores_arr2, label_idxes_arr2 = \ + BK.get_value(arc_scores2), BK.get_value(arc_idxes2), BK.get_value(label_scores2), BK.get_value(label_idxes2) + label_alls_t1 = BK.get_cpu_tensor(ew_label_scores1.val) + label_alls_t2 = BK.get_cpu_tensor(ew_label_scores2.val) + rets = [] + bidx = 0 + for ag in ags: + pcands = [] + gcands = [] + # plain beam + for state in ag.beam: + # plain expanding + self._add_states(pcands, state, nlabel, arc_k1, label_k1, cands_mask_arr[bidx], arc_scores_arr1[bidx], + arc_idxes_arr1[bidx], label_scores_arr1[bidx], label_idxes_arr1[bidx], + ew_full_arc_score, ew_label_scores1, bidx*slen, bidx*stride_ew_label1, label_alls_t1) + # todo(note): plain+oracle also goes to gbeam: let the GlobalSelector decide. + self._add_states(gcands, state, nlabel, arc_k2, label_k2, cands_mask_arr[bidx], arc_scores_arr2[bidx], + arc_idxes_arr2[bidx], label_scores_arr2[bidx], label_idxes_arr2[bidx], + ew_full_arc_score, ew_label_scores2, bidx*slen, bidx*stride_ew_label2, label_alls_t2) + bidx += 1 + # gold beam + for state in ag.gbeam: + # todo(note): only add gold+oracle + self._add_states(gcands, state, nlabel, arc_k2, label_k2, cands_mask_arr[bidx], arc_scores_arr2[bidx], + arc_idxes_arr2[bidx], label_scores_arr2[bidx], label_idxes_arr2[bidx], + ew_full_arc_score, ew_label_scores2, bidx*slen, bidx*stride_ew_label2, label_alls_t2) + bidx += 1 + rets.append((pcands, gcands)) + assert bidx == bsize, "Unmatched dim0 size!" + return rets + +# pure topk, for decoding +class TdLocalSelectorTopk(TdLocalSelector): + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, oracle_manager=None, force_oracle: bool=False): + super().__init__(scorer, local_arc_beam_size, local_label_beam_size, oracle_manager) + self.force_oracle = force_oracle + + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + flattened_states, _, cands_mask_arr = candidates + sel_arr, _ = self.oracle_manager.fill_oracles(flattened_states, cands_mask_arr) if self.force_oracle else (None, None) + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, True, False, False, False, sel_arr) + return self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[0], False) + +# pure oracle, for basic static learning +class TdLocalSelectorOracle(TdLocalSelector): + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, oracle_manager, log_sum_prob): + super().__init__(scorer, local_arc_beam_size, local_label_beam_size, oracle_manager) + self.log_sum_prob = log_sum_prob + + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + lps = self.log_sum_prob + flattened_states, _, cands_mask_arr = candidates + sel_arr, _ = self.oracle_manager.fill_oracles(flattened_states, cands_mask_arr) if lps else (None, None) + general_pack, more_packs, oracle_log_prob_sum = \ + self._score_and_choose(candidates, cache_pack, False, False, False, True, sel_arr, lps) + rets = self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[3], True) + # todo(+2): repeated codes + # todo(warn): replace scores with log-prob-sum, states returned are exactly gold-states in this mode + if lps: + bsize, = BK.get_shape(oracle_log_prob_sum) + lps_full_arc_score = ExprWrapper(oracle_log_prob_sum, bsize) + bidx = 0 + for pcands, _ in rets: + for p in pcands: + p.arc_score_slice = SlicedExpr(lps_full_arc_score, bidx) + bidx += 1 + assert bidx == bsize + return rets + +# pure sampler +class TdLocalSelectorSample(TdLocalSelector): + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, tsample: bool, oracle_manager=None): + super().__init__(scorer, local_arc_beam_size, local_label_beam_size, oracle_manager) + self.tsample = tsample + + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + cands_mask_arr = candidates[-1] + if self.tsample: + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, False, True, False, False) + return self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[1], False) + else: + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, False, False, True, False) + return self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[2], False) + +# oracle sampler +class TdLocalSelectorOracleSample(TdLocalSelector): + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, oracle_manager, log_sum_prob): + super().__init__(scorer, local_arc_beam_size, local_label_beam_size, oracle_manager) + self.log_sum_prob = log_sum_prob + + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + flattened_states, _, cands_mask_arr = candidates + sel_arr, sel_labs = self.oracle_manager.fill_oracles(flattened_states, cands_mask_arr) + lps = self.log_sum_prob + general_pack, more_packs, oracle_log_prob_sum = self._score_and_sample_oracle(candidates, cache_pack, sel_arr, sel_labs, lps) + rets = self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[0], False) + # todo(warn): replace scores with log-prob-sum + if lps: + bsize, = BK.get_shape(oracle_log_prob_sum) + lps_full_arc_score = ExprWrapper(oracle_log_prob_sum, bsize) + bidx = 0 + for pcands, _ in rets: + for p in pcands: + p.arc_score_slice = SlicedExpr(lps_full_arc_score, bidx) + bidx += 1 + assert bidx == bsize + return rets + +# sample (with or without oracle) with 1 instance, for (scheduled) sampling +class TdLocalSelectorSS(TdLocalSelector): + def __init__(self, scorer: TdScorer, local_arc_beam_size: int, local_label_beam_size: int, oracle_manager, ss_rate_plain: ScheduledValue, tsample: bool, ss_strict_oracle: bool, ss_include_correct_rate: float): + super().__init__(scorer, local_arc_beam_size, local_label_beam_size, oracle_manager) + self.tsample = tsample + self.ss_rate_plain = ss_rate_plain + self.ss_strict_oracle = ss_strict_oracle + self.ss_include_correct_rate = ss_include_correct_rate + + def select(self, ags: List[BfsLinearAgenda], candidates, cache_pack): + rate_plain = self.ss_rate_plain.value + sso = self.ss_strict_oracle + sic = self.ss_include_correct_rate + cands_mask_arr = candidates[-1] + if rate_plain<=0.: # pure oracle mode + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, False, False, False, True) + rs = self._new_states_plain(ags, cands_mask_arr, general_pack, more_packs[3], True) + elif self.tsample: + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, False, True, False, True) + rs = self._new_states_ss(ags, cands_mask_arr, general_pack, more_packs[1], more_packs[3], rate_plain, sso, sic) + else: + general_pack, more_packs, _ = self._score_and_choose(candidates, cache_pack, False, False, True, True) + rs = self._new_states_ss(ags, cands_mask_arr, general_pack, more_packs[2], more_packs[3], rate_plain, sso, sic) + return rs + +# TODO(+N) +class TdLocalSelectorMixed(TdLocalSelector): + pass + +# --- +# global selector/arranger (ranking the beam) +class TdGlobalArranger: + def __init__(self, global_beam_size: int): + self.global_beam_size = global_beam_size + + def arrange(self, ags: List[BfsLinearAgenda], local_cands: List): + raise NotImplementedError() + + @staticmethod + def create(global_beam_size: int, merge_sig_type: str, attach_num_beam_size: int): + if merge_sig_type == "none" and attach_num_beam_size>=global_beam_size: + ret = TdGlobalArrangerPlain(global_beam_size) + else: + ret = TdGlobalArrangerMerge(global_beam_size, merge_sig_type, attach_num_beam_size) + zlog(f"Create GlobalArranger: {ret}") + return ret + +# plain mode (no mixing) +class TdGlobalArrangerPlain(TdGlobalArranger): + def __init__(self, global_beam_size: int): + super().__init__(global_beam_size) + + def arrange(self, ags: List[BfsLinearAgenda], local_cands: List): + batch_len = len(ags) + selections = [None] * batch_len + assert batch_len == len(local_cands) + for batch_idx in range(batch_len): + pcands = local_cands[batch_idx][0] # # ignore gcands + pcands.sort(key=lambda x: x.score_accu, reverse=True) + selections[batch_idx] = (pcands[:self.global_beam_size], None) + return selections + +# slightly advanced mode: with second-beam and merging +class TdGlobalArrangerMerge(TdGlobalArranger): + def __init__(self, global_beam_size: int, merge_sig_type: str, attach_num_beam_size: int): + super().__init__(global_beam_size) + # + self.sig_manager = TdSigManager(merge_sig_type) + self.attach_num_beam_size = attach_num_beam_size + + def arrange(self, ags: List[BfsLinearAgenda], local_cands: List): + batch_len = len(ags) + selections = [None] * batch_len + assert batch_len == len(local_cands) + global_beam_size, sig_manager, attach_num_beam_size = \ + self.global_beam_size, self.sig_manager, self.attach_num_beam_size + # go + for batch_idx in range(batch_len): + pcands = local_cands[batch_idx][0] # ignore gcands + pcands.sort(key=lambda x: x.score_accu, reverse=True) + # select with merging + scands = [] + beam_sig, beam_attach_num = {}, {} # bytes->State, int->int + for one_cand in pcands: + # secondary attach_num beam + one_remain_num = one_cand.num_rest + already_count = beam_attach_num.get(one_remain_num, 0) + if already_count >= attach_num_beam_size: + continue + beam_attach_num[one_remain_num] = already_count + 1 + # sig beam + # todo(+N): more secondary beams? + one_sig = sig_manager.get_sig(one_cand) + merger_state = beam_sig.get(one_sig, None) + if merger_state is None: + beam_sig[one_sig] = one_cand + else: + one_cand.merge_by(merger_state) # merge by the higher-scored one + continue + # adding + scands.append(one_cand) + # finally main beam + if len(scands) >= global_beam_size: + break + selections[batch_idx] = (scands, None) + return selections + +# TODO(+N) mixing mode / allow-merge (for advanced decoding and training) + +# ----- +# final step, end once loop and prepare for the next +class TdEnder: + def end(self, ags: List[BfsLinearAgenda], selections): + raise NotImplementedError() + +# plain mode +class TdEnderPlain(TdEnder): + def end(self, ags: List[BfsLinearAgenda], selections): + for ag, one_selections in zip(ags, selections): + new_beam = [] + for s in one_selections[0]: + ending_length = (2*s.num_tok-1) if TdState.is_bfs else (2*s.num_tok-2) + # ending_length = (2*s.num_tok-2) + # todo(note): finish all 2*n-2 steps, can be different than num_rest>0 for global models + if s.length >= ending_length: + s.mark_end() + ag.ends.append(s) + else: + new_beam.append(s) + # replace and record last one + if len(ag.beam) > 0: + ag.last_beam = ag.beam + ag.beam = new_beam + +# ===== +# the overall search helper or actually the searcher +class TdSearcher(BfsLinearSearcher): + def __int__(self): + # components + self.expander: TdExpander = None + self.local_selector: TdLocalSelector = None + self.global_arranger: TdGlobalArranger = None + self.ender: TdEnder = None + # cache for the scorer + self.cache_pack = None + self.max_length = -1 + # TODO(+N): where to add margin? + self.margin = 0. + + # setup for each search + def refresh(self, cache_pack): + self.cache_pack = cache_pack + self.max_length = cache_pack[2] # todo(+1): not elegant + # todo(+N): need to refresh other things? + + # expand candidates for each state to get scores + def expand(self, ags: List[BfsLinearAgenda]): + expander = self.expander + this_bsize = sum(len(z.beam) + len(z.gbeam) for z in ags) # current flattened bsize + # + state_bidxes = np.zeros(this_bsize, dtype=np.int32) + cands_mask_arr = np.zeros([this_bsize, self.max_length], dtype=np.float32) + flattened_states = [] + # + bidx = 0 + for sidx, ag in enumerate(ags): + # add all states + for one_beam in [ag.beam, ag.gbeam]: + flattened_states.extend(one_beam) + for state in one_beam: + one_cands = expander.expand(state) + state_bidxes[bidx] = sidx + cands_mask_arr[bidx][one_cands] = 1. + bidx += 1 + return (flattened_states, state_bidxes, cands_mask_arr) + + # score and select + def select(self, ags: List[BfsLinearAgenda], candidates): + # step 1: scoring and local select + flattened_states, state_bidxes, cands_mask_arr = candidates + # List[(pcands, gcands)] + local_selections = self.local_selector.select(ags, candidates, self.cache_pack) + # step 2: global select with the local candidates + # List[List(cands)] + if self.global_arranger is None: + # for single-instance ones + global_selections = local_selections + else: + global_selections = self.global_arranger.arrange(ags, local_selections) + return global_selections + + # preparing for next loop or ending + def end(self, ags: List[BfsLinearAgenda], selections): + self.ender.end(ags, selections) + + # ===== + # put all components together and get a Searcher + + # 1. pure beam searcher for decoding + @staticmethod + def create_beam_searcher(scorer: TdScorer, oracle_manager, iconf, force_oracle: bool): + searcher = TdSearcher() + searcher.expander = TdExpander.get_expander(iconf.expand_strategy, iconf.expand_projective) + searcher.local_selector = TdLocalSelectorTopk(scorer, iconf.local_arc_beam_size, iconf.local_label_beam_size, oracle_manager, force_oracle) + # todo(+N): merge? + searcher.global_arranger = TdGlobalArranger.create(iconf.global_beam_size, iconf.merge_sig_type, iconf.attach_num_beam_size) + searcher.ender = TdEnderPlain() + return searcher + + # 2. pure oracle-follower for teacher forcing like training + @staticmethod + def create_oracle_follower(scorer: TdScorer, oracle_manager, iconf, log_prob_sum: bool): + searcher = TdSearcher() + searcher.expander = TdExpander.get_expander(iconf.expand_strategy, iconf.expand_projective) + # todo(warn): single instance! + searcher.local_selector = TdLocalSelectorOracle(scorer, iconf.local_arc_beam_size, iconf.local_label_beam_size, oracle_manager, log_prob_sum) + searcher.global_arranger = None + searcher.ender = TdEnderPlain() + return searcher + + # 3. mixing oracle/sample randomly for scheduled sampling training + @staticmethod + def create_scheduled_sampler(scorer: TdScorer, oracle_manager, iconf, ss_rate_sample: ScheduledValue, topk_sample: bool, ss_strict_oracle: bool, ss_include_correct_rate: float): + searcher = TdSearcher() + searcher.expander = TdExpander.get_expander(iconf.expand_strategy, iconf.expand_projective) + # todo(warn): single instance! + searcher.local_selector = TdLocalSelectorSS(scorer, iconf.local_arc_beam_size, iconf.local_label_beam_size, oracle_manager, ss_rate_sample, topk_sample, ss_strict_oracle, ss_include_correct_rate) + searcher.global_arranger = None + searcher.ender = TdEnderPlain() + return searcher + + # 4. pure sampler for RL training + # use oracle_manager only for recording things + @staticmethod + def create_rl_sampler(scorer: TdScorer, oracle_manager, iconf, topk_sample: bool): + searcher = TdSearcher() + searcher.expander = TdExpander.get_expander(iconf.expand_strategy, iconf.expand_projective) + # todo(warn): single instance! + searcher.local_selector = TdLocalSelectorSample(scorer, iconf.local_arc_beam_size, iconf.local_label_beam_size, topk_sample, oracle_manager) + searcher.global_arranger = None + searcher.ender = TdEnderPlain() + return searcher + + # 4.5 pure sampler for oracle-force/sampler training + # use oracle_manager only for recording things + @staticmethod + def create_of_sampler(scorer: TdScorer, oracle_manager, iconf, log_prob_sum: bool): + searcher = TdSearcher() + searcher.expander = TdExpander.get_expander(iconf.expand_strategy, iconf.expand_projective) + # todo(warn): single instance! + searcher.local_selector = TdLocalSelectorOracleSample( + scorer, iconf.local_arc_beam_size, iconf.local_label_beam_size, oracle_manager, log_prob_sum) + searcher.global_arranger = None + searcher.ender = TdEnderPlain() + return searcher + + # 5. multi-sample or beam-search + oracle for beam-as-all RL-styled min-risk training + # TODO(+N) + @staticmethod + def create_min_risker(): + pass + + # 6. TODO(+N) + +# b tasks/zdpar/transition/topdown/decoder:623 diff --git a/tasks/zdpar/transition/topdown/parser.py b/tasks/zdpar/transition/topdown/parser.py new file mode 100644 index 0000000..f3d63d8 --- /dev/null +++ b/tasks/zdpar/transition/topdown/parser.py @@ -0,0 +1,189 @@ +# + +# the top-down transition-based parser + +import numpy as np +from typing import List + +from msp.utils import Conf, zlog, JsonRW, zwarn, zcheck +from msp.data import VocabPackage +from msp.nn import BK +from msp.zext.process_train import SVConf, ScheduledValue + +from ...common.data import ParseInstance +from ...common.model import BaseParserConf, BaseInferenceConf, BaseTrainingConf, BaseParser +from .decoder import TdScorerConf, TdScorer, TdState +from .runner import TdInferencer, TdFber + +# ===== +# confs + +# decoding conf +class InferenceConf(BaseInferenceConf): + def __init__(self): + super().__init__() + # self.dec_algorithm = "TD" + # expand + self.expand_strategy = "free" # constrains for expanding candidates + self.expand_projective = False # projective expanding? + # beam sizes + self.global_beam_size = 5 # beam size (global_beam_size) or sampling size (global-selector size) + self.local_arc_beam_size = 5 + self.local_label_beam_size = 2 + # for sampling + self.topk_sample = False # constraining topk for sampling mode? + # slightly advanced beam + self.merge_sig_type = "none" # none, plain, ... + self.attach_num_beam_size = 5 # secondary beam of how many attaching edges that current state has + # todo(warn): special mode, check search error + self.check_serr = False + +# training conf +class TraningConf(BaseTrainingConf): + def __init__(self): + super().__init__() + # oracle + self.oracle_strategy = "free" # constrains for expanding candidates + self.oracle_projective = False # projective expanding? + self.oracle_free_dist_alpha = 0. # for free oracle, distribution \alpha for prob <= e^(|dist|*alpha) + self.oracle_label_order = "freq" # core/freq + # training strategy + self.train_strategy = "force" # force(teacher-forcing), ss(scheduled-sampling), rl, of(oracle-force/sample) + self.loss_div_step = True # loss /= steps or sents? + self.oracle_log_prob_sum = False # use log sum prob of all oracles for of/force-mode + # loss functions & general + # TODO(+N) + # depth scheduling (100 levels should be enough, set 'mode=linear' and 'init_val=1.' to turn on! + self.sched_depth = SVConf().init_from_kwargs(val=100., which_idx="eidx", mode="none", k=1., b=1., scale=1.) + # for scheduled sampling + # the oracler only returns one oracle, but the current sampling can be anther zero-loss good operation, + # whether strcitly follow oracle or accept this one rather than possible random oracle as local loss + self.ss_strict_oracle = False + # include rate of correct states in loss (effective ss_strict_oracle==False) + self.ss_include_correct_rate = 0.1 + # alpha for local normalization + self.local_norm_alpha = 1.0 + +# overall parser conf +class TdParserConf(BaseParserConf): + def __init__(self): + super().__init__(InferenceConf(), TraningConf()) + self.sc_conf = TdScorerConf() + # + self.is_bfs = False + # output + # self.sc_conf.output_local_norm = ? + # self.iconf.dec_algorithm = "?" + # self.tconf.loss_function = "?" + +# ===== +# model + +# the model +class TdParser(BaseParser): + def __init__(self, conf: TdParserConf, vpack: VocabPackage): + super().__init__(conf, vpack) + # ===== For decoding ===== + self.inferencer = TdInferencer(self.scorer, conf.iconf) + # ===== For training ===== + sched_depth = ScheduledValue("depth", conf.tconf.sched_depth) + self.add_scheduled_values(sched_depth) + self.fber = TdFber(self.scorer, conf.iconf, conf.tconf, self.margin, self.sched_sampling, sched_depth) + # todo(warn): not elegant, global flag! + TdState.is_bfs = conf.is_bfs + # ===== + zcheck(not self.bter.jpos_multitask_enabled(), "Not implemented for joint pos in this mode!!") + zwarn("WARN: This topdown mode is deprecated!!") + + def build_decoder(self): + conf = self.conf + conf.sc_conf._input_dim = self.enc_output_dim + conf.sc_conf._num_label = self.label_vocab.trg_len(True) # todo(warn): plus 0 for reduce-label + return TdScorer(self.pc, conf.sc_conf) + + # called before each mini-batch + def refresh_batch(self, training): + super().refresh_batch(training) + # todo(+N): not elegant to put it here! + if training: + self.scorer.set_local_norm_alpha(self.conf.tconf.local_norm_alpha) + else: + self.scorer.set_local_norm_alpha(1.) + + # obtaining the instance preparer for training + def get_inst_preper(self, training, **kwargs): + if training: + return self.fber.oracle_manager.init_inst # method + else: + return None + + # ===== + # main procedures + def _prepare_score(self, insts, training): + input_repr, enc_repr, jpos_pack, mask_arr = self.bter.run(insts, training) + # assert jpos_pack[0] is None, "not allowing jpos in this mode yet!" + # mask_expr = BK.input_real(mask_arr) # do not need this as in Graph-parser + mask_expr = None + scoring_expr_pack = self.scorer.transform_space(enc_repr) + return scoring_expr_pack, mask_expr + + # get children order info from state + def _get_chs_ordering(self, state: TdState): + assert state.is_end() + # collect all last reduce operations + remain = state.num_tok + ret = [""] * remain + while remain>0 and state is not None: + idx_cur = state.idx_cur + if len(ret[idx_cur]) == 0: + act_head, act_attach, act_label = state.action + if act_head == act_attach: + ret[idx_cur] = f"CORDER={','.join([str(z) for z in state.idxes_chs])}" + remain -= 1 + state = state.prev + return ret + + # decoding + def inference_on_batch(self, insts: List[ParseInstance], **kwargs): + _LEQ_MAP = lambda x, m: [max(m, z) for z in x] # larger or equal map + # + with BK.no_grad_env(): + self.refresh_batch(False) + scoring_prep_f = lambda ones: self._prepare_score(ones, False)[0] + get_best_state_f = lambda ag: sorted(ag.ends, key=lambda s: s.score_accu, reverse=True)[0] + ags, info = self.inferencer.decode(insts, scoring_prep_f, False) + # put the results inplaced + for one_inst, one_ag in zip(insts, ags): + best_state = get_best_state_f(one_ag) + one_inst.pred_heads.set_vals(_LEQ_MAP(best_state.list_arc, 0)) # directly int-val for heads + # todo(warn): already the correct labels, no need to transform + # -- one_inst.pred_labels.build_vals(self.pred2real_labels(best_state.list_label), self.label_vocab) + one_inst.pred_labels.build_vals(_LEQ_MAP(best_state.list_label, 1), self.label_vocab) + # todo(warn): add children ordering in MISC field + one_inst.pred_miscs.set_vals(self._get_chs_ordering(best_state)) + # check search err + if self.conf.iconf.check_serr: + ags_fo, _ = self.fber.force_decode(insts, scoring_prep_f, False) + serr_sent = 0 + serr_tok = 0 + for ag, ag_fo in zip(ags, ags_fo): + best_state = get_best_state_f(ag) + best_fo_state = get_best_state_f(ag_fo) + # for this one, only care about UAS + if best_state.score_accu < best_fo_state.score_accu: + cur_serr_tok = sum((1 if a!=b else 0) for a,b in zip(best_state.list_arc[1:], best_fo_state.list_arc[1:])) + if cur_serr_tok > 0: + serr_sent += 1 + serr_tok += cur_serr_tok + info["serr_sent"] = serr_sent + info["serr_tok"] = serr_tok + return info + + # training (forward/backward) + def fb_on_batch(self, annotated_insts: List[ParseInstance], training=True, loss_factor=1.): + self.refresh_batch(training) + scoring_expr_pack, mask_expr = self._prepare_score(annotated_insts, training) + info = self.fber.fb(annotated_insts, scoring_expr_pack, training, loss_factor) + return info + +# b tasks/zdpar/transition/topdown/parser:150 diff --git a/tasks/zdpar/transition/topdown/runner.py b/tasks/zdpar/transition/topdown/runner.py new file mode 100644 index 0000000..e802d57 --- /dev/null +++ b/tasks/zdpar/transition/topdown/runner.py @@ -0,0 +1,188 @@ +# +import copy + +from msp.nn import BK, SliceManager +from msp.utils import Helper +from msp.zext.process_train import ScheduledValue +from .decoder import TdSearcher, TdOracleManager +from .decoder import BfsLinearAgenda, TdState + +# ===== +# decoder +class TdInferencer: + def __init__(self, scorer, iconf, oracle_manager=None, force_oracle=False): + self.searcher = TdSearcher.create_beam_searcher(scorer, oracle_manager, iconf, force_oracle) + # todo(warn): specify the non-failing backoff one + bo_iconf = copy.copy(iconf) + bo_iconf.expand_strategy = "free" + self.backoff_searcher = TdSearcher.create_beam_searcher(scorer, oracle_manager, bo_iconf, force_oracle) + self.require_sg = False # whether need the recordings of SearchGraph + + # return final states + def decode(self, insts, scoring_prep_f, backoff): + cur_searcher = self.backoff_searcher if backoff else self.searcher + # decoding + scoring_expr_pack = scoring_prep_f(insts) + ags = [BfsLinearAgenda.init_agenda(TdState, z, self.require_sg) for z in insts] + cur_searcher.refresh(scoring_expr_pack) + cur_searcher.go(ags) + # backoff + failed_insts = [] + failed_idxes = [] + idx = 0 + for one_inst, one_ag in zip(insts, ags): + if len(one_ag.ends) == 0: + failed_insts.append(one_inst) + failed_idxes.append(idx) + idx += 1 + # ----- + if len(failed_insts) > 0: + if backoff: + assert len(failed_insts) == 0, "Cannot fail in backoff mode!" + else: + # inplaced output results + bo_ags, _ = self.decode(failed_insts, scoring_prep_f, True) + assert len(bo_ags) == len(failed_idxes) + for bo_idx, bo_ag in zip(failed_idxes, bo_ags): + ags[bo_idx] = bo_ag + # ----- + info = {"sent": len(insts), "tok": sum(map(len, insts)), "failed": len(failed_insts)} + return ags, info + +# ===== +# fber(learner) +class TdFber: + def __init__(self, scorer, iconf, tconf, margin: ScheduledValue, sched_sampling: ScheduledValue, sched_depth: ScheduledValue): + self.tconf = tconf + self.oracle_manager = TdOracleManager(tconf.oracle_strategy, tconf.oracle_projective, tconf.oracle_free_dist_alpha, tconf.oracle_label_order) + self.require_sg = False + # forced decoding + self.oracled_searcher = TdInferencer(scorer, iconf, self.oracle_manager, True) + # training strategy + self.train_force, self.train_ss, self.train_rl, self.train_of = \ + [tconf.train_strategy==z for z in ["force", "ss", "rl", "of"]] + if self.train_force: + self.searcher = TdSearcher.create_oracle_follower(scorer, self.oracle_manager, iconf, tconf.oracle_log_prob_sum) + elif self.train_ss: + self.searcher = TdSearcher.create_scheduled_sampler(scorer, self.oracle_manager, iconf, sched_sampling, iconf.topk_sample, tconf.ss_strict_oracle, tconf.ss_include_correct_rate) + elif self.train_rl: + self.searcher = TdSearcher.create_rl_sampler(scorer, self.oracle_manager, iconf, iconf.topk_sample) + elif self.train_of: + self.searcher = TdSearcher.create_of_sampler(scorer, self.oracle_manager, iconf, tconf.oracle_log_prob_sum) + else: + raise NotImplementedError(f"Err: UNK training strategy {tconf.train_strategy}") + self.sched_depth = sched_depth + # TODO(+N): margin, local/global? + self.margin = margin + + def force_decode(self, insts, scoring_prep_f, backoff): + insts = [self.oracle_manager.init_inst(z) for z in insts] + self.oracle_manager.refresh_insts(insts) + return self.oracled_searcher.decode(insts, scoring_prep_f, backoff) + + def fb(self, annotated_insts, scoring_expr_pack, training: bool, loss_factor: float): + # depth constrain: <= sched_depth + cur_depth_constrain = int(self.sched_depth.value) + # run + ags = [BfsLinearAgenda.init_agenda(TdState, z, self.require_sg) for z in annotated_insts] + self.oracle_manager.refresh_insts(annotated_insts) + self.searcher.refresh(scoring_expr_pack) + self.searcher.go(ags) + # collect local loss: credit assignment + if self.train_force or self.train_ss: + states = [] + for ag in ags: + for final_state in ag.local_golds: + # todo(warn): remember to use depth_eff rather than depth + # todo(warn): deprecated + # if final_state.depth_eff > cur_depth_constrain: + # continue + states.append(final_state) + logprobs_arc = [s.arc_score_slice for s in states] + # no labeling scores for reduce operations + logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None] + credits_arc, credits_label = None, None + elif self.train_of: + states = [] + for ag in ags: + for final_state in ag.ends: + for s in final_state.get_path(True): + states.append(s) + logprobs_arc = [s.arc_score_slice for s in states] + # no labeling scores for reduce operations + logprobs_label = [s.label_score_slice for s in states if s.label_score_slice is not None] + credits_arc, credits_label = None, None + elif self.train_rl: + logprobs_arc, logprobs_label, credits_arc, credits_label = [], [], [], [] + for ag in ags: + # todo(+2): need to check search failure? + # todo(+2): ignoring labels when reducing or wrong-arc + for final_state in ag.ends: + # todo(warn): deprecated + # if final_state.depth_eff > cur_depth_constrain: + # continue + one_credits_arc = [] + one_credits_label = [] + self.oracle_manager.set_losses(final_state) + for s in final_state.get_path(True): + _, _, delta_arc, delta_label = s.oracle_loss_cache + logprobs_arc.append(s.arc_score_slice) + if delta_arc > 0: + # only blame arc + one_credits_arc.append(-delta_arc) + else: + one_credits_arc.append(0) + if delta_label > 0: + logprobs_label.append(s.label_score_slice) + one_credits_label.append(-delta_label) + elif s.label_score_slice is not None: + # not bad labeling + logprobs_label.append(s.label_score_slice) + one_credits_label.append(0) + # TODO(+N): minus average may encourage bad moves? + # balance + # avg_arc = sum(one_credits_arc) / len(one_credits_arc) + # avg_label = 0. if len(one_credits_label)==0 else sum(one_credits_label) / len(one_credits_label) + baseline_arc = baseline_label = -0.5 + credits_arc.extend(z-baseline_arc for z in one_credits_arc) + credits_label.extend(z-baseline_label for z in one_credits_label) + else: + raise NotImplementedError("CANNOT get here!") + # sum all local losses + loss_zero = BK.zeros([]) + if len(logprobs_arc) > 0: + batched_logprobs_arc = SliceManager.combine_slices(logprobs_arc, None) + loss_arc = (-BK.sum(batched_logprobs_arc)) if (credits_arc is None) \ + else (-BK.sum(batched_logprobs_arc * BK.input_real(credits_arc))) + else: + loss_arc = loss_zero + if len(logprobs_label) > 0: + batched_logprobs_label = SliceManager.combine_slices(logprobs_label, None) + loss_label = (-BK.sum(batched_logprobs_label)) if (credits_label is None) \ + else (-BK.sum(batched_logprobs_label*BK.input_real(credits_label))) + else: + loss_label = loss_zero + final_loss_sum = loss_arc + loss_label + # divide loss by what? + num_sent = len(annotated_insts) + num_valid_arcs, num_valid_labels = len(logprobs_arc), len(logprobs_label) + # num_valid_steps = len(states) + if self.tconf.loss_div_step: + final_loss = loss_arc/max(1,num_valid_arcs) + loss_label/max(1,num_valid_labels) + else: + final_loss = final_loss_sum/num_sent + # + val_loss_arc = BK.get_value(loss_arc).item() + val_loss_label = BK.get_value(loss_label).item() + val_loss_sum = val_loss_arc + val_loss_label + # + cur_has_loss = 1 if ((num_valid_arcs+num_valid_labels)>0) else 0 + if training and cur_has_loss: + BK.backward(final_loss, loss_factor) + # todo(warn): make tok==steps for dividing in common.run + info = {"sent": num_sent, "tok": num_valid_arcs, "valid_arc": num_valid_arcs, "valid_label": num_valid_labels, + "loss_sum": val_loss_sum, "loss_arc": val_loss_arc, "loss_label": val_loss_label, + "fb_all": 1, "fb_valid": cur_has_loss} + return info + +# b tasks/zdpar/transition/topdown/runner:46 diff --git a/tasks/zdpar/transition/topdown/search/lsg/__init__.py b/tasks/zdpar/transition/topdown/search/lsg/__init__.py new file mode 100644 index 0000000..575e271 --- /dev/null +++ b/tasks/zdpar/transition/topdown/search/lsg/__init__.py @@ -0,0 +1,4 @@ +# + +from .graph import LinearState, LinearGraph +from .searcher import BfsLinearAgenda, BfsLinearSearcher diff --git a/tasks/zdpar/transition/topdown/search/lsg/graph.py b/tasks/zdpar/transition/topdown/search/lsg/graph.py new file mode 100644 index 0000000..3d90c91 --- /dev/null +++ b/tasks/zdpar/transition/topdown/search/lsg/graph.py @@ -0,0 +1,121 @@ +# +from msp.utils import Constants, zcheck + +# ===== +# inner graph representation +# todo(+N): this part can be re-written in cython or c++ + +# this class only maintain the properties of basic graph info, other things are not here! +class LinearState(object): + # _STATUS = "NONE"(unknown), "EXPAND"(survived), "MERGED", "END" + STATUS_NONE, STATUS_EXPAND, STATUS_MERGED, STATUS_END = 0, 1, 2, 3 + + # ==== constructions + def __init__(self, prev: 'LinearState'=None, action=None, score=0., sg: 'LinearGraph'=None, padded=False): + # 1. basic info + self.sg = sg # related SearchGraph + self.prev = prev # parent + self.length = 0 # how many steps from start + # graph info + self.id = None + self.nexts = None # valid next states + self.merge_list = None # merge others + self.merger = None # merged by other + + # 2. status + self.padded = padded + self.status = LinearState.STATUS_NONE + + # 3. values + self.action = action + self.score = score + self.score_accu = 0. + # ===== + # depend on prev state + if prev is not None: + self.sg = prev.sg + self.length = prev.length + 1 + self.score_accu = prev.score_accu + self.score + # no need to preserve graph info if not recording it + if padded: + self.sg = None + if self.sg is not None: + self.nexts = [] # valid next states + self.merge_list = [] # merge others + self.id = self.sg.reg(self) + if prev is not None: + prev.status = LinearState.STATUS_EXPAND + prev.nexts.append(self) + + def get_pid(self): + if self.prev is None: + return -1 + else: + return self.prev.id + + def __len__(self): + return self.length + + def __repr__(self): + return f"State[{self.id}<-{self.get_pid()}] len={self.length}, act={self.action}, sc={self.score_accu:.3f}({self.score:.3f})" + + # graph related + def is_end(self): + return self.status == LinearState.STATUS_END + + def is_start(self): + return self.prev is None + + def mark_end(self): + self.status = LinearState.STATUS_END + if self.sg is not None: + self.sg.add_end(self) + + # get the history path towards root + def get_path(self, forward_order=True, depth=Constants.INT_PRAC_MAX): + cur = self + ret = [] + # todo(warn): excluding the usually artificial root-node + while cur.prev is not None and len(ret) < depth: + ret.append(cur) + cur = cur.prev + if forward_order: + ret.reverse() + return ret + + # merge: modify both states + def merge_by(self, s: 'LinearState'): + zcheck(s.merger is None, "Err: multiple level of merges!") + self.status = LinearState.STATUS_MERGED + self.merger = s + s.add_merge(self) + + def add_merge(self, s: 'LinearState'): + if self.merge_list: + self.merge_list.append(s) + +# the linear search graph (lattice) +class LinearGraph(object): + def __init__(self, info=None): + # information about the settings like srcs + self.info = info + # unique ids for states (must be >=1) + self.counts = 0 + # special nodes + self.root = None + self.ends = [] + + # register a new state, return its id in this graph + def reg(self, s: LinearState): + self.counts += 1 + # todo(+N): record more info about the state? + return self.counts + + def set_root(self, s: LinearState): + zcheck(s.sg is self, "SGErr: State does not belong here") + zcheck(self.root is None, "SGErr: Can only have one root") + zcheck(s.is_start(), "SGErr: Only start node can be root") + self.root = s + + def add_end(self, s: LinearState): + self.ends.append(s) diff --git a/tasks/zdpar/transition/topdown/search/lsg/searcher.py b/tasks/zdpar/transition/topdown/search/lsg/searcher.py new file mode 100644 index 0000000..0983e65 --- /dev/null +++ b/tasks/zdpar/transition/topdown/search/lsg/searcher.py @@ -0,0 +1,68 @@ +# + +# the (batched) Linear Searchers +# the three layers: State, Agenda, Searcher + +from typing import List +from .graph import LinearState, LinearGraph + +# ===== +# main components + +# agenda for one instance +class BfsLinearAgenda: + def __init__(self, inst, init_beam: List[LinearState]=None, init_gbeam: List[LinearState]=None): + self.inst = inst + self.beam = [] if init_beam is None else init_beam # plain search beam + self.gbeam = [] if init_gbeam is None else init_gbeam # gold-informed ones + # for loss collection + self.local_golds = [] # gold (oracle) states for local loss + # for results collection + self.ends = [] + self.last_beam = [] + + @staticmethod + def init_agenda(state_type, inst, require_sg: bool): + sg = None + if require_sg: + sg = LinearGraph() + init_state = state_type(prev=None, sg=sg, inst=inst) + a = BfsLinearAgenda(inst, init_beam=[init_state]) + return a + + # TODO(+N): what to do with ended states? + def is_end(self): + return all(x.is_end() for x in self.beam) and all(x.is_end() for x in self.gbeam) + +# ===== +# batched searcher +class BfsLinearSearcher: + # ===== + def expand(self, ags: List[BfsLinearAgenda]): + raise NotImplementedError() + + def select(self, ags: List[BfsLinearAgenda], candidates): + raise NotImplementedError() + + def end(self, ags: List[BfsLinearAgenda], selections): + raise NotImplementedError() + # ===== + + def refresh(self, *args, **kwargs): + raise NotImplementedError() + + def go(self, ags: List[BfsLinearAgenda]): + while True: + # finish if all ended + if all(a.is_end() for a in ags): + break + # step 1: expand (list of list of candidates) + candidates = self.expand(ags) + # step 2: score (batched all) & select + selections = self.select(ags, candidates) + # step 3: next (inplaced modifications) + self.end(ags, selections) + # results are stored in ags + # pass + +# b tasks/zdpar/transition/topdown/search/lsg/searcher:63