Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for pg procedure #143

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/alembic_utils/depends.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def upgrade() -> None:
"""
from alembic_utils.pg_function import PGFunction
from alembic_utils.pg_materialized_view import PGMaterializedView
from alembic_utils.pg_procedure import PGProcedure
from alembic_utils.pg_trigger import PGTrigger
from alembic_utils.pg_view import PGView
from alembic_utils.replaceable_entity import ReplaceableEntity
Expand All @@ -91,6 +92,7 @@ def collect_all_db_entities(sess: Session) -> List[ReplaceableEntity]:

return [
*PGFunction.from_database(sess, "%"),
*PGProcedure.from_database(sess, "%"),
*PGTrigger.from_database(sess, "%"),
*PGView.from_database(sess, "%"),
*PGMaterializedView.from_database(sess, "%"),
Expand Down
154 changes: 154 additions & 0 deletions src/alembic_utils/pg_procedure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# pylint: disable=unused-argument,invalid-name,line-too-long
from typing import List

from parse import parse
from sqlalchemy import text as sql_text

from alembic_utils.exceptions import SQLParseFailure
from alembic_utils.replaceable_entity import ReplaceableEntity
from alembic_utils.statement import (
escape_colon_for_plpgsql,
escape_colon_for_sql,
normalize_whitespace,
strip_terminating_semicolon,
)


class PGProcedure(ReplaceableEntity):
"""A PostgreSQL Procedure compatible with `alembic revision --autogenerate`

**Parameters:**

* **schema** - *str*: A SQL schema name
* **signature** - *str*: A SQL procedure's call signature
* **definition** - *str*: The remainig procedure body and identifiers
"""

type_ = "procedure"

def __init__(self, schema: str, signature: str, definition: str):
super().__init__(schema, signature, definition)
# Detect if procedure uses plpgsql and update escaping rules to not escape ":="
is_plpgsql: bool = "language plpgsql" in normalize_whitespace(definition).lower().replace(
"'", ""
)
escaping_callable = escape_colon_for_plpgsql if is_plpgsql else escape_colon_for_sql
# Override definition with correct escaping rules
self.definition: str = escaping_callable(strip_terminating_semicolon(definition))

@classmethod
def from_sql(cls, sql: str) -> "PGProcedure":
"""Create an instance instance from a SQL string"""
template = "create{}procedure{:s}{schema}.{signature_name}({signature_arg}){:s}{definition}"
result = parse(template, sql.strip(), case_sensitive=False)
if result is not None:
raw_signature = f'{result["signature_name"]}({result["signature_arg"]})'
# remove possible quotes from signature
signature = (
"".join(raw_signature.split('"', 2))
if raw_signature.startswith('"')
else raw_signature
)
return cls(
schema=result["schema"],
signature=signature,
definition=result["definition"],
)
raise SQLParseFailure(f'Failed to parse SQL into PGProcedure """{sql}"""')

@property
def literal_signature(self) -> str:
"""Adds quoting around the procedure name when emitting SQL statements

e.g.
'toUpper(text) returns text' -> '"toUpper"(text) text'
"""
# May already be quoted if loading from database or SQL file
name, remainder = self.signature.split("(", 1)
return '"' + name.strip() + '"(' + remainder

def to_sql_statement_create(self):
"""Generates a SQL "create procedure" statement for PGProcedure"""
return sql_text(
f"CREATE PROCEDURE {self.literal_schema}.{self.literal_signature} {self.definition}",
)

def to_sql_statement_drop(self, cascade=False):
"""Generates a SQL "drop procedure" statement for PGProcedure"""
cascade = "cascade" if cascade else ""
template = "{procedure_name}({parameters})"
result = parse(template, self.signature, case_sensitive=False)
try:
procedure_name = result["procedure_name"].strip()
parameters_str = result["parameters"].strip()
except TypeError:
# Did not match, NoneType is not scriptable
result = parse("{procedure_name}()", self.signature, case_sensitive=False)
procedure_name = result["procedure_name"].strip()
parameters_str = ""

# NOTE: Will fail if a text field has a default and that deafult contains a comma...
parameters: list[str] = parameters_str.split(",")
parameters = [x[: len(x.lower().split("default")[0])] for x in parameters]
parameters = [x.strip() for x in parameters]
drop_params = ", ".join(parameters)
return sql_text(
f'DROP PROCEDURE {self.literal_schema}."{procedure_name}"({drop_params}) {cascade}',
)

def to_sql_statement_create_or_replace(self):
"""Generates a SQL "create or replace procedure" statement for PGProcedure"""
yield sql_text(
f"CREATE OR REPLACE PROCEDURE {self.literal_schema}.{self.literal_signature} {self.definition}",
)

@classmethod
def from_database(cls, sess, schema):
"""Get a list of all procedures defined in the db"""

sql = sql_text(
"""
with extension_functions as (
select
objid as extension_function_oid
from
pg_depend
where
-- depends on an extension
deptype='e'
-- is a proc/function
and classid = 'pg_proc'::regclass
)

select
n.nspname as function_schema,
p.proname as procedure_name,
pg_get_function_arguments(p.oid) as function_arguments,
case
when l.lanname = 'internal' then p.prosrc
else pg_get_functiondef(p.oid)
end as create_statement,
t.typname as return_type,
l.lanname as function_language
from
pg_proc p
left join pg_namespace n on p.pronamespace = n.oid
left join pg_language l on p.prolang = l.oid
left join pg_type t on t.oid = p.prorettype
left join extension_functions ef on p.oid = ef.extension_function_oid
where
n.nspname not in ('pg_catalog', 'information_schema')
-- Filter out functions from extensions
and ef.extension_function_oid is null
and n.nspname = :schema
and p.prokind = 'p'
"""
)

rows = sess.execute(sql, {"schema": schema}).fetchall()
db_functions = [cls.from_sql(x[3]) for x in rows]

for func in db_functions:
assert func is not None

return db_functions
Loading