Created Queries and Execute endpoint
This commit is contained in:
14
app/__init__.py
Normal file
14
app/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.connections import router as connections_router
|
||||
from app.operations import router as router
|
||||
from app.users import router as user_router
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(router=user_router, prefix="/users", tags=["Users"])
|
||||
api_router.include_router(
|
||||
router=connections_router, prefix="/connections", tags=["Connections"]
|
||||
)
|
||||
api_router.include_router(
|
||||
router=router, prefix='/operations', tags=["Operations"]
|
||||
)
|
||||
@@ -11,10 +11,10 @@ from data.crud import (
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
|
||||
connections_router = APIRouter()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@connections_router.post("/", status_code=status.HTTP_201_CREATED)
|
||||
@router.post("/", status_code=status.HTTP_201_CREATED)
|
||||
async def create_connection_endpoint(
|
||||
connection: ConnectionCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
@@ -23,7 +23,7 @@ async def create_connection_endpoint(
|
||||
return await create_connection(db=db, connection=connection, user_id=admin.id)
|
||||
|
||||
|
||||
@connections_router.get(
|
||||
@router.get(
|
||||
"/",
|
||||
response_model=list[Connection],
|
||||
dependencies=[Depends(get_current_user)],
|
||||
@@ -35,7 +35,7 @@ async def read_connections_endpoint(
|
||||
return db_connection
|
||||
|
||||
|
||||
@connections_router.get(
|
||||
@router.get(
|
||||
"/{connection_id}",
|
||||
response_model=Connection,
|
||||
dependencies=[Depends(get_current_user)],
|
||||
@@ -47,7 +47,7 @@ async def read_connection_endpoint(connection_id: int, db: AsyncSession = Depend
|
||||
return db_connection
|
||||
|
||||
|
||||
@connections_router.put(
|
||||
@router.put(
|
||||
"/{connection_id}",
|
||||
response_model=Connection,
|
||||
dependencies=[Depends(get_admin_user)],
|
||||
@@ -63,7 +63,7 @@ async def update_connection_endpoint(
|
||||
return db_connection
|
||||
|
||||
|
||||
@connections_router.delete(
|
||||
@router.delete(
|
||||
"/{connection_id}",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
dependencies=[Depends(get_admin_user)],
|
||||
|
||||
92
app/operations.py
Normal file
92
app/operations.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from fastapi.routing import APIRouter
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import Field
|
||||
from data.schemas import (
|
||||
SelectQueryBase,
|
||||
SelectQueryInDB,
|
||||
SelectQuery,
|
||||
SelectQueryIn,
|
||||
SelectResult,
|
||||
SelectQueryMetaData,
|
||||
SelectQueryInResult,
|
||||
)
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from data.crud import (
|
||||
read_connection,
|
||||
create_select_query,
|
||||
read_all_select_queries,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound
|
||||
from utils.sql_creator import build_sql_query_text
|
||||
from dbs import mysql
|
||||
|
||||
router = APIRouter(prefix="/select")
|
||||
|
||||
|
||||
@router.post("/check-query", dependencies=[Depends(get_current_user)])
|
||||
async def check_select_query(query: SelectQueryBase) -> SelectQuery:
|
||||
sql, params = build_sql_query_text(query)
|
||||
q = SelectQuery(**query.model_dump(), params=params, sql=sql)
|
||||
return q
|
||||
|
||||
|
||||
@router.post("/query")
|
||||
async def create_select_query_endpoint(
|
||||
query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
|
||||
) -> SelectQueryInDB:
|
||||
sql, params = build_sql_query_text(query)
|
||||
query_in = SelectQueryIn(
|
||||
**query.model_dump(), owner_id=user.id, params=params, sql=sql
|
||||
)
|
||||
return await create_select_query(db=db, query=query_in)
|
||||
|
||||
|
||||
@router.get("/query", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
|
||||
return await read_all_select_queries(db=db)
|
||||
|
||||
|
||||
@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(
|
||||
query_id: int, db=Depends(get_db)
|
||||
) -> SelectQueryInDB:
|
||||
return await read_select_query(db=db, query_id=query_id)
|
||||
|
||||
|
||||
@router.post("/execute", dependencies=[Depends(get_current_user)])
|
||||
async def execute_select(
|
||||
query_id: int,
|
||||
connection_id: int,
|
||||
page_size: Annotated[int, Field(ge=1, le=100)] = 50,
|
||||
db=Depends(get_db),
|
||||
) -> SelectResult:
|
||||
query = await read_select_query(db=db, query_id=query_id)
|
||||
|
||||
if query is None:
|
||||
raise QueryNotFound
|
||||
connection = await read_connection(db=db, connection_id=connection_id)
|
||||
|
||||
if connection is None:
|
||||
raise ConnectionNotFound
|
||||
pool = mysql.pools.get(connection.id, None)
|
||||
if pool is None:
|
||||
raise PoolNotFound
|
||||
|
||||
raw_result, rowcount = await mysql.execute_select_query(
|
||||
pool=pool, query=query.sql, params=query.params, fetch_num=page_size
|
||||
)
|
||||
|
||||
results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
|
||||
|
||||
meta = SelectQueryMetaData(
|
||||
cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount
|
||||
)
|
||||
|
||||
return SelectResult(
|
||||
meta=meta,
|
||||
query=query,
|
||||
results=results,
|
||||
)
|
||||
16
app/users.py
16
app/users.py
@@ -7,17 +7,17 @@ from data.crud import read_all_users, read_user, create_user, delete_user
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from core.exceptions import ObjectNotFoundInDB, UserNotFound
|
||||
from core.scripts import create_secret
|
||||
from utils.scripts import create_secret
|
||||
|
||||
users_router = APIRouter()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@users_router.get("/me")
|
||||
@router.get("/me")
|
||||
async def get_me(user=Depends(get_current_user)) -> UserOut:
|
||||
return user
|
||||
|
||||
|
||||
@users_router.get(
|
||||
@router.get(
|
||||
"/",
|
||||
dependencies=[Depends(get_current_user)],
|
||||
)
|
||||
@@ -25,7 +25,7 @@ async def get_all_users_endpoint(db=Depends(get_db)) -> list[UserOut]:
|
||||
return await read_all_users(db=db)
|
||||
|
||||
|
||||
@users_router.post(
|
||||
@router.post(
|
||||
"/", dependencies=[Depends(get_current_user)], status_code=status.HTTP_201_CREATED
|
||||
)
|
||||
async def create_user_endpoint(
|
||||
@@ -42,7 +42,7 @@ async def create_user_endpoint(
|
||||
},
|
||||
)
|
||||
|
||||
@users_router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT)
|
||||
@router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get_db)):
|
||||
if user.role == UserRole.admin:
|
||||
raise HTTPException(status_code=400, detail={
|
||||
@@ -54,7 +54,7 @@ async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
@users_router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)])
|
||||
@router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)])
|
||||
async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBase:
|
||||
user = await read_user(db=db, user_id=user_id)
|
||||
if user is None:
|
||||
@@ -65,7 +65,7 @@ async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBa
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
@users_router.delete(
|
||||
@router.delete(
|
||||
"/", dependencies=[Depends(get_admin_user)], status_code=status.HTTP_204_NO_CONTENT
|
||||
)
|
||||
async def delete_user_endpoint(user_id: int, db=Depends(get_db)):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# dependencies.py
|
||||
from fastapi import Depends, HTTPException, status, Security
|
||||
|
||||
from fastapi.security import APIKeyHeader
|
||||
@@ -6,10 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from data.db import SessionLocal
|
||||
from data.models import User, UserRole
|
||||
from pydantic import BaseModel
|
||||
|
||||
# class UserInDB(User):
|
||||
# hashed_password: str
|
||||
|
||||
async def get_db():
|
||||
async with SessionLocal() as session:
|
||||
@@ -18,7 +13,7 @@ async def get_db():
|
||||
API_KEY_NAME = "Authorization"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)):
|
||||
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
|
||||
if api_key_header is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing"
|
||||
@@ -34,7 +29,7 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
return user
|
||||
|
||||
async def get_admin_user(current_user: User = Depends(get_current_user)):
|
||||
async def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
|
||||
if current_user.role != UserRole.admin:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
|
||||
return current_user
|
||||
|
||||
@@ -7,3 +7,21 @@ class ConnectionTypes(str, enum.Enum):
|
||||
class UserRole(enum.Enum):
|
||||
admin = "admin"
|
||||
user = "user"
|
||||
|
||||
|
||||
class FilterOperator(str, enum.Enum):
|
||||
eq = "="
|
||||
neq = "!="
|
||||
gt = ">"
|
||||
lt = "<"
|
||||
gte = ">="
|
||||
lte = "<="
|
||||
like = "LIKE"
|
||||
ilike = "ILIKE"
|
||||
in_ = "IN"
|
||||
is_null = "IS NULL"
|
||||
is_not_null = "IS NOT NULL"
|
||||
|
||||
class SortOrder(str, enum.Enum):
|
||||
asc = "ASC"
|
||||
desc = "DESC"
|
||||
@@ -17,3 +17,33 @@ class UserNotFound(HTTPException):
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class QueryValidationError(ValueError):
|
||||
def __init__(self, loc: list[str], msg: str):
|
||||
self.loc = loc
|
||||
self.msg = msg
|
||||
super().__init__(msg)
|
||||
|
||||
class QueryNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "The referenced query was not found.",
|
||||
"code": 'query-not-found'
|
||||
}, headers = None):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class ConnectionNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "The referenced connection was not found.",
|
||||
"code": 'connection-not-found'
|
||||
}, headers = None):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class PoolNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "We didn't find a running Pool for the referenced connection.",
|
||||
"code": 'pool-not-found'
|
||||
}, headers = None):
|
||||
super().__init__(status_code, detail, headers)
|
||||
21
data/crud.py
21
data/crud.py
@@ -1,8 +1,8 @@
|
||||
# crud.py
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from data.models import Connection, User
|
||||
from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate
|
||||
from data.models import Connection, User, Query
|
||||
from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate, SelectQueryIn
|
||||
from core.exceptions import ObjectNotFoundInDB
|
||||
|
||||
async def read_user(db:AsyncSession, user_id:int):
|
||||
@@ -14,7 +14,7 @@ async def read_all_users(db:AsyncSession):
|
||||
return result.scalars().all()
|
||||
|
||||
async def create_user(db: AsyncSession, user: UserCreate):
|
||||
from core.scripts import create_secret
|
||||
from utils.scripts import create_secret
|
||||
db_user = User(**user.model_dump(), api_key=create_secret())
|
||||
db.add(db_user)
|
||||
await db.commit()
|
||||
@@ -60,3 +60,18 @@ async def delete_connection(db: AsyncSession, connection_id: int):
|
||||
await db.delete(db_connection)
|
||||
await db.commit()
|
||||
return db_connection
|
||||
|
||||
async def create_select_query(db:AsyncSession, query:SelectQueryIn) -> Query:
|
||||
db_query = Query(**query.model_dump())
|
||||
db.add(db_query)
|
||||
await db.commit()
|
||||
await db.refresh(db_query)
|
||||
return db_query
|
||||
|
||||
async def read_all_select_queries(db:AsyncSession) -> list[Query]:
|
||||
result = await db.execute(select(Query))
|
||||
return result.scalars().all()
|
||||
|
||||
async def read_select_query(db:AsyncSession, query_id:int) -> Query:
|
||||
result = await db.execute(select(Query).filter(Query.id == query_id))
|
||||
return result.scalars().first()
|
||||
@@ -1,5 +1,5 @@
|
||||
# models.py
|
||||
from sqlalchemy import Column, Integer, String, Enum, ForeignKey
|
||||
from sqlalchemy import Column, Integer, String, Enum, ForeignKey, JSON
|
||||
from sqlalchemy.orm import relationship
|
||||
from data.db import Base
|
||||
from core.enums import ConnectionTypes, UserRole
|
||||
@@ -27,3 +27,19 @@ class Connection(Base):
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
|
||||
# owner = relationship("User", back_populates="connections")
|
||||
|
||||
class Query(Base):
|
||||
__tablename__ = "queries"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(String, nullable=True)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
table_name = Column(String, nullable=False)
|
||||
columns = Column(JSON, nullable=False)
|
||||
filters = Column(JSON, nullable=True)
|
||||
sort_by = Column(JSON, nullable=True)
|
||||
limit = Column(Integer, nullable=True)
|
||||
offset = Column(Integer, nullable=True)
|
||||
sql= Column(String, nullable=False)
|
||||
params = Column(JSON, nullable=False)
|
||||
|
||||
143
data/schemas.py
143
data/schemas.py
@@ -1,7 +1,9 @@
|
||||
# schemas.py
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from core.enums import ConnectionTypes, UserRole
|
||||
import re
|
||||
from typing import Union, List, Optional, Literal, Any
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationInfo, UUID4
|
||||
from core.enums import ConnectionTypes, UserRole, FilterOperator, SortOrder
|
||||
from core.exceptions import QueryValidationError
|
||||
|
||||
|
||||
class ConnectionBase(BaseModel):
|
||||
@@ -57,3 +59,136 @@ class UserInDBBase(UserBase):
|
||||
|
||||
class User(UserInDBBase):
|
||||
pass
|
||||
|
||||
|
||||
class FilterClause(BaseModel):
|
||||
column: str
|
||||
operator: FilterOperator
|
||||
value: Optional[Union[str, int, float, bool, list]] = None
|
||||
|
||||
@field_validator("column")
|
||||
@classmethod
|
||||
def validate_column(cls, v: str) -> str:
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||
raise QueryValidationError(
|
||||
loc=["filters", "column"], msg="Invalid column name format"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("value")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls,
|
||||
v: Optional[Union[str, int, float, bool, list]],
|
||||
values: ValidationInfo,
|
||||
) -> Optional[Union[str, int, float, bool, list]]:
|
||||
operator = values.data.get("operator")
|
||||
|
||||
if operator in [FilterOperator.is_null, FilterOperator.is_not_null]:
|
||||
if v is not None:
|
||||
raise QueryValidationError(
|
||||
loc=["filters", "value"],
|
||||
msg="Value must be null for IS NULL/IS NOT NULL operators",
|
||||
)
|
||||
elif operator == FilterOperator.in_:
|
||||
if not isinstance(v, list):
|
||||
raise QueryValidationError(
|
||||
loc=["filters", "value"],
|
||||
msg="IN operator requires a list of values",
|
||||
)
|
||||
elif v is None:
|
||||
raise QueryValidationError(
|
||||
loc=["filters", "value"], msg="Value required for this operator"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class SortClause(BaseModel):
|
||||
column: str
|
||||
order: SortOrder = SortOrder.asc
|
||||
|
||||
@field_validator("column")
|
||||
@classmethod
|
||||
def validate_column(cls, v: str) -> str:
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||
raise QueryValidationError(
|
||||
loc=["sort_by", "column"], msg="Invalid column name format"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class SelectQueryBase(BaseModel):
|
||||
name: str # a short name for this query.
|
||||
description: str | None = None # describing what does this query do.
|
||||
table_name: str
|
||||
columns: Union[Literal["*"], List[str]] = "*"
|
||||
filters: Optional[List[FilterClause]] = None
|
||||
sort_by: Optional[List[SortClause]] = None
|
||||
limit: Annotated[int, Field(strict=True, gt=0)] = None
|
||||
offset: Annotated[int, Field(strict=True, ge=0)] = None
|
||||
|
||||
@field_validator("table_name")
|
||||
@classmethod
|
||||
def validate_table_name(cls, v: str) -> str:
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||
raise QueryValidationError(
|
||||
loc=["table_name"], msg="Invalid table name format"
|
||||
)
|
||||
return v
|
||||
|
||||
@field_validator("columns")
|
||||
@classmethod
|
||||
def validate_columns(
|
||||
cls, v: Union[Literal["*"], List[str]]
|
||||
) -> Union[Literal["*"], List[str]]:
|
||||
if v == "*":
|
||||
return v
|
||||
|
||||
for col in v:
|
||||
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", col):
|
||||
raise QueryValidationError(
|
||||
loc=["columns"], msg=f"Invalid column name: {col}"
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
class SelectQuery(SelectQueryBase):
|
||||
sql: str
|
||||
params: list[str | int | float]
|
||||
|
||||
|
||||
class SelectQueryIn(SelectQuery):
|
||||
owner_id: int
|
||||
|
||||
|
||||
class SelectQueryInDB(SelectQueryIn):
|
||||
id: int
|
||||
|
||||
|
||||
class SelectResultData(BaseModel):
|
||||
columns: List[str]
|
||||
data: List[List[Any]]
|
||||
|
||||
|
||||
class SelectQueryInResult(BaseModel):
|
||||
id: int
|
||||
sql: str
|
||||
params: list[str | int | float]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SelectQueryMetaData(BaseModel):
|
||||
cursor: Optional[UUID4] = Field(
|
||||
None,
|
||||
description="A UUID4 cursor for pagination. Can be None if no more data is available.",
|
||||
)
|
||||
total_number: int
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class SelectResult(BaseModel):
|
||||
meta: SelectQueryMetaData
|
||||
query: SelectQueryInResult
|
||||
results: SelectResultData | None
|
||||
|
||||
177
dbs/mysql.py
177
dbs/mysql.py
@@ -1,2 +1,177 @@
|
||||
import aiomysql
|
||||
import aiomysql, decimal, datetime
|
||||
import asyncio
|
||||
from typing import Literal, Any
|
||||
from data.schemas import Connection, SelectResultData
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": "localhost",
|
||||
"user": "me", # Replace with your MySQL username
|
||||
"password": "Passwd3.14", # Replace with your MySQL password
|
||||
"db": "testing", # Replace with your database name
|
||||
"port": 3306, # Default MySQL port
|
||||
}
|
||||
|
||||
pools: None | dict[str, aiomysql.Pool] = {}
|
||||
|
||||
|
||||
async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
||||
|
||||
return await aiomysql.create_pool(
|
||||
# **DB_CONFIG,
|
||||
host=connection.host,
|
||||
user=connection.username,
|
||||
password=connection.password,
|
||||
db=connection.db_name,
|
||||
port=connection.port,
|
||||
minsize=minsize,
|
||||
maxsize=maxsize,
|
||||
)
|
||||
|
||||
|
||||
async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
|
||||
"""
|
||||
Executes a SELECT query on the MySQL database asynchronously and returns the results.
|
||||
|
||||
Args:
|
||||
pool (aiomysql.Pool): Connection pool to use.
|
||||
query (str): The SELECT query to execute.
|
||||
|
||||
Returns:
|
||||
list: A list of rows (as dictionaries) returned by the query.
|
||||
"""
|
||||
try:
|
||||
async with pool.acquire() as connection:
|
||||
async with connection.cursor(aiomysql.DictCursor) as cursor:
|
||||
await cursor.execute(query, params)
|
||||
result = await cursor.fetchmany(fetch_num)
|
||||
|
||||
return result, cursor.rowcount
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error executing query: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def construct_select_query(
|
||||
table_name,
|
||||
columns: Literal["*"] | list[str] = "*",
|
||||
filters=None,
|
||||
sort_by=None,
|
||||
limit=None,
|
||||
offset=None,
|
||||
):
|
||||
"""
|
||||
Constructs a dynamic SELECT query based on the provided parameters.
|
||||
|
||||
Args:
|
||||
table_name (str): The name of the table to query.
|
||||
columns (str or list): The columns to select. Default is "*" (all columns).
|
||||
filters (dict): A dictionary of filters (e.g., {"column": "value"}).
|
||||
sort_by (str): The column to sort by (e.g., "column_name ASC").
|
||||
limit (int): The maximum number of rows to return.
|
||||
offset (int): The number of rows to skip (for pagination).
|
||||
|
||||
Returns:
|
||||
str: The constructed SELECT query.
|
||||
"""
|
||||
# Handle columns
|
||||
if isinstance(columns, list):
|
||||
columns = ", ".join(columns)
|
||||
elif columns != "*":
|
||||
raise ValueError("Columns must be a list or '*'.")
|
||||
|
||||
# Base query
|
||||
query = f"SELECT {columns}\n FROM {table_name}"
|
||||
|
||||
# Add filters
|
||||
if filters:
|
||||
filter_conditions = [
|
||||
f"{column} = '{value}'" for column, value in filters.items()
|
||||
]
|
||||
query += "\n WHERE " + " AND ".join(filter_conditions)
|
||||
|
||||
# Add sorting
|
||||
if sort_by:
|
||||
query += f"\n ORDER BY {sort_by}"
|
||||
|
||||
# Add pagination
|
||||
if limit is not None:
|
||||
query += f"\n LIMIT {limit}"
|
||||
if offset is not None:
|
||||
query += f"\n OFFSET {offset}"
|
||||
|
||||
return query
|
||||
|
||||
|
||||
async def get_tables_and_datatypes(pool: aiomysql.Pool):
|
||||
"""
|
||||
Retrieves all table names and their column data types from a MySQL database.
|
||||
|
||||
Args:
|
||||
db_config (dict): MySQL database configuration (host, user, password, db, port).
|
||||
|
||||
Returns:
|
||||
dict: A dictionary where keys are table names and values are lists of (column_name, data_type) tuples.
|
||||
"""
|
||||
try:
|
||||
# Connect to the MySQL database
|
||||
# connection = await aiomysql.connect(**db_config)
|
||||
async with pool.acquire() as connection:
|
||||
async with connection.cursor() as cursor:
|
||||
# Get all table names
|
||||
await cursor.execute("SHOW TABLES")
|
||||
tables = await cursor.fetchall()
|
||||
|
||||
result = {}
|
||||
for table in tables:
|
||||
table_name = table[0]
|
||||
result[table_name] = {}
|
||||
|
||||
# Get column names and data types for the current table
|
||||
await cursor.execute(f"DESCRIBE {table_name}")
|
||||
columns = await cursor.fetchall()
|
||||
# Store column names and data types
|
||||
for col in columns:
|
||||
result[table_name][col[0]] = col[1]
|
||||
|
||||
# Close the connection
|
||||
# connection.close()
|
||||
# await connection.wait_closed()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def serializer(raw_result: list[dict[str, Any]]):
|
||||
serialized = []
|
||||
while raw_result:
|
||||
re = raw_result.pop()
|
||||
serialized_re = {}
|
||||
for key, value in re.items():
|
||||
if isinstance(value, str | int | None | float):
|
||||
serialized_re[key] = str(value)
|
||||
elif isinstance(value, decimal.Decimal):
|
||||
serialized_re[key] = float(value)
|
||||
elif isinstance(value, datetime.date):
|
||||
serialized_re[key] = value.strftime("%Y-%m-%d")
|
||||
else:
|
||||
serialized_re[key] = str(value)
|
||||
print(key, type(value))
|
||||
|
||||
serialized.append(serialized_re)
|
||||
return serialized
|
||||
|
||||
|
||||
def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
|
||||
if len(result) == 0:
|
||||
return None
|
||||
|
||||
columns = list(result[0].keys())
|
||||
data = []
|
||||
for row in result:
|
||||
data.append(list(row.values()))
|
||||
|
||||
return SelectResultData(columns=columns, data=data)
|
||||
|
||||
29
main.py
29
main.py
@@ -1,26 +1,19 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from data.db import engine, Base
|
||||
|
||||
from app.connections import connections_router
|
||||
from app.users import users_router
|
||||
from app import api_router
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.include_router(router=users_router, prefix="/users", tags=["Users"])
|
||||
app.include_router(
|
||||
router=connections_router, prefix="/connections", tags=["Connections"]
|
||||
)
|
||||
from utils.scripts import pools_creator, pools_destroy, db_startup
|
||||
|
||||
|
||||
# @app.on_event("startup")
|
||||
async def startup():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
|
||||
# import asyncio
|
||||
# asyncio.run(startup())
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await pools_creator()
|
||||
yield
|
||||
await pools_destroy()
|
||||
|
||||
# import uvicorn
|
||||
|
||||
# uvicorn.run(app=app)
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(router=api_router)
|
||||
@@ -1,5 +1,6 @@
|
||||
# add_user.py
|
||||
import asyncio
|
||||
|
||||
import asyncio, logging
|
||||
import secrets
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
@@ -7,6 +8,29 @@ from getpass import getpass
|
||||
from data.db import engine, SessionLocal
|
||||
from data.models import Base, User, UserRole
|
||||
|
||||
|
||||
async def pools_creator():
|
||||
from data.crud import read_all_connections
|
||||
from dbs import mysql
|
||||
|
||||
async with SessionLocal() as db:
|
||||
connections = await read_all_connections(db=db)
|
||||
|
||||
for connection in connections:
|
||||
mysql.pools[connection.id] = await mysql.pool_creator(connection=connection)
|
||||
logging.info(msg='Created Pools')
|
||||
|
||||
async def pools_destroy():
|
||||
from dbs import mysql
|
||||
for connection_id, pool in mysql.pools.items():
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
logging.info(f'Closed pool: {connection_id}')
|
||||
|
||||
async def db_startup():
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
def create_secret():
|
||||
return secrets.token_hex(32)
|
||||
|
||||
78
utils/sql_creator.py
Normal file
78
utils/sql_creator.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from data.schemas import SelectQueryBase
|
||||
|
||||
def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]:
|
||||
"""
|
||||
Builds a SQL query text and parameters from a SelectQuery schema object.
|
||||
|
||||
Args:
|
||||
query (SelectQuery): The query schema object.
|
||||
|
||||
Returns:
|
||||
tuple[str, list]: The SQL query text and a list of parameters.
|
||||
"""
|
||||
|
||||
# Build SELECT clause
|
||||
if query.columns == "*":
|
||||
select_clause = "SELECT *"
|
||||
else:
|
||||
select_clause = f"SELECT {', '.join(query.columns)}"
|
||||
|
||||
# Build FROM clause
|
||||
from_clause = f"FROM {query.table_name}"
|
||||
|
||||
# Build WHERE clause
|
||||
where_clause = ""
|
||||
params = []
|
||||
if query.filters:
|
||||
conditions = []
|
||||
for filter in query.filters:
|
||||
column = filter.column
|
||||
operator = filter.operator.value
|
||||
value = filter.value
|
||||
|
||||
if operator in ["IS NULL", "IS NOT NULL"]:
|
||||
conditions.append(f"{column} {operator}")
|
||||
elif operator == "IN":
|
||||
placeholders = ", ".join(["%s"] * len(value))
|
||||
conditions.append(f"{column} IN ({placeholders})")
|
||||
params.extend(value)
|
||||
else:
|
||||
# operators like < > == !=
|
||||
conditions.append(f"{column} {operator} %s")
|
||||
params.append(value)
|
||||
|
||||
where_clause = "WHERE " + " AND ".join(conditions)
|
||||
|
||||
# Build ORDER BY clause
|
||||
order_by_clause = ""
|
||||
if query.sort_by:
|
||||
order_by = []
|
||||
for sort in query.sort_by:
|
||||
column = sort.column
|
||||
order = sort.order.value
|
||||
order_by.append(f"{column} {order}")
|
||||
order_by_clause = "ORDER BY " + ", ".join(order_by)
|
||||
|
||||
# Build LIMIT and OFFSET clauses
|
||||
limit_clause = ""
|
||||
if query.limit:
|
||||
limit_clause = f"LIMIT {query.limit}"
|
||||
|
||||
offset_clause = ""
|
||||
if query.offset:
|
||||
offset_clause = f"OFFSET {query.offset}"
|
||||
|
||||
# Combine all clauses
|
||||
sql_query = """
|
||||
""".join(
|
||||
[
|
||||
select_clause,
|
||||
from_clause,
|
||||
where_clause,
|
||||
order_by_clause,
|
||||
limit_clause,
|
||||
offset_clause,
|
||||
]
|
||||
).strip()
|
||||
|
||||
return sql_query, params
|
||||
Reference in New Issue
Block a user