242 lines
7.2 KiB
Python
242 lines
7.2 KiB
Python
import asyncio, aiomysql, decimal, datetime, uuid, logging
|
|
|
|
from typing import Literal, Any
|
|
from data.schemas import Connection, SelectResultData, SelectQuery
|
|
from data.app_types import CachedCursor
|
|
from core.exceptions import PoolNotFound
|
|
|
|
# Registered connection pools, one per configured database connection.
# NOTE(review): annotated with str keys, but create_cursor() looks pools up
# with an int connection_id — confirm which key type callers really use.
pools: dict[str, aiomysql.Pool] = {}

# Open server-side cursors, keyed by the UUID string create_cursor() assigns.
cached_cursors: dict[str, CachedCursor] = {}

# Cursors the cleaner has already closed. They linger here until
# close_at + ttl * 5 has passed (see remove_old_closed_cached_cursors).
closed_cached_cursors: dict[str, CachedCursor] = {}

# Handle for the background cleaner; nothing in this module assigns it.
# Presumably set elsewhere to the task running cached_cursors_cleaner() — verify.
cached_cursors_cleaner_task: asyncio.Task | None = None
|
|
|
|
|
|
async def close_old_cached_cursors():
|
|
global cached_cursors, closed_cached_cursors
|
|
|
|
for cursor_id in list(cached_cursors.keys()):
|
|
cursor = cached_cursors.get(cursor_id, None)
|
|
if cursor is None:
|
|
continue
|
|
print(cursor.close_at, datetime.datetime.now(datetime.UTC).timestamp())
|
|
if cursor.close_at > datetime.datetime.now(datetime.UTC).timestamp():
|
|
continue
|
|
|
|
try:
|
|
await cursor.close()
|
|
cached_cursors.pop(cursor_id, None)
|
|
closed_cached_cursors[cursor_id] = cursor
|
|
print(f"Closed cursor {cursor_id}")
|
|
except Exception as e:
|
|
print(f"Error closing Cursor {cursor_id} -> {e}")
|
|
|
|
|
|
async def remove_old_closed_cached_cursors():
|
|
global closed_cached_cursors
|
|
for cursor_id in set(closed_cached_cursors.keys()):
|
|
closed_cursor = closed_cached_cursors.get(cursor_id, None)
|
|
if closed_cursor is None:
|
|
continue
|
|
if (
|
|
closed_cursor.close_at + closed_cursor.ttl * 5
|
|
> datetime.datetime.now(datetime.UTC).timestamp()
|
|
):
|
|
continue
|
|
|
|
del closed_cached_cursors[cursor_id]
|
|
print(f"Removed cursor {cursor_id}")
|
|
|
|
|
|
async def cached_cursors_cleaner(interval: float = 10):
    """Run cursor housekeeping forever.

    Each pass closes expired cursors (``close_old_cached_cursors``), drops
    closed cursors past retention (``remove_old_closed_cached_cursors``), then
    sleeps for ``interval`` seconds.

    Args:
        interval: Seconds to sleep between passes (default 10).
    """
    logger = logging.getLogger(__name__)
    while True:
        try:
            await close_old_cached_cursors()
            await remove_old_closed_cached_cursors()
        except Exception:
            # This loop normally runs as a background task; an escaped
            # exception would end it silently and cursors would never be
            # cleaned again. Log and try again on the next pass.
            # (asyncio.CancelledError is not an Exception, so cancellation
            # still stops the loop.)
            logger.exception("Cached cursor cleanup pass failed")
        await asyncio.sleep(interval)
|
|
|
|
|
|
async def pool_creator(connection: Connection, minsize=5, maxsize=10):
    """Create an aiomysql connection pool from stored connection settings.

    Args:
        connection: Settings object; its host, username, password, db_name
            and port attributes are passed to ``aiomysql.create_pool``.
        minsize: Forwarded to ``aiomysql.create_pool`` as ``minsize``.
        maxsize: Forwarded to ``aiomysql.create_pool`` as ``maxsize``.

    Returns:
        The pool object returned by ``aiomysql.create_pool``.
    """
    pool_options = {
        "host": connection.host,
        # The settings object names these differently from aiomysql.
        "user": connection.username,
        "password": connection.password,
        "db": connection.db_name,
        "port": connection.port,
        "minsize": minsize,
        "maxsize": maxsize,
    }
    new_pool = await aiomysql.create_pool(**pool_options)
    return new_pool
|
|
|
|
|
|
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
    """Execute ``query`` on an unbuffered cursor and wrap it as a CachedCursor.

    A connection is checked out of the pool registered under
    ``connection_id``. An ``aiomysql.SSCursor`` is opened on it and
    ``query.sql`` is executed with ``query.params``. The result is wrapped in a
    ``CachedCursor`` with a fresh UUID4 string id.

    Args:
        connection_id: Key into the module-level ``pools`` dict.
        query: Object providing ``sql`` and ``params``.

    Returns:
        The new CachedCursor. It holds the checked-out connection; releasing
        it is presumably CachedCursor's job (not visible here — verify).

    Raises:
        PoolNotFound: No pool is registered under ``connection_id``.
        Any exception raised while creating the cursor or executing the query
        is re-raised after the connection has been handed back to the pool.
    """
    pool = pools.get(connection_id)
    if pool is None:
        raise PoolNotFound

    # Acquired manually (not with ``async with``) because the connection must
    # stay checked out after this function returns.
    connection = await pool.acquire()
    try:
        cursor = await connection.cursor(aiomysql.SSCursor)
        await cursor.execute(query.sql, query.params)
    except BaseException:
        # Previously a failing query (bad SQL, cancellation, network error)
        # leaked this connection and slowly exhausted the pool. The connection
        # may be part-way through an unbuffered result, so close it rather
        # than let the pool reuse it.
        connection.close()
        pool.release(connection)
        raise

    return CachedCursor(
        id=str(uuid.uuid4()),
        cursor=cursor,
        connection=connection,
        pool=pool,
        connection_id=connection_id,
        query=query,
    )
|
|
|
|
|
|
async def execute_select_query(
    pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
):
    """Run a SELECT on a pooled connection and return the first rows.

    Args:
        pool: Connection pool to borrow a connection from for this call.
        sql_query: The SELECT statement, with placeholders for ``params``.
        params: Values bound to the placeholders in ``sql_query``.
        fetch_num: Maximum number of rows to fetch (default 100).

    Returns:
        tuple: ``(rows, rowcount)``. ``rows`` holds up to ``fetch_num`` rows as
        dicts (column name -> value, via ``aiomysql.DictCursor``), and
        ``rowcount`` is the cursor's ``rowcount`` after execution. If anything
        fails, the error is logged and ``([], 0)`` is returned. The shape
        always matches the success path, so callers can unpack it safely.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql_query, params)
                rows = await cursor.fetchmany(fetch_num)
                return rows, cursor.rowcount
    except Exception:
        logging.getLogger(__name__).exception("Error executing query")
        return [], 0
|
|
|
|
|
|
def construct_select_query(
|
|
table_name,
|
|
columns: Literal["*"] | list[str] = "*",
|
|
filters=None,
|
|
sort_by=None,
|
|
limit=None,
|
|
offset=None,
|
|
):
|
|
"""
|
|
Constructs a dynamic SELECT query based on the provided parameters.
|
|
|
|
Args:
|
|
table_name (str): The name of the table to query.
|
|
columns (str or list): The columns to select. Default is "*" (all columns).
|
|
filters (dict): A dictionary of filters (e.g., {"column": "value"}).
|
|
sort_by (str): The column to sort by (e.g., "column_name ASC").
|
|
limit (int): The maximum number of rows to return.
|
|
offset (int): The number of rows to skip (for pagination).
|
|
|
|
Returns:
|
|
str: The constructed SELECT query.
|
|
"""
|
|
# Handle columns
|
|
if isinstance(columns, list):
|
|
columns = ", ".join(columns)
|
|
elif columns != "*":
|
|
raise ValueError("Columns must be a list or '*'.")
|
|
|
|
# Base query
|
|
query = f"SELECT {columns}\n FROM {table_name}"
|
|
|
|
# Add filters
|
|
if filters:
|
|
filter_conditions = [
|
|
f"{column} = '{value}'" for column, value in filters.items()
|
|
]
|
|
query += "\n WHERE " + " AND ".join(filter_conditions)
|
|
|
|
# Add sorting
|
|
if sort_by:
|
|
query += f"\n ORDER BY {sort_by}"
|
|
|
|
# Add pagination
|
|
if limit is not None:
|
|
query += f"\n LIMIT {limit}"
|
|
if offset is not None:
|
|
query += f"\n OFFSET {offset}"
|
|
|
|
return query
|
|
|
|
|
|
async def get_tables_and_datatypes(pool: aiomysql.Pool):
    """
    Retrieves all table names and their column data types from a MySQL database.

    Uses ``SHOW TABLES`` on a connection borrowed from ``pool``, then runs
    ``DESCRIBE`` on each table.

    Args:
        pool (aiomysql.Pool): Connection pool to borrow a connection from.

    Returns:
        dict: ``{table_name: {column_name: data_type}}``, where data_type is
        the second column of the ``DESCRIBE`` output. On any error the problem
        is logged and an empty dict is returned.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                # Get all table names
                await cursor.execute("SHOW TABLES")
                tables = await cursor.fetchall()

                result = {}
                for table in tables:
                    table_name = table[0]
                    # Backtick-quote the identifier (doubling embedded
                    # backticks). Otherwise names that are reserved words
                    # (e.g. "order") or contain spaces or dashes break DESCRIBE,
                    # and one such table made the whole call return {}.
                    quoted_name = "`" + table_name.replace("`", "``") + "`"
                    await cursor.execute(f"DESCRIBE {quoted_name}")
                    columns = await cursor.fetchall()
                    # Each DESCRIBE row starts with (Field, Type, ...).
                    result[table_name] = {col[0]: col[1] for col in columns}

        return result
    except Exception:
        logging.getLogger(__name__).exception("Error reading tables and datatypes")
        return {}
|
|
|
|
|
|
def serializer(raw_result: list[dict[str, Any]]):
|
|
serialized = []
|
|
while raw_result:
|
|
re = raw_result.pop()
|
|
serialized_re = {}
|
|
for key, value in re.items():
|
|
if isinstance(value, str | int | None | float):
|
|
serialized_re[key] = str(value)
|
|
elif isinstance(value, decimal.Decimal):
|
|
serialized_re[key] = float(value)
|
|
elif isinstance(value, datetime.date):
|
|
serialized_re[key] = value.strftime("%Y-%m-%d")
|
|
else:
|
|
serialized_re[key] = str(value)
|
|
print(key, type(value))
|
|
|
|
serialized.append(serialized_re)
|
|
return serialized
|
|
|
|
|
|
def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
    """Split dict rows into a column header plus positional row values.

    The header is the first row's keys. Each row contributes
    ``list(row.values())``, so all rows are assumed to share the first row's
    key order.

    Returns:
        None when ``result`` is empty; otherwise a SelectResultData built with
        ``columns`` (header) and ``data`` (one list of values per row).
    """
    if not len(result):
        return None

    first_row = result[0]
    return SelectResultData(
        columns=list(first_row),
        data=[list(row.values()) for row in result],
    )
|