From fbf82f23cecb754beb6df7c1a30f11700774f2b9 Mon Sep 17 00:00:00 2001 From: Michael Graeb Date: Wed, 1 Nov 2023 08:23:10 -0700 Subject: [PATCH 1/2] expose S3 multipart-threshold, and status-code for errors (#517) --- awscrt/s3.py | 41 ++++++++--- source/s3_client.c | 25 ++++--- source/s3_meta_request.c | 3 +- test/test_s3.py | 148 ++++++++++++++++++++++++--------------- 4 files changed, 140 insertions(+), 77 deletions(-) diff --git a/awscrt/s3.py b/awscrt/s3.py index c06f7bf4c..9019a4c64 100644 --- a/awscrt/s3.py +++ b/awscrt/s3.py @@ -137,11 +137,17 @@ class S3Client(NativeResource): for each connection, unless `tls_mode` is :attr:`S3RequestTlsMode.DISABLED` part_size (Optional[int]): Size, in bytes, of parts that files will be downloaded or uploaded in. - Note: for :attr:`S3RequestType.PUT_OBJECT` request, S3 requires the part size greater than 5MB. - (5*1024*1024 by default) + Note: for :attr:`S3RequestType.PUT_OBJECT` request, S3 requires the part size greater than 5 MiB. + (8*1024*1024 by default) - throughput_target_gbps (Optional[float]): Throughput target in Gbps that we are trying to reach. - (5 Gbps by default) + multipart_upload_threshold (Optional[int]): The size threshold in bytes, for when to use multipart uploads. + Uploads over this size will use the multipart upload strategy. + Uploads this size or less will use a single request. + If not set, `part_size` is used as the threshold. + + throughput_target_gbps (Optional[float]): Throughput target in + Gigabits per second (Gbps) that we are trying to reach. + (10.0 Gbps by default) """ __slots__ = ('shutdown_event', '_region') @@ -156,6 +162,7 @@ def __init__( credential_provider=None, tls_connection_options=None, part_size=None, + multipart_upload_threshold=None, throughput_target_gbps=None): assert isinstance(bootstrap, ClientBootstrap) or bootstrap is None assert isinstance(region, str) @@ -193,6 +200,8 @@ def on_shutdown(): tls_mode = 0 if part_size is None: part_size = 0 + if multipart_upload_threshold is None: + multipart_upload_threshold = 0 if throughput_target_gbps is None: throughput_target_gbps = 0 @@ -205,6 +214,7 @@ def on_shutdown(): region, tls_mode, part_size, + multipart_upload_threshold, throughput_target_gbps, s3_client_core) @@ -287,10 +297,16 @@ def make_request( failed because server side sent an unsuccessful response, the headers of the response is provided here. Else None will be returned. - * `error_body` (Optional[Bytes]): If request failed because server + * `error_body` (Optional[bytes]): If request failed because server side sent an unsuccessful response, the body of the response is provided here. Else None will be returned. + * `status_code` (Optional[int]): HTTP response status code (if available). + If request failed because server side sent an unsuccessful response, + this is its status code. If the operation was successful, + this is the final response's status code. If the operation + failed for another reason, None is returned. + * `**kwargs` (dict): Forward-compatibility kwargs. on_progress: Optional callback invoked when part of the transfer is done to report the progress. 
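As an illustration of the options this commit adds (a sketch, not code from the patch): the new `multipart_upload_threshold` client option and the `status_code` field of the `on_done` callback can be used together roughly as below. The region, bucket name, object key, and file path are placeholders, and the sketch assumes default-chain credentials are available.

    from awscrt.auth import AwsCredentialsProvider
    from awscrt.http import HttpHeaders, HttpRequest
    from awscrt.io import ClientBootstrap, DefaultHostResolver, EventLoopGroup
    from awscrt.s3 import S3Client, S3RequestType, create_default_s3_signing_config

    event_loop_group = EventLoopGroup()
    host_resolver = DefaultHostResolver(event_loop_group)
    bootstrap = ClientBootstrap(event_loop_group, host_resolver)
    credential_provider = AwsCredentialsProvider.new_default_chain(bootstrap)
    signing_config = create_default_s3_signing_config(
        region="us-west-2", credential_provider=credential_provider)

    client = S3Client(
        bootstrap=bootstrap,
        region="us-west-2",
        signing_config=signing_config,
        part_size=8 * 1024 * 1024,                    # 8 MiB parts (the new default)
        multipart_upload_threshold=32 * 1024 * 1024)  # single-request uploads up to 32 MiB

    def on_done(error, error_headers, error_body, status_code, **kwargs):
        # status_code is None if the operation failed before any HTTP response arrived
        if error is None:
            print("succeeded, final response status:", status_code)
        else:
            print("failed:", error, "status:", status_code, "error body:", error_body)

    headers = HttpHeaders([("host", "my-bucket.s3.us-west-2.amazonaws.com")])
    request = HttpRequest("GET", "/my-object.txt", headers)
    s3_request = client.make_request(
        request=request,
        type=S3RequestType.GET_OBJECT,
        recv_filepath="my-object.txt",
        on_done=on_done)
    s3_request.finished_future.result()

Uploads at or below the threshold go out as a single request, which is what lets the Content-MD5 test later in this patch force a single-part upload simply by using a small file.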
@@ -461,19 +477,26 @@ def _on_body(self, chunk, offset): def _on_shutdown(self): self._shutdown_event.set() - def _on_finish(self, error_code, error_headers, error_body): + def _on_finish(self, error_code, status_code, error_headers, error_body): + # If C layer gives status_code 0, that means "unknown" + if status_code == 0: + status_code = None + error = None if error_code: error = awscrt.exceptions.from_code(error_code) if error_body: # TODO The error body is XML, will need to parse it to something prettier. - extra_message = ". Body from error request is: " + str(error_body) - error.message = error.message + extra_message + try: + extra_message = ". Body from error request is: " + str(error_body) + error.message = error.message + extra_message + except BaseException: + pass self._finished_future.set_exception(error) else: self._finished_future.set_result(None) if self._on_done_cb: - self._on_done_cb(error=error, error_headers=error_headers, error_body=error_body) + self._on_done_cb(error=error, error_headers=error_headers, error_body=error_body, status_code=status_code) def _on_progress(self, progress): if self._on_progress_cb: diff --git a/source/s3_client.c b/source/s3_client.c index 171c66608..dc3454159 100644 --- a/source/s3_client.c +++ b/source/s3_client.c @@ -98,19 +98,20 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) { struct aws_allocator *allocator = aws_py_get_allocator(); - PyObject *bootstrap_py; /* O */ - PyObject *signing_config_py; /* O */ - PyObject *credential_provider_py; /* O */ - PyObject *tls_options_py; /* O */ - PyObject *on_shutdown_py; /* O */ - struct aws_byte_cursor region; /* s# */ - int tls_mode; /* i */ - uint64_t part_size; /* K */ - double throughput_target_gbps; /* d */ - PyObject *py_core; /* O */ + PyObject *bootstrap_py; /* O */ + PyObject *signing_config_py; /* O */ + PyObject *credential_provider_py; /* O */ + PyObject *tls_options_py; /* O */ + PyObject *on_shutdown_py; /* O */ + struct aws_byte_cursor region; /* s# */ + int tls_mode; /* i */ + uint64_t part_size; /* K */ + uint64_t multipart_upload_threshold; /* K */ + double throughput_target_gbps; /* d */ + PyObject *py_core; /* O */ if (!PyArg_ParseTuple( args, - "OOOOOs#iKdO", + "OOOOOs#iKKdO", &bootstrap_py, &signing_config_py, &credential_provider_py, @@ -120,6 +121,7 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) { ®ion.len, &tls_mode, &part_size, + &multipart_upload_threshold, &throughput_target_gbps, &py_core)) { return NULL; @@ -185,6 +187,7 @@ PyObject *aws_py_s3_client_new(PyObject *self, PyObject *args) { .tls_mode = tls_mode, .signing_config = signing_config, .part_size = part_size, + .multipart_upload_threshold = multipart_upload_threshold, .tls_connection_options = tls_options, .throughput_target_gbps = throughput_target_gbps, .shutdown_callback = s_s3_client_shutdown, diff --git a/source/s3_meta_request.c b/source/s3_meta_request.c index 5646dda75..aacbd4cfb 100644 --- a/source/s3_meta_request.c +++ b/source/s3_meta_request.c @@ -253,8 +253,9 @@ static void s_s3_request_on_finish( result = PyObject_CallMethod( request_binding->py_core, "_on_finish", - "(iOy#)", + "(iiOy#)", error_code, + meta_request_result->response_status, header_list ? 
header_list : Py_None, (const char *)(error_body.buffer), (Py_ssize_t)error_body.len); diff --git a/test/test_s3.py b/test/test_s3.py index 0e32f6675..09aa252cf 100644 --- a/test/test_s3.py +++ b/test/test_s3.py @@ -92,9 +92,19 @@ def full_path(self, filename): return os.path.join(self.rootdir, filename) -def s3_client_new(secure, region, part_size=0): - - event_loop_group = EventLoopGroup() +def s3_client_new(secure, region, part_size=0, is_cancel_test=False): + + if is_cancel_test: + # for cancellation tests, make things slow, so it's less likely that + # stuff succeeds on other threads before the cancellation is processed. + num_threads = 1 + throughput_target_gbps = 0.000028 # 28 Kbps beeepdiiingeep beeeeeekskhskshhKKKKchCH + else: + # else use defaults + num_threads = None + throughput_target_gbps = None + + event_loop_group = EventLoopGroup(num_threads) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) credential_provider = AwsCredentialsProvider.new_default_chain(bootstrap) @@ -110,7 +120,8 @@ def s3_client_new(secure, region, part_size=0): region=region, signing_config=signing_config, tls_connection_options=tls_option, - part_size=part_size) + part_size=part_size, + throughput_target_gbps=throughput_target_gbps) return s3_client @@ -163,7 +174,7 @@ def setUp(self): self.timeout = 100 # seconds self.num_threads = 0 self.special_path = "put_object_test_10MB@$%.txt" - self.non_ascii_file_name = "ÉxÅmple.txt".encode("utf-8") + self.non_ascii_file_name = "ÉxÅmple.txt" self.response_headers = None self.response_status_code = None @@ -171,6 +182,10 @@ def setUp(self): self.transferred_len = 0 self.data_len = 0 self.progress_invoked = 0 + self.done_error = None + self.done_status_code = None + self.done_error_headers = None + self.done_error_body = None self.files = FileCreator() self.temp_put_obj_file_path = self.files.create_file_with_size("temp_put_obj_10mb", 10 * MB) @@ -205,11 +220,20 @@ def _on_request_headers(self, status_code, headers, **kargs): def _on_request_body(self, chunk, offset, **kargs): self.received_body_len = self.received_body_len + len(chunk) + def _on_request_done(self, error, error_headers, error_body, status_code, **kwargs): + self.done_error = error + self.done_error_headers = error_headers + self.done_error_body = error_body + self.done_status_code = status_code + def _on_progress(self, progress): self.transferred_len += progress def _validate_successful_response(self, is_put_object): self.assertEqual(self.response_status_code, 200, "status code is not 200") + self.assertEqual(self.done_status_code, self.response_status_code, + "status-code from on_done doesn't match code from on_headers") + self.assertIsNone(self.done_error) headers = HttpHeaders(self.response_headers) self.assertIsNone(headers.get("Content-Range")) body_length = headers.get("Content-Length") @@ -235,19 +259,22 @@ def _test_s3_put_get_object( type=request_type, on_headers=self._on_request_headers, on_body=self._on_request_body, + on_done=self._on_request_done, **kwargs) - finished_future = s3_request.finished_future - try: - finished_future.result(self.timeout) - except Exception as e: - self.assertEqual(e.name, exception_name) - else: - self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) + finished_future = s3_request.finished_future shutdown_event = s3_request.shutdown_event s3_request = None self.assertTrue(shutdown_event.wait(self.timeout)) + if exception_name is None: + finished_future.result() + 
self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) + else: + e = finished_future.exception() + self.assertEqual(e.name, exception_name) + self.assertEqual(e, self.done_error) + def test_get_object(self): request = self._get_object_request(self.get_test_object_path) self._test_s3_put_get_object(request, S3RequestType.GET_OBJECT) @@ -286,7 +313,8 @@ def test_put_object_multiple_times(self): type=S3RequestType.PUT_OBJECT, send_filepath=tempfile, on_headers=self._on_request_headers, - on_body=self._on_request_body) + on_body=self._on_request_body, + on_done=self._on_request_done) finished_futures.append(s3_request.finished_future) # request keeps connection alive. delete pointer so connection can shut down del s3_request @@ -312,7 +340,8 @@ def test_get_object_filepath(self): type=request_type, recv_filepath=file.name, on_headers=self._on_request_headers, - on_progress=self._on_progress) + on_progress=self._on_progress, + on_done=self._on_request_done) finished_future = s3_request.finished_future # Regression test: Let S3Request get GC'd early. @@ -359,6 +388,7 @@ def test_put_object_filepath_move(self): done_future = Future() def on_done_remove_file(**kwargs): + self._on_request_done(**kwargs) os.remove(tempfile) done_future.set_result(None) @@ -417,7 +447,7 @@ def _on_progress_cancel_after_first_chunk(self, progress): def test_multipart_get_object_cancel(self): # a 5 GB file request = self._get_object_request("/get_object_test_5120MB.txt") - s3_client = s3_client_new(False, self.region, 5 * MB) + s3_client = s3_client_new(False, self.region, 5 * MB, is_cancel_test=True) with tempfile.NamedTemporaryFile(mode="w", delete=False) as file: file.close() self.s3_request = s3_client.make_request( @@ -425,12 +455,11 @@ def test_multipart_get_object_cancel(self): recv_filepath=file.name, type=S3RequestType.GET_OBJECT, on_headers=self._on_request_headers, - on_progress=self._on_progress_cancel_after_first_chunk) + on_progress=self._on_progress_cancel_after_first_chunk, + on_done=self._on_request_done) finished_future = self.s3_request.finished_future - try: - finished_future.result(self.timeout) - except Exception as e: - self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") + e = finished_future.exception(self.timeout) + self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") # Result check self.data_len = int(HttpHeaders(self.response_headers).get("Content-Length")) @@ -449,7 +478,7 @@ def test_multipart_get_object_cancel(self): def test_get_object_quick_cancel(self): # a 5 GB file request = self._get_object_request("/get_object_test_5120MB.txt") - s3_client = s3_client_new(False, self.region, 5 * MB) + s3_client = s3_client_new(False, self.region, 5 * MB, is_cancel_test=True) with tempfile.NamedTemporaryFile(mode="w", delete=False) as file: file.close() s3_request = s3_client.make_request( @@ -457,39 +486,37 @@ def test_get_object_quick_cancel(self): recv_filepath=file.name, type=S3RequestType.GET_OBJECT, on_headers=self._on_request_headers, - on_progress=self._on_progress) + on_progress=self._on_progress, + on_done=self._on_request_done) s3_request.cancel() finished_future = s3_request.finished_future - try: - finished_future.result(self.timeout) - except Exception as e: - self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") + e = finished_future.exception(self.timeout) + self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") shutdown_event = s3_request.shutdown_event s3_request = None self.assertTrue(shutdown_event.wait(self.timeout)) os.remove(file.name) def 
_put_object_cancel_helper(self, cancel_after_read): - read_futrue = Future() - put_body_stream = FakeReadStream(read_futrue) + read_future = Future() + put_body_stream = FakeReadStream(read_future) data_len = 10 * GB # some fake length headers = HttpHeaders([("host", self._build_endpoint_string(self.region, self.bucket_name)), ("Content-Type", "text/plain"), ("Content-Length", str(data_len))]) http_request = HttpRequest("PUT", "/cancelled_request", headers, put_body_stream) - s3_client = s3_client_new(False, self.region, 8 * MB) + s3_client = s3_client_new(False, self.region, 8 * MB, is_cancel_test=True) s3_request = s3_client.make_request( request=http_request, type=S3RequestType.PUT_OBJECT, - on_headers=self._on_request_headers) + on_headers=self._on_request_headers, + on_done=self._on_request_done) if cancel_after_read: - read_futrue.result(self.timeout) + read_future.result(self.timeout) s3_request.cancel() finished_future = s3_request.finished_future - try: - finished_future.result(self.timeout) - except Exception as e: - self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") + e = finished_future.exception(self.timeout) + self.assertEqual(e.name, "AWS_ERROR_S3_CANCELED") shutdown_event = s3_request.shutdown_event s3_request = None @@ -505,19 +532,30 @@ def test_put_object_quick_cancel(self): return self._put_object_cancel_helper(False) def test_multipart_upload_with_invalid_request(self): - put_body_stream = open(self.temp_put_obj_file_path, "r+b") - content_length = os.stat(self.temp_put_obj_file_path).st_size + # send upload with incorrect Content-MD5 + # need to do single-part upload so the Content-MD5 header is sent along as-is. + content_length = 100 + file_path = self.files.create_file_with_size("temp_file", content_length) + put_body_stream = open(file_path, "r+b") request = self._put_object_request(put_body_stream, content_length) request.headers.set("Content-MD5", "something") self._test_s3_put_get_object(request, S3RequestType.PUT_OBJECT, "AWS_ERROR_S3_INVALID_RESPONSE_STATUS") + + # check that data from on_done callback came through correctly + self.assertIsNotNone(self.done_error) + self.assertEqual(self.done_status_code, 400) + self.assertIsNotNone(self.done_error_headers) + self.assertTrue(any(h[0].lower() == 'x-amz-request-id' for h in self.done_error_headers)) + self.assertIsNotNone(self.done_error_body) + self.assertTrue(b"InvalidDigest" in self.done_error_body) + put_body_stream.close() def test_special_filepath_upload(self): # remove the input file when request done - with open(self.special_path, 'wb') as file: - file.write(b"a" * 10 * MB) + content_length = 10 * MB + special_path = self.files.create_file_with_size(self.special_path, content_length) - content_length = os.stat(self.special_path).st_size request = self._put_object_request(None, content_length) s3_client = s3_client_new(False, self.region, 5 * MB) request_type = S3RequestType.PUT_OBJECT @@ -542,10 +580,11 @@ def test_special_filepath_upload(self): s3_request = s3_client.make_request( request=request, type=request_type, - send_filepath=self.special_path, + send_filepath=special_path, signing_config=signing_config, on_headers=self._on_request_headers, - on_progress=self._on_progress) + on_progress=self._on_progress, + on_done=self._on_request_done) finished_future = s3_request.finished_future finished_future.result(self.timeout) @@ -555,14 +594,12 @@ def test_special_filepath_upload(self): self.transferred_len, "the transferred length reported does not match body we sent") 
self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) - os.remove(self.special_path) + os.remove(special_path) def test_non_ascii_filepath_upload(self): # remove the input file when request done - with open(self.non_ascii_file_name, 'wb') as file: - file.write(b"a" * 10 * MB) - - content_length = os.stat(self.non_ascii_file_name).st_size + content_length = 10 * MB + non_ascii_file_path = self.files.create_file_with_size(self.non_ascii_file_name, content_length) request = self._put_object_request(None, content_length) s3_client = s3_client_new(False, self.region, 5 * MB) request_type = S3RequestType.PUT_OBJECT @@ -570,9 +607,10 @@ def test_non_ascii_filepath_upload(self): s3_request = s3_client.make_request( request=request, type=request_type, - send_filepath=self.non_ascii_file_name.decode("utf-8"), + send_filepath=non_ascii_file_path, on_headers=self._on_request_headers, - on_progress=self._on_progress) + on_progress=self._on_progress, + on_done=self._on_request_done) finished_future = s3_request.finished_future finished_future.result(self.timeout) @@ -582,26 +620,25 @@ def test_non_ascii_filepath_upload(self): self.transferred_len, "the transferred length reported does not match body we sent") self._validate_successful_response(request_type is S3RequestType.PUT_OBJECT) - os.remove(self.non_ascii_file_name) def test_non_ascii_filepath_download(self): - with open(self.non_ascii_file_name, 'wb') as file: - file.write(b"") + non_ascii_file_path = self.files.create_file_with_size(self.non_ascii_file_name, 0) request = self._get_object_request(self.get_test_object_path) request_type = S3RequestType.GET_OBJECT s3_client = s3_client_new(False, self.region, 5 * MB) s3_request = s3_client.make_request( request=request, type=request_type, - recv_filepath=self.non_ascii_file_name.decode("utf-8"), + recv_filepath=non_ascii_file_path, on_headers=self._on_request_headers, - on_progress=self._on_progress) + on_progress=self._on_progress, + on_done=self._on_request_done) finished_future = s3_request.finished_future finished_future.result(self.timeout) # Result check self.data_len = int(HttpHeaders(self.response_headers).get("Content-Length")) - file_stats = os.stat(self.non_ascii_file_name) + file_stats = os.stat(non_ascii_file_path) file_len = file_stats.st_size self.assertEqual( file_len, @@ -612,7 +649,6 @@ def test_non_ascii_filepath_download(self): self.transferred_len, "the transferred length reported does not match the content-length header") self.assertEqual(self.response_status_code, 200, "status code is not 200") - os.remove(self.non_ascii_file_name) if __name__ == '__main__': From 16e6492413b7ef94c6e7bcd787d6a57859d3d63c Mon Sep 17 00:00:00 2001 From: "Jonathan M. Henson" Date: Tue, 7 Nov 2023 13:01:44 -0800 Subject: [PATCH 2/2] Bind out cross-process lock with unit tests. 
(#519) Co-authored-by: Nate Prewitt --- awscrt/s3.py | 31 +++++++++++++++ crt/aws-c-common | 2 +- source/module.c | 3 ++ source/s3.h | 4 ++ source/s3_client.c | 99 ++++++++++++++++++++++++++++++++++++++++++++++ test/test_s3.py | 54 +++++++++++++++++++++++++ 6 files changed, 192 insertions(+), 1 deletion(-) diff --git a/awscrt/s3.py b/awscrt/s3.py index 9019a4c64..2278d8856 100644 --- a/awscrt/s3.py +++ b/awscrt/s3.py @@ -19,6 +19,37 @@ from enum import IntEnum +class CrossProcessLock(NativeResource): + """ + Class representing an exclusive cross-process lock, scoped by `lock_scope_name` + + Recommended usage is to either explicitly call acquire() followed by release() when the lock is no longer required, or use this in a 'with' statement. + + acquire() will throw a RuntimeError with AWS_MUTEX_CALLER_NOT_OWNER as the error code, if the lock could not be acquired. + + If the lock has not been explicitly released when the process exits, it will be released by the operating system. + + Keyword Args: + lock_scope_name (str): Unique string identifying the caller holding the lock. + """ + + def __init__(self, lock_scope_name): + super().__init__() + self._binding = _awscrt.s3_cross_process_lock_new(lock_scope_name) + + def acquire(self): + _awscrt.s3_cross_process_lock_acquire(self._binding) + + def __enter__(self): + self.acquire() + + def release(self): + _awscrt.s3_cross_process_lock_release(self._binding) + + def __exit__(self, exc_type, exc_value, exc_tb): + self.release() + + class S3RequestType(IntEnum): """The type of the AWS S3 request""" diff --git a/crt/aws-c-common b/crt/aws-c-common index e381a7bee..fb3182c54 160000 --- a/crt/aws-c-common +++ b/crt/aws-c-common @@ -1 +1 @@ -Subproject commit e381a7beeacb070f1816989dcb0e2c0ae6eccaea +Subproject commit fb3182c5411e4f5da2ee9372e0d66aa3f15a026d diff --git a/source/module.c b/source/module.c index 182b50860..20c6fadfe 100644 --- a/source/module.c +++ b/source/module.c @@ -801,6 +801,9 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(s3_meta_request_cancel, METH_VARARGS), AWS_PY_METHOD_DEF(s3_get_ec2_instance_type, METH_NOARGS), AWS_PY_METHOD_DEF(s3_is_crt_s3_optimized_for_system, METH_NOARGS), + AWS_PY_METHOD_DEF(s3_cross_process_lock_new, METH_VARARGS), + AWS_PY_METHOD_DEF(s3_cross_process_lock_acquire, METH_VARARGS), + AWS_PY_METHOD_DEF(s3_cross_process_lock_release, METH_VARARGS), /* WebSocket */ AWS_PY_METHOD_DEF(websocket_client_connect, METH_VARARGS), diff --git a/source/s3.h b/source/s3.h index 85543adc0..48c847ea2 100644 --- a/source/s3.h +++ b/source/s3.h @@ -15,6 +15,10 @@ PyObject *aws_py_s3_client_make_meta_request(PyObject *self, PyObject *args); PyObject *aws_py_s3_meta_request_cancel(PyObject *self, PyObject *args); +PyObject *aws_py_s3_cross_process_lock_new(PyObject *self, PyObject *args); +PyObject *aws_py_s3_cross_process_lock_acquire(PyObject *self, PyObject *args); +PyObject *aws_py_s3_cross_process_lock_release(PyObject *self, PyObject *args); + struct aws_s3_client *aws_py_get_s3_client(PyObject *s3_client); struct aws_s3_meta_request *aws_py_get_s3_meta_request(PyObject *s3_client); diff --git a/source/s3_client.c b/source/s3_client.c index dc3454159..25966d756 100644 --- a/source/s3_client.c +++ b/source/s3_client.c @@ -6,9 +6,11 @@ #include "auth.h" #include "io.h" +#include #include static const char *s_capsule_name_s3_client = "aws_s3_client"; +static const char *s_capsule_name_s3_instance_lock = "aws_cross_process_lock"; PyObject *aws_py_s3_get_ec2_instance_type(PyObject *self, PyObject *args) { 
(void)self; @@ -37,6 +39,103 @@ PyObject *aws_py_s3_is_crt_s3_optimized_for_system(PyObject *self, PyObject *arg Py_RETURN_FALSE; } +struct cross_process_lock_binding { + struct aws_cross_process_lock *lock; + struct aws_string *name; +}; + +/* Invoked when the python object gets cleaned up */ +static void s_s3_cross_process_lock_destructor(PyObject *capsule) { + struct cross_process_lock_binding *lock_binding = PyCapsule_GetPointer(capsule, s_capsule_name_s3_instance_lock); + + if (lock_binding->lock) { + aws_cross_process_lock_release(lock_binding->lock); + lock_binding->lock = NULL; + } + + if (lock_binding->name) { + aws_string_destroy(lock_binding->name); + } + + aws_mem_release(aws_py_get_allocator(), lock_binding); +} + +PyObject *aws_py_s3_cross_process_lock_new(PyObject *self, PyObject *args) { + (void)self; + + struct aws_allocator *allocator = aws_py_get_allocator(); + + struct aws_byte_cursor lock_name; /* s# */ + + if (!PyArg_ParseTuple(args, "s#", &lock_name.ptr, &lock_name.len)) { + return NULL; + } + + struct cross_process_lock_binding *binding = + aws_mem_calloc(allocator, 1, sizeof(struct cross_process_lock_binding)); + binding->name = aws_string_new_from_cursor(allocator, &lock_name); + + PyObject *capsule = PyCapsule_New(binding, s_capsule_name_s3_instance_lock, s_s3_cross_process_lock_destructor); + if (!capsule) { + aws_string_destroy(binding->name); + aws_mem_release(allocator, binding); + return PyErr_AwsLastError(); + } + + return capsule; +} + +PyObject *aws_py_s3_cross_process_lock_acquire(PyObject *self, PyObject *args) { + (void)self; + + struct aws_allocator *allocator = aws_py_get_allocator(); + + PyObject *lock_capsule; /* O */ + + if (!PyArg_ParseTuple(args, "O", &lock_capsule)) { + return NULL; + } + + struct cross_process_lock_binding *lock_binding = + PyCapsule_GetPointer(lock_capsule, s_capsule_name_s3_instance_lock); + if (!lock_binding) { + return NULL; + } + + if (!lock_binding->lock) { + struct aws_cross_process_lock *lock = + aws_cross_process_lock_try_acquire(allocator, aws_byte_cursor_from_string(lock_binding->name)); + + if (!lock) { + return PyErr_AwsLastError(); + } + lock_binding->lock = lock; + } + + Py_RETURN_NONE; +} + +PyObject *aws_py_s3_cross_process_lock_release(PyObject *self, PyObject *args) { + PyObject *lock_capsule; /* O */ + + if (!PyArg_ParseTuple(args, "O", &lock_capsule)) { + return NULL; + } + + struct cross_process_lock_binding *lock_binding = + PyCapsule_GetPointer(lock_capsule, s_capsule_name_s3_instance_lock); + if (!lock_binding) { + return NULL; + } + + if (lock_binding->lock) { + aws_cross_process_lock_release(lock_binding->lock); + lock_binding->lock = NULL; + } + + Py_RETURN_NONE; +} + struct s3_client_binding { struct aws_s3_client *native; diff --git a/test/test_s3.py b/test/test_s3.py index 09aa252cf..8602a8a14 100644 --- a/test/test_s3.py +++ b/test/test_s3.py @@ -8,8 +8,10 @@ import tempfile import math import shutil +import time from test import NativeResourceTest from concurrent.futures import Future +from multiprocessing import Process from awscrt.http import HttpHeaders, HttpRequest from awscrt.s3 import ( @@ -18,6 +20,7 @@ S3ChecksumLocation, S3Client, S3RequestType, + CrossProcessLock, create_default_s3_signing_config, ) from awscrt.io import ( @@ -41,6 +44,57 @@ MB = 1024 ** 2 GB = 1024 ** 3 +cross_process_lock_name = "instance_lock_test" + + +def cross_proc_task(): + try: + lock = CrossProcessLock(cross_process_lock_name) + lock.acquire() + lock.release() + exit(0) + except RuntimeError as e: + exit(-1) 
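As an illustration of the lock bound out in this commit (a sketch, not code from the patch): a section of work can be guarded so that only one process on the machine runs it at a time. The scope name below is a placeholder, and do_exclusive_work() is a hypothetical stand-in for the protected work.

    from awscrt.s3 import CrossProcessLock

    lock = CrossProcessLock("my-app-instance-lock")
    try:
        lock.acquire()  # raises RuntimeError if another process already holds this scope name
    except RuntimeError:
        print("another process holds the lock; doing nothing")
    else:
        try:
            do_exclusive_work()  # hypothetical helper, not part of awscrt
        finally:
            lock.release()

    # The same guard via the context-manager support added above:
    with CrossProcessLock("my-app-instance-lock"):
        do_exclusive_work()

Because the operating system releases the lock if the holding process exits without calling release(), a crashed holder does not leave the scope permanently locked.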
+ + +class CrossProcessLockTest(NativeResourceTest): + def setUp(self): + self.nonce = time.time() + super().setUp() + + def test_with_statement(self): + nonce_str = f'lock_a_{self.nonce}' + with CrossProcessLock(nonce_str) as lock: + try: + new_lock = CrossProcessLock(nonce_str) + new_lock.acquire() + self.fail("Acquiring a lock by the same nonce should fail when it's already held") + except RuntimeError as e: + unique_nonce_str = f'lock_b{self.nonce}' + new_lock = CrossProcessLock(unique_nonce_str) + new_lock.acquire() + new_lock.release() + + lock_after_with_same_nonce = CrossProcessLock(nonce_str) + lock_after_with_same_nonce.acquire() + lock_after_with_same_nonce.release() + + def test_cross_proc(self): + with CrossProcessLock(cross_process_lock_name) as lock: + process = Process(target=cross_proc_task) + process.start() + process.join() + # aquiring this lock in a sub-process should fail since we + # already hold the lock in this process. + self.assertNotEqual(0, process.exitcode) + + # now that we've released the lock above, the same sub-process path + # should now succeed. + unlocked_process = Process(target=cross_proc_task) + unlocked_process.start() + unlocked_process.join() + self.assertEqual(0, unlocked_process.exitcode) + class FileCreator(object): def __init__(self):