import asyncio
import datetime
import decimal
import logging
from typing import Any, Literal

import aiomysql

from data.schemas import Connection, SelectResultData

logger = logging.getLogger(__name__)

# Database configuration
# NOTE(review): not used by pool_creator below (it takes a Connection).
# Hard-coded credentials should not be committed; consider loading them
# from the environment or a secrets store.
DB_CONFIG = {
    "host": "localhost",
    "user": "me",  # Replace with your MySQL username
    "password": "Passwd3.14",  # Replace with your MySQL password
    "db": "testing",  # Replace with your database name
    "port": 3306,  # Default MySQL port
}

# Registry of connection pools (str key -> aiomysql.Pool); populated
# elsewhere — presumably keyed by a connection identifier, verify with callers.
pools: None | dict[str, aiomysql.Pool] = {}

# MySQL has no "OFFSET without LIMIT"; its manual recommends this maximum
# value as the LIMIT when only an offset is wanted.
_MYSQL_MAX_LIMIT = 18446744073709551615


async def pool_creator(connection: Connection, minsize=5, maxsize=10):
    """Create an aiomysql connection pool from the given connection settings.

    Args:
        connection: Settings object; its ``host``, ``username``, ``password``,
            ``db_name`` and ``port`` attributes are read.
        minsize: Passed to ``aiomysql.create_pool`` as ``minsize``.
        maxsize: Passed to ``aiomysql.create_pool`` as ``maxsize``.

    Returns:
        aiomysql.Pool: The newly created pool.
    """
    return await aiomysql.create_pool(
        host=connection.host,
        user=connection.username,
        password=connection.password,
        db=connection.db_name,
        port=connection.port,
        minsize=minsize,
        maxsize=maxsize,
    )


async def execute_select_query(
    pool: aiomysql.Pool,
    query: str,
    params: list | tuple | None = None,
    fetch_num: int = 100,
):
    """Execute a SELECT query and return up to ``fetch_num`` rows and the row count.

    Args:
        pool: Connection pool to acquire a connection from.
        query: The SELECT statement; use ``%s`` placeholders for values.
        params: Values bound to the ``%s`` placeholders in ``query``
            (may be omitted when the query has none).
        fetch_num: Maximum number of rows fetched from the result set.

    Returns:
        tuple: ``(rows, rowcount)`` where ``rows`` are the fetched rows as
        dicts keyed by column name (``aiomysql.DictCursor``) and ``rowcount``
        is ``cursor.rowcount`` as reported by the driver. On any error the
        exception is logged and ``([], 0)`` is returned, so callers can
        always unpack two values.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                result = await cursor.fetchmany(fetch_num)
                return result, cursor.rowcount
    except Exception as e:
        # Best-effort boundary: log and return an empty result of the same
        # shape as the success path (previously a bare [] broke unpacking).
        logger.exception("Error executing query: %s", e)
        return [], 0


def _quote_literal(value) -> str:
    """Render ``value`` as a single-quoted MySQL string literal.

    Backslashes and single quotes are escaped so the value cannot terminate
    the literal early. This assumes the server's default sql_mode (backslash
    is an escape character); binding values via ``params`` is preferred.
    """
    text = str(value).replace("\\", "\\\\").replace("'", "\\'")
    return f"'{text}'"


def construct_select_query(
    table_name,
    columns: Literal["*"] | list[str] = "*",
    filters=None,
    sort_by=None,
    limit=None,
    offset=None,
    params: list | None = None,
):
    """
    Constructs a dynamic SELECT query based on the provided parameters.

    Args:
        table_name (str): The name of the table to query.
        columns (str or list): The columns to select. Default is "*" (all columns).
        filters (dict): A dictionary of equality filters (e.g., {"column": "value"}).
            A value of None produces ``column IS NULL``.
        sort_by (str): The column to sort by (e.g., "column_name ASC").
        limit (int): The maximum number of rows to return.
        offset (int): The number of rows to skip (for pagination). If given
            without ``limit``, a maximal LIMIT is emitted because MySQL does
            not accept OFFSET on its own.
        params (list | None): Optional out-list. When provided, filter values
            are emitted as ``%s`` placeholders and appended to this list in
            order, ready to pass to ``execute_select_query``. When omitted,
            values are inlined as escaped string literals (legacy behavior).

    Returns:
        str: The constructed SELECT query.

    Raises:
        ValueError: If ``columns`` is neither a list nor "*", or if ``limit``
            / ``offset`` cannot be converted to int.

    Note:
        ``table_name``, ``columns`` and ``sort_by`` are inserted verbatim as
        SQL identifiers/expressions; they must come from trusted code, never
        from user input.
    """
    # Handle columns
    if isinstance(columns, list):
        columns = ", ".join(columns)
    elif columns != "*":
        raise ValueError("Columns must be a list or '*'.")

    # Base query
    query = f"SELECT {columns}\n FROM {table_name}"

    # Add filters
    if filters:
        conditions = []
        for column, value in filters.items():
            if value is None:
                # "= 'None'" would compare against the literal string; NULL
                # needs IS NULL.
                conditions.append(f"{column} IS NULL")
            elif params is not None:
                conditions.append(f"{column} = %s")
                params.append(value)
            else:
                conditions.append(f"{column} = {_quote_literal(value)}")
        query += "\n WHERE " + " AND ".join(conditions)

    # Add sorting
    if sort_by:
        query += f"\n ORDER BY {sort_by}"

    # Add pagination; int() rejects anything that is not a plain number.
    if limit is not None:
        query += f"\n LIMIT {int(limit)}"
    elif offset is not None:
        query += f"\n LIMIT {_MYSQL_MAX_LIMIT}"
    if offset is not None:
        query += f"\n OFFSET {int(offset)}"

    return query


def _quote_identifier(name) -> str:
    """Backtick-quote a MySQL identifier, doubling any embedded backticks."""
    return "`" + str(name).replace("`", "``") + "`"


async def get_tables_and_datatypes(pool: aiomysql.Pool):
    """
    Retrieves all table names and their column data types from the pool's database.

    Args:
        pool (aiomysql.Pool): Connection pool to acquire a connection from.

    Returns:
        dict: ``{table_name: {column_name: column_type}}`` built from
        ``SHOW TABLES`` and ``DESCRIBE``. On any error the exception is
        logged and an empty dict is returned.
    """
    try:
        async with pool.acquire() as connection:
            async with connection.cursor() as cursor:
                # Get all table names
                await cursor.execute("SHOW TABLES")
                tables = await cursor.fetchall()

                result = {}
                for table in tables:
                    table_name = table[0]
                    # Quote the name so reserved words or special characters
                    # in table names do not break the statement.
                    await cursor.execute(f"DESCRIBE {_quote_identifier(table_name)}")
                    columns = await cursor.fetchall()
                    # DESCRIBE rows begin with (Field, Type, ...).
                    result[table_name] = {col[0]: col[1] for col in columns}
                return result
    except Exception as e:
        logger.exception("Error: %s", e)
        return {}


def _serialize_value(value: Any) -> Any:
    """Convert one column value to a JSON-friendly form (see ``serializer``)."""
    if isinstance(value, str | int | None | float):
        return str(value)
    if isinstance(value, decimal.Decimal):
        return float(value)
    # datetime.datetime is a subclass of datetime.date, so it must be checked
    # first or the time component would be silently dropped.
    if isinstance(value, datetime.datetime):
        return value.isoformat(sep=" ")
    if isinstance(value, datetime.date):
        return value.strftime("%Y-%m-%d")
    logger.debug("serializer: unhandled value type %s", type(value))
    return str(value)


def serializer(raw_result: list[dict[str, Any]]):
    """Convert query rows into JSON-friendly dicts.

    Per-value conversion:
      * str / int / float / None -> ``str(value)`` (so NULL becomes "None")
      * decimal.Decimal -> float
      * datetime.datetime -> "YYYY-MM-DD HH:MM:SS[.ffffff]"
      * datetime.date -> "YYYY-MM-DD"
      * anything else -> ``str(value)``

    Rows keep their original order and ``raw_result`` is not modified.

    Returns:
        list[dict]: One converted dict per input row.
    """
    return [
        {key: _serialize_value(value) for key, value in row.items()}
        for row in raw_result
    ]


def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
    """Split dict rows into a column-name list and a list of row-value lists.

    Column names are taken from the first row; this assumes every row has the
    same keys in the same order.

    Returns:
        SelectResultData with ``columns`` and ``data``, or None if ``result``
        is empty.
    """
    if not result:
        return None
    columns = list(result[0].keys())
    data = [list(row.values()) for row in result]
    return SelectResultData(columns=columns, data=data)