Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added new sts role configuration parameter to assume aws role dynamic… #944

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/awsaccess.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,19 @@ If for some reason you can't use conventional AWS configuration methods, you can
aws_secret_access_key = 'my_secret_access_key'
aws_session_token = 'my_session_token' # Optional, only for temporary credentials like those received when assuming a role

If you need to access DynamoDB passing for a specific AWS Role and so you need to perform an assume-role operation on your Role ARN you can configure it in the Model Meta class:

.. code-block:: python

from pynamodb.models import Model

class MyModel(Model):
class Meta:
aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role'
aws_sts_role_session_name='MySession' # Optional, by default it is PynamoDB
aws_sts_session_expiration=3600 # Optional, it is the session token duration in seconds. Default is 3600

Note that your environment variables credentials, or in case, credentials setted in the Model Meta class, need to have grant to perform the assume-role operation.

Finally, see the `AWS CLI documentation <http://docs.aws.amazon.com/cli/latest/userguide/cli-chap-getting-started.html#cli-installing-credentials>`_
for more details on how to pass credentials to botocore.
58 changes: 55 additions & 3 deletions pynamodb/connection/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""
Lowest level connection
"""
from datetime import datetime

import json
import logging
import random
Expand Down Expand Up @@ -54,6 +56,9 @@
from pynamodb.settings import get_settings_value, OperationSettings
from pynamodb.signals import pre_dynamodb_send, post_dynamodb_send
from pynamodb.types import HASH, RANGE
import pytz

utc=pytz.UTC

BOTOCORE_EXCEPTIONS = (BotoCoreError, ClientError)
RATE_LIMITING_ERROR_CODES = ['ProvisionedThroughputExceededException', 'ThrottlingException']
Expand Down Expand Up @@ -233,6 +238,15 @@ def get_exclusive_start_key_map(self, exclusive_start_key):
}
}

class STSCredentials:
def __init__(self, **kwargs):
self.access_key = kwargs.get('AccessKeyId')
self.secret_key = kwargs.get('SecretAccessKey')
self.token = kwargs.get('SessionToken')
self.token_expiration = kwargs.get('Expiration', datetime.now(tz=utc))

def expired(self):
return self.token_expiration <= datetime.now(tz=utc)

class Connection(object):
"""
Expand All @@ -247,11 +261,21 @@ def __init__(self,
max_retry_attempts: Optional[int] = None,
base_backoff_ms: Optional[int] = None,
max_pool_connections: Optional[int] = None,
extra_headers: Optional[Mapping[str, str]] = None):
extra_headers: Optional[Mapping[str, str]] = None,
aws_sts_role_arn=None, aws_sts_role_session_name=None,
aws_sts_session_expiration=None):
self._tables: Dict[str, MetaTable] = {}
self.host = host
self._local = local()
self._client = None

if aws_sts_role_arn is not None:
# Initialize empty STS Credentials if STS auth is configured
self.sts_session = STSCredentials()
self.aws_sts_role_arn = aws_sts_role_arn
self.aws_sts_role_session_name = aws_sts_role_session_name or 'PynamoDB'
self.aws_sts_session_expiration = aws_sts_session_expiration or 3600

if region:
self.region = region
else:
Expand Down Expand Up @@ -508,6 +532,24 @@ def _handle_binary_attributes(data):
_convert_binary(attr)
return data

def is_sts_session_required(self):
if not hasattr(self, 'sts_session'):
return False

if self.sts_session.expired():
self.assume_role_session()

return True

def assume_role_session(self):
sts = self.session.create_client('sts')
sts_response = sts.assume_role(
RoleArn=self.aws_sts_role_arn,
RoleSessionName=self.aws_sts_role_session_name,
DurationSeconds=self.aws_sts_session_expiration
)
self.sts_session = STSCredentials(**sts_response['Credentials'])

@property
def session(self) -> botocore.session.Session:
"""
Expand All @@ -527,13 +569,23 @@ def client(self):
# https://github.com/boto/botocore/blob/4d55c9b4142/botocore/credentials.py#L1016-L1021
# if the client does not have credentials, we create a new client
# otherwise the client is permanently poisoned in the case of metadata service flakiness when using IAM roles
if not self._client or (self._client._request_signer and not self._client._request_signer._credentials):
if not self._client or (self._client._request_signer and not self._client._request_signer._credentials) or self.is_sts_session_required():
config = botocore.client.Config(
parameter_validation=False, # Disable unnecessary validation for performance
connect_timeout=self._connect_timeout_seconds,
read_timeout=self._read_timeout_seconds,
max_pool_connections=self._max_pool_connections)
self._client = self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config)

credentials = self.sts_session if self.is_sts_session_required() else self.session.get_credentials()

self._client = self.session.create_client(
SERVICE_NAME,
self.region,
endpoint_url=self.host,
config=config,
aws_access_key_id=credentials.access_key,
aws_secret_access_key=credentials.secret_key,
aws_session_token=credentials.token)
return self._client

def get_meta_table(self, table_name: str, refresh: bool = False):
Expand Down
38 changes: 22 additions & 16 deletions pynamodb/connection/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,24 @@ class TableConnection:
A higher level abstraction over botocore
"""

def __init__(
self,
table_name: str,
region: Optional[str] = None,
host: Optional[str] = None,
connect_timeout_seconds: Optional[float] = None,
read_timeout_seconds: Optional[float] = None,
max_retry_attempts: Optional[int] = None,
base_backoff_ms: Optional[int] = None,
max_pool_connections: Optional[int] = None,
extra_headers: Optional[Mapping[str, str]] = None,
aws_access_key_id: Optional[str] = None,
aws_secret_access_key: Optional[str] = None,
aws_session_token: Optional[str] = None,
) -> None:
def __init__(self,
table_name,
region=None,
host=None,
connect_timeout_seconds=None,
read_timeout_seconds=None,
max_retry_attempts=None,
base_backoff_ms=None,
max_pool_connections=None,
extra_headers=None,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
aws_sts_role_arn=None,
aws_sts_role_session_name=None,
aws_sts_session_expiration=None):
self._hash_keyname = None
self._range_keyname = None
self.table_name = table_name
self.connection = Connection(region=region,
host=host,
Expand All @@ -39,7 +42,10 @@ def __init__(
max_retry_attempts=max_retry_attempts,
base_backoff_ms=base_backoff_ms,
max_pool_connections=max_pool_connections,
extra_headers=extra_headers)
extra_headers=extra_headers,
aws_sts_role_arn=aws_sts_role_arn,
aws_sts_role_session_name=aws_sts_role_session_name,
aws_sts_session_expiration=aws_sts_session_expiration)

if aws_access_key_id and aws_secret_access_key:
self.connection.session.set_credentials(aws_access_key_id,
Expand Down
38 changes: 26 additions & 12 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ class MetaProtocol(Protocol):
billing_mode: Optional[str]
tags: Optional[Dict[str, str]]
stream_view_type: Optional[str]
aws_sts_role_arn: Optional[str]
aws_sts_role_session_name: Optional[str]
aws_sts_session_expiration: Optional[int]


class MetaModel(AttributeContainerMeta):
Expand Down Expand Up @@ -253,6 +256,12 @@ def __init__(self, name, bases, namespace, discriminator=None) -> None:
setattr(attr_obj, 'aws_secret_access_key', None)
if not hasattr(attr_obj, 'aws_session_token'):
setattr(attr_obj, 'aws_session_token', None)
if not hasattr(attr_obj, 'aws_sts_role_arn'):
setattr(attr_obj, 'aws_sts_role_arn', None)
if not hasattr(attr_obj, 'aws_sts_role_session_name'):
setattr(attr_obj, 'aws_sts_role_session_name', None)
if not hasattr(attr_obj, 'aws_sts_session_expiration'):
setattr(attr_obj, 'aws_sts_session_expiration', None)
elif isinstance(attr_obj, Index):
attr_obj.Meta.model = cls
if not hasattr(attr_obj.Meta, "index_name"):
Expand Down Expand Up @@ -1079,18 +1088,23 @@ def _get_connection(cls) -> TableConnection:
# For now we just check that the connection exists and (in the case of model inheritance)
# points to the same table. In the future we should update the connection if any of the attributes differ.
if cls._connection is None or cls._connection.table_name != cls.Meta.table_name:
cls._connection = TableConnection(cls.Meta.table_name,
region=cls.Meta.region,
host=cls.Meta.host,
connect_timeout_seconds=cls.Meta.connect_timeout_seconds,
read_timeout_seconds=cls.Meta.read_timeout_seconds,
max_retry_attempts=cls.Meta.max_retry_attempts,
base_backoff_ms=cls.Meta.base_backoff_ms,
max_pool_connections=cls.Meta.max_pool_connections,
extra_headers=cls.Meta.extra_headers,
aws_access_key_id=cls.Meta.aws_access_key_id,
aws_secret_access_key=cls.Meta.aws_secret_access_key,
aws_session_token=cls.Meta.aws_session_token)
cls._connection = TableConnection(
cls.Meta.table_name,
region=cls.Meta.region,
host=cls.Meta.host,
connect_timeout_seconds=cls.Meta.connect_timeout_seconds,
read_timeout_seconds=cls.Meta.read_timeout_seconds,
max_retry_attempts=cls.Meta.max_retry_attempts,
base_backoff_ms=cls.Meta.base_backoff_ms,
max_pool_connections=cls.Meta.max_pool_connections,
extra_headers=cls.Meta.extra_headers,
aws_access_key_id=cls.Meta.aws_access_key_id,
aws_secret_access_key=cls.Meta.aws_secret_access_key,
aws_session_token=cls.Meta.aws_session_token,
aws_sts_role_arn=cls.Meta.aws_sts_role_arn,
aws_sts_role_session_name=cls.Meta.aws_sts_role_session_name,
aws_sts_session_expiration=cls.Meta.aws_sts_session_expiration
)
return cls._connection

@classmethod
Expand Down
78 changes: 75 additions & 3 deletions tests/test_base_connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Tests for the base connection class
"""
import time
import base64
from datetime import datetime, timedelta
import json
from unittest import mock, TestCase
from unittest.mock import patch
Expand All @@ -25,6 +27,9 @@
from pynamodb.settings import OperationSettings
from .data import DESCRIBE_TABLE_DATA, GET_ITEM_DATA, LIST_TABLE_DATA
from .deep_eq import deep_eq
import pytz

utc=pytz.UTC

PATCH_METHOD = 'pynamodb.connection.Connection._make_api_call'

Expand Down Expand Up @@ -86,8 +91,14 @@ def test_subsequent_client_is_not_cached_when_credentials_none(self):

session_mock.create_client.assert_has_calls(
[
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY),
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY),
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY,
aws_access_key_id=mock.ANY,
aws_secret_access_key=mock.ANY,
aws_session_token=mock.ANY),
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY,
aws_access_key_id=mock.ANY,
aws_secret_access_key=mock.ANY,
aws_session_token=mock.ANY),
],
any_order=True
)
Expand All @@ -102,7 +113,14 @@ def test_subsequent_client_is_cached_when_credentials_truthy(self):
self.assertIsNotNone(conn.client)

self.assertEqual(
session_mock.create_client.mock_calls.count(mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY)),
session_mock.create_client.mock_calls.count(mock.call(
'dynamodb',
'us-east-1',
endpoint_url=None,
config=mock.ANY,
aws_access_key_id=mock.ANY,
aws_secret_access_key=mock.ANY,
aws_session_token=mock.ANY)),
1
)

Expand Down Expand Up @@ -1652,3 +1670,57 @@ def test_update_time_to_live_fail(self):
with patch(PATCH_METHOD) as req:
req.side_effect = BotoCoreError
self.assertRaises(TableError, conn.update_time_to_live, 'test table', 'my_ttl')

def test_sts_assume_role_long_expiration_one_time_client(self):
with patch('pynamodb.connection.Connection.session') as session_mock:
sts_client = mock.MagicMock()
sts_client.assume_role.return_value = { 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(hours=1) } }
session_mock.create_client.return_value = sts_client
conn = Connection(
aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role'
)

self.assertIsNotNone(conn.client)
self.assertIsNotNone(conn.client)

assert sts_client.assume_role.call_count == 1
session_mock.create_client.assert_has_calls(
[
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY,
aws_access_key_id='12345',
aws_secret_access_key='123456789',
aws_session_token='abcd')
]
)

def test_sts_assume_role_short_expiration_two_time_client(self):
with patch('pynamodb.connection.Connection.session') as session_mock:
sts_client = mock.MagicMock()
sts_client.assume_role.side_effect = [
{ 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(seconds=10) } },
{ 'Credentials': {'AccessKeyId': '12345', 'SecretAccessKey': '123456789', 'SessionToken': 'abcd', 'Expiration': datetime.now(tz=utc) + timedelta(seconds=30) } }
]
session_mock.create_client.return_value = sts_client
conn = Connection(
aws_sts_role_arn='arn:aws:iam::1234567:role/my-aws-role'
)

# the session token will expire in 10 second, to expect the second call to assume role is made when asking for a client
self.assertIsNotNone(conn.client)
time.sleep(10)
self.assertIsNotNone(conn.client)

assert sts_client.assume_role.call_count == 2
session_mock.create_client.assert_has_calls(
[
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY,
aws_access_key_id='12345',
aws_secret_access_key='123456789',
aws_session_token='abcd'),
mock.call('dynamodb', 'us-east-1', endpoint_url=None, config=mock.ANY,
aws_access_key_id='12345',
aws_secret_access_key='123456789',
aws_session_token='abcd'),
],
any_order=True
)