Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gRPC timeout #21

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion run_validation/main_task/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from compiled_protobufs.passage_validator_pb2_grpc import add_PassageValidatorServicer_to_server
from passage_id_db import PassageIDDatabase
from passage_validator_servicer import EXPECTED_ID_COUNT
from main import load_run_file, get_service_stub
from main import load_run_file, get_service_stub, GRPC_DEFAULT_TIMEOUT

# see https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option
def pytest_addoption(parser):
Expand Down Expand Up @@ -81,6 +81,7 @@ def default_validate_args():
skip_passage_validation=False,
fileroot=test_root,
strict=False,
timeout=GRPC_DEFAULT_TIMEOUT,
)

@pytest.fixture
Expand Down
27 changes: 16 additions & 11 deletions run_validation/main_task/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from compiled_protobufs.run_pb2 import CastRun, Turn
from utils import check_provenance, validate_passages, check_response

GRPC_DEFAULT_TIMEOUT = 3.0

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

Expand Down Expand Up @@ -66,7 +68,7 @@ def load_run_file(run_file_path: str) -> CastRun:

return run

def validate_turn(turn: Turn, turn_lookup_set: dict, service_stub: PassageValidatorStub) -> (int, bool):
def validate_turn(turn: Turn, turn_lookup_set: dict, service_stub: PassageValidatorStub, timeout: float) -> (int, bool):
warning_count, service_errors = 0, 0

# check turns are valid
Expand All @@ -80,7 +82,7 @@ def validate_turn(turn: Turn, turn_lookup_set: dict, service_stub: PassageValida
# will be None if skip_passage_validation was used
if service_stub is not None:
try:
warning_count = validate_passages(service_stub, logger, warning_count, turn)
warning_count = validate_passages(service_stub, logger, warning_count, turn, timeout)
except grpc.RpcError as rpce:
logger.warning(f'A gRPC error occurred when validating passages ({rpce.code().name})')
service_errors += 1
Expand All @@ -107,12 +109,12 @@ def validate_turn(turn: Turn, turn_lookup_set: dict, service_stub: PassageValida

return warning_count, service_errors

def validate_run(run: CastRun, turn_lookup_set: dict, service_stub: PassageValidatorStub, max_warnings: int, strict: bool) -> (int, int, int):
def validate_run(run: CastRun, turn_lookup_set: dict, service_stub: PassageValidatorStub, max_warnings: int, strict: bool, timeout: float) -> (int, int, int):
total_warnings, service_errors = 0, 0
turns_validated = 0

for turn in run.turns:
_warnings, _service_errors = validate_turn(turn, turn_lookup_set, service_stub)
_warnings, _service_errors = validate_turn(turn, turn_lookup_set, service_stub, timeout)
total_warnings += _warnings
service_errors += _service_errors
turns_validated += 1
Expand All @@ -128,7 +130,7 @@ def validate_run(run: CastRun, turn_lookup_set: dict, service_stub: PassageValid
logger.info(f'Validation completed on {turns_validated}/{len(run.turns)} turns with {total_warnings} warnings, {service_errors} service errors')
return turns_validated, service_errors, total_warnings

def validate(run_file_path: str, fileroot: str, max_warnings: int, skip_validation: bool, strict: bool) -> (int, int, int):
def validate(run_file_path: str, fileroot: str, max_warnings: int, skip_validation: bool, strict: bool, timeout: float = GRPC_DEFAULT_TIMEOUT) -> (int, int, int):
run_file_name = PurePath(run_file_path).name
fileHandler = logging.FileHandler(filename=f'{run_file_name}.errlog')
fileHandler.setFormatter(formatter)
Expand All @@ -149,22 +151,25 @@ def validate(run_file_path: str, fileroot: str, max_warnings: int, skip_validati
logger.warning('Loaded run file has 0 turns, not performing any validation!')
return len(run.turns)

turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, validator_stub, max_warnings, strict)
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, validator_stub, max_warnings, strict, timeout)

return turns_validated, service_errors, total_warnings

if __name__ == '__main__':
ap = argparse.ArgumentParser(description='TREC 2022 CAsT main task validator',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
ap.add_argument('task_name')
ap.add_argument('path_to_run_file')
ap.add_argument('task_name', help='CAST is currently the only supported option')
ap.add_argument('path_to_run_file', help='Path to the JSON-formatted run file to be validated')
ap.add_argument('-f', '--fileroot', help='Location of data files',
default='.')
ap.add_argument('--skip_passage_validation', action='store_true')
ap.add_argument('-m', '--max_warnings', help='Maximum number of warnings to allow',
ap.add_argument('--skip_passage_validation', help='Do NOT validate passage IDs using the validator service',
action='store_true')
ap.add_argument('-m', '--max_warnings', help='Maximum number of validation warnings to allow',
type=int, default=25)
ap.add_argument('-s', '--strict', help='Abort if any passage validation service errors occur',
action='store_true')
ap.add_argument('-t', '--timeout', help='Set the gRPC timeout (secs) for contacting the validation service',
type=float, default=GRPC_DEFAULT_TIMEOUT)
args = ap.parse_args()

validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)
24 changes: 12 additions & 12 deletions run_validation/main_task/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from main import get_service_stub, load_turn_lookup_set, load_run_file
from main import validate_turn, validate_run, validate
from main import validate_turn, validate_run, validate, GRPC_DEFAULT_TIMEOUT

def test_get_service_stub(grpc_server_test):
assert(get_service_stub() is not None)
Expand Down Expand Up @@ -53,7 +53,7 @@ def test_validate_missing_run_file():

def test_validate_turn(turns_lookup_path, run_file_path, grpc_stub_test, sample_turn):
turn_lookup_set = load_turn_lookup_set(turns_lookup_path)
warnings, service_errors = validate_turn(sample_turn, turn_lookup_set, grpc_stub_test)
warnings, service_errors = validate_turn(sample_turn, turn_lookup_set, grpc_stub_test, GRPC_DEFAULT_TIMEOUT)
assert(warnings == 4) # due to small database being used
assert(service_errors == 0)

Expand All @@ -71,7 +71,7 @@ def test_validate_run(turns_lookup_path, run_file_path, grpc_stub_test, default_
run.turns.extend(first_turns)
assert(len(run.turns) == 8)

turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test, args.max_warnings, args.strict)
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test, args.max_warnings, args.strict, args.timeout)

assert(turns_validated == 8)
assert(service_errors == 0)
Expand All @@ -92,7 +92,7 @@ def test_validate_run_strict(turns_lookup_path, run_file_path, grpc_stub_test, d
run.turns.extend(first_turns)
assert(len(run.turns) == 8)

turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test, args.max_warnings, args.strict)
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test, args.max_warnings, args.strict, args.timeout)
assert(turns_validated == 8)
assert(service_errors == 0)
assert(total_warnings == 25)
Expand All @@ -114,7 +114,7 @@ def test_validate_run_strict_invalid(turns_lookup_path, run_file_path, grpc_stub
assert(len(run.turns) == 8)

with pytest.raises(SystemExit) as pytest_exc:
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test_invalid, args.max_warnings, args.strict)
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, grpc_stub_test_invalid, args.max_warnings, args.strict, args.timeout)

assert pytest_exc.type == SystemExit
assert pytest_exc.value.code == 255
Expand All @@ -125,7 +125,7 @@ def test_validate_run_no_service(turns_lookup_path, run_file_path, default_valid
run = load_run_file(run_file_path)
assert(len(run.turns) == 205)

turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, None, args.max_warnings, args.strict)
turns_validated, service_errors, total_warnings = validate_run(run, turn_lookup_set, None, args.max_warnings, args.strict, args.timeout)
assert(turns_validated == 205)
assert(service_errors == 0)
assert(total_warnings == 0)
Expand All @@ -134,7 +134,7 @@ def test_validate_run_no_service(turns_lookup_path, run_file_path, default_valid
def test_validate(default_validate_args, grpc_server_full):
args = default_validate_args

turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)
assert(turns_validated == 205)
assert(service_errors == 0)
assert(total_warnings == 1) # seems to be 1 invalid ID in the sample_runs.json file?
Expand All @@ -146,7 +146,7 @@ def test_validate_no_service(default_validate_args, grpc_server_full):
# terminate the service
grpc_server_full.stop(None)

turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)
assert(turns_validated == 205)
assert(service_errors == 205)
assert(total_warnings == 0)
Expand All @@ -159,7 +159,7 @@ def test_validate_no_service_skip_validation(default_validate_args, grpc_server_
# terminate the service
grpc_server_full.stop(None)

turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)
assert(turns_validated == 205)
assert(service_errors == 0)
assert(total_warnings == 0)
Expand All @@ -173,7 +173,7 @@ def test_validate_no_service_strict(default_validate_args, grpc_server_full):
grpc_server_full.stop(None)

with pytest.raises(SystemExit) as pytest_exc:
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)

assert(pytest_exc.type == SystemExit)
assert(pytest_exc.value.code == 255)
Expand All @@ -182,14 +182,14 @@ def test_validate_empty(default_validate_args):
args = default_validate_args
args.path_to_run_file = 'foobar'
with pytest.raises(FileNotFoundError):
_, _, _ = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
_, _, _ = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)

def test_validate_small(default_validate_args, grpc_server_test):
args = default_validate_args

# this should abort after generating enough warnings, since the smaller database won't match most of the IDs
with pytest.raises(SystemExit) as pytest_exc:
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict)
turns_validated, service_errors, total_warnings = validate(args.path_to_run_file, args.fileroot, args.max_warnings, args.skip_passage_validation, args.strict, args.timeout)

assert(pytest_exc.type == SystemExit)
assert(pytest_exc.value.code == 255)
3 changes: 2 additions & 1 deletion run_validation/main_task/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from compiled_protobufs.passage_validator_pb2 import PassageValidationRequest
from utils import validate_passages
from main import GRPC_DEFAULT_TIMEOUT

def build_request(ids):
request = PassageValidationRequest()
Expand All @@ -13,7 +14,7 @@ def get_invalid_indices(response):

def test_validate_passages(grpc_stub_test, test_logger, sample_turn):
warning_count = 0
warning_count = validate_passages(grpc_stub_test, test_logger, warning_count, sample_turn)
warning_count = validate_passages(grpc_stub_test, test_logger, warning_count, sample_turn, GRPC_DEFAULT_TIMEOUT)

assert(warning_count == 4)

Expand Down
4 changes: 2 additions & 2 deletions run_validation/main_task/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def check_provenance(previous_score: float, provenance: Provenance, logger: Logg

return previous_score, provenance_count, warning_count

def validate_passages(passage_validation_client: PassageValidatorStub, logger: Logger, warning_count: int, turn: Turn) -> int:
def validate_passages(passage_validation_client: PassageValidatorStub, logger: Logger, warning_count: int, turn: Turn, timeout: float) -> int:
# collect passage ids
passage_validation_request = PassageValidationRequest()

Expand All @@ -48,7 +48,7 @@ def validate_passages(passage_validation_client: PassageValidatorStub, logger: L
passage_validation_request.passage_ids.MergeFrom(passage_ids)

# validate ids
passage_validation_result = passage_validation_client.validate_passages(passage_validation_request)
passage_validation_result = passage_validation_client.validate_passages(passage_validation_request, timeout=timeout)

invalid_indexes = [i for i, passage_validation in enumerate(passage_validation_result.passage_validations) if not passage_validation.is_valid]
for index in invalid_indexes:
Expand Down