diff --git a/docker/web/celery/worker/run.sh b/docker/web/celery/worker/run.sh index 5e5cffe80..7740338a9 100755 --- a/docker/web/celery/worker/run.sh +++ b/docker/web/celery/worker/run.sh @@ -19,6 +19,9 @@ if [ ${RUN_MIGRATIONS:-0} = 1 ]; then echo "==> $(date +%H:%M:%S) ==> Setting up service... " python manage.py setup_service + + echo "==> $(date +%H:%M:%S) ==> Setting contracts... " + python manage.py update_safe_contracts_logo fi echo "==> $(date +%H:%M:%S) ==> Check RPC connected matches previously used RPC... " diff --git a/requirements.txt b/requirements.txt index a383e5977..022cbb593 100644 --- a/requirements.txt +++ b/requirements.txt @@ -32,5 +32,5 @@ psycogreen==1.0.2 psycopg2==2.9.9 redis==5.0.1 requests==2.31.0 -safe-eth-py[django]==6.0.0b2 +safe-eth-py[django]==6.0.0b3 web3==6.11.1 diff --git a/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py b/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py new file mode 100644 index 000000000..3212ef37a --- /dev/null +++ b/safe_transaction_service/contracts/management/commands/update_safe_contracts_logo.py @@ -0,0 +1,101 @@ +from django.core.files import File +from django.core.management import BaseCommand, CommandError + +from gnosis.eth import EthereumClientProvider +from gnosis.safe.safe_deployments import safe_deployments + +from config.settings.base import STATICFILES_DIRS +from safe_transaction_service.contracts.models import Contract + + +def generate_safe_contract_display_name(contract_name: str, version: str) -> str: + """ + Generates the display name for Safe contract. + Append Safe at the beginning if the contract name doesn't contain Safe word and append the contract version at the end. + + :param contract_name: + :param version: + :return: display_name + """ + # Remove gnosis word + contract_name = contract_name.replace("Gnosis", "") + if "safe" not in contract_name.lower(): + return f"Safe: {contract_name} {version}" + else: + return f"{contract_name} {version}" + + +class Command(BaseCommand): + help = "Update or create Safe contracts with provided logo" + + def add_arguments(self, parser): + parser.add_argument( + "--safe-version", type=str, help="Contract version", required=False + ) + parser.add_argument( + "--force-update-contract-names", + help="Update all the safe contract names and display names", + action="store_true", + default=False, + ) + parser.add_argument( + "--logo-path", + type=str, + help="Path of new logo", + required=False, + default=f"{STATICFILES_DIRS[0]}/safe/safe_contract_logo.png", + ) + + def handle(self, *args, **options): + """ + Command to create or update Safe contracts with provided logo. + + :param args: + :param options: Safe version and logo path + :return: + """ + safe_version = options["safe_version"] + force_update_contract_names = options["force_update_contract_names"] + logo_path = options["logo_path"] + ethereum_client = EthereumClientProvider() + chain_id = ethereum_client.get_chain_id() + logo_file = File(open(logo_path, "rb")) + if not safe_version: + versions = list(safe_deployments.keys()) + elif safe_version in safe_deployments: + versions = [safe_version] + else: + raise CommandError( + f"Wrong Safe version {safe_version}, supported versions {safe_deployments.keys()}" + ) + + if force_update_contract_names: + # update all safe contract names + queryset = Contract.objects.update_or_create + else: + # only update the contracts with empty values + queryset = Contract.objects.get_or_create + + for version in versions: + for contract_name, addresses in safe_deployments[version].items(): + display_name = generate_safe_contract_display_name( + contract_name, version + ) + contract, created = queryset( + address=addresses[str(chain_id)], + defaults={ + "name": contract_name, + "display_name": display_name, + }, + ) + + if not created: + # Remove previous logo file + contract.logo.delete(save=True) + # update name only for contracts with empty names + if not force_update_contract_names and contract.name == "": + contract.display_name = display_name + contract.name = contract_name + + contract.logo.save(f"{contract.address}.png", logo_file) + contract.save() diff --git a/safe_transaction_service/contracts/tests/test_commands.py b/safe_transaction_service/contracts/tests/test_commands.py index c65698691..6fdce8e77 100644 --- a/safe_transaction_service/contracts/tests/test_commands.py +++ b/safe_transaction_service/contracts/tests/test_commands.py @@ -1,8 +1,14 @@ from io import StringIO +from unittest.mock import patch from django.core.management import call_command from django.test import TestCase +from gnosis.eth import EthereumClient + +from safe_transaction_service.contracts.models import Contract +from safe_transaction_service.contracts.tests.factories import ContractFactory + class TestCommands(TestCase): def test_index_contracts_with_metadata(self): @@ -21,3 +27,58 @@ def test_index_contracts_with_metadata(self): "Calling `reindex_contracts_without_metadata_task` task", buf.getvalue() ) self.assertIn("Processing finished", buf.getvalue()) + + @patch.object(EthereumClient, "get_chain_id", autospec=True, return_value=1) + def test_update_safe_contracts_logo(self, mock_chain_id): + command = "update_safe_contracts_logo" + buf = StringIO() + random_contract = ContractFactory() + previous_random_contract_logo = random_contract.logo.read() + multisend_address = "0xA238CBeb142c10Ef7Ad8442C6D1f9E89e07e7761" + multisend_contract = ContractFactory( + address=multisend_address, name="GnosisMultisend" + ) + multisend_contract_logo = multisend_contract.logo.read() + + call_command(command, stdout=buf) + current_multisend_contract = Contract.objects.get(address=multisend_address) + # Previous created contracts logo should be updated + self.assertNotEqual( + current_multisend_contract.logo.read(), multisend_contract_logo + ) + + # Previous created contracts name and display name should keep unchanged + self.assertEqual(multisend_contract.name, current_multisend_contract.name) + self.assertEqual( + multisend_contract.display_name, current_multisend_contract.display_name + ) + + # No safe contract logos should keep unchanged + current_no_safe_contract_logo: bytes = Contract.objects.get( + address=random_contract.address + ).logo.read() + self.assertEqual(current_no_safe_contract_logo, previous_random_contract_logo) + + # Missing safe addresses should be added + self.assertEqual(Contract.objects.count(), 28) + + # Contract name and display name should be correctly generated + safe_l2_141_address = "0x29fcB43b46531BcA003ddC8FCB67FFE91900C762" + contract = Contract.objects.get(address=safe_l2_141_address) + self.assertEqual(contract.name, "SafeL2") + self.assertEqual(contract.display_name, "SafeL2 1.4.1") + + safe_multisend_141_address = "0x38869bf66a61cF6bDB996A6aE40D5853Fd43B526" + contract = Contract.objects.get(address=safe_multisend_141_address) + self.assertEqual(contract.name, "MultiSend") + self.assertEqual(contract.display_name, "Safe: MultiSend 1.4.1") + + # Force to update contract names should update the name and display name of the contract + call_command( + command, + "--force-update-contract-names", + stdout=buf, + ) + contract = Contract.objects.get(address=multisend_address) + self.assertEqual(contract.name, "MultiSend") + self.assertEqual(contract.display_name, "Safe: MultiSend 1.3.0") diff --git a/safe_transaction_service/static/safe/safe_contract_logo.png b/safe_transaction_service/static/safe/safe_contract_logo.png new file mode 100644 index 000000000..14eb5bb02 Binary files /dev/null and b/safe_transaction_service/static/safe/safe_contract_logo.png differ