Skip to content

Commit

Permalink
chore: run black and isort on create_assets
Browse files Browse the repository at this point in the history
  • Loading branch information
Ian2012 committed Nov 14, 2024
1 parent 6e0272d commit ea2aa1e
Showing 1 changed file with 51 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
"""Import a list of assets from a yaml file and create them in the superset assets folder."""

import os

from superset.app import create_app

app = create_app()
app.app_context().push()

import logging
import uuid
import yaml
from collections import defaultdict
from copy import deepcopy
from pathlib import Path

import yaml
from openedx.create_assets_utils import load_configs_from_directory
from openedx.create_row_level_security import create_rls_filters
from openedx.delete_assets import delete_assets
from openedx.localization import get_translation
from sqlfmt.api import format_string
from sqlfmt.mode import Mode
from collections import defaultdict

from superset import security_manager
from superset.commands.utils import update_tags
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.connectors.sqla.models import SqlaTable
from superset.models.embedded_dashboard import EmbeddedDashboard


from superset.tags.models import TagType, get_tag
from superset.commands.utils import update_tags
from superset.tags.models import ObjectType
from openedx.delete_assets import delete_assets
from openedx.create_assets_utils import load_configs_from_directory
from openedx.localization import get_translation
from openedx.create_row_level_security import create_rls_filters

from superset.models.slice import Slice
from superset.tags.models import ObjectType, TagType, get_tag

logger = logging.getLogger("create_assets")
# Supress output from black (sqlfmt) formatting
blib2to3_logger = logging.getLogger('blib2to3.pgen2.driver')
blib2to3_logger = logging.getLogger("blib2to3.pgen2.driver")
blib2to3_logger.setLevel(logging.WARN)

BASE_DIR = "/app/assets/superset"
Expand Down Expand Up @@ -77,7 +75,7 @@ def create_assets():
process_asset(asset, roles, translated_asset_uuids)

# Get extra assets
with open(os.path.join(PYTHONPATH,"assets.yaml"), "r") as file:
with open(os.path.join(PYTHONPATH, "assets.yaml"), "r") as file:
extra_assets = yaml.safe_load_all(file)

if extra_assets:
Expand All @@ -92,11 +90,14 @@ def create_assets():
create_rls_filters()

# Delete unused assets
with open(os.path.join(PYTHONPATH,"aspects_asset_list.yaml"), "r", encoding="utf-8") as file:
with open(
os.path.join(PYTHONPATH, "aspects_asset_list.yaml"), "r", encoding="utf-8"
) as file:
assets = yaml.safe_load(file)
unused_aspect_uuids = assets['unused_uuids']
unused_aspect_uuids = assets["unused_uuids"]
delete_assets(unused_aspect_uuids, translated_asset_uuids)


def process_asset(asset, roles, translated_asset_uuids):
if FILE_NAME_ATTRIBUTE not in asset:
raise Exception(f"Asset {asset} has no {FILE_NAME_ATTRIBUTE}")
Expand All @@ -105,7 +106,9 @@ def process_asset(asset, roles, translated_asset_uuids):
# Find the right folder to create the asset in
for asset_name, folder in ASSET_FOLDER_MAPPING.items():
if asset_name in asset:
write_asset_to_file(asset, asset_name, folder, file_name, roles, translated_asset_uuids)
write_asset_to_file(
asset, asset_name, folder, file_name, roles, translated_asset_uuids
)
return


Expand All @@ -117,16 +120,21 @@ def get_localized_uuid(base_uuid, language):
base_namespace = uuid.uuid5(base_uuid, "superset")
normalized_language = language.lower().replace("-", "_")
return str(uuid.uuid5(base_namespace, normalized_language))

def write_asset_to_file(asset, asset_name, folder, file_name, roles, translated_asset_uuids):


def write_asset_to_file(
asset, asset_name, folder, file_name, roles, translated_asset_uuids
):
"""Write an asset to a file and generated translated assets"""
if folder == "databases":
# Update the sqlalchery_uri from the asset override pre-generated values
asset["sqlalchemy_uri"] = DATABASES.get(asset["database_name"])
if folder in ["charts", "dashboards", "datasets"]:
for locale in DASHBOARD_LOCALES:
if folder == "datasets":
asset["sql"] = format_string(asset["sql"], mode=Mode(dialect_name="clickhouse"))
asset["sql"] = format_string(
asset["sql"], mode=Mode(dialect_name="clickhouse")
)

updated_asset = generate_translated_asset(
asset, asset_name, folder, locale, roles, translated_asset_uuids
Expand All @@ -135,7 +143,9 @@ def write_asset_to_file(asset, asset_name, folder, file_name, roles, translated_
# Clean up old localized dashboards
if folder == "dashboards":
dashboard_slug = updated_asset["slug"]
dashboard = db.session.query(Dashboard).filter_by(slug=dashboard_slug).first()
dashboard = (
db.session.query(Dashboard).filter_by(slug=dashboard_slug).first()
)
if dashboard:
db.session.delete(dashboard)
path = f"{BASE_DIR}/{folder}/{file_name.split('.')[0]}-{locale}.yaml"
Expand All @@ -146,7 +156,7 @@ def write_asset_to_file(asset, asset_name, folder, file_name, roles, translated_
# access the original dashboards.
dashboard_roles = asset.pop("_roles", None)
if dashboard_roles:
roles[asset["uuid"]] = ([security_manager.find_role("Admin")], 'original')
roles[asset["uuid"]] = ([security_manager.find_role("Admin")], "original")

# Clean up old un-localized dashboard
dashboard_slug = asset.get("slug")
Expand All @@ -162,15 +172,17 @@ def write_asset_to_file(asset, asset_name, folder, file_name, roles, translated_
db.session.commit()


def generate_translated_asset(asset, asset_name, folder, language, roles, translated_asset_uuids):
def generate_translated_asset(
asset, asset_name, folder, language, roles, translated_asset_uuids
):
"""Generate a translated asset with their elements updated"""
copy = deepcopy(asset)
parent_uuid = copy['uuid']
parent_uuid = copy["uuid"]
copy["uuid"] = str(get_localized_uuid(copy["uuid"], language))
copy[asset_name] = get_translation(copy[asset_name], language)

# Save parent & translated uuids in yaml file
translated_asset_uuids[parent_uuid].add(copy['uuid'])
translated_asset_uuids[parent_uuid].add(copy["uuid"])
if folder == "dashboards":
copy["slug"] = f"{copy['slug']}-{language}"
copy["description"] = get_translation(copy["description"], language)
Expand All @@ -181,9 +193,10 @@ def generate_translated_asset(asset, asset_name, folder, language, roles, transl
for role in dashboard_roles:
translated_dashboard_roles.append(f"{role} - {language}")

roles[copy["uuid"]] = ([
(security_manager.find_role(role)) for role in translated_dashboard_roles
], language)
roles[copy["uuid"]] = (
[(security_manager.find_role(role)) for role in translated_dashboard_roles],
language,
)

generate_translated_dashboard_elements(copy, language)
generate_translated_dashboard_filters(copy, language)
Expand Down Expand Up @@ -284,6 +297,7 @@ def update_dashboard_roles(roles):

db.session.commit()


def get_owners():
owners_username = {{SUPERSET_OWNERS}}
owners = []
Expand All @@ -293,6 +307,7 @@ def get_owners():
owners.append(user)
return owners


def update_embeddable_uuids():
"""Update the uuids of the embeddable dashboards"""
for dashboard_slug, embeddable_uuid in EMBEDDABLE_DASHBOARDS.items():
Expand All @@ -313,7 +328,9 @@ def create_embeddable_dashboard_by_slug(dashboard_slug, embeddable_uuid):
logger.info(f"WARNING: Dashboard {dashboard_slug} not found")
return

embedded_dashboard = db.session.query(EmbeddedDashboard).filter_by(dashboard_id=dashboard.id).first()
embedded_dashboard = (
db.session.query(EmbeddedDashboard).filter_by(dashboard_id=dashboard.id).first()
)
if embedded_dashboard is None:
embedded_dashboard = EmbeddedDashboard()
embedded_dashboard.dashboard_id = dashboard.id
Expand All @@ -322,13 +339,12 @@ def create_embeddable_dashboard_by_slug(dashboard_slug, embeddable_uuid):
db.session.add(embedded_dashboard)
db.session.commit()


def update_datasets():
"""Update the datasets"""
logger.info("Refreshing datasets")
if {{SUPERSET_REFRESH_DATASETS}}:
datasets = (
db.session.query(SqlaTable).all()
)
datasets = db.session.query(SqlaTable).all()
for dataset in datasets:
logger.info(f"Refreshing dataset {dataset.table_name}")
dataset.fetch_metadata(commit=True)
Expand Down

0 comments on commit ea2aa1e

Please sign in to comment.