Skip to content

Commit

Permalink
Add command to update safe contract logos (#1657)
Browse files Browse the repository at this point in the history
Call default command to create missing contracts at service start
  • Loading branch information
moisses89 authored Oct 25, 2023
1 parent 74dbad0 commit c57ff93
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 1 deletion.
3 changes: 3 additions & 0 deletions docker/web/celery/worker/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ if [ ${RUN_MIGRATIONS:-0} = 1 ]; then

echo "==> $(date +%H:%M:%S) ==> Setting up service... "
python manage.py setup_service

echo "==> $(date +%H:%M:%S) ==> Setting contracts... "
python manage.py update_safe_contracts_logo
fi

echo "==> $(date +%H:%M:%S) ==> Check RPC connected matches previously used RPC... "
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ psycogreen==1.0.2
psycopg2==2.9.9
redis==5.0.1
requests==2.31.0
safe-eth-py[django]==6.0.0b2
safe-eth-py[django]==6.0.0b3
web3==6.11.1
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from django.core.files import File
from django.core.management import BaseCommand, CommandError

from gnosis.eth import EthereumClientProvider
from gnosis.safe.safe_deployments import safe_deployments

from config.settings.base import STATICFILES_DIRS
from safe_transaction_service.contracts.models import Contract


def generate_safe_contract_display_name(contract_name: str, version: str) -> str:
"""
Generates the display name for Safe contract.
Append Safe at the beginning if the contract name doesn't contain Safe word and append the contract version at the end.
:param contract_name:
:param version:
:return: display_name
"""
# Remove gnosis word
contract_name = contract_name.replace("Gnosis", "")
if "safe" not in contract_name.lower():
return f"Safe: {contract_name} {version}"
else:
return f"{contract_name} {version}"


class Command(BaseCommand):
help = "Update or create Safe contracts with provided logo"

def add_arguments(self, parser):
parser.add_argument(
"--safe-version", type=str, help="Contract version", required=False
)
parser.add_argument(
"--force-update-contract-names",
help="Update all the safe contract names and display names",
action="store_true",
default=False,
)
parser.add_argument(
"--logo-path",
type=str,
help="Path of new logo",
required=False,
default=f"{STATICFILES_DIRS[0]}/safe/safe_contract_logo.png",
)

def handle(self, *args, **options):
"""
Command to create or update Safe contracts with provided logo.
:param args:
:param options: Safe version and logo path
:return:
"""
safe_version = options["safe_version"]
force_update_contract_names = options["force_update_contract_names"]
logo_path = options["logo_path"]
ethereum_client = EthereumClientProvider()
chain_id = ethereum_client.get_chain_id()
logo_file = File(open(logo_path, "rb"))
if not safe_version:
versions = list(safe_deployments.keys())
elif safe_version in safe_deployments:
versions = [safe_version]
else:
raise CommandError(
f"Wrong Safe version {safe_version}, supported versions {safe_deployments.keys()}"
)

if force_update_contract_names:
# update all safe contract names
queryset = Contract.objects.update_or_create
else:
# only update the contracts with empty values
queryset = Contract.objects.get_or_create

for version in versions:
for contract_name, addresses in safe_deployments[version].items():
display_name = generate_safe_contract_display_name(
contract_name, version
)
contract, created = queryset(
address=addresses[str(chain_id)],
defaults={
"name": contract_name,
"display_name": display_name,
},
)

if not created:
# Remove previous logo file
contract.logo.delete(save=True)
# update name only for contracts with empty names
if not force_update_contract_names and contract.name == "":
contract.display_name = display_name
contract.name = contract_name

contract.logo.save(f"{contract.address}.png", logo_file)
contract.save()
61 changes: 61 additions & 0 deletions safe_transaction_service/contracts/tests/test_commands.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
from io import StringIO
from unittest.mock import patch

from django.core.management import call_command
from django.test import TestCase

from gnosis.eth import EthereumClient

from safe_transaction_service.contracts.models import Contract
from safe_transaction_service.contracts.tests.factories import ContractFactory


class TestCommands(TestCase):
def test_index_contracts_with_metadata(self):
Expand All @@ -21,3 +27,58 @@ def test_index_contracts_with_metadata(self):
"Calling `reindex_contracts_without_metadata_task` task", buf.getvalue()
)
self.assertIn("Processing finished", buf.getvalue())

@patch.object(EthereumClient, "get_chain_id", autospec=True, return_value=1)
def test_update_safe_contracts_logo(self, mock_chain_id):
command = "update_safe_contracts_logo"
buf = StringIO()
random_contract = ContractFactory()
previous_random_contract_logo = random_contract.logo.read()
multisend_address = "0xA238CBeb142c10Ef7Ad8442C6D1f9E89e07e7761"
multisend_contract = ContractFactory(
address=multisend_address, name="GnosisMultisend"
)
multisend_contract_logo = multisend_contract.logo.read()

call_command(command, stdout=buf)
current_multisend_contract = Contract.objects.get(address=multisend_address)
# Previous created contracts logo should be updated
self.assertNotEqual(
current_multisend_contract.logo.read(), multisend_contract_logo
)

# Previous created contracts name and display name should keep unchanged
self.assertEqual(multisend_contract.name, current_multisend_contract.name)
self.assertEqual(
multisend_contract.display_name, current_multisend_contract.display_name
)

# No safe contract logos should keep unchanged
current_no_safe_contract_logo: bytes = Contract.objects.get(
address=random_contract.address
).logo.read()
self.assertEqual(current_no_safe_contract_logo, previous_random_contract_logo)

# Missing safe addresses should be added
self.assertEqual(Contract.objects.count(), 28)

# Contract name and display name should be correctly generated
safe_l2_141_address = "0x29fcB43b46531BcA003ddC8FCB67FFE91900C762"
contract = Contract.objects.get(address=safe_l2_141_address)
self.assertEqual(contract.name, "SafeL2")
self.assertEqual(contract.display_name, "SafeL2 1.4.1")

safe_multisend_141_address = "0x38869bf66a61cF6bDB996A6aE40D5853Fd43B526"
contract = Contract.objects.get(address=safe_multisend_141_address)
self.assertEqual(contract.name, "MultiSend")
self.assertEqual(contract.display_name, "Safe: MultiSend 1.4.1")

# Force to update contract names should update the name and display name of the contract
call_command(
command,
"--force-update-contract-names",
stdout=buf,
)
contract = Contract.objects.get(address=multisend_address)
self.assertEqual(contract.name, "MultiSend")
self.assertEqual(contract.display_name, "Safe: MultiSend 1.3.0")
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c57ff93

Please sign in to comment.