From 836ce1dc8259894b2159a93752478350cfa0b86e Mon Sep 17 00:00:00 2001 From: abdulhade Date: Mon, 24 Feb 2025 12:15:01 +0300 Subject: [PATCH] Created Queries and Execute endpoint --- app/__init__.py | 14 +++ app/connections.py | 12 +-- app/operations.py | 92 +++++++++++++++++++ app/users.py | 16 ++-- core/dependencies.py | 9 +- core/enums.py | 20 ++++- core/exceptions.py | 30 +++++++ data/crud.py | 21 ++++- data/models.py | 18 +++- data/schemas.py | 143 +++++++++++++++++++++++++++++- dbs/mysql.py | 177 ++++++++++++++++++++++++++++++++++++- main.py | 29 +++--- {core => utils}/scripts.py | 26 +++++- utils/sql_creator.py | 78 ++++++++++++++++ 14 files changed, 635 insertions(+), 50 deletions(-) create mode 100644 app/__init__.py create mode 100644 app/operations.py rename {core => utils}/scripts.py (65%) create mode 100644 utils/sql_creator.py diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..ecce03e --- /dev/null +++ b/app/__init__.py @@ -0,0 +1,14 @@ +from fastapi import APIRouter + +from app.connections import router as connections_router +from app.operations import router as router +from app.users import router as user_router + +api_router = APIRouter() +api_router.include_router(router=user_router, prefix="/users", tags=["Users"]) +api_router.include_router( + router=connections_router, prefix="/connections", tags=["Connections"] +) +api_router.include_router( + router=router, prefix='/operations', tags=["Operations"] +) \ No newline at end of file diff --git a/app/connections.py b/app/connections.py index 3a790b3..c47c20e 100644 --- a/app/connections.py +++ b/app/connections.py @@ -11,10 +11,10 @@ from data.crud import ( ) from core.dependencies import get_db, get_current_user, get_admin_user -connections_router = APIRouter() +router = APIRouter() -@connections_router.post("/", status_code=status.HTTP_201_CREATED) +@router.post("/", status_code=status.HTTP_201_CREATED) async def create_connection_endpoint( connection: ConnectionCreate, db: AsyncSession = Depends(get_db), @@ -23,7 +23,7 @@ async def create_connection_endpoint( return await create_connection(db=db, connection=connection, user_id=admin.id) -@connections_router.get( +@router.get( "/", response_model=list[Connection], dependencies=[Depends(get_current_user)], @@ -35,7 +35,7 @@ async def read_connections_endpoint( return db_connection -@connections_router.get( +@router.get( "/{connection_id}", response_model=Connection, dependencies=[Depends(get_current_user)], @@ -47,7 +47,7 @@ async def read_connection_endpoint(connection_id: int, db: AsyncSession = Depend return db_connection -@connections_router.put( +@router.put( "/{connection_id}", response_model=Connection, dependencies=[Depends(get_admin_user)], @@ -63,7 +63,7 @@ async def update_connection_endpoint( return db_connection -@connections_router.delete( +@router.delete( "/{connection_id}", status_code=status.HTTP_204_NO_CONTENT, dependencies=[Depends(get_admin_user)], diff --git a/app/operations.py b/app/operations.py new file mode 100644 index 0000000..d50d75d --- /dev/null +++ b/app/operations.py @@ -0,0 +1,92 @@ +from fastapi.routing import APIRouter +from typing_extensions import Annotated +from pydantic import Field +from data.schemas import ( + SelectQueryBase, + SelectQueryInDB, + SelectQuery, + SelectQueryIn, + SelectResult, + SelectQueryMetaData, + SelectQueryInResult, +) +from fastapi import Depends, HTTPException, status +from sqlalchemy.ext.asyncio import AsyncSession +from data.crud import ( + read_connection, + create_select_query, + read_all_select_queries, + read_select_query, +) +from core.dependencies import get_db, get_current_user, get_admin_user +from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound +from utils.sql_creator import build_sql_query_text +from dbs import mysql + +router = APIRouter(prefix="/select") + + +@router.post("/check-query", dependencies=[Depends(get_current_user)]) +async def check_select_query(query: SelectQueryBase) -> SelectQuery: + sql, params = build_sql_query_text(query) + q = SelectQuery(**query.model_dump(), params=params, sql=sql) + return q + + +@router.post("/query") +async def create_select_query_endpoint( + query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user) +) -> SelectQueryInDB: + sql, params = build_sql_query_text(query) + query_in = SelectQueryIn( + **query.model_dump(), owner_id=user.id, params=params, sql=sql + ) + return await create_select_query(db=db, query=query_in) + + +@router.get("/query", dependencies=[Depends(get_current_user)]) +async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]: + return await read_all_select_queries(db=db) + + +@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)]) +async def get_select_queries_endpoint( + query_id: int, db=Depends(get_db) +) -> SelectQueryInDB: + return await read_select_query(db=db, query_id=query_id) + + +@router.post("/execute", dependencies=[Depends(get_current_user)]) +async def execute_select( + query_id: int, + connection_id: int, + page_size: Annotated[int, Field(ge=1, le=100)] = 50, + db=Depends(get_db), +) -> SelectResult: + query = await read_select_query(db=db, query_id=query_id) + + if query is None: + raise QueryNotFound + connection = await read_connection(db=db, connection_id=connection_id) + + if connection is None: + raise ConnectionNotFound + pool = mysql.pools.get(connection.id, None) + if pool is None: + raise PoolNotFound + + raw_result, rowcount = await mysql.execute_select_query( + pool=pool, query=query.sql, params=query.params, fetch_num=page_size + ) + + results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result)) + + meta = SelectQueryMetaData( + cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount + ) + + return SelectResult( + meta=meta, + query=query, + results=results, + ) diff --git a/app/users.py b/app/users.py index 1af92b4..8898ea7 100644 --- a/app/users.py +++ b/app/users.py @@ -7,17 +7,17 @@ from data.crud import read_all_users, read_user, create_user, delete_user from core.dependencies import get_db, get_current_user, get_admin_user from sqlalchemy.exc import IntegrityError from core.exceptions import ObjectNotFoundInDB, UserNotFound -from core.scripts import create_secret +from utils.scripts import create_secret -users_router = APIRouter() +router = APIRouter() -@users_router.get("/me") +@router.get("/me") async def get_me(user=Depends(get_current_user)) -> UserOut: return user -@users_router.get( +@router.get( "/", dependencies=[Depends(get_current_user)], ) @@ -25,7 +25,7 @@ async def get_all_users_endpoint(db=Depends(get_db)) -> list[UserOut]: return await read_all_users(db=db) -@users_router.post( +@router.post( "/", dependencies=[Depends(get_current_user)], status_code=status.HTTP_201_CREATED ) async def create_user_endpoint( @@ -42,7 +42,7 @@ async def create_user_endpoint( }, ) -@users_router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT) +@router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT) async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get_db)): if user.role == UserRole.admin: raise HTTPException(status_code=400, detail={ @@ -54,7 +54,7 @@ async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get await db.commit() await db.refresh(user) -@users_router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)]) +@router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)]) async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBase: user = await read_user(db=db, user_id=user_id) if user is None: @@ -65,7 +65,7 @@ async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBa await db.refresh(user) return user -@users_router.delete( +@router.delete( "/", dependencies=[Depends(get_admin_user)], status_code=status.HTTP_204_NO_CONTENT ) async def delete_user_endpoint(user_id: int, db=Depends(get_db)): diff --git a/core/dependencies.py b/core/dependencies.py index e5f6683..f7abea4 100644 --- a/core/dependencies.py +++ b/core/dependencies.py @@ -1,4 +1,3 @@ -# dependencies.py from fastapi import Depends, HTTPException, status, Security from fastapi.security import APIKeyHeader @@ -6,10 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from data.db import SessionLocal from data.models import User, UserRole -from pydantic import BaseModel - -# class UserInDB(User): -# hashed_password: str async def get_db(): async with SessionLocal() as session: @@ -18,7 +13,7 @@ async def get_db(): API_KEY_NAME = "Authorization" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) -async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)): +async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User: if api_key_header is None: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="API key missing" @@ -34,7 +29,7 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") return user -async def get_admin_user(current_user: User = Depends(get_current_user)): +async def get_admin_user(current_user: User = Depends(get_current_user)) -> User: if current_user.role != UserRole.admin: raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions") return current_user diff --git a/core/enums.py b/core/enums.py index ce2ce26..5b90661 100644 --- a/core/enums.py +++ b/core/enums.py @@ -6,4 +6,22 @@ class ConnectionTypes(str, enum.Enum): class UserRole(enum.Enum): admin = "admin" - user = "user" \ No newline at end of file + user = "user" + + +class FilterOperator(str, enum.Enum): + eq = "=" + neq = "!=" + gt = ">" + lt = "<" + gte = ">=" + lte = "<=" + like = "LIKE" + ilike = "ILIKE" + in_ = "IN" + is_null = "IS NULL" + is_not_null = "IS NOT NULL" + +class SortOrder(str, enum.Enum): + asc = "ASC" + desc = "DESC" \ No newline at end of file diff --git a/core/exceptions.py b/core/exceptions.py index ed7c453..1804b96 100644 --- a/core/exceptions.py +++ b/core/exceptions.py @@ -17,3 +17,33 @@ class UserNotFound(HTTPException): headers=None, ): super().__init__(status_code, detail, headers) + + +class QueryValidationError(ValueError): + def __init__(self, loc: list[str], msg: str): + self.loc = loc + self.msg = msg + super().__init__(msg) + +class QueryNotFound(HTTPException): + def __init__(self, status_code=404, detail = { + 'message': "The referenced query was not found.", + "code": 'query-not-found' + }, headers = None): + super().__init__(status_code, detail, headers) + + +class ConnectionNotFound(HTTPException): + def __init__(self, status_code=404, detail = { + 'message': "The referenced connection was not found.", + "code": 'connection-not-found' + }, headers = None): + super().__init__(status_code, detail, headers) + + +class PoolNotFound(HTTPException): + def __init__(self, status_code=404, detail = { + 'message': "We didn't find a running Pool for the referenced connection.", + "code": 'pool-not-found' + }, headers = None): + super().__init__(status_code, detail, headers) \ No newline at end of file diff --git a/data/crud.py b/data/crud.py index 5f2641b..c04e0b0 100644 --- a/data/crud.py +++ b/data/crud.py @@ -1,8 +1,8 @@ # crud.py from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select -from data.models import Connection, User -from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate +from data.models import Connection, User, Query +from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate, SelectQueryIn from core.exceptions import ObjectNotFoundInDB async def read_user(db:AsyncSession, user_id:int): @@ -14,7 +14,7 @@ async def read_all_users(db:AsyncSession): return result.scalars().all() async def create_user(db: AsyncSession, user: UserCreate): - from core.scripts import create_secret + from utils.scripts import create_secret db_user = User(**user.model_dump(), api_key=create_secret()) db.add(db_user) await db.commit() @@ -60,3 +60,18 @@ async def delete_connection(db: AsyncSession, connection_id: int): await db.delete(db_connection) await db.commit() return db_connection + +async def create_select_query(db:AsyncSession, query:SelectQueryIn) -> Query: + db_query = Query(**query.model_dump()) + db.add(db_query) + await db.commit() + await db.refresh(db_query) + return db_query + +async def read_all_select_queries(db:AsyncSession) -> list[Query]: + result = await db.execute(select(Query)) + return result.scalars().all() + +async def read_select_query(db:AsyncSession, query_id:int) -> Query: + result = await db.execute(select(Query).filter(Query.id == query_id)) + return result.scalars().first() \ No newline at end of file diff --git a/data/models.py b/data/models.py index a707aaa..54d8c3d 100644 --- a/data/models.py +++ b/data/models.py @@ -1,5 +1,5 @@ # models.py -from sqlalchemy import Column, Integer, String, Enum, ForeignKey +from sqlalchemy import Column, Integer, String, Enum, ForeignKey, JSON from sqlalchemy.orm import relationship from data.db import Base from core.enums import ConnectionTypes, UserRole @@ -27,3 +27,19 @@ class Connection(Base): owner_id = Column(Integer, ForeignKey("users.id")) # owner = relationship("User", back_populates="connections") + +class Query(Base): + __tablename__ = "queries" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + description = Column(String, nullable=True) + owner_id = Column(Integer, ForeignKey("users.id")) + table_name = Column(String, nullable=False) + columns = Column(JSON, nullable=False) + filters = Column(JSON, nullable=True) + sort_by = Column(JSON, nullable=True) + limit = Column(Integer, nullable=True) + offset = Column(Integer, nullable=True) + sql= Column(String, nullable=False) + params = Column(JSON, nullable=False) diff --git a/data/schemas.py b/data/schemas.py index a713775..9678df4 100644 --- a/data/schemas.py +++ b/data/schemas.py @@ -1,7 +1,9 @@ -# schemas.py -from pydantic import BaseModel -from typing import Optional -from core.enums import ConnectionTypes, UserRole +import re +from typing import Union, List, Optional, Literal, Any +from typing_extensions import Annotated +from pydantic import BaseModel, Field, field_validator, ValidationInfo, UUID4 +from core.enums import ConnectionTypes, UserRole, FilterOperator, SortOrder +from core.exceptions import QueryValidationError class ConnectionBase(BaseModel): @@ -57,3 +59,136 @@ class UserInDBBase(UserBase): class User(UserInDBBase): pass + + +class FilterClause(BaseModel): + column: str + operator: FilterOperator + value: Optional[Union[str, int, float, bool, list]] = None + + @field_validator("column") + @classmethod + def validate_column(cls, v: str) -> str: + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v): + raise QueryValidationError( + loc=["filters", "column"], msg="Invalid column name format" + ) + return v + + @field_validator("value") + @classmethod + def validate_value( + cls, + v: Optional[Union[str, int, float, bool, list]], + values: ValidationInfo, + ) -> Optional[Union[str, int, float, bool, list]]: + operator = values.data.get("operator") + + if operator in [FilterOperator.is_null, FilterOperator.is_not_null]: + if v is not None: + raise QueryValidationError( + loc=["filters", "value"], + msg="Value must be null for IS NULL/IS NOT NULL operators", + ) + elif operator == FilterOperator.in_: + if not isinstance(v, list): + raise QueryValidationError( + loc=["filters", "value"], + msg="IN operator requires a list of values", + ) + elif v is None: + raise QueryValidationError( + loc=["filters", "value"], msg="Value required for this operator" + ) + return v + + +class SortClause(BaseModel): + column: str + order: SortOrder = SortOrder.asc + + @field_validator("column") + @classmethod + def validate_column(cls, v: str) -> str: + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v): + raise QueryValidationError( + loc=["sort_by", "column"], msg="Invalid column name format" + ) + return v + + +class SelectQueryBase(BaseModel): + name: str # a short name for this query. + description: str | None = None # describing what does this query do. + table_name: str + columns: Union[Literal["*"], List[str]] = "*" + filters: Optional[List[FilterClause]] = None + sort_by: Optional[List[SortClause]] = None + limit: Annotated[int, Field(strict=True, gt=0)] = None + offset: Annotated[int, Field(strict=True, ge=0)] = None + + @field_validator("table_name") + @classmethod + def validate_table_name(cls, v: str) -> str: + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v): + raise QueryValidationError( + loc=["table_name"], msg="Invalid table name format" + ) + return v + + @field_validator("columns") + @classmethod + def validate_columns( + cls, v: Union[Literal["*"], List[str]] + ) -> Union[Literal["*"], List[str]]: + if v == "*": + return v + + for col in v: + if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", col): + raise QueryValidationError( + loc=["columns"], msg=f"Invalid column name: {col}" + ) + return v + + +class SelectQuery(SelectQueryBase): + sql: str + params: list[str | int | float] + + +class SelectQueryIn(SelectQuery): + owner_id: int + + +class SelectQueryInDB(SelectQueryIn): + id: int + + +class SelectResultData(BaseModel): + columns: List[str] + data: List[List[Any]] + + +class SelectQueryInResult(BaseModel): + id: int + sql: str + params: list[str | int | float] + + class Config: + from_attributes = True + + +class SelectQueryMetaData(BaseModel): + cursor: Optional[UUID4] = Field( + None, + description="A UUID4 cursor for pagination. Can be None if no more data is available.", + ) + total_number: int + has_more: bool = False + + +class SelectResult(BaseModel): + meta: SelectQueryMetaData + query: SelectQueryInResult + results: SelectResultData | None diff --git a/dbs/mysql.py b/dbs/mysql.py index 83d2f9b..099cbb0 100644 --- a/dbs/mysql.py +++ b/dbs/mysql.py @@ -1,2 +1,177 @@ -import aiomysql +import aiomysql, decimal, datetime +import asyncio +from typing import Literal, Any +from data.schemas import Connection, SelectResultData +# Database configuration +DB_CONFIG = { + "host": "localhost", + "user": "me", # Replace with your MySQL username + "password": "Passwd3.14", # Replace with your MySQL password + "db": "testing", # Replace with your database name + "port": 3306, # Default MySQL port +} + +pools: None | dict[str, aiomysql.Pool] = {} + + +async def pool_creator(connection: Connection, minsize=5, maxsize=10): + + return await aiomysql.create_pool( + # **DB_CONFIG, + host=connection.host, + user=connection.username, + password=connection.password, + db=connection.db_name, + port=connection.port, + minsize=minsize, + maxsize=maxsize, + ) + + +async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100): + """ + Executes a SELECT query on the MySQL database asynchronously and returns the results. + + Args: + pool (aiomysql.Pool): Connection pool to use. + query (str): The SELECT query to execute. + + Returns: + list: A list of rows (as dictionaries) returned by the query. + """ + try: + async with pool.acquire() as connection: + async with connection.cursor(aiomysql.DictCursor) as cursor: + await cursor.execute(query, params) + result = await cursor.fetchmany(fetch_num) + + return result, cursor.rowcount + + except Exception as e: + print(f"Error executing query: {e}") + return [] + + +def construct_select_query( + table_name, + columns: Literal["*"] | list[str] = "*", + filters=None, + sort_by=None, + limit=None, + offset=None, +): + """ + Constructs a dynamic SELECT query based on the provided parameters. + + Args: + table_name (str): The name of the table to query. + columns (str or list): The columns to select. Default is "*" (all columns). + filters (dict): A dictionary of filters (e.g., {"column": "value"}). + sort_by (str): The column to sort by (e.g., "column_name ASC"). + limit (int): The maximum number of rows to return. + offset (int): The number of rows to skip (for pagination). + + Returns: + str: The constructed SELECT query. + """ + # Handle columns + if isinstance(columns, list): + columns = ", ".join(columns) + elif columns != "*": + raise ValueError("Columns must be a list or '*'.") + + # Base query + query = f"SELECT {columns}\n FROM {table_name}" + + # Add filters + if filters: + filter_conditions = [ + f"{column} = '{value}'" for column, value in filters.items() + ] + query += "\n WHERE " + " AND ".join(filter_conditions) + + # Add sorting + if sort_by: + query += f"\n ORDER BY {sort_by}" + + # Add pagination + if limit is not None: + query += f"\n LIMIT {limit}" + if offset is not None: + query += f"\n OFFSET {offset}" + + return query + + +async def get_tables_and_datatypes(pool: aiomysql.Pool): + """ + Retrieves all table names and their column data types from a MySQL database. + + Args: + db_config (dict): MySQL database configuration (host, user, password, db, port). + + Returns: + dict: A dictionary where keys are table names and values are lists of (column_name, data_type) tuples. + """ + try: + # Connect to the MySQL database + # connection = await aiomysql.connect(**db_config) + async with pool.acquire() as connection: + async with connection.cursor() as cursor: + # Get all table names + await cursor.execute("SHOW TABLES") + tables = await cursor.fetchall() + + result = {} + for table in tables: + table_name = table[0] + result[table_name] = {} + + # Get column names and data types for the current table + await cursor.execute(f"DESCRIBE {table_name}") + columns = await cursor.fetchall() + # Store column names and data types + for col in columns: + result[table_name][col[0]] = col[1] + + # Close the connection + # connection.close() + # await connection.wait_closed() + + return result + except Exception as e: + print(f"Error: {e}") + return {} + + +def serializer(raw_result: list[dict[str, Any]]): + serialized = [] + while raw_result: + re = raw_result.pop() + serialized_re = {} + for key, value in re.items(): + if isinstance(value, str | int | None | float): + serialized_re[key] = str(value) + elif isinstance(value, decimal.Decimal): + serialized_re[key] = float(value) + elif isinstance(value, datetime.date): + serialized_re[key] = value.strftime("%Y-%m-%d") + else: + serialized_re[key] = str(value) + print(key, type(value)) + + serialized.append(serialized_re) + return serialized + + +def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData: + if len(result) == 0: + return None + + columns = list(result[0].keys()) + data = [] + for row in result: + data.append(list(row.values())) + + return SelectResultData(columns=columns, data=data) diff --git a/main.py b/main.py index 2cc1efc..cb436d7 100644 --- a/main.py +++ b/main.py @@ -1,26 +1,19 @@ +from contextlib import asynccontextmanager + from fastapi import FastAPI -from data.db import engine, Base -from app.connections import connections_router -from app.users import users_router +from app import api_router -app = FastAPI() - -app.include_router(router=users_router, prefix="/users", tags=["Users"]) -app.include_router( - router=connections_router, prefix="/connections", tags=["Connections"] -) +from utils.scripts import pools_creator, pools_destroy, db_startup -# @app.on_event("startup") -async def startup(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) -# import asyncio -# asyncio.run(startup()) +@asynccontextmanager +async def lifespan(app: FastAPI): + await pools_creator() + yield + await pools_destroy() -# import uvicorn - -# uvicorn.run(app=app) +app = FastAPI(lifespan=lifespan) +app.include_router(router=api_router) \ No newline at end of file diff --git a/core/scripts.py b/utils/scripts.py similarity index 65% rename from core/scripts.py rename to utils/scripts.py index 6b6e9a1..dea0887 100644 --- a/core/scripts.py +++ b/utils/scripts.py @@ -1,5 +1,6 @@ # add_user.py -import asyncio + +import asyncio, logging import secrets from sqlalchemy.future import select from sqlalchemy.exc import IntegrityError @@ -7,6 +8,29 @@ from getpass import getpass from data.db import engine, SessionLocal from data.models import Base, User, UserRole + +async def pools_creator(): + from data.crud import read_all_connections + from dbs import mysql + + async with SessionLocal() as db: + connections = await read_all_connections(db=db) + + for connection in connections: + mysql.pools[connection.id] = await mysql.pool_creator(connection=connection) + logging.info(msg='Created Pools') + +async def pools_destroy(): + from dbs import mysql + for connection_id, pool in mysql.pools.items(): + pool.close() + await pool.wait_closed() + logging.info(f'Closed pool: {connection_id}') + +async def db_startup(): + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + def create_secret(): return secrets.token_hex(32) diff --git a/utils/sql_creator.py b/utils/sql_creator.py new file mode 100644 index 0000000..0242e8f --- /dev/null +++ b/utils/sql_creator.py @@ -0,0 +1,78 @@ +from data.schemas import SelectQueryBase + +def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]: + """ + Builds a SQL query text and parameters from a SelectQuery schema object. + + Args: + query (SelectQuery): The query schema object. + + Returns: + tuple[str, list]: The SQL query text and a list of parameters. + """ + + # Build SELECT clause + if query.columns == "*": + select_clause = "SELECT *" + else: + select_clause = f"SELECT {', '.join(query.columns)}" + + # Build FROM clause + from_clause = f"FROM {query.table_name}" + + # Build WHERE clause + where_clause = "" + params = [] + if query.filters: + conditions = [] + for filter in query.filters: + column = filter.column + operator = filter.operator.value + value = filter.value + + if operator in ["IS NULL", "IS NOT NULL"]: + conditions.append(f"{column} {operator}") + elif operator == "IN": + placeholders = ", ".join(["%s"] * len(value)) + conditions.append(f"{column} IN ({placeholders})") + params.extend(value) + else: + # operators like < > == != + conditions.append(f"{column} {operator} %s") + params.append(value) + + where_clause = "WHERE " + " AND ".join(conditions) + + # Build ORDER BY clause + order_by_clause = "" + if query.sort_by: + order_by = [] + for sort in query.sort_by: + column = sort.column + order = sort.order.value + order_by.append(f"{column} {order}") + order_by_clause = "ORDER BY " + ", ".join(order_by) + + # Build LIMIT and OFFSET clauses + limit_clause = "" + if query.limit: + limit_clause = f"LIMIT {query.limit}" + + offset_clause = "" + if query.offset: + offset_clause = f"OFFSET {query.offset}" + + # Combine all clauses + sql_query = """ + """.join( + [ + select_clause, + from_clause, + where_clause, + order_by_clause, + limit_clause, + offset_clause, + ] + ).strip() + + return sql_query, params \ No newline at end of file