-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathutils.py
92 lines (80 loc) · 2.43 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import os
import tensorflow as tf
def process_config() -> dict:
    """Build the static configuration for an experiment.

    Put settings here that are unlikely to change from run to run.

    :return: a freshly built dictionary of static configuration data
    """
    static_settings = dict(exp_name="example_model_train")
    return static_settings
def get_args(argv=None) -> dict:
    """Parse the command-line arguments for a training job.

    Add or remove arguments as your project needs. Unrecognised arguments
    are not an error: they are logged as a warning and otherwise ignored.
    Parsing also configures TensorFlow's Python and C++ log verbosity from
    ``--verbosity``.

    :param argv: list of argument strings to parse; ``None`` (the default)
        parses ``sys.argv[1:]`` exactly as before.
    :return: dictionary mapping each argument's destination name
        (e.g. ``train_files``, ``keep_prob``) to its parsed value
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train-files",
        help="GCS or local paths to training data",
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--num-epochs",
        help="Maximum number of training data epochs on which to train.",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--train-batch-size", help="Batch size for training steps", type=int, default=32
    )
    parser.add_argument(
        "--eval-batch-size",
        help="Batch size for evaluation steps",
        type=int,
        default=32,
    )
    parser.add_argument(
        "--export-path",
        type=str,
        help="Where to export the saved model to locally or on GCP",
    )
    parser.add_argument(
        "--eval-files",
        help="GCS or local paths to evaluation data",
        nargs="+",
        required=True,
    )
    parser.add_argument(
        "--test-files", help="GCS or local paths to test data", nargs="+", required=True
    )
    # Training arguments
    parser.add_argument(
        "--learning-rate",
        help="Learning rate for the optimizer",
        default=0.001,
        type=float,
    )
    parser.add_argument(
        "--job-dir",
        help="GCS location to write checkpoints and export models",
        required=True,
    )
    parser.add_argument(
        "--verbosity",
        choices=["DEBUG", "ERROR", "FATAL", "INFO", "WARN"],
        default="DEBUG",
        help="Set logging verbosity",
    )
    # A probability is fractional: with type=int, argparse rejected values
    # like "0.8" even though the default is 0.5.
    parser.add_argument(
        "--keep-prob", help="Keep probability for dropout", default=0.5, type=float
    )
    args, unknown = parser.parse_known_args(argv)

    # Set python level verbosity
    tf.logging.set_verbosity(args.verbosity)

    # Set C++ graph-execution verbosity. The environment variable must hold an
    # integer string, so use floor division: with "/" Python 3 would store
    # e.g. "1.0" instead of "1".
    # NOTE(review): this maps DEBUG->1, INFO->2, WARN->3, ERROR->4, FATAL->5,
    # while TF's C++ logger documents levels 0-3 -- confirm the intended mapping.
    python_level = getattr(tf.logging, args.verbosity)
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = str(python_level // 10)

    if unknown:
        # `warn` is a deprecated alias of `warning`; the logged text is unchanged.
        tf.logging.warning("Unknown arguments: %s", unknown)
    return vars(args)