-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
691 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
### For the Implicit Event Argument Detection | ||
|
||
Hi, this describes our implementation for our ACL-2020 paper: "A Two-Step Approach for Implicit Event Argument Detection". | ||
|
||
Please refer to the paper for more details: [[paper]](https://www.aclweb.org/anthology/??) [[bib]](https://aclweb.org/anthology/papers/??) | ||
|
||
### Repo | ||
|
||
When we were carrying out our experiments for this work, we used the repo at this commit [`here`](https://github.com/zzsfornlp/zmsp/commit/80a7fa85d9e7c12d50df460d8bc0029d2c6cf40b). In later versions of this repo, there may be slight changes (for example, default hyper-parameter change or hyper-parameter name change). | ||
|
||
### Environment | ||
|
||
As those of the main `msp` package: | ||
|
||
python>=3.6 | ||
dependencies: pytorch>=1.0.0, numpy, scipy, gensim, cython, transformers, stanfordnlp=0.2.0, ... | ||
|
||
Please refer to [`scripts/ie/iarg/conda_env_zie.yml`](../scripts/ie/iarg/conda_env_zie.yml) for an example of specific conda environment and you can specify conda environment using this. | ||
|
||
### Data Preparation | ||
|
||
Please refer to [`scripts/ie/iarg/prepare_data.sh`](../scripts/ie/iarg/prepare_data.sh) for more details. To be noted, for easier processing, we transformed from RAMS's format to our own format. | ||
|
||
### Data Format | ||
|
||
Our own data format is also in json (jsonlines), where each line is a dict for one doc, it roughly contains the following fields: | ||
|
||
## updated json format | ||
{"doc_id": ~, "dataset": ~, "source": str, "lang": str[2], | ||
"sents": List[Dict], | ||
extra fields for these: posi, score=0., extra_info={}, is_aug=False | ||
"fillers": [{id, offset, length, type}], | ||
"entity_mentions": [id, gid, offset, length, type, mtype, (head)], | ||
"relation_mentions": [id, gid, type, rel_arg1{aid, role}, rel_arg2{aid, role}], | ||
"event_mentions": [id, gid, type, trigger{offset, length}, em_arg[{aid, role}]]} | ||
|
||
Please refer to [`tasks/zie/common/data.py`](../tasks/zie/common/data.py) or loading the data in python for more details. We think that the namings of the json fields are relatively straightforward. | ||
|
||
We also provide a script [`scripts/ie/iarg/doc2rams.py`](../scripts/ie/iarg/doc2rams.py) to transform from our format back to RAMS's. | ||
|
||
### Running | ||
|
||
Please refer to [`scripts/ie/iarg/go20.sh`](../scripts/ie/iarg/go20.sh) for an example of training and testing. | ||
|
||
For training, please make a new dir (for example, "run") and run `go20.sh` in it. (Maybe need to fix some paths.) | ||
|
||
For testing, simply run the following after training (`RGPU` is gpu id and `SRC_DIR` is the top path of the source code); here, input/output can be specified with: `test:[input]` and `output_file:[output]`: | ||
|
||
`CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/tasks/cmd.py zie.main.test _conf device:0 test:../RAMS_1.0/data/en.rams.dev.json output_file:zout.dev.json` | ||
|
||
### Eval | ||
|
||
Our codes directly do evaluations and print the results (assume annotations are already in the input), you can also convert the outputs back to RAMS format and use the scorer in RAMS's package for evaluations (in our experiments, the results are the same as ours). | ||
|
||
# convert back to RAMS format | ||
PYTHONPATH=${SRC_DIR} python3 ${SRC_DIR}/scripts/ie/iarg/doc2rams.py zout.dev.json zout.dev.rams eng | ||
python3 ../RAMS_1.0/scorer/scorer.py --gold_file ../RAMS_1.0/data/dev.jsonlines --pred_file zout.dev.rams --ontology_file ../RAMS_1.0/scorer/event_role_multiplicities.txt --do_all |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
### For the Some Other Parsers (sop) | ||
|
||
Hi, this describes our implementation for some other (higher-order) parsers, including higher-order graph parsers with approximate decoding algorithms and a generalized easy-first transition based parser. | ||
|
||
### Repo | ||
|
||
When we were carrying out our experiments, we used the repo at the commit named `Add some other parsers.` (for main codes) and its following `Update scripts.` (for scripts). In later versions of this repo, there may be slight changes (for example, default hyper-parameter change or hyper-parameter name change). | ||
|
||
### Environment | ||
|
||
As those of the main `msp` package: | ||
|
||
python>=3.6 | ||
dependencies: pytorch>=0.4.1, numpy, scipy, gensim, cython, pybind11 | ||
|
||
### Data Preparation | ||
|
||
Data preparation is similar to those in [`emp_graph.md`](./emp_graph.md), please refer it for data preparation. | ||
|
||
### Compiling (for certain functionalities) | ||
|
||
To include full functionalities (especially high-order approximate decoding), compiling (with cython) is required. This can be done by simply running [`tasks/compile.sh`](../tasks/compile.sh). | ||
|
||
cd tasks; bash compile.sh; cd .. | ||
|
||
### Running | ||
|
||
Please refer to the scripts in [`scripts/dpar/zrun2/*`](../scripts/dpar/zrun2) for examples of running. Here, `train_ef.sh` is for training easy-first parser, `train_g1.sh` is for training basic first-order graph parser, which is needed as the base of higher-order graph parsers and `train_g2.sh` is for training higher-order graph parsers. | ||
|
||
(Note: some paths might need to be changed for actual running) | ||
|
||
(Note: to train `g2`, we also need to use [`score_go.py`](../tasks/zdpar/main/score_go.py) to pre-compute first-order scores for pruning for higher-order models, check this python script for more details.) | ||
|
||
(Note: the script [`scripts/dpar/zrun2/runm.py`](../scripts/dpar/zrun2/runm.py) also provides a reference for multi-lingual training for some of these parsers.) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
name: zie | ||
channels: | ||
- pytorch | ||
- defaults | ||
dependencies: | ||
- _libgcc_mutex=0.1=main | ||
- asn1crypto=0.24.0=py37_0 | ||
- blas=1.0=mkl | ||
- boto=2.49.0=py37_0 | ||
- boto3=1.9.199=py_0 | ||
- botocore=1.12.199=py_0 | ||
- bz2file=0.98=py37_1 | ||
- ca-certificates=2019.11.27=0 | ||
- certifi=2019.11.28=py37_0 | ||
- cffi=1.12.3=py37h2e261b9_0 | ||
- chardet=3.0.4=py37_1003 | ||
- cryptography=2.7=py37h1ba5d50_0 | ||
- cycler=0.10.0=py37_0 | ||
- dbus=1.13.12=h746ee38_0 | ||
- docutils=0.14=py37_0 | ||
- expat=2.2.6=he6710b0_0 | ||
- fontconfig=2.13.0=h9420a91_0 | ||
- freetype=2.9.1=h8a8886c_1 | ||
- gensim=3.4.0=py37h14c3975_0 | ||
- glib=2.63.1=h5a9c865_0 | ||
- gst-plugins-base=1.14.0=hbbd80ab_1 | ||
- gstreamer=1.14.0=hb453b48_1 | ||
- icu=58.2=h9c2bf20_1 | ||
- idna=2.8=py37_0 | ||
- intel-openmp=2019.5=281 | ||
- jmespath=0.9.4=py_0 | ||
- joblib=0.14.0=py_0 | ||
- jpeg=9b=h024ee3a_2 | ||
- kiwisolver=1.1.0=py37he6710b0_0 | ||
- libedit=3.1.20181209=hc058e9b_0 | ||
- libffi=3.2.1=hd88cf55_4 | ||
- libgcc-ng=9.1.0=hdf63c60_0 | ||
- libgfortran-ng=7.3.0=hdf63c60_0 | ||
- libpng=1.6.37=hbc83047_0 | ||
- libstdcxx-ng=9.1.0=hdf63c60_0 | ||
- libuuid=1.0.3=h1bed415_2 | ||
- libxcb=1.13=h1bed415_1 | ||
- libxml2=2.9.9=hea5a465_1 | ||
- matplotlib=3.1.1=py37h5429711_0 | ||
- mkl=2019.4=243 | ||
- mkl-service=2.3.0=py37he904b0f_0 | ||
- mkl_fft=1.0.14=py37ha843d7b_0 | ||
- mkl_random=1.0.2=py37hd81dba3_0 | ||
- ncurses=6.1=he6710b0_1 | ||
- ninja=1.9.0=py37hfd86e86_0 | ||
- numpy=1.16.5=py37h7e9f1db_0 | ||
- numpy-base=1.16.5=py37hde5b4d6_0 | ||
- openssl=1.1.1d=h7b6447c_3 | ||
- pandas=0.25.2=py37he6710b0_0 | ||
- pcre=8.43=he6710b0_0 | ||
- pip=19.2.2=py37_0 | ||
- pycparser=2.19=py37_0 | ||
- pyopenssl=19.0.0=py37_0 | ||
- pyparsing=2.4.5=py_0 | ||
- pyqt=5.9.2=py37h05f1152_2 | ||
- pysocks=1.7.0=py37_0 | ||
- python=3.7.4=h265db76_1 | ||
- python-dateutil=2.8.0=py37_0 | ||
- pytorch=1.0.0=py3.7_cuda9.0.176_cudnn7.4.1_1 | ||
- pytz=2019.3=py_0 | ||
- qt=5.9.7=h5867ecd_1 | ||
- readline=7.0=h7b6447c_5 | ||
- requests=2.22.0=py37_0 | ||
- s3transfer=0.2.1=py37_0 | ||
- scikit-learn=0.21.3=py37hd81dba3_0 | ||
- scipy=1.3.1=py37h7c811a0_0 | ||
- setuptools=41.0.1=py37_0 | ||
- sip=4.19.8=py37hf484d3e_0 | ||
- six=1.12.0=py37_0 | ||
- smart_open=1.8.4=py_0 | ||
- sqlite=3.29.0=h7b6447c_0 | ||
- tk=8.6.8=hbc83047_0 | ||
- tornado=6.0.3=py37h7b6447c_0 | ||
- urllib3=1.24.2=py37_0 | ||
- wheel=0.33.4=py37_0 | ||
- xlsxwriter=1.2.2=py_0 | ||
- xz=5.2.4=h14c3975_4 | ||
- zlib=1.2.11=h7b6447c_3 | ||
- pip: | ||
- absl-py==0.8.1 | ||
- astor==0.8.0 | ||
- cachetools==3.1.1 | ||
- click==7.0 | ||
- gast==0.2.2 | ||
- google-auth==1.7.1 | ||
- google-auth-oauthlib==0.4.1 | ||
- google-pasta==0.1.8 | ||
- grpcio==1.25.0 | ||
- h5py==2.10.0 | ||
- keras==2.3.1 | ||
- keras-applications==1.0.8 | ||
- keras-preprocessing==1.1.0 | ||
- langdetect==1.0.7 | ||
- markdown==3.1.1 | ||
- oauthlib==3.1.0 | ||
- opt-einsum==3.1.0 | ||
- protobuf==3.9.2 | ||
- pyasn1==0.4.8 | ||
- pyasn1-modules==0.2.7 | ||
- pyyaml==5.1.2 | ||
- regex==2019.8.19 | ||
- requests-oauthlib==1.3.0 | ||
- rsa==4.0 | ||
- sacremoses==0.0.35 | ||
- sentencepiece==0.1.83 | ||
- seqeval==0.0.12 | ||
- stanfordnlp==0.2.0 | ||
- tensorboard==2.0.2 | ||
- tensorflow==2.0.0 | ||
- tensorflow-estimator==2.0.1 | ||
- termcolor==1.1.0 | ||
- tqdm==4.36.1 | ||
- transformers==2.0.0 | ||
- werkzeug==0.16.0 | ||
- wrapt==1.11.2 | ||
prefix: /home/username/miniconda3/envs/zie |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# | ||
|
||
# convert from doc back to rams format (for scoring) | ||
|
||
import sys, json | ||
from tasks.zie.common.data import DocInstance, MyDocReader | ||
|
||
# | ||
def doc2rams(doc, lang): | ||
ret = {"doc_key": doc.doc_id, "language_id": lang, "source_url": None, "split": "UNK", | ||
"sentences": [z.words.vals[1:] for z in doc.sents], "predictions": []} | ||
# sidwid -> overall-tid | ||
cur_tid = 0 | ||
posi2tid = [] | ||
for sent in ret["sentences"]: | ||
slen = len(sent) | ||
posi2tid.append([z+cur_tid for z in range(slen)]) | ||
cur_tid += slen | ||
# ===== | ||
def mention2tids(mention, sents): | ||
sid, wid, wlen = mention.hard_span.position(headed=False) | ||
# anchor_sent = sents[sid] | ||
# wid2 = wid+wlen-1 | ||
# assert wid>=1 and wid<anchor_sent.length | ||
# assert wid2>=1 and wid2<anchor_sent.length | ||
wid = max(1, wid) # keep in sent bound | ||
tid = posi2tid[sid][wid-1] # todo(note): here, -1 to exclude ROOT | ||
return [tid, tid+wlen-1] | ||
# ===== | ||
# get the args | ||
all_evts = doc.events | ||
sents = doc.sents | ||
# assert len(all_evts) <= 1 # todo(note): this mode is for special decoding mode | ||
for one_evt in all_evts: | ||
one_res = [mention2tids(one_evt.mention, sents)] | ||
for one_arg in one_evt.links: | ||
one_res.append(mention2tids(one_arg.ef.mention, sents) + [one_arg.role, 1.]) | ||
ret["predictions"].append(one_res) | ||
return ret | ||
|
||
# | ||
def main(input_f, output_f, lang): | ||
# fin = sys.stdin if (input_f=="-" or input_f=="") else open(input_f) | ||
fout = sys.stdout if (output_f=="-" or output_f=="") else open(output_f, "w") | ||
# | ||
for doc in MyDocReader(input_f, False, False): | ||
rams_doc = doc2rams(doc, lang) | ||
fout.write(json.dumps(rams_doc) + "\n") | ||
# | ||
# fin.close() | ||
fout.close() | ||
|
||
if __name__ == '__main__': | ||
main(*sys.argv[1:]) | ||
|
||
# PYTHONPATH=../src/ python3 doc2rams.py INPUT OUTPUT eng |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
# | ||
|
||
# train m3a on RAMS | ||
|
||
LANGUAGE="en" | ||
DATA_SET="en.rams" | ||
TEST_NAME="test" | ||
# find data | ||
SRC_DIR="../src/" | ||
DATA_DIR="../RAMS_1.0/data//" | ||
if [[ ! -d ${SRC_DIR} ]]; then | ||
SRC_DIR="../../src/" | ||
DATA_DIR="../../RAMS_1.0/data//" | ||
if [[ ! -d ${SRC_DIR} ]]; then | ||
SRC_DIR="../../../src/" | ||
DATA_DIR="../../../RAMS_1.0/data/" | ||
fi | ||
fi | ||
# ===== | ||
function run_one () { | ||
# check envs | ||
echo "The environment variables: RGPU=${RGPU} CUR_RUN=${CUR_RUN}"; | ||
|
||
# ===== | ||
# basis | ||
|
||
# basic | ||
base_opt="model_type:m3a conf_output:_conf init_from_pretrain:0 word_freeze:1" | ||
|
||
# enc | ||
base_opt+=" bert2_model:bert-base-cased bert_use_special_typeids:1 bert2_output_layers:[-1] bert2_trainable_layers:13" | ||
base_opt+=" m2e_use_basic_plus:0 m2e_use_basic_dep:0 ms_extend_step:2 ms_extend_budget:250" # avoid oom? | ||
base_opt+=" m3_enc_conf.enc_rnn_layer:0 m3_enc_conf.enc_hidden:512" # extra LSTM layer? | ||
|
||
# training and dropout | ||
base_opt+=" tconf.train_msent_based:1 tconf.batch_size:18 split_batch:3 iconf.batch_size:64" | ||
base_opt+=" lrate.val:0.00005 lrate.m:0.5 min_save_epochs:0 max_epochs:20 patience:2 anneal_times:5" | ||
base_opt+=" drop_embed:0.1 dropmd_embed:0.1 drop_hidden:0.2 gdrop_rnn:0.33 idrop_rnn:0.1 fix_drop:0" | ||
base_opt+=" bert2_training_mask_rate:0." | ||
|
||
# bucket | ||
base_opt+=" benc_bucket_range:20 enc_bucket_range:20" | ||
|
||
# ===== | ||
|
||
# overall | ||
base_opt+=" train_skip_noevt_rate:1.0 res_list:argument" # only keep center sentences | ||
|
||
# cand | ||
base_opt+=" lambda_cand.val:0.5" | ||
base_opt+=" c_cand.train_neg_rate:0.02 c_cand.nil_penalty:100." | ||
base_opt+=" c_cand.pred_sent_ratio:0.4 c_cand.pred_sent_ratio_sep:1" | ||
|
||
# arg | ||
base_opt+=" eval_conf.arg_mode:span" # eval on span | ||
base_opt+=" lambda_arg.val:1." # loss lambda | ||
# | ||
base_opt+=" ps_conf.nil_score0:1 arc_space:0 lab_space:0" # scorer architecture | ||
base_opt+=" c_arg.share_adp:1 c_arg.share_scorer:1" # param sharing | ||
base_opt+=" use_borrowed_evt_hlnode:0 use_borrowed_evt_adp:0" # extra param sharing (for evt-emb and evt-adp) | ||
base_opt+=" c_arg.use_cons_frame:0 which_frame_cons:rams" # use constraints (mainly for predicted evt types) | ||
base_opt+=" max_sdist:2 max_pairwise_role:2 max_pairwise_ef:1" # other dec-constraints | ||
base_opt+=" center_train_neg_rate:0.5 outside_train_neg_rate:0.5" # arg neg rate | ||
base_opt+=" c_arg.norm_mode:acand c_arg.use_evt_label:0 c_arg.use_ef_label:0 c_arg.use_sdist:0" # arg mode | ||
|
||
# arg expander | ||
base_opt+=" lambda_span.val:0.5 max_range:7 use_lstm_scorer:0 use_binary_scorer:0" | ||
|
||
# special reg & cut-lrate deterministically | ||
#base_opt+=" biaffine_div:0. biaffine_freeze:1" | ||
base_opt+=" lrate.which_idx:eidx lrate.start_bias:5 lrate.scale:4 lrate.m:0.5 lrate_warmup:0 lrate.min_val:0.000001 min_save_epochs:10" | ||
base_opt+=" lrate.val:0.00005 tconf.batch_size:12 split_batch:3 lrate.m:0.5 lrate_warmup:-2 bert2_training_mask_rate:0.15" | ||
|
||
# final group | ||
#base_opt+=" use_lstm_scorer:1 use_binary_scorer:1" | ||
base_opt+=" use_lstm_scorer:0 use_binary_scorer:0 lrate.val:0.000075 lrate.m:0.5 lrate_warmup:0 arc_space:0 lab_space:256 min_save_epochs:7" | ||
|
||
# ===== | ||
# note: these are for "gold_span" linking experiments | ||
#base_opt+=" c_arg.use_sdist:0 pred_span:0 lambda_span.val:0." | ||
#base_opt+=" iconf.lookup_ef:1 mix_pred_ef_rate:0 bert_use_center_typeids:0 bert_use_special_typeids:0 bert2_output_layers:[9,10,11,12] bert2_trainable_layers:0" | ||
#base_opt+=" lrate.val:0.0002 drop_hidden:0.33 tconf.batch_size:64 split_batch:1" | ||
# ===== | ||
|
||
# training | ||
if [[ -n ${DEBUG} ]]; then | ||
DEBUG_OPTION="-m pdb" | ||
else | ||
DEBUG_OPTION="" | ||
fi | ||
CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 ${DEBUG_OPTION} ${SRC_DIR}/tasks/cmd.py zie.main.train train:${DATA_DIR}/${DATA_SET}.train.json dev:${DATA_DIR}/${DATA_SET}.dev.json test:${DATA_DIR}/${DATA_SET}.${TEST_NAME}.json pretrain_file: device:-1 ${base_opt} "niconf.random_seed:9347${CUR_RUN}" "niconf.random_cuda_seed:9349${CUR_RUN}" ${EXTRA_ARGS} | ||
} | ||
|
||
run_one | ||
|
||
# RGPU=2 EXTRA_ARGS="device:0" DEBUG=1 bash _go.sh | ||
|
||
# to test | ||
#CUDA_VISIBLE_DEVICES=${RGPU} PYTHONPATH=${SRC_DIR} python3 -m pdb ${SRC_DIR}/tasks/cmd.py zie.main.test ${RUN_DIR}/_conf device:-1 dict_dir:${RUN_DIR}/ model_load_name:${RUN_DIR}/zmodel.best |
Oops, something went wrong.