"""Async MySQL helpers: connection pools, cached cursors and query utilities."""

import asyncio
import datetime
import decimal
import logging
import uuid
from typing import Any, Literal

import aiomysql

from data.schemas import Connection, SelectResultData, SelectQuery
from data.app_types import CachedCursor
from core.exceptions import PoolNotFound

# Connection pools, looked up by connection id in ``create_cursor``.
# NOTE(review): the key is annotated ``str`` here while ``create_cursor``
# annotates ``connection_id`` as ``int`` -- confirm which one callers use.
pools: dict[str, aiomysql.Pool] = {}

# Open cursors, keyed by cursor id. Nothing in this module inserts entries;
# presumably callers register the cursors returned by ``create_cursor``.
cached_cursors: dict[str, CachedCursor] = {}

# Cursors already closed by the cleaner, kept until they are purged.
closed_cached_cursors: dict[str, CachedCursor] = {}

# Not assigned anywhere in this module; presumably holds the asyncio task
# running ``cached_cursors_cleaner`` -- verify against the caller.
cached_cursors_cleaner_task = None


def _utc_timestamp() -> float:
    """Return the current UTC time as a POSIX timestamp."""
    return datetime.datetime.now(datetime.UTC).timestamp()


def _escape_sql_string(value: Any) -> str:
    """Return ``str(value)`` escaped for use inside a single-quoted SQL literal.

    Backslashes and single quotes are doubled, which is correct for MySQL's
    default SQL mode.
    NOTE(review): with ``NO_BACKSLASH_ESCAPES`` enabled the doubled backslash
    would be stored literally; bound parameters avoid the problem entirely.
    """
    return str(value).replace("\\", "\\\\").replace("'", "''")


async def close_old_cached_cursors():
    """Close every cached cursor whose ``close_at`` deadline has passed.

    A successfully closed cursor is moved from ``cached_cursors`` to
    ``closed_cached_cursors``. A cursor that fails to close is reported and
    left in ``cached_cursors``, so the next pass tries again.
    """
    # Iterate over a snapshot of the keys: the ``await`` below yields to the
    # event loop, so other tasks may add or remove entries meanwhile.
    for cursor_id in list(cached_cursors):
        cursor = cached_cursors.get(cursor_id)
        if cursor is None:
            continue
        now = _utc_timestamp()
        print(cursor.close_at, now)
        if cursor.close_at > now:
            continue
        try:
            await cursor.close()
            cached_cursors.pop(cursor_id, None)
            closed_cached_cursors[cursor_id] = cursor
            print(f"Closed cursor {cursor_id}")
        except Exception as e:
            print(f"Error closing Cursor {cursor_id} -> {e}")


async def remove_old_closed_cached_cursors():
    """Drop closed cursors once five TTLs have elapsed since ``close_at``."""
    now = _utc_timestamp()
    for cursor_id in set(closed_cached_cursors):
        closed_cursor = closed_cached_cursors.get(cursor_id)
        if closed_cursor is None:
            continue
        if closed_cursor.close_at + closed_cursor.ttl * 5 > now:
            continue
        del closed_cached_cursors[cursor_id]
        print(f"Removed cursor {cursor_id}")


async def cached_cursors_cleaner():
    """Run both cursor clean-up passes forever, sleeping 10 seconds between."""
    while True:
        await close_old_cached_cursors()
        await remove_old_closed_cached_cursors()
        await asyncio.sleep(10)


async def pool_creator(connection: Connection, minsize=5, maxsize=10):
    """Create an aiomysql pool from the settings stored on ``connection``.

    Args:
        connection: Source of host, username, password, db_name and port.
        minsize: Minimum number of connections passed to the pool.
        maxsize: Maximum number of connections passed to the pool.

    Returns:
        The pool returned by ``aiomysql.create_pool``.
    """
    return await aiomysql.create_pool(
        host=connection.host,
        user=connection.username,
        password=connection.password,
        db=connection.db_name,
        port=connection.port,
        minsize=minsize,
        maxsize=maxsize,
    )


async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
    """Execute ``query`` on an ``aiomysql.SSCursor`` and wrap it for caching.

    The connection stays checked out of the pool and is handed to the
    returned ``CachedCursor``. This function does not add the result to
    ``cached_cursors``.

    Args:
        connection_id: Key of the pool in ``pools``.
        query: Provides the ``sql`` text and its ``params``.

    Returns:
        A ``CachedCursor`` with a fresh UUID4 string as its id.

    Raises:
        PoolNotFound: If no pool is registered under ``connection_id``.
    """
    pool = pools.get(connection_id)
    if pool is None:
        raise PoolNotFound

    connection = await pool.acquire()
    try:
        cursor = await connection.cursor(aiomysql.SSCursor)
        await cursor.execute(query.sql, query.params)
        return CachedCursor(
            id=str(uuid.uuid4()),
            cursor=cursor,
            connection=connection,
            pool=pool,
            connection_id=connection_id,
            query=query,
        )
    except BaseException:
        # Nothing took ownership of the connection, so hand it back to the
        # pool instead of leaking it. ``Pool.release`` is not a coroutine.
        pool.release(connection)
        raise


async def execute_select_query(
    pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
):
    """
    Executes a SELECT query asynchronously and returns the first rows.

    Args:
        pool (aiomysql.Pool): Connection pool to use.
        sql_query (str): The SELECT statement to execute.
        params (list): Parameters bound to the statement's placeholders.
        fetch_num (int): Maximum number of rows to fetch.

    Returns:
        tuple: ``(rows, rowcount)`` where ``rows`` is a list of dictionaries
        and ``rowcount`` is the cursor's ``rowcount``. On any error the
        problem is printed and ``([], 0)`` is returned.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql_query, params)
                result = await cursor.fetchmany(fetch_num)
                return result, cursor.rowcount
    except Exception as e:
        print(f"Error executing query: {e}")
        # Same two-element shape as the success path, so callers that unpack
        # the result keep working when the query fails.
        return [], 0


def construct_select_query(
    table_name,
    columns: Literal["*"] | list[str] = "*",
    filters=None,
    sort_by=None,
    limit=None,
    offset=None,
):
    """
    Constructs a dynamic SELECT query based on the provided parameters.

    Filter values are escaped before being embedded as string literals.
    Identifiers and clauses (``table_name``, ``columns``, filter keys,
    ``sort_by``, ``limit``, ``offset``) are inserted verbatim and must come
    from trusted input.

    Args:
        table_name (str): The name of the table to query.
        columns (str or list): The columns to select. Default is "*".
        filters (dict): Equality filters (e.g., {"column": "value"}),
            combined with AND.
        sort_by (str): The ORDER BY expression (e.g., "column_name ASC").
        limit (int): The maximum number of rows to return.
        offset (int): The number of rows to skip (for pagination).

    Returns:
        str: The constructed SELECT query.

    Raises:
        ValueError: If ``columns`` is neither a list nor "*".
    """
    # Handle columns
    if isinstance(columns, list):
        columns = ", ".join(columns)
    elif columns != "*":
        raise ValueError("Columns must be a list or '*'.")

    # Base query
    query = f"SELECT {columns}\n FROM {table_name}"

    # Add filters
    if filters:
        filter_conditions = [
            f"{column} = '{_escape_sql_string(value)}'"
            for column, value in filters.items()
        ]
        query += "\n WHERE " + " AND ".join(filter_conditions)

    # Add sorting
    if sort_by:
        query += f"\n ORDER BY {sort_by}"

    # Add pagination
    if limit is not None:
        query += f"\n LIMIT {limit}"
    if offset is not None:
        query += f"\n OFFSET {offset}"

    return query


async def get_tables_and_datatypes(pool: aiomysql.Pool):
    """
    Retrieves all table names and their column data types.

    Args:
        pool (aiomysql.Pool): Connection pool to use.

    Returns:
        dict: Maps each table name to a ``{column_name: data_type}``
        dictionary. On any error the problem is printed and ``{}`` is
        returned.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                # Get all table names
                await cursor.execute("SHOW TABLES")
                tables = await cursor.fetchall()

                result = {}
                for table in tables:
                    table_name = table[0]
                    # Backtick-quote the identifier so reserved words and
                    # special characters in table names do not break the
                    # statement.
                    quoted_name = "`" + table_name.replace("`", "``") + "`"
                    await cursor.execute(f"DESCRIBE {quoted_name}")
                    columns = await cursor.fetchall()
                    # The first two fields of each row are used as the
                    # column name and its data type.
                    result[table_name] = {col[0]: col[1] for col in columns}

                return result
    except Exception as e:
        print(f"Error: {e}")
        return {}


def serializer(raw_result: list[dict[str, Any]]):
    """Convert raw row dictionaries into plain serializable dictionaries.

    Rows are returned in input order and ``raw_result`` is left untouched.

    Conversion rules per value:
        * ``str``, ``int``, ``float`` and ``None`` become ``str(value)``.
        * ``decimal.Decimal`` becomes ``float``.
        * ``datetime.datetime`` becomes ``"YYYY-MM-DD HH:MM:SS"``.
        * ``datetime.date`` becomes ``"YYYY-MM-DD"``.
        * Anything else becomes ``str(value)`` and its type is printed.

    Args:
        raw_result: Rows as dictionaries mapping column name to value.

    Returns:
        list: One converted dictionary per input row.
    """
    serialized = []
    for row in raw_result:
        serialized_row = {}
        for key, value in row.items():
            if value is None or isinstance(value, (str, int, float)):
                serialized_row[key] = str(value)
            elif isinstance(value, decimal.Decimal):
                serialized_row[key] = float(value)
            elif isinstance(value, datetime.datetime):
                # Must be tested before ``date``: ``datetime`` subclasses
                # ``date`` and would otherwise lose its time of day.
                serialized_row[key] = value.isoformat(sep=" ")
            elif isinstance(value, datetime.date):
                serialized_row[key] = value.strftime("%Y-%m-%d")
            else:
                serialized_row[key] = str(value)
                print(key, type(value))
        serialized.append(serialized_row)
    return serialized


def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
    """Convert row dictionaries into a column list plus row value lists.

    Column names are taken from the keys of the first row.

    Args:
        result: Rows as dictionaries.

    Returns:
        ``None`` when ``result`` is empty, otherwise a ``SelectResultData``
        built from the column names and one list of values per row.
    """
    if not result:
        return None
    columns = list(result[0].keys())
    data = [list(row.values()) for row in result]
    return SelectResultData(columns=columns, data=data)