diff --git a/awscrt/crypto.py b/awscrt/crypto.py index cb158c9cc..c216ad48f 100644 --- a/awscrt/crypto.py +++ b/awscrt/crypto.py @@ -119,21 +119,21 @@ def rsa_public_key_from_pem_data(pem_data: Union[str, bytes, bytearray, memoryvi """ return RSA(native_handle=_awscrt.rsa_public_key_from_pem_data(pem_data)) - def encrypt(self, encryption_algorithm: RSAEncryptionAlgorithmType, + def encrypt(self, encryption_algorithm: RSAEncryptionAlgorithm, plaintext: Union[str, bytes, bytearray, memoryview]) -> bytes: """ Encrypts data using a given algorithm. """ return _awscrt.rsa_encrypt(self._rsa, encryption_algorithm, plaintext) - def decrypt(self, encryption_algorithm: RSAEncryptionAlgorithmType, + def decrypt(self, encryption_algorithm: RSAEncryptionAlgorithm, ciphertext: Union[str, bytes, bytearray, memoryview]) -> bytes: """ Decrypts data using a given algorithm. """ return _awscrt.rsa_decrypt(self._rsa, encryption_algorithm, ciphertext) - def sign(self, encryption_algorithm: RSASignatureAlgorithmType, + def sign(self, encryption_algorithm: RSASignatureAlgorithm, digest: Union[str, bytes, bytearray, memoryview]) -> bytes: """ Signs data using a given algorithm. @@ -141,7 +141,7 @@ def sign(self, encryption_algorithm: RSASignatureAlgorithmType, """ return _awscrt.rsa_sign(self._rsa, encryption_algorithm, digest) - def verify(self, encryption_algorithm: RSASignatureAlgorithmType, + def verify(self, encryption_algorithm: RSASignatureAlgorithm, digest: Union[str, bytes, bytearray, memoryview], signature: Union[str, bytes, bytearray, memoryview]) -> bool: """ diff --git a/test/test_crypto.py b/test/test_crypto.py index 984de0f91..606994737 100644 --- a/test/test_crypto.py +++ b/test/test_crypto.py @@ -3,7 +3,7 @@ from test import NativeResourceTest -from awscrt.crypto import Hash, RSA, RSAEncryptionAlgorithmType, RSASignatureAlgorithmType +from awscrt.crypto import Hash, RSA, RSAEncryptionAlgorithm, RSASignatureAlgorithm import unittest RSA_PRIVATE_KEY_PEM = """ @@ -117,9 +117,9 @@ def test_md5_iterated(self): self.assertEqual(expected, digest) def test_rsa_encryption_roundtrip(self): - param_list = [RSAEncryptionAlgorithmType.PKCS1_5, - RSAEncryptionAlgorithmType.OAEP_SHA256, - RSAEncryptionAlgorithmType.OAEP_SHA512] + param_list = [RSAEncryptionAlgorithm.PKCS1_5, + RSAEncryptionAlgorithm.OAEP_SHA256, + RSAEncryptionAlgorithm.OAEP_SHA512] for p in param_list: with self.subTest(msg="RSA Encryption Roundtrip using algo p", p=p): @@ -139,8 +139,8 @@ def test_rsa_signing_roundtrip(self): h.update(b'totally original test string') digest = h.digest() - param_list = [RSASignatureAlgorithmType.PKCS1_5_SHA256, - RSASignatureAlgorithmType.PSS_SHA256] + param_list = [RSASignatureAlgorithm.PKCS1_5_SHA256, + RSASignatureAlgorithm.PSS_SHA256] for p in param_list: with self.subTest(msg="RSA Signing Roundtrip using algo p", p=p): @@ -168,9 +168,9 @@ def test_rsa_signing_verify_fail(self): digest2 = h2.digest() rsa = RSA.rsa_private_key_from_pem_data(RSA_PRIVATE_KEY_PEM) - signature = rsa.sign(RSASignatureAlgorithmType.PKCS1_5_SHA256, digest) - self.assertEqual(rsa.verify(RSASignatureAlgorithmType.PKCS1_5_SHA256, digest2, signature), False) - self.assertEqual(rsa.verify(RSASignatureAlgorithmType.PKCS1_5_SHA256, digest, b'bad signature'), False) + signature = rsa.sign(RSASignatureAlgorithm.PKCS1_5_SHA256, digest) + self.assertEqual(rsa.verify(RSASignatureAlgorithm.PKCS1_5_SHA256, digest2, signature), False) + self.assertEqual(rsa.verify(RSASignatureAlgorithm.PKCS1_5_SHA256, digest, b'bad signature'), False) if __name__ == '__main__':