Skip to content

Commit

Permalink
Add command to update safe contract logos
Browse files Browse the repository at this point in the history
  • Loading branch information
moisses89 committed Oct 16, 2023
1 parent ec32be0 commit 7fa8576
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 1 deletion.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ psycogreen==1.0.2
psycopg2==2.9.9
redis==5.0.1
requests==2.31.0
safe-eth-py[django]==6.0.0b2
safe-eth-py[django]==6.0.0b3
web3==6.10.0
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from django.core.files import File
from django.core.management import BaseCommand, CommandError

from gnosis.eth import EthereumClientProvider
from gnosis.safe.safe_deployments import safe_deployments

from safe_transaction_service.contracts.models import Contract


def get_deployment_addresses(safe_deployments: dict, chain_id: str) -> list:
"""
Return the deployment addresses of passed dict for a chain id.
:param safe_deployments:
:param chain_id:
:return:
"""
addresses = []
if isinstance(safe_deployments, dict):
for key, value in safe_deployments.items():
if isinstance(value, dict):
addresses.extend(get_deployment_addresses(value, chain_id))
elif key == chain_id:
addresses.append(value)
return addresses


class Command(BaseCommand):
help = "Update safe contract logos by new one"

def add_arguments(self, parser):
parser.add_argument(
"--safe-version", type=str, help="Contract version", required=False
)
parser.add_argument(
"--logo-path", type=str, help="Path of new logo", required=True
)

def handle(self, *args, **options):
"""
Command to add or update safe contract logos if exist.
:param args:
:param options: safe version and logo path
:return:
"""
safe_version = options["safe_version"]
logo_path = options["logo_path"]
ethereum_client = EthereumClientProvider()
chain_id = str(ethereum_client.get_chain_id())

if not safe_version:
addresses = get_deployment_addresses(safe_deployments, chain_id)
elif safe_version in safe_deployments:
addresses = get_deployment_addresses(
safe_deployments[safe_version], chain_id
)
else:
raise CommandError(
f"Wrong Safe version {safe_version}, supported versions {safe_deployments.keys()}"
)

for contract_address in addresses:
try:
contract = Contract.objects.get(address=contract_address)
# Remove previous one if exist
contract.logo.delete(save=True)
contract.logo.save(
f"{contract.address}.png", File(open(logo_path, "rb"))
)
contract.save()
self.stdout.write(
self.style.SUCCESS(
f"Contract {contract_address} successfully updated"
)
)
except Contract.DoesNotExist:
self.stdout.write(
self.style.WARNING(
f"Contract {contract_address} does not exist on database"
)
)
continue
49 changes: 49 additions & 0 deletions safe_transaction_service/contracts/tests/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
from io import StringIO
from unittest.mock import patch

from django.core.management import call_command
from django.test import TestCase

from gnosis.eth import EthereumClient
from gnosis.safe.safe_deployments import safe_deployments

from config.settings.base import STATIC_ROOT
from safe_transaction_service.contracts.management.commands.update_safe_contracts_logo import (
get_deployment_addresses,
)
from safe_transaction_service.contracts.models import Contract
from safe_transaction_service.contracts.tests.factories import ContractFactory


class TestCommands(TestCase):
def test_index_contracts_with_metadata(self):
Expand All @@ -21,3 +32,41 @@ def test_index_contracts_with_metadata(self):
"Calling `reindex_contracts_without_metadata_task` task", buf.getvalue()
)
self.assertIn("Processing finished", buf.getvalue())

@patch.object(EthereumClient, "get_chain_id", autospec=True, return_value=5)
def test_update_safe_contracts_logo(self, mock_chain_id):
command = "update_safe_contracts_logo"
buf = StringIO()
multisend_address = "0xA238CBeb142c10Ef7Ad8442C6D1f9E89e07e7761"
random_contract = ContractFactory()
previous_random_contract_logo = random_contract.logo.read()
previous_multisend_logo: bytes = ContractFactory(
address=multisend_address
).logo.read()
call_command(command, f"--logo-path={STATIC_ROOT}/safe/logo.png", stdout=buf)
current_multisend_logo: bytes = Contract.objects.get(
address=multisend_address
).logo.read()
self.assertNotEqual(current_multisend_logo, previous_multisend_logo)
# No safe contract logos should keep unchanged
current_no_safe_contract_logo: bytes = Contract.objects.get(
address=random_contract.address
).logo.read()
self.assertEqual(current_no_safe_contract_logo, previous_random_contract_logo)

def test_get_deployment_addresses(self):
expected_result = [
"0x29fcB43b46531BcA003ddC8FCB67FFE91900C762",
"0x38869bf66a61cF6bDB996A6aE40D5853Fd43B526",
"0x9641d764fc13c8B624c04430C7356C1C7C8102e2",
"0x41675C099F32341bf84BFc5382aF534df5C7461a",
"0x3d4BA2E0884aa488718476ca2FB8Efc291A46199",
"0x4e1DCf7AD4e460CfD30791CCC4F9c8a4f820ec67",
"0x9b35Af71d77eaf8d7e40252370304687390A1A52",
"0xfd0732Dc9E303f09fCEf3a7388Ad10A83459Ec99",
"0xd53cd0aB83D845Ac265BE939c57F53AD838012c9",
]

result = get_deployment_addresses(safe_deployments["1.4.1"], "1")

self.assertEqual(expected_result, result)

0 comments on commit 7fa8576

Please sign in to comment.