diff --git a/requirements.txt b/requirements.txt index 0e580ec95..d5b3b225c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,5 +32,5 @@ psycogreen==1.0.2 psycopg2==2.9.9 redis==5.0.1 requests==2.31.0 -safe-eth-py[django]==6.0.0b2 +safe-eth-py[django]==6.0.0b3 web3==6.10.0 diff --git a/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py b/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py new file mode 100644 index 000000000..0df34a780 --- /dev/null +++ b/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py @@ -0,0 +1,81 @@ +from django.core.files import File +from django.core.management import BaseCommand, CommandError + +from gnosis.eth import EthereumClientProvider +from gnosis.safe.safe_deployments import safe_deployments + +from safe_transaction_service.contracts.models import Contract + + +def get_deployment_addresses(safe_deployments: dict, chain_id: str) -> list: + """ + Return the deployment addresses of passed dict for a chain id. + :param safe_deployments: + :param chain_id: + :return: + """ + addresses = [] + if isinstance(safe_deployments, dict): + for key, value in safe_deployments.items(): + if isinstance(value, dict): + addresses.extend(get_deployment_addresses(value, chain_id)) + elif key == chain_id: + addresses.append(value) + return addresses + + +class Command(BaseCommand): + help = "Update safe contract logos by new one" + + def add_arguments(self, parser): + parser.add_argument( + "--safe-version", type=str, help="Contract version", required=False + ) + parser.add_argument( + "--logo-path", type=str, help="Path of new logo", required=True + ) + + def handle(self, *args, **options): + """ + Command to add or update safe contract logos if exist. + :param args: + :param options: safe version and logo path + :return: + """ + safe_version = options["safe_version"] + logo_path = options["logo_path"] + ethereum_client = EthereumClientProvider() + chain_id = str(ethereum_client.get_chain_id()) + + if not safe_version: + addresses = get_deployment_addresses(safe_deployments, chain_id) + elif safe_version in safe_deployments: + addresses = get_deployment_addresses( + safe_deployments[safe_version], chain_id + ) + else: + raise CommandError( + f"Wrong Safe version {safe_version}, supported versions {safe_deployments.keys()}" + ) + + for contract_address in addresses: + try: + contract = Contract.objects.get(address=contract_address) + # Remove previous one if exist + contract.logo.delete(save=True) + contract.logo.save( + f"{contract.address}.png", File(open(logo_path, "rb")) + ) + contract.save() + self.stdout.write( + self.style.SUCCESS( + f"Contract {contract_address} successfully updated" + ) + ) + except Contract.DoesNotExist: + self.stdout.write( + self.style.WARNING( + f"Contract {contract_address} does not exist on database" + ) + ) + continue diff --git a/safe_transaction_service/contracts/tests/test_commands.py b/safe_transaction_service/contracts/tests/test_commands.py index c65698691..cdb4a49fc 100644 --- a/safe_transaction_service/contracts/tests/test_commands.py +++ b/safe_transaction_service/contracts/tests/test_commands.py @@ -1,8 +1,19 @@ from io import StringIO +from unittest.mock import patch from django.core.management import call_command from django.test import TestCase +from gnosis.eth import EthereumClient +from gnosis.safe.safe_deployments import safe_deployments + +from config.settings.base import STATIC_ROOT +from safe_transaction_service.contracts.management.commands.update_safe_contracts_logo import ( + get_deployment_addresses, +) +from safe_transaction_service.contracts.models import Contract +from safe_transaction_service.contracts.tests.factories import ContractFactory + class TestCommands(TestCase): def test_index_contracts_with_metadata(self): @@ -21,3 +32,41 @@ def test_index_contracts_with_metadata(self): "Calling `reindex_contracts_without_metadata_task` task", buf.getvalue() ) self.assertIn("Processing finished", buf.getvalue()) + + @patch.object(EthereumClient, "get_chain_id", autospec=True, return_value=5) + def test_update_safe_contracts_logo(self, mock_chain_id): + command = "update_safe_contracts_logo" + buf = StringIO() + multisend_address = "0xA238CBeb142c10Ef7Ad8442C6D1f9E89e07e7761" + random_contract = ContractFactory() + previous_random_contract_logo = random_contract.logo.read() + previous_multisend_logo: bytes = ContractFactory( + address=multisend_address + ).logo.read() + call_command(command, f"--logo-path={STATIC_ROOT}/safe/logo.png", stdout=buf) + current_multisend_logo: bytes = Contract.objects.get( + address=multisend_address + ).logo.read() + self.assertNotEqual(current_multisend_logo, previous_multisend_logo) + # No safe contract logos should keep unchanged + current_no_safe_contract_logo: bytes = Contract.objects.get( + address=random_contract.address + ).logo.read() + self.assertEqual(current_no_safe_contract_logo, previous_random_contract_logo) + + def test_get_deployment_addresses(self): + expected_result = [ + "0x29fcB43b46531BcA003ddC8FCB67FFE91900C762", + "0x38869bf66a61cF6bDB996A6aE40D5853Fd43B526", + "0x9641d764fc13c8B624c04430C7356C1C7C8102e2", + "0x41675C099F32341bf84BFc5382aF534df5C7461a", + "0x3d4BA2E0884aa488718476ca2FB8Efc291A46199", + "0x4e1DCf7AD4e460CfD30791CCC4F9c8a4f820ec67", + "0x9b35Af71d77eaf8d7e40252370304687390A1A52", + "0xfd0732Dc9E303f09fCEf3a7388Ad10A83459Ec99", + "0xd53cd0aB83D845Ac265BE939c57F53AD838012c9", + ] + + result = get_deployment_addresses(safe_deployments["1.4.1"], "1") + + self.assertEqual(expected_result, result)