Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow different schema for tmp tables created during table materialization #664

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
10 changes: 10 additions & 0 deletions dbt/include/athena/macros/adapters/relation.sql
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@
{%- endcall %}
{%- endmacro %}

{#
    set_table_relation_schema
    Re-point `relation` at `schema` and make sure that schema exists.

    Used during table materialization so that tmp/ha relations can live in a
    dedicated schema (the `temp_schema` model config) instead of the target
    schema. Returns the relation unchanged when `schema` is none.

    Args:
        relation: the relation to move.
        schema: the schema name to move it into, or none for a no-op.
#}
{% macro set_table_relation_schema(relation, schema) %}
  {#- Guard on the macro's own `schema` argument: `temp_schema` is not
      defined in this scope, so testing it here would be a bug. -#}
  {%- if schema is not none -%}
    {%- set relation = relation.incorporate(path={
        "schema": schema
    }) -%}
    {#- Create the schema up front so DDL against the relation cannot fail
        with "schema does not exist". -#}
    {%- do create_schema(relation) -%}
  {%- endif -%}
  {{ return(relation) }}
{% endmacro %}

{% macro make_temp_relation(base_relation, suffix='__dbt_tmp', temp_schema=none) %}
{%- set temp_identifier = base_relation.identifier ~ suffix -%}
{%- set temp_relation = base_relation.incorporate(path={"identifier": temp_identifier}) -%}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
{%- set lf_grants = config.get('lf_grants') -%}

{%- set table_type = config.get('table_type', default='hive') | lower -%}
{%- set temp_schema = config.get('temp_schema') -%}
nicor88 marked this conversation as resolved.
Show resolved Hide resolved
{%- set old_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%}
{%- set old_tmp_relation = adapter.get_relation(identifier=identifier ~ '__ha',
schema=schema,
database=database) -%}
{%- if temp_schema is not none and old_tmp_relation is not none-%}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully get this check, why do you need to include a check for old_tmp_relation not being none here? If old_tmp_relation is none, we still want to create the new tmp table in the tmp schema.

{%- set old_tmp_relation = set_table_relation_schema(relation=old_tmp_relation, schema=temp_schema) -%}
{%- endif -%}
{%- set old_bkp_relation = adapter.get_relation(identifier=identifier ~ '__bkp',
schema=schema,
database=database) -%}
Expand All @@ -31,6 +35,9 @@
database=database,
s3_path_table_part=target_relation.identifier,
type='table') -%}
{%- if temp_schema is not none -%}
{%- set tmp_relation = set_table_relation_schema(relation=tmp_relation, schema=temp_schema) -%}
{%- endif -%}

{%- if (
table_type == 'hive'
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/adapter/test_incremental_tmp_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import yaml
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_running_create_statements,
extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
Expand Down Expand Up @@ -61,7 +61,7 @@ def test__schema_tmp(self, project, capsys):
assert records_count_first_run == 1

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

Expand Down Expand Up @@ -95,7 +95,7 @@ def test__schema_tmp(self, project, capsys):
assert records_count_incremental_run == 2

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

Expand Down
118 changes: 118 additions & 0 deletions tests/functional/adapter/test_table_temp_schema_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest
import yaml
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_rename_statement_table_names,
extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt

# Jinja source for the model under test: a `table`-materialized iceberg model
# whose temporary relations should be created in the schema supplied through
# the `temp_schema_name` dbt var. `test_id` lets each run write a new value.
models__iceberg_table = """
{{ config(
materialized='table',
table_type='iceberg',
temp_schema=var('temp_schema_name'),
)
}}

select
{{ var('test_id') }} as id
"""


class TestTableIcebergTableUnique:
    """Functional test: `temp_schema` config for iceberg table materialization.

    First run creates the table directly in the target schema (no tmp swap
    needed); the second run must create the tmp table in the configured temp
    schema and then rename it into the target schema.
    """

    @pytest.fixture(scope="class")
    def models(self):
        # Single model project built from the Jinja source above.
        return {"models__iceberg_table.sql": models__iceberg_table}

    def test__temp_schema_name_iceberg_table(self, project, capsys):
        relation_name = "models__iceberg_table"
        # Temp schema is derived from the test schema so parallel test runs
        # don't collide; dropped again at the end of the test.
        temp_schema_name = f"{project.test_schema}_tmp"
        drop_temp_schema = f"drop schema if exists `{temp_schema_name}` cascade"
        model_run_result_row_count_query = f"select count(*) as records from {project.test_schema}.{relation_name}"
        model_run_result_test_id_query = f"select id from {project.test_schema}.{relation_name}"

        vars_dict = {
            "temp_schema_name": temp_schema_name,
            "test_id": 1,
        }

        # First run: table does not exist yet, so dbt creates it directly.
        # Debug/json logging is required so the Athena DDL shows up in capsys
        # output for extract_running_ddl_statements to parse.
        model_run = run_dbt(
            [
                "run",
                "--select",
                relation_name,
                "--vars",
                yaml.safe_dump(vars_dict),
                "--log-level",
                "debug",
                "--log-format",
                "json",
            ]
        )

        model_run_result = model_run.results[0]
        assert model_run_result.status == RunStatus.Success

        # Drain captured output now; the second run's assertions must only
        # see the second run's log lines.
        out, _ = capsys.readouterr()
        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
        assert len(athena_running_create_statements) == 1

        # On the initial run the create must target the final schema, not the
        # temp schema.
        incremental_model_run_result_table_name = extract_create_statement_table_names(
            athena_running_create_statements[0]
        )[0]
        assert temp_schema_name not in incremental_model_run_result_table_name

        model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
        assert model_records_count == 1

        model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
        assert model_test_id_in_table == 1

        # Second run writes a different value so a full rebuild is observable.
        vars_dict["test_id"] = 2

        model_run = run_dbt(
            [
                "run",
                "--select",
                relation_name,
                "--vars",
                yaml.safe_dump(vars_dict),
                "--log-level",
                "debug",
                "--log-format",
                "json",
            ]
        )

        model_run_result = model_run.results[0]
        assert model_run_result.status == RunStatus.Success

        # Still exactly one row, now carrying the new test_id: the table was
        # replaced, not appended to.
        model_records_count = project.run_sql(model_run_result_row_count_query, fetch="all")[0][0]
        assert model_records_count == 1

        model_test_id_in_table = project.run_sql(model_run_result_test_id_query, fetch="all")[0][0]
        assert model_test_id_in_table == 2

        out, _ = capsys.readouterr()
        athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")
        assert len(athena_running_create_statements) == 1

        # The rebuild's tmp table must have been created in the temp schema.
        model_run_2_result_table_name = extract_create_statement_table_names(athena_running_create_statements[0])[0]
        assert temp_schema_name in model_run_2_result_table_name

        athena_running_alter_statements = extract_running_ddl_statements(out, relation_name, "alter table")
        assert len(athena_running_alter_statements) == 1

        # The swap renames from the temp schema into the target schema.
        athena_running_alter_statement_tables = extract_rename_statement_table_names(athena_running_alter_statements[0])
        athena_running_alter_statement_origin_table = athena_running_alter_statement_tables.get("alter_table_names")[0]
        athena_running_alter_statement_renamed_to_table = athena_running_alter_statement_tables.get(
            "rename_to_table_names"
        )[0]

        assert temp_schema_name in athena_running_alter_statement_origin_table
        assert athena_running_alter_statement_renamed_to_table == f"`{project.test_schema}`.`{relation_name}`"

        # Cleanup: remove the temp schema created by the materialization.
        project.run_sql(drop_temp_schema)
8 changes: 4 additions & 4 deletions tests/functional/adapter/test_unique_tmp_table_suffix.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from tests.functional.adapter.utils.parse_dbt_run_output import (
extract_create_statement_table_names,
extract_running_create_statements,
extract_running_ddl_statements,
)

from dbt.contracts.results import RunStatus
Expand Down Expand Up @@ -55,7 +55,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert first_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

Expand Down Expand Up @@ -87,7 +87,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

assert len(athena_running_create_statements) == 1

Expand Down Expand Up @@ -119,7 +119,7 @@ def test__unique_tmp_table_suffix(self, project, capsys):
assert incremental_model_run_result.status == RunStatus.Success

out, _ = capsys.readouterr()
athena_running_create_statements = extract_running_create_statements(out, relation_name)
athena_running_create_statements = extract_running_ddl_statements(out, relation_name, "create table")

incremental_model_run_result_table_name_2 = extract_create_statement_table_names(
athena_running_create_statements[0]
Expand Down
26 changes: 18 additions & 8 deletions tests/functional/adapter/utils/parse_dbt_run_output.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json
import re
from typing import List
from typing import Dict, List


def extract_running_create_statements(dbt_run_capsys_output: str, relation_name: str) -> List[str]:
sql_create_statements = []
def extract_running_ddl_statements(dbt_run_capsys_output: str, relation_name: str, ddl_type: str) -> List[str]:
sql_ddl_statements = []
# Skipping "Invoking dbt with ['run', '--select', 'unique_tmp_table_suffix'..."
for events_msg in dbt_run_capsys_output.split("\n")[1:]:
base_msg_data = None
Expand All @@ -21,16 +21,26 @@ def extract_running_create_statements(dbt_run_capsys_output: str, relation_name:
if base_msg_data:
base_msg = base_msg_data.get("base_msg")
if "Running Athena query:" in str(base_msg):
if "create table" in base_msg:
sql_create_statements.append(base_msg)
if ddl_type in base_msg:
sql_ddl_statements.append(base_msg)

if base_msg_data.get("conn_name") == f"model.test.{relation_name}" and "sql" in base_msg_data:
if "create table" in base_msg_data.get("sql"):
sql_create_statements.append(base_msg_data.get("sql"))
if ddl_type in base_msg_data.get("sql"):
sql_ddl_statements.append(base_msg_data.get("sql"))

return sql_create_statements
return sql_ddl_statements


def extract_create_statement_table_names(sql_create_statement: str) -> List[str]:
    """Extract the table name(s) from a "create table ... with ..." statement.

    Returns every run of text between ``create table `` and the first
    following ``with`` (DOTALL, so names may span newlines), with trailing
    whitespace removed. Statements lacking a ``with`` clause yield nothing.
    """
    name_pattern = re.compile(r"(?s)create table (.*?)(?=with)")
    return [name.rstrip() for name in name_pattern.findall(sql_create_statement)]


def extract_rename_statement_table_names(sql_alter_rename_statement: str) -> Dict[str, List[str]]:
    """Split an ``alter table ... rename to ...`` statement into its table names.

    Returns a dict with two keys:
      - ``alter_table_names``: text between ``alter table `` and `` rename``
      - ``rename_to_table_names``: text after ``rename to `` up to end of input
    Each extracted name is right-stripped of trailing whitespace.
    """
    extraction_patterns = {
        "alter_table_names": r"(?s)alter table (.*?)(?= rename)",
        "rename_to_table_names": r"(?s)rename to (.*?)(?=$)",
    }
    return {
        key: [name.rstrip() for name in re.findall(pattern, sql_alter_rename_statement)]
        for key, pattern in extraction_patterns.items()
    }
Loading