# db-middleware/dbs/mysql.py
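"""Async MySQL helpers built on aiomysql: connection pooling, TTL-based cleanup
of cached server-side cursors, SELECT query construction and execution, and
result serialization for the db-middleware service."""
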
import asyncio
import datetime
import decimal
import logging
import uuid
from typing import Any, Literal

import aiomysql

from core.exceptions import PoolNotFound
from data.app_types import CachedCursor
from data.schemas import Connection, SelectQuery, SelectResultData

logger = logging.getLogger(__name__)

pools: dict[int, aiomysql.Pool] = {}  # keyed by the connection id used in create_cursor()
cached_cursors: dict[str, CachedCursor] = {}
closed_cached_cursors: dict[str, CachedCursor] = {}
cached_cursors_cleaner_task: asyncio.Task | None = None

async def close_old_cached_cursors():
    """Close cached cursors whose TTL has expired and move them to the closed registry."""
    for cursor_id in list(cached_cursors.keys()):
        cursor = cached_cursors.get(cursor_id)
        if cursor is None:
            continue
        if cursor.close_at > datetime.datetime.now(datetime.UTC).timestamp():
            continue
        try:
            await cursor.close()
            cached_cursors.pop(cursor_id, None)
            closed_cached_cursors[cursor_id] = cursor
            logger.info("Closed cursor %s", cursor_id)
        except Exception as e:
            logger.error("Error closing cursor %s -> %s", cursor_id, e)

async def remove_old_closed_cached_cursors():
    """Forget closed cursors once they have been closed for 5x their TTL."""
    for cursor_id in list(closed_cached_cursors.keys()):
        closed_cursor = closed_cached_cursors.get(cursor_id)
        if closed_cursor is None:
            continue
        if (
            closed_cursor.close_at + closed_cursor.ttl * 5
            > datetime.datetime.now(datetime.UTC).timestamp()
        ):
            continue
        del closed_cached_cursors[cursor_id]
        logger.info("Removed cursor %s", cursor_id)

async def cached_cursors_cleaner():
    """Background loop: close expired cursors and prune the closed registry every 10 seconds."""
    while True:
        await close_old_cached_cursors()
        await remove_old_closed_cached_cursors()
        await asyncio.sleep(10)

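# A minimal sketch of how the cleaner could be started, e.g. from an async
# application-startup hook (the hook itself is an assumption, not part of this module):
#
#     async def start_cursor_cleaner():
#         global cached_cursors_cleaner_task
#         cached_cursors_cleaner_task = asyncio.create_task(cached_cursors_cleaner())
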
async def pool_creator(connection: Connection, minsize=5, maxsize=10) -> aiomysql.Pool:
    """Create an aiomysql connection pool from a Connection definition."""
    return await aiomysql.create_pool(
        host=connection.host,
        user=connection.username,
        password=connection.password,
        db=connection.db_name,
        port=connection.port,
        minsize=minsize,
        maxsize=maxsize,
    )

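# Sketch: registering a pool so it can be looked up later by connection id.
# Where connection_id comes from is an assumption; this module only reads `pools`:
#
#     pools[connection_id] = await pool_creator(connection)
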
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
    """Acquire a connection and open a server-side (streaming) cursor for the query."""
    pool = pools.get(connection_id)
    if pool is None:
        raise PoolNotFound
    connection = await pool.acquire()
    try:
        # SSCursor streams rows from the server instead of buffering the full result set.
        cursor = await connection.cursor(aiomysql.SSCursor)
        await cursor.execute(query.sql, query.params)
    except Exception:
        # Return the connection to the pool if the query fails to start.
        pool.release(connection)
        raise
    cached_cursor = CachedCursor(
        id=str(uuid.uuid4()),
        cursor=cursor,
        connection=connection,
        pool=pool,
        connection_id=connection_id,
        query=query,
    )
    return cached_cursor

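# Sketch: creating and registering a cached cursor, then streaming a first batch.
# The SelectQuery constructor fields and the batch size are assumptions:
#
#     cached = await create_cursor(connection_id, SelectQuery(sql="SELECT * FROM users", params=[]))
#     cached_cursors[cached.id] = cached
#     first_batch = await cached.cursor.fetchmany(100)
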
async def execute_select_query(
    pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
):
    """
    Execute a SELECT query asynchronously and return the first batch of results.

    Args:
        pool (aiomysql.Pool): Connection pool to use.
        sql_query (str): The SELECT statement to execute.
        params (list): Parameters bound to the statement's placeholders.
        fetch_num (int): Maximum number of rows to fetch. Defaults to 100.

    Returns:
        tuple: (list of rows as dictionaries, cursor rowcount), or ([], 0) on error.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql_query, params)
                result = await cursor.fetchmany(fetch_num)
                return result, cursor.rowcount
    except Exception as e:
        logger.error("Error executing query: %s", e)
        return [], 0

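# Sketch: running a constructed query against a registered pool
# (connection_id and the table name are illustrative):
#
#     sql = construct_select_query("users", ["id", "name"], limit=10)
#     rows, rowcount = await execute_select_query(pools[connection_id], sql, [])
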
def construct_select_query(
    table_name,
    columns: Literal["*"] | list[str] = "*",
    filters=None,
    sort_by=None,
    limit=None,
    offset=None,
):
    """
    Construct a dynamic SELECT query from the provided parameters.

    Note: table, column, filter, and sort values are interpolated directly into
    the SQL string, so this should only be used with trusted input.

    Args:
        table_name (str): The name of the table to query.
        columns (str or list): The columns to select. Default is "*" (all columns).
        filters (dict): A dictionary of equality filters (e.g., {"column": "value"}).
        sort_by (str): The column to sort by (e.g., "column_name ASC").
        limit (int): The maximum number of rows to return.
        offset (int): The number of rows to skip (for pagination).

    Returns:
        str: The constructed SELECT query.
    """
    # Handle columns
    if isinstance(columns, list):
        columns = ", ".join(columns)
    elif columns != "*":
        raise ValueError("Columns must be a list or '*'.")
    # Base query
    query = f"SELECT {columns}\n FROM {table_name}"
    # Add filters
    if filters:
        filter_conditions = [
            f"{column} = '{value}'" for column, value in filters.items()
        ]
        query += "\n WHERE " + " AND ".join(filter_conditions)
    # Add sorting
    if sort_by:
        query += f"\n ORDER BY {sort_by}"
    # Add pagination
    if limit is not None:
        query += f"\n LIMIT {limit}"
    if offset is not None:
        query += f"\n OFFSET {offset}"
    return query

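# Example of the generated SQL (whitespace follows the f-strings above):
#
#     construct_select_query(
#         "users", ["id", "name"], filters={"status": "active"},
#         sort_by="id ASC", limit=10, offset=20,
#     )
#
#     SELECT id, name
#      FROM users
#      WHERE status = 'active'
#      ORDER BY id ASC
#      LIMIT 10
#      OFFSET 20
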
async def get_tables_and_datatypes(pool: aiomysql.Pool):
    """
    Retrieve all table names and their column data types from a MySQL database.

    Args:
        pool (aiomysql.Pool): Connection pool to use.

    Returns:
        dict: A dictionary mapping table names to {column_name: data_type} dictionaries.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                # Get all table names
                await cursor.execute("SHOW TABLES")
                tables = await cursor.fetchall()
                result = {}
                for table in tables:
                    table_name = table[0]
                    result[table_name] = {}
                    # Get column names and data types for the current table
                    await cursor.execute(f"DESCRIBE {table_name}")
                    columns = await cursor.fetchall()
                    # Store column names and data types
                    for col in columns:
                        result[table_name][col[0]] = col[1]
                return result
    except Exception as e:
        logger.error("Error retrieving schema: %s", e)
        return {}

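# Illustrative return shape (table and column names are examples only):
#
#     {
#         "users": {"id": "int(11)", "name": "varchar(255)"},
#         "orders": {"id": "int(11)", "created_at": "datetime"},
#     }
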
def serializer(raw_result: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert database values into JSON-friendly types, preserving row order."""
    serialized = []
    for row in raw_result:
        serialized_row = {}
        for key, value in row.items():
            if isinstance(value, str | int | float | None):
                serialized_row[key] = str(value)
            elif isinstance(value, decimal.Decimal):
                serialized_row[key] = float(value)
            elif isinstance(value, datetime.date):
                serialized_row[key] = value.strftime("%Y-%m-%d")
            else:
                # Fall back to the string representation for unhandled types.
                serialized_row[key] = str(value)
        serialized.append(serialized_row)
    return serialized

def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
    if len(result) == 0:
        return None
    columns = list(result[0].keys())
    data = []
    for row in result:
        data.append(list(row.values()))
    return SelectResultData(columns=columns, data=data)

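# Sketch: turning a DictCursor batch into a SelectResultData payload
# (the pool lookup and query text are illustrative):
#
#     rows, rowcount = await execute_select_query(pools[connection_id], "SELECT * FROM users", [])
#     payload = dict_result_to_list(serializer(rows))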