Skip to content

Commit

Permalink
chore: Apply non-functional refactoring and fix typos
Browse files Browse the repository at this point in the history
  • Loading branch information
edgarrmondragon committed Feb 19, 2024
1 parent 14a7377 commit 7f5fa06
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 142 deletions.
24 changes: 16 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,22 @@ target-version = "py38"

[tool.ruff.lint]
select = [
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"ICN", # flake8-import-conventions
"RUF", # ruff
"F", # Pyflakes
"W", # pycodestyle warnings
"E", # pycodestyle errors
"I", # isort
"N", # pep8-naming
"D", # pydocsyle
"UP", # pyupgrade
"ICN", # flake8-import-conventions
"RET", # flake8-return
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"ERA", # eradicate
"PGH", # pygrep-hooks
"PL", # Pylint
"PERF", # Perflint
"RUF", # ruff
]

[tool.ruff.lint.flake8-import-conventions]
Expand Down
99 changes: 48 additions & 51 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Handles Postgres interactions."""


from __future__ import annotations

import atexit
import io
import itertools
import signal
import sys
import typing as t
from contextlib import contextmanager
from functools import cached_property
Expand Down Expand Up @@ -92,7 +95,7 @@ def interpret_content_encoding(self) -> bool:
"""
return self.config.get("interpret_content_encoding", False)

def prepare_table( # type: ignore[override]
def prepare_table( # type: ignore[override] # noqa: PLR0913
self,
full_table_name: str,
schema: dict,
Expand All @@ -118,7 +121,7 @@ def prepare_table( # type: ignore[override]
meta = sa.MetaData(schema=schema_name)
table: sa.Table
if not self.table_exists(full_table_name=full_table_name):
table = self.create_empty_table(
return self.create_empty_table(
table_name=table_name,
meta=meta,
schema=schema,
Expand All @@ -127,7 +130,6 @@ def prepare_table( # type: ignore[override]
as_temp_table=as_temp_table,
connection=connection,
)
return table
meta.reflect(connection, only=[table_name])
table = meta.tables[
full_table_name
Expand Down Expand Up @@ -174,19 +176,19 @@ def copy_table_structure(
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sa.MetaData(schema=schema_name)
new_table: sa.Table
columns = []
if self.table_exists(full_table_name=full_table_name):
raise RuntimeError("Table already exists")
for column in from_table.columns:
columns.append(column._copy())

columns = [column._copy() for column in from_table.columns]

if as_temp_table:
new_table = sa.Table(table_name, meta, *columns, prefixes=["TEMPORARY"])
new_table.create(bind=connection)
return new_table
else:
new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table

new_table = sa.Table(table_name, meta, *columns)
new_table.create(bind=connection)
return new_table

@contextmanager
def _connect(self) -> t.Iterator[sa.engine.Connection]:
Expand All @@ -197,18 +199,17 @@ def drop_table(self, table: sa.Table, connection: sa.engine.Connection):
"""Drop table data."""
table.drop(bind=connection)

def clone_table(
def clone_table( # noqa: PLR0913
self, new_table_name, table, metadata, connection, temp_table
) -> sa.Table:
"""Clone a table."""
new_columns = []
for column in table.columns:
new_columns.append(
sa.Column(
column.name,
column.type,
)
new_columns = [
sa.Column(
column.name,
column.type,
)
for column in table.columns
]
if temp_table is True:
new_table = sa.Table(
new_table_name, metadata, *new_columns, prefixes=["TEMPORARY"]
Expand Down Expand Up @@ -293,9 +294,8 @@ def pick_individual_type(self, jsonschema_type: dict):
):
return HexByteString()
individual_type = th.to_sql_type(jsonschema_type)
if isinstance(individual_type, VARCHAR):
return TEXT()
return individual_type

return TEXT() if isinstance(individual_type, VARCHAR) else individual_type

@staticmethod
def pick_best_sql_type(sql_type_array: list):
Expand Down Expand Up @@ -323,13 +323,12 @@ def pick_best_sql_type(sql_type_array: list):
NOTYPE,
]

for sql_type in precedence_order:
for obj in sql_type_array:
if isinstance(obj, sql_type):
return obj
for sql_type, obj in itertools.product(precedence_order, sql_type_array):
if isinstance(obj, sql_type):
return obj
return TEXT()

def create_empty_table( # type: ignore[override]
def create_empty_table( # type: ignore[override] # noqa: PLR0913
self,
table_name: str,
meta: sa.MetaData,
Expand All @@ -343,7 +342,7 @@ def create_empty_table( # type: ignore[override]
Args:
table_name: the target table name.
meta: the SQLAchemy metadata object.
meta: the SQLAlchemy metadata object.
schema: the JSON schema for the new table.
connection: the database connection.
primary_keys: list of key properties.
Expand Down Expand Up @@ -386,7 +385,7 @@ def create_empty_table( # type: ignore[override]
new_table.create(bind=connection)
return new_table

def prepare_column(
def prepare_column( # noqa: PLR0913
self,
full_table_name: str,
column_name: str,
Expand Down Expand Up @@ -434,7 +433,7 @@ def prepare_column(
column_object=column_object,
)

def _create_empty_column( # type: ignore[override]
def _create_empty_column( # type: ignore[override] # noqa: PLR0913
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -499,7 +498,7 @@ def get_column_add_ddl( # type: ignore[override]
},
)

def _adapt_column_type( # type: ignore[override]
def _adapt_column_type( # type: ignore[override] # noqa: PLR0913
self,
schema_name: str,
table_name: str,
Expand Down Expand Up @@ -542,7 +541,7 @@ def _adapt_column_type( # type: ignore[override]
return

# Not the same type, generic type or compatible types
# calling merge_sql_types for assistnace
# calling merge_sql_types for assistance
compatible_sql_type = self.merge_sql_types([current_type, sql_type])

if str(compatible_sql_type) == str(current_type):
Expand Down Expand Up @@ -612,17 +611,16 @@ def get_sqlalchemy_url(self, config: dict) -> str:
if config.get("sqlalchemy_url"):
return cast(str, config["sqlalchemy_url"])

else:
sqlalchemy_url = URL.create(
drivername=config["dialect+driver"],
username=config["user"],
password=config["password"],
host=config["host"],
port=config["port"],
database=config["database"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)
sqlalchemy_url = URL.create(
drivername=config["dialect+driver"],
username=config["user"],
password=config["password"],
host=config["host"],
port=config["port"],
database=config["database"],
query=self.get_sqlalchemy_query(config),
)
return cast(str, sqlalchemy_url)

def get_sqlalchemy_query(self, config: dict) -> dict:
"""Get query values to be used for sqlalchemy URL creation.
Expand All @@ -638,7 +636,7 @@ def get_sqlalchemy_query(self, config: dict) -> dict:
# ssl_enable is for verifying the server's identity to the client.
if config["ssl_enable"]:
ssl_mode = config["ssl_mode"]
query.update({"sslmode": ssl_mode})
query["sslmode"] = ssl_mode
query["sslrootcert"] = self.filepath_or_certificate(
value=config["ssl_certificate_authority"],
alternative_name=config["ssl_storage_directory"] + "/root.crt",
Expand Down Expand Up @@ -684,12 +682,11 @@ def filepath_or_certificate(
"""
if path.isfile(value):
return value
else:
with open(alternative_name, "wb") as alternative_file:
alternative_file.write(value.encode("utf-8"))
if restrict_permissions:
chmod(alternative_name, 0o600)
return alternative_name
with open(alternative_name, "wb") as alternative_file:
alternative_file.write(value.encode("utf-8"))
if restrict_permissions:
chmod(alternative_name, 0o600)
return alternative_name

def guess_key_type(self, key_data: str) -> paramiko.PKey:
"""Guess the type of the private key.
Expand All @@ -714,7 +711,7 @@ def guess_key_type(self, key_data: str) -> paramiko.PKey:
):
try:
key = key_class.from_private_key(io.StringIO(key_data)) # type: ignore[attr-defined]
except paramiko.SSHException:
except paramiko.SSHException: # noqa: PERF203
continue
else:
return key
Expand All @@ -734,7 +731,7 @@ def catch_signal(self, signum, frame) -> None:
signum: The signal number
frame: The current stack frame
"""
exit(1) # Calling this to be sure atexit is called, so clean_up gets called
sys.exit(1) # Calling this to be sure atexit is called, so clean_up gets called

def _get_column_type( # type: ignore[override]
self,
Expand Down
53 changes: 21 additions & 32 deletions target_postgres/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,7 @@ def setup(self) -> None:
This method is called on Sink creation, and creates the required Schema and
Table entities in the target database.
"""
if self.key_properties is None or self.key_properties == []:
self.append_only = True
else:
self.append_only = False
self.append_only = self.key_properties is None or self.key_properties == []
if self.schema_name:
self.connector.prepare_schema(self.schema_name)
with self.connector._connect() as connection, connection.begin():
Expand Down Expand Up @@ -109,14 +106,14 @@ def process_batch(self, context: dict) -> None:

def generate_temp_table_name(self):
"""Uuid temp table name."""
# sa.exc.IdentifierError: Identifier
# sa.exc.IdentifierError: Identifier # noqa: ERA001
# 'temp_test_optional_attributes_388470e9_fbd0_47b7_a52f_d32a2ee3f5f6'
# exceeds maximum length of 63 characters
# Is hit if we have a long table name, there is no limit on Temporary tables
# in postgres, used a guid just in case we are using the same session
return f"{str(uuid.uuid4()).replace('-', '_')}"

def bulk_insert_records( # type: ignore[override]
def bulk_insert_records( # type: ignore[override] # noqa: PLR0913
self,
table: sa.Table,
schema: dict,
Expand Down Expand Up @@ -156,24 +153,24 @@ def bulk_insert_records( # type: ignore[override]
if self.append_only is False:
insert_records: Dict[str, Dict] = {} # pk : record
for record in records:
insert_record = {}
for column in columns:
insert_record[column.name] = record.get(column.name)
insert_record = {
column.name: record.get(column.name) for column in columns
}
# No need to check for a KeyError here because the SDK already
# guaruntees that all key properties exist in the record.
# guarantees that all key properties exist in the record.
primary_key_value = "".join([str(record[key]) for key in primary_keys])
insert_records[primary_key_value] = insert_record
data_to_insert = list(insert_records.values())
else:
for record in records:
insert_record = {}
for column in columns:
insert_record[column.name] = record.get(column.name)
insert_record = {
column.name: record.get(column.name) for column in columns
}
data_to_insert.append(insert_record)
connection.execute(insert, data_to_insert)
return True

def upsert(
def upsert( # noqa: PLR0913
self,
from_table: sa.Table,
to_table: sa.Table,
Expand Down Expand Up @@ -232,7 +229,7 @@ def upsert(
# Update
where_condition = join_condition
update_columns = {}
for column_name in self.schema["properties"].keys():
for column_name in self.schema["properties"]:
from_table_column: sa.Column = from_table.columns[column_name]
to_table_column: sa.Column = to_table.columns[column_name]
update_columns[to_table_column] = from_table_column
Expand All @@ -249,14 +246,13 @@ def column_representation(
schema: dict,
) -> List[sa.Column]:
"""Return a sqlalchemy table representation for the current schema."""
columns: list[sa.Column] = []
for property_name, property_jsonschema in schema["properties"].items():
columns.append(
sa.Column(
property_name,
self.connector.to_sql_type(property_jsonschema),
)
columns: list[sa.Column] = [
sa.Column(
property_name,
self.connector.to_sql_type(property_jsonschema),
)
for property_name, property_jsonschema in schema["properties"].items()
]
return columns

def generate_insert_statement(
Expand Down Expand Up @@ -286,12 +282,12 @@ def schema_name(self) -> Optional[str]:
"""Return the schema name or `None` if using names with no schema part.
Note that after the next SDK release (after 0.14.0) we can remove this
as it's already upstreamed.
as it's already up-streamed.
Returns:
The target schema name.
"""
# Look for a default_target_scheme in the configuraion fle
# Look for a default_target_scheme in the configuration file
default_target_schema: str = self.config.get("default_target_schema", None)
parts = self.stream_name.split("-")

Expand All @@ -302,14 +298,7 @@ def schema_name(self) -> Optional[str]:
if default_target_schema:
return default_target_schema

if len(parts) in {2, 3}:
# Stream name is a two-part or three-part identifier.
# Use the second-to-last part as the schema name.
stream_schema = self.conform_name(parts[-2], "schema")
return stream_schema

# Schema name not detected.
return None
return self.conform_name(parts[-2], "schema") if len(parts) in {2, 3} else None

def activate_version(self, new_version: int) -> None:
"""Bump the active version of the target table.
Expand Down
Loading

0 comments on commit 7f5fa06

Please sign in to comment.