Files
db-middleware/dbs/mysql.py

178 lines
5.4 KiB
Python

import aiomysql, decimal, datetime
import asyncio
from typing import Literal, Any
from data.schemas import Connection, SelectResultData
# Default MySQL connection settings; every value is a placeholder meant to be
# replaced for the target deployment.
# NOTE(review): credentials are hard-coded in source; consider loading them
# from the environment or a secrets store instead.
DB_CONFIG = dict(
    host="localhost",
    user="me",  # MySQL username to connect as
    password="Passwd3.14",  # MySQL password for that user
    db="testing",  # database (schema) name
    port=3306,  # standard MySQL port
)
# Module-level registry of aiomysql pools, keyed by a string identifier
# (presumably a connection name/id — confirm against the code that fills it).
pools: None | dict[str, aiomysql.Pool] = dict()
async def pool_creator(connection: Connection, minsize=5, maxsize=10):
    """
    Create an aiomysql connection pool from a ``Connection`` description.

    Args:
        connection (Connection): Supplies ``host``, ``username``, ``password``,
            ``db_name`` and ``port`` for the pool's connections.
        minsize (int): Lower bound on the pool size passed to aiomysql.
        maxsize (int): Upper bound on the pool size passed to aiomysql.

    Returns:
        The pool returned by awaiting ``aiomysql.create_pool``.
    """
    pool_options = {
        "host": connection.host,
        "user": connection.username,
        "password": connection.password,
        "db": connection.db_name,
        "port": connection.port,
        "minsize": minsize,
        "maxsize": maxsize,
    }
    return await aiomysql.create_pool(**pool_options)
async def execute_select_query(
    pool: aiomysql.Pool,
    query: str,
    params: list | None = None,
    fetch_num: int = 100,
):
    """
    Execute a SELECT query asynchronously and return up to ``fetch_num`` rows.

    Args:
        pool (aiomysql.Pool): Connection pool to acquire a connection from.
        query (str): The SELECT statement. It may contain driver placeholders
            (``%s`` for aiomysql) that are bound from ``params``.
        params (list | None): Values bound to the placeholders in ``query``.
            Defaults to ``None`` for queries without placeholders, such as the
            fully rendered SQL from ``construct_select_query``.
        fetch_num (int): Maximum number of rows fetched from the result set.

    Returns:
        On success, a tuple ``(rows, rowcount)``:
        - ``rows`` holds at most ``fetch_num`` rows, each a dict keyed by
          column name (``aiomysql.DictCursor``).
        - ``rowcount`` is the cursor's ``rowcount`` after execution. For
          aiomysql's buffered cursors this is presumably the total number of
          matched rows, not only the fetched ones. Verify this.

        On any exception, the error is printed and an empty list ``[]`` is
        returned, not a tuple.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                result = await cursor.fetchmany(fetch_num)
                return result, cursor.rowcount
    except Exception as e:
        print(f"Error executing query: {e}")
        # NOTE(review): this shape differs from the success tuple, so callers
        # that unpack ``rows, count = ...`` will fail here. Consider returning
        # ``([], 0)`` or re-raising once the callers have been checked.
        return []
def construct_select_query(
table_name,
columns: Literal["*"] | list[str] = "*",
filters=None,
sort_by=None,
limit=None,
offset=None,
):
"""
Constructs a dynamic SELECT query based on the provided parameters.
Args:
table_name (str): The name of the table to query.
columns (str or list): The columns to select. Default is "*" (all columns).
filters (dict): A dictionary of filters (e.g., {"column": "value"}).
sort_by (str): The column to sort by (e.g., "column_name ASC").
limit (int): The maximum number of rows to return.
offset (int): The number of rows to skip (for pagination).
Returns:
str: The constructed SELECT query.
"""
# Handle columns
if isinstance(columns, list):
columns = ", ".join(columns)
elif columns != "*":
raise ValueError("Columns must be a list or '*'.")
# Base query
query = f"SELECT {columns}\n FROM {table_name}"
# Add filters
if filters:
filter_conditions = [
f"{column} = '{value}'" for column, value in filters.items()
]
query += "\n WHERE " + " AND ".join(filter_conditions)
# Add sorting
if sort_by:
query += f"\n ORDER BY {sort_by}"
# Add pagination
if limit is not None:
query += f"\n LIMIT {limit}"
if offset is not None:
query += f"\n OFFSET {offset}"
return query
async def get_tables_and_datatypes(pool: aiomysql.Pool):
    """
    Retrieve every table in the pool's database with the type of each column.

    Runs ``SHOW TABLES`` and then one ``DESCRIBE`` per table, all on a single
    connection acquired from ``pool``.

    Args:
        pool (aiomysql.Pool): Connection pool to use.

    Returns:
        dict: ``{table_name: {column_name: column_type}}``. ``column_name`` and
        ``column_type`` are the first and second fields of each ``DESCRIBE``
        row (MySQL's ``Field`` and ``Type`` columns). On any error, the
        exception is printed and an empty dict is returned.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                # Get all table names
                await cursor.execute("SHOW TABLES")
                tables = await cursor.fetchall()

                result = {}
                for table in tables:
                    table_name = table[0]
                    # Backtick-quote the identifier, doubling any embedded
                    # backticks. Without this, tables named after reserved
                    # words (e.g. `order`) or containing special characters
                    # make DESCRIBE fail, and the whole result is lost.
                    quoted_name = "`" + str(table_name).replace("`", "``") + "`"
                    await cursor.execute(f"DESCRIBE {quoted_name}")
                    columns = await cursor.fetchall()
                    result[table_name] = {col[0]: col[1] for col in columns}
                return result
    except Exception as e:
        print(f"Error: {e}")
        return {}
def serializer(raw_result: list[dict[str, Any]]):
    """
    Convert database rows into JSON-friendly dicts.

    Values are converted as follows:
      * ``decimal.Decimal``   -> ``float``
      * ``datetime.datetime`` -> ``str(value)`` ("YYYY-MM-DD HH:MM:SS[.ffffff]")
      * ``datetime.date``     -> "YYYY-MM-DD"
      * anything else (str, int, float, None, bytes, ...) -> ``str(value)``

    NOTE(review): SQL NULL becomes the string "None". Confirm that consumers
    expect this rather than a JSON null.

    Args:
        raw_result: Rows as dicts mapping column name to value. The list is
            read but not modified.

    Returns:
        list[dict]: One converted dict per input row, in input order.
    """
    serialized = []
    for row in raw_result:
        serialized_row = {}
        for key, value in row.items():
            if isinstance(value, decimal.Decimal):
                serialized_row[key] = float(value)
            elif isinstance(value, datetime.datetime):
                # This check must come before the ``date`` check. datetime
                # subclasses date, and "%Y-%m-%d" would silently drop the time.
                serialized_row[key] = str(value)
            elif isinstance(value, datetime.date):
                serialized_row[key] = value.strftime("%Y-%m-%d")
            else:
                serialized_row[key] = str(value)
        serialized.append(serialized_row)
    return serialized
def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
    """
    Split dict rows into a column header plus positional row values.

    Args:
        result: Rows as dicts. The keys of the first row become the column
            names.

    Returns:
        ``None`` when ``result`` is empty. Otherwise a ``SelectResultData`` with:
        - ``columns``: the first row's keys, in order.
        - ``data``: each row's values as a list, in that row's own key order.
    """
    if len(result) == 0:
        return None
    header = [*result[0]]
    rows = [[*row.values()] for row in result]
    return SelectResultData(columns=header, data=rows)