Skip to content

Commit

Permalink
Merge pull request #1600 from arc53/postgres-tool
Browse files Browse the repository at this point in the history
feat: postgres tool
  • Loading branch information
dartpain authored Jan 24, 2025
2 parents 92528af + 5a38c09 commit c477a49
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 0 deletions.
1 change: 1 addition & 0 deletions application/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ prance==23.6.21.0
primp==0.10.0
prompt-toolkit==3.0.48
protobuf==5.29.3
psycopg2-binary==2.9.10
py==1.11.0
pydantic==2.10.4
pydantic-core==2.27.2
Expand Down
163 changes: 163 additions & 0 deletions application/tools/implementations/postgres.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import psycopg2
from application.tools.base import Tool

class PostgresTool(Tool):
"""
PostgreSQL Database Tool
A tool for connecting to a PostgreSQL database using a connection string,
executing SQL queries, and retrieving schema information.
"""

def __init__(self, config):
self.config = config
self.connection_string = config.get("token", "")

def execute_action(self, action_name, **kwargs):
actions = {
"postgres_execute_sql": self._execute_sql,
"postgres_get_schema": self._get_schema,
}

if action_name in actions:
return actions[action_name](**kwargs)
else:
raise ValueError(f"Unknown action: {action_name}")

def _execute_sql(self, sql_query):
"""
Executes an SQL query against the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()
cur.execute(sql_query)
conn.commit()

if sql_query.strip().lower().startswith("select"):
column_names = [desc[0] for desc in cur.description] if cur.description else []
results = []
rows = cur.fetchall()
for row in rows:
results.append(dict(zip(column_names, row)))
response_data = {"data": results, "column_names": column_names}
else:
row_count = cur.rowcount
response_data = {"message": f"Query executed successfully, {row_count} rows affected."}

cur.close()
return {
"status_code": 200,
"message": "SQL query executed successfully.",
"response_data": response_data,
}

except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to execute SQL query.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
conn.close()

def _get_schema(self, db_name):
"""
Retrieves the schema of the PostgreSQL database using a connection string.
"""
conn = None # Initialize conn to None for error handling
try:
conn = psycopg2.connect(self.connection_string)
cur = conn.cursor()

cur.execute("""
SELECT
table_name,
column_name,
data_type,
column_default,
is_nullable
FROM
information_schema.columns
WHERE
table_schema = 'public'
ORDER BY
table_name,
ordinal_position;
""")

schema_data = {}
for row in cur.fetchall():
table_name, column_name, data_type, column_default, is_nullable = row
if table_name not in schema_data:
schema_data[table_name] = []
schema_data[table_name].append({
"column_name": column_name,
"data_type": data_type,
"column_default": column_default,
"is_nullable": is_nullable
})

cur.close()
return {
"status_code": 200,
"message": "Database schema retrieved successfully.",
"schema": schema_data,
}

except psycopg2.Error as e:
error_message = f"Database error: {e}"
print(f"Database error: {e}")
return {
"status_code": 500,
"message": "Failed to retrieve database schema.",
"error": error_message,
}
finally:
if conn: # Ensure connection is closed even if errors occur
conn.close()

def get_actions_metadata(self):
return [
{
"name": "postgres_execute_sql",
"description": "Execute an SQL query against the PostgreSQL database and return the results. Use this tool to interact with the database, e.g., retrieve specific data or perform updates. Only SELECT queries will return data, other queries will return execution status.",
"parameters": {
"type": "object",
"properties": {
"sql_query": {
"type": "string",
"description": "The SQL query to execute.",
},
},
"required": ["sql_query"],
"additionalProperties": False,
},
},
{
"name": "postgres_get_schema",
"description": "Retrieve the schema of the PostgreSQL database, including tables and their columns. Use this to understand the database structure before executing queries. db_name is 'default' if not provided.",
"parameters": {
"type": "object",
"properties": {
"db_name": {
"type": "string",
"description": "The name of the database to retrieve the schema for.",
},
},
"required": ["db_name"],
"additionalProperties": False,
},
},
]

def get_config_requirements(self):
return {
"token": {
"type": "string",
"description": "PostgreSQL database connection string (e.g., 'postgresql://user:password@host:port/dbname')",
},
}
29 changes: 29 additions & 0 deletions frontend/public/toolIcons/tool_postgres.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit c477a49

Please sign in to comment.