From 97ebaede149b834c72c0ddf6b44b048a507c7132 Mon Sep 17 00:00:00 2001 From: KiraPC Date: Sun, 30 May 2021 12:32:34 +0200 Subject: [PATCH 1/2] added new sts role configuration parameter to assume aws role dynamically --- pynamodb/connection/base.py | 59 +++++++++++++++++++++++++++++++++-- pynamodb/connection/table.py | 38 ++++++++++++---------- pynamodb/models.py | 38 +++++++++++++++------- tests/test_base_connection.py | 19 +++++++++-- 4 files changed, 120 insertions(+), 34 deletions(-) diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index cca050f7d..1ca3fcfcb 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -1,6 +1,8 @@ """ Lowest level connection """ +from datetime import datetime + import json import logging import random @@ -54,6 +56,9 @@ from pynamodb.settings import get_settings_value, OperationSettings from pynamodb.signals import pre_dynamodb_send, post_dynamodb_send from pynamodb.types import HASH, RANGE +import pytz + +utc=pytz.UTC BOTOCORE_EXCEPTIONS = (BotoCoreError, ClientError) RATE_LIMITING_ERROR_CODES = ['ProvisionedThroughputExceededException', 'ThrottlingException'] @@ -233,6 +238,15 @@ def get_exclusive_start_key_map(self, exclusive_start_key): } } +class STSCredentials: + def __init__(self, **kwargs): + self.access_key = kwargs.get('AccessKeyId') + self.secret_key = kwargs.get('SecretAccessKey') + self.token = kwargs.get('SessionToken') + self.token_expiration = kwargs.get('Expiration', datetime.now(tz=utc)) + + def expired(self): + return self.token_expiration <= datetime.now(tz=utc) class Connection(object): """ @@ -247,11 +261,21 @@ def __init__(self, max_retry_attempts: Optional[int] = None, base_backoff_ms: Optional[int] = None, max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None): + extra_headers: Optional[Mapping[str, str]] = None, + aws_sts_role_arn=None, aws_sts_role_session_name=None, + aws_sts_session_expiration=None): self._tables: Dict[str, MetaTable] = {} self.host = host self._local = local() self._client = None + + if aws_sts_role_arn is not None: + # Initialize empty STS Credentials if STS auth is configured + self.sts_session = STSCredentials() + self.aws_sts_role_arn = aws_sts_role_arn + self.aws_sts_role_session_name = aws_sts_role_session_name + self.aws_sts_session_expiration = aws_sts_session_expiration or 3600 + if region: self.region = region else: @@ -508,6 +532,25 @@ def _handle_binary_attributes(data): _convert_binary(attr) return data + def is_sts_session_required(self): + if not hasattr(self, 'sts_session'): + return False + + if self.sts_session.expired(): + self.assume_role_session() + return True + + return False + + def assume_role_session(self): + sts = self.session.create_client('sts') + sts_response = sts.assume_role( + RoleArn=self.aws_sts_role_arn, + RoleSessionName=self.aws_sts_role_session_name or 'PynamoDB', + DurationSeconds=self.aws_sts_session_expiration + ) + self.sts_session = STSCredentials(**sts_response['Credentials']) + @property def session(self) -> botocore.session.Session: """ @@ -527,13 +570,23 @@ def client(self): # https://github.com/boto/botocore/blob/4d55c9b4142/botocore/credentials.py#L1016-L1021 # if the client does not have credentials, we create a new client # otherwise the client is permanently poisoned in the case of metadata service flakiness when using IAM roles - if not self._client or (self._client._request_signer and not self._client._request_signer._credentials): + if not self._client or (self._client._request_signer and not self._client._request_signer._credentials) or self.is_sts_session_required(): config = botocore.client.Config( parameter_validation=False, # Disable unnecessary validation for performance connect_timeout=self._connect_timeout_seconds, read_timeout=self._read_timeout_seconds, max_pool_connections=self._max_pool_connections) - self._client = self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config) + + credentials = self.sts_session if self.is_sts_session_required() else self.session.get_credentials() + + self._client = self.session.create_client( + SERVICE_NAME, + self.region, + endpoint_url=self.host, + config=config, + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token) return self._client def get_meta_table(self, table_name: str, refresh: bool = False): diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..ce91c0d46 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -16,21 +16,24 @@ class TableConnection: A higher level abstraction over botocore """ - def __init__( - self, - table_name: str, - region: Optional[str] = None, - host: Optional[str] = None, - connect_timeout_seconds: Optional[float] = None, - read_timeout_seconds: Optional[float] = None, - max_retry_attempts: Optional[int] = None, - base_backoff_ms: Optional[int] = None, - max_pool_connections: Optional[int] = None, - extra_headers: Optional[Mapping[str, str]] = None, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - ) -> None: + def __init__(self, + table_name, + region=None, + host=None, + connect_timeout_seconds=None, + read_timeout_seconds=None, + max_retry_attempts=None, + base_backoff_ms=None, + max_pool_connections=None, + extra_headers=None, + aws_access_key_id=None, + aws_secret_access_key=None, + aws_session_token=None, + aws_sts_role_arn=None, + aws_sts_role_session_name=None, + aws_sts_session_expiration=None): + self._hash_keyname = None + self._range_keyname = None self.table_name = table_name self.connection = Connection(region=region, host=host, @@ -39,7 +42,10 @@ def __init__( max_retry_attempts=max_retry_attempts, base_backoff_ms=base_backoff_ms, max_pool_connections=max_pool_connections, - extra_headers=extra_headers) + extra_headers=extra_headers, + aws_sts_role_arn=aws_sts_role_arn, + aws_sts_role_session_name=aws_sts_role_session_name, + aws_sts_session_expiration=aws_sts_session_expiration) if aws_access_key_id and aws_secret_access_key: self.connection.session.set_credentials(aws_access_key_id, diff --git a/pynamodb/models.py b/pynamodb/models.py index 3ec7952cc..30abf1085 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -191,6 +191,9 @@ class MetaProtocol(Protocol): billing_mode: Optional[str] tags: Optional[Dict[str, str]] stream_view_type: Optional[str] + aws_sts_role_arn: Optional[str] + aws_sts_role_session_name: Optional[str] + aws_sts_session_expiration: Optional[int] class MetaModel(AttributeContainerMeta): @@ -253,6 +256,12 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None: setattr(attr_obj, 'aws_secret_access_key', None) if not hasattr(attr_obj, 'aws_session_token'): setattr(attr_obj, 'aws_session_token', None) + if not hasattr(attr_obj, 'aws_sts_role_arn'): + setattr(attr_obj, 'aws_sts_role_arn', None) + if not hasattr(attr_obj, 'aws_sts_role_session_name'): + setattr(attr_obj, 'aws_sts_role_session_name', None) + if not hasattr(attr_obj, 'aws_sts_session_expiration'): + setattr(attr_obj, 'aws_sts_session_expiration', None) elif isinstance(attr_obj, Index): attr_obj.Meta.model = cls if not hasattr(attr_obj.Meta, "index_name"): @@ -1079,18 +1088,23 @@ def _get_connection(cls) -> TableConnection: # For now we just check that the connection exists and (in the case of model inheritance) # points to the same table. In the future we should update the connection if any of the attributes differ. if cls._connection is None or cls._connection.table_name != cls.Meta.table_name: - cls._connection = TableConnection(cls.Meta.table_name, - region=cls.Meta.region, - host=cls.Meta.host, - connect_timeout_seconds=cls.Meta.connect_timeout_seconds, - read_timeout_seconds=cls.Meta.read_timeout_seconds, - max_retry_attempts=cls.Meta.max_retry_attempts, - base_backoff_ms=cls.Meta.base_backoff_ms, - max_pool_connections=cls.Meta.max_pool_connections, - extra_headers=cls.Meta.extra_headers, - aws_access_key_id=cls.Meta.aws_access_key_id, - aws_secret_access_key=cls.Meta.aws_secret_access_key, - aws_session_token=cls.Meta.aws_session_token) + cls._connection = TableConnection( + cls.Meta.table_name, + region=cls.Meta.region, + host=cls.Meta.host, + connect_timeout_seconds=cls.Meta.connect_timeout_seconds, + read_timeout_seconds=cls.Meta.read_timeout_seconds, + max_retry_attempts=cls.Meta.max_retry_attempts, + base_backoff_ms=cls.Meta.base_backoff_ms, + max_pool_connections=cls.Meta.max_pool_connections, + extra_headers=cls.Meta.extra_headers, + aws_access_key_id=cls.Meta.aws_access_key_id, + aws_secret_access_key=cls.Meta.aws_secret_access_key, + aws_session_token=cls.Meta.aws_session_token, + aws_sts_role_arn=cls.Meta.aws_sts_role_arn, + aws_sts_role_session_name=cls.Meta.aws_sts_role_session_name, + aws_sts_session_expiration=cls.Meta.aws_sts_session_expiration + ) return cls._connection @classmethod diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 7319f91d2..a4901350d 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -86,8 +86,14 @@ def test_subsequent_client_is_not_cached_when_credentials_none(self): session_mock.create_client.assert_has_calls( [ - mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY), - mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY), + mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY, + aws_access_key_id=mock.ANY, + aws_secret_access_key=mock.ANY, + aws_session_token=mock.ANY), + mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY, + aws_access_key_id=mock.ANY, + aws_secret_access_key=mock.ANY, + aws_session_token=mock.ANY), ], any_order=True ) @@ -102,7 +108,14 @@ def test_subsequent_client_is_cached_when_credentials_truthy(self): self.assertIsNotNone(conn.client) self.assertEqual( - session_mock.create_client.mock_calls.count(mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY)), + session_mock.create_client.mock_calls.count(mock.call( + 'dynamodb', + 'us-east-1', + endpoint_url=None, + config=mock.ANY, + aws_access_key_id=mock.ANY, + aws_secret_access_key=mock.ANY, + aws_session_token=mock.ANY)), 1 ) From 5ec3c0d54c68e79b7e1fa4ae850bfcd328d1f1e8 Mon Sep 17 00:00:00 2001 From: KiraPC Date: Mon, 31 May 2021 10:31:42 +0200 Subject: [PATCH 2/2] added tests and updated doc --- docs/awsaccess.rst | 14 +++++++++ pynamodb/connection/base.py | 7 ++--- tests/test_base_connection.py | 59 +++++++++++++++++++++++++++++++++++ 3 files changed, 76 insertions(+), 4 deletions(-) diff --git a/docs/awsaccess.rst b/docs/awsaccess.rst index 38e26ae2a..12fe61c1c 100644 --- a/docs/awsaccess.rst +++ b/docs/awsaccess.rst @@ -24,5 +24,19 @@ If for some reason you can't use conventional AWS configuration methods, you can aws_secret_access_key = 'my_secret_access_key' aws_session_token = 'my_session_token' # Optional, only for temporary credentials like those received when assuming a role +If you need to access DynamoDB passing for a specific AWS Role and so you need to perform an assume-role operation on your Role ARN you can configure it in the Model Meta class: + +.. code-block:: python + + from pynamodb.models import Model + + class MyModel(Model): + class Meta: + aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role' + aws_sts_role_session_name='MySession' # Optional, by default it is PynamoDB + aws_sts_session_expiration=3600 # Optional, it is the session token duration in seconds. Default is 3600 + +Note that your environment variables credentials, or in case, credentials setted in the Model Meta class, need to have grant to perform the assume-role operation. + Finally, see the `AWS CLI documentation `_ for more details on how to pass credentials to botocore. \ No newline at end of file diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 1ca3fcfcb..67a5d0612 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -273,7 +273,7 @@ def __init__(self, # Initialize empty STS Credentials if STS auth is configured self.sts_session = STSCredentials() self.aws_sts_role_arn = aws_sts_role_arn - self.aws_sts_role_session_name = aws_sts_role_session_name + self.aws_sts_role_session_name = aws_sts_role_session_name or 'PynamoDB' self.aws_sts_session_expiration = aws_sts_session_expiration or 3600 if region: @@ -538,15 +538,14 @@ def is_sts_session_required(self): if self.sts_session.expired(): self.assume_role_session() - return True - return False + return True def assume_role_session(self): sts = self.session.create_client('sts') sts_response = sts.assume_role( RoleArn=self.aws_sts_role_arn, - RoleSessionName=self.aws_sts_role_session_name or 'PynamoDB', + RoleSessionName=self.aws_sts_role_session_name, DurationSeconds=self.aws_sts_session_expiration ) self.sts_session = STSCredentials(**sts_response['Credentials']) diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index a4901350d..db01889c8 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -1,7 +1,9 @@ """ Tests for the base connection class """ +import time import base64 +from datetime import datetime, timedelta import json from unittest import mock, TestCase from unittest.mock import patch @@ -25,6 +27,9 @@ from pynamodb.settings import OperationSettings from .data import DESCRIBE_TABLE_DATA, GET_ITEM_DATA, LIST_TABLE_DATA from .deep_eq import deep_eq +import pytz + +utc=pytz.UTC PATCH_METHOD = 'pynamodb.connection.Connection._make_api_call' @@ -1665,3 +1670,57 @@ def test_update_time_to_live_fail(self): with patch(PATCH_METHOD) as req: req.side_effect = BotoCoreError self.assertRaises(TableError, conn.update_time_to_live, 'test table', 'my_ttl') + + def test_sts_assume_role_long_expiration_one_time_client(self): + with patch('pynamodb.connection.Connection.session') as session_mock: + sts_client = mock.MagicMock() + sts_client.assume_role.return_value = { 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(hours=1) } } + session_mock.create_client.return_value = sts_client + conn = Connection( + aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role' + ) + + self.assertIsNotNone(conn.client) + self.assertIsNotNone(conn.client) + + assert sts_client.assume_role.call_count == 1 + session_mock.create_client.assert_has_calls( + [ + mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY, + aws_access_key_id='12345', + aws_secret_access_key='123456789', + aws_session_token='abcd') + ] + ) + + def test_sts_assume_role_short_expiration_two_time_client(self): + with patch('pynamodb.connection.Connection.session') as session_mock: + sts_client = mock.MagicMock() + sts_client.assume_role.side_effect = [ + { 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(seconds=10) } }, + { 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(seconds=30) } } + ] + session_mock.create_client.return_value = sts_client + conn = Connection( + aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role' + ) + + # the session token will expire in 10 second, to expect the second call to assume role is made when asking for a client + self.assertIsNotNone(conn.client) + time.sleep(10) + self.assertIsNotNone(conn.client) + + assert sts_client.assume_role.call_count == 2 + session_mock.create_client.assert_has_calls( + [ + mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY, + aws_access_key_id='12345', + aws_secret_access_key='123456789', + aws_session_token='abcd'), + mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY, + aws_access_key_id='12345', + aws_secret_access_key='123456789', + aws_session_token='abcd'), + ], + any_order=True + )