Created Queries and Execute endpoint
This commit is contained in:
14
app/__init__.py
Normal file
14
app/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.connections import router as connections_router
|
||||||
|
from app.operations import router as router
|
||||||
|
from app.users import router as user_router
|
||||||
|
|
||||||
|
api_router = APIRouter()
|
||||||
|
api_router.include_router(router=user_router, prefix="/users", tags=["Users"])
|
||||||
|
api_router.include_router(
|
||||||
|
router=connections_router, prefix="/connections", tags=["Connections"]
|
||||||
|
)
|
||||||
|
api_router.include_router(
|
||||||
|
router=router, prefix='/operations', tags=["Operations"]
|
||||||
|
)
|
||||||
@@ -11,10 +11,10 @@ from data.crud import (
|
|||||||
)
|
)
|
||||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||||
|
|
||||||
connections_router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@connections_router.post("/", status_code=status.HTTP_201_CREATED)
|
@router.post("/", status_code=status.HTTP_201_CREATED)
|
||||||
async def create_connection_endpoint(
|
async def create_connection_endpoint(
|
||||||
connection: ConnectionCreate,
|
connection: ConnectionCreate,
|
||||||
db: AsyncSession = Depends(get_db),
|
db: AsyncSession = Depends(get_db),
|
||||||
@@ -23,7 +23,7 @@ async def create_connection_endpoint(
|
|||||||
return await create_connection(db=db, connection=connection, user_id=admin.id)
|
return await create_connection(db=db, connection=connection, user_id=admin.id)
|
||||||
|
|
||||||
|
|
||||||
@connections_router.get(
|
@router.get(
|
||||||
"/",
|
"/",
|
||||||
response_model=list[Connection],
|
response_model=list[Connection],
|
||||||
dependencies=[Depends(get_current_user)],
|
dependencies=[Depends(get_current_user)],
|
||||||
@@ -35,7 +35,7 @@ async def read_connections_endpoint(
|
|||||||
return db_connection
|
return db_connection
|
||||||
|
|
||||||
|
|
||||||
@connections_router.get(
|
@router.get(
|
||||||
"/{connection_id}",
|
"/{connection_id}",
|
||||||
response_model=Connection,
|
response_model=Connection,
|
||||||
dependencies=[Depends(get_current_user)],
|
dependencies=[Depends(get_current_user)],
|
||||||
@@ -47,7 +47,7 @@ async def read_connection_endpoint(connection_id: int, db: AsyncSession = Depend
|
|||||||
return db_connection
|
return db_connection
|
||||||
|
|
||||||
|
|
||||||
@connections_router.put(
|
@router.put(
|
||||||
"/{connection_id}",
|
"/{connection_id}",
|
||||||
response_model=Connection,
|
response_model=Connection,
|
||||||
dependencies=[Depends(get_admin_user)],
|
dependencies=[Depends(get_admin_user)],
|
||||||
@@ -63,7 +63,7 @@ async def update_connection_endpoint(
|
|||||||
return db_connection
|
return db_connection
|
||||||
|
|
||||||
|
|
||||||
@connections_router.delete(
|
@router.delete(
|
||||||
"/{connection_id}",
|
"/{connection_id}",
|
||||||
status_code=status.HTTP_204_NO_CONTENT,
|
status_code=status.HTTP_204_NO_CONTENT,
|
||||||
dependencies=[Depends(get_admin_user)],
|
dependencies=[Depends(get_admin_user)],
|
||||||
|
|||||||
92
app/operations.py
Normal file
92
app/operations.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from fastapi.routing import APIRouter
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
from pydantic import Field
|
||||||
|
from data.schemas import (
|
||||||
|
SelectQueryBase,
|
||||||
|
SelectQueryInDB,
|
||||||
|
SelectQuery,
|
||||||
|
SelectQueryIn,
|
||||||
|
SelectResult,
|
||||||
|
SelectQueryMetaData,
|
||||||
|
SelectQueryInResult,
|
||||||
|
)
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from data.crud import (
|
||||||
|
read_connection,
|
||||||
|
create_select_query,
|
||||||
|
read_all_select_queries,
|
||||||
|
read_select_query,
|
||||||
|
)
|
||||||
|
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||||
|
from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound
|
||||||
|
from utils.sql_creator import build_sql_query_text
|
||||||
|
from dbs import mysql
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/select")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/check-query", dependencies=[Depends(get_current_user)])
|
||||||
|
async def check_select_query(query: SelectQueryBase) -> SelectQuery:
|
||||||
|
sql, params = build_sql_query_text(query)
|
||||||
|
q = SelectQuery(**query.model_dump(), params=params, sql=sql)
|
||||||
|
return q
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/query")
|
||||||
|
async def create_select_query_endpoint(
|
||||||
|
query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
|
||||||
|
) -> SelectQueryInDB:
|
||||||
|
sql, params = build_sql_query_text(query)
|
||||||
|
query_in = SelectQueryIn(
|
||||||
|
**query.model_dump(), owner_id=user.id, params=params, sql=sql
|
||||||
|
)
|
||||||
|
return await create_select_query(db=db, query=query_in)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/query", dependencies=[Depends(get_current_user)])
|
||||||
|
async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
|
||||||
|
return await read_all_select_queries(db=db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)])
|
||||||
|
async def get_select_queries_endpoint(
|
||||||
|
query_id: int, db=Depends(get_db)
|
||||||
|
) -> SelectQueryInDB:
|
||||||
|
return await read_select_query(db=db, query_id=query_id)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/execute", dependencies=[Depends(get_current_user)])
|
||||||
|
async def execute_select(
|
||||||
|
query_id: int,
|
||||||
|
connection_id: int,
|
||||||
|
page_size: Annotated[int, Field(ge=1, le=100)] = 50,
|
||||||
|
db=Depends(get_db),
|
||||||
|
) -> SelectResult:
|
||||||
|
query = await read_select_query(db=db, query_id=query_id)
|
||||||
|
|
||||||
|
if query is None:
|
||||||
|
raise QueryNotFound
|
||||||
|
connection = await read_connection(db=db, connection_id=connection_id)
|
||||||
|
|
||||||
|
if connection is None:
|
||||||
|
raise ConnectionNotFound
|
||||||
|
pool = mysql.pools.get(connection.id, None)
|
||||||
|
if pool is None:
|
||||||
|
raise PoolNotFound
|
||||||
|
|
||||||
|
raw_result, rowcount = await mysql.execute_select_query(
|
||||||
|
pool=pool, query=query.sql, params=query.params, fetch_num=page_size
|
||||||
|
)
|
||||||
|
|
||||||
|
results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
|
||||||
|
|
||||||
|
meta = SelectQueryMetaData(
|
||||||
|
cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount
|
||||||
|
)
|
||||||
|
|
||||||
|
return SelectResult(
|
||||||
|
meta=meta,
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
)
|
||||||
16
app/users.py
16
app/users.py
@@ -7,17 +7,17 @@ from data.crud import read_all_users, read_user, create_user, delete_user
|
|||||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from core.exceptions import ObjectNotFoundInDB, UserNotFound
|
from core.exceptions import ObjectNotFoundInDB, UserNotFound
|
||||||
from core.scripts import create_secret
|
from utils.scripts import create_secret
|
||||||
|
|
||||||
users_router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@users_router.get("/me")
|
@router.get("/me")
|
||||||
async def get_me(user=Depends(get_current_user)) -> UserOut:
|
async def get_me(user=Depends(get_current_user)) -> UserOut:
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
@users_router.get(
|
@router.get(
|
||||||
"/",
|
"/",
|
||||||
dependencies=[Depends(get_current_user)],
|
dependencies=[Depends(get_current_user)],
|
||||||
)
|
)
|
||||||
@@ -25,7 +25,7 @@ async def get_all_users_endpoint(db=Depends(get_db)) -> list[UserOut]:
|
|||||||
return await read_all_users(db=db)
|
return await read_all_users(db=db)
|
||||||
|
|
||||||
|
|
||||||
@users_router.post(
|
@router.post(
|
||||||
"/", dependencies=[Depends(get_current_user)], status_code=status.HTTP_201_CREATED
|
"/", dependencies=[Depends(get_current_user)], status_code=status.HTTP_201_CREATED
|
||||||
)
|
)
|
||||||
async def create_user_endpoint(
|
async def create_user_endpoint(
|
||||||
@@ -42,7 +42,7 @@ async def create_user_endpoint(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@users_router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT)
|
@router.post('/update-my-api_key/', status_code=status.HTTP_204_NO_CONTENT)
|
||||||
async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get_db)):
|
async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get_db)):
|
||||||
if user.role == UserRole.admin:
|
if user.role == UserRole.admin:
|
||||||
raise HTTPException(status_code=400, detail={
|
raise HTTPException(status_code=400, detail={
|
||||||
@@ -54,7 +54,7 @@ async def update_user_own_api_key(user=Depends(get_current_user), db=Depends(get
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
|
|
||||||
@users_router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)])
|
@router.post('/update-user-api_key/', status_code=status.HTTP_202_ACCEPTED, dependencies=[Depends(get_admin_user)])
|
||||||
async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBase:
|
async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBase:
|
||||||
user = await read_user(db=db, user_id=user_id)
|
user = await read_user(db=db, user_id=user_id)
|
||||||
if user is None:
|
if user is None:
|
||||||
@@ -65,7 +65,7 @@ async def update_user_own_api_key(user_id:int, db=Depends(get_db)) -> UserInDBBa
|
|||||||
await db.refresh(user)
|
await db.refresh(user)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@users_router.delete(
|
@router.delete(
|
||||||
"/", dependencies=[Depends(get_admin_user)], status_code=status.HTTP_204_NO_CONTENT
|
"/", dependencies=[Depends(get_admin_user)], status_code=status.HTTP_204_NO_CONTENT
|
||||||
)
|
)
|
||||||
async def delete_user_endpoint(user_id: int, db=Depends(get_db)):
|
async def delete_user_endpoint(user_id: int, db=Depends(get_db)):
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
# dependencies.py
|
|
||||||
from fastapi import Depends, HTTPException, status, Security
|
from fastapi import Depends, HTTPException, status, Security
|
||||||
|
|
||||||
from fastapi.security import APIKeyHeader
|
from fastapi.security import APIKeyHeader
|
||||||
@@ -6,10 +5,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from data.db import SessionLocal
|
from data.db import SessionLocal
|
||||||
from data.models import User, UserRole
|
from data.models import User, UserRole
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
# class UserInDB(User):
|
|
||||||
# hashed_password: str
|
|
||||||
|
|
||||||
async def get_db():
|
async def get_db():
|
||||||
async with SessionLocal() as session:
|
async with SessionLocal() as session:
|
||||||
@@ -18,7 +13,7 @@ async def get_db():
|
|||||||
API_KEY_NAME = "Authorization"
|
API_KEY_NAME = "Authorization"
|
||||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||||
|
|
||||||
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)):
|
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
|
||||||
if api_key_header is None:
|
if api_key_header is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing"
|
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing"
|
||||||
@@ -34,7 +29,7 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec
|
|||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def get_admin_user(current_user: User = Depends(get_current_user)):
|
async def get_admin_user(current_user: User = Depends(get_current_user)) -> User:
|
||||||
if current_user.role != UserRole.admin:
|
if current_user.role != UserRole.admin:
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not enough permissions")
|
||||||
return current_user
|
return current_user
|
||||||
|
|||||||
@@ -6,4 +6,22 @@ class ConnectionTypes(str, enum.Enum):
|
|||||||
|
|
||||||
class UserRole(enum.Enum):
|
class UserRole(enum.Enum):
|
||||||
admin = "admin"
|
admin = "admin"
|
||||||
user = "user"
|
user = "user"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterOperator(str, enum.Enum):
|
||||||
|
eq = "="
|
||||||
|
neq = "!="
|
||||||
|
gt = ">"
|
||||||
|
lt = "<"
|
||||||
|
gte = ">="
|
||||||
|
lte = "<="
|
||||||
|
like = "LIKE"
|
||||||
|
ilike = "ILIKE"
|
||||||
|
in_ = "IN"
|
||||||
|
is_null = "IS NULL"
|
||||||
|
is_not_null = "IS NOT NULL"
|
||||||
|
|
||||||
|
class SortOrder(str, enum.Enum):
|
||||||
|
asc = "ASC"
|
||||||
|
desc = "DESC"
|
||||||
@@ -17,3 +17,33 @@ class UserNotFound(HTTPException):
|
|||||||
headers=None,
|
headers=None,
|
||||||
):
|
):
|
||||||
super().__init__(status_code, detail, headers)
|
super().__init__(status_code, detail, headers)
|
||||||
|
|
||||||
|
|
||||||
|
class QueryValidationError(ValueError):
|
||||||
|
def __init__(self, loc: list[str], msg: str):
|
||||||
|
self.loc = loc
|
||||||
|
self.msg = msg
|
||||||
|
super().__init__(msg)
|
||||||
|
|
||||||
|
class QueryNotFound(HTTPException):
|
||||||
|
def __init__(self, status_code=404, detail = {
|
||||||
|
'message': "The referenced query was not found.",
|
||||||
|
"code": 'query-not-found'
|
||||||
|
}, headers = None):
|
||||||
|
super().__init__(status_code, detail, headers)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionNotFound(HTTPException):
|
||||||
|
def __init__(self, status_code=404, detail = {
|
||||||
|
'message': "The referenced connection was not found.",
|
||||||
|
"code": 'connection-not-found'
|
||||||
|
}, headers = None):
|
||||||
|
super().__init__(status_code, detail, headers)
|
||||||
|
|
||||||
|
|
||||||
|
class PoolNotFound(HTTPException):
|
||||||
|
def __init__(self, status_code=404, detail = {
|
||||||
|
'message': "We didn't find a running Pool for the referenced connection.",
|
||||||
|
"code": 'pool-not-found'
|
||||||
|
}, headers = None):
|
||||||
|
super().__init__(status_code, detail, headers)
|
||||||
21
data/crud.py
21
data/crud.py
@@ -1,8 +1,8 @@
|
|||||||
# crud.py
|
# crud.py
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from data.models import Connection, User
|
from data.models import Connection, User, Query
|
||||||
from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate
|
from data.schemas import ConnectionCreate, ConnectionUpdate, UserCreate, SelectQueryIn
|
||||||
from core.exceptions import ObjectNotFoundInDB
|
from core.exceptions import ObjectNotFoundInDB
|
||||||
|
|
||||||
async def read_user(db:AsyncSession, user_id:int):
|
async def read_user(db:AsyncSession, user_id:int):
|
||||||
@@ -14,7 +14,7 @@ async def read_all_users(db:AsyncSession):
|
|||||||
return result.scalars().all()
|
return result.scalars().all()
|
||||||
|
|
||||||
async def create_user(db: AsyncSession, user: UserCreate):
|
async def create_user(db: AsyncSession, user: UserCreate):
|
||||||
from core.scripts import create_secret
|
from utils.scripts import create_secret
|
||||||
db_user = User(**user.model_dump(), api_key=create_secret())
|
db_user = User(**user.model_dump(), api_key=create_secret())
|
||||||
db.add(db_user)
|
db.add(db_user)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
@@ -60,3 +60,18 @@ async def delete_connection(db: AsyncSession, connection_id: int):
|
|||||||
await db.delete(db_connection)
|
await db.delete(db_connection)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
return db_connection
|
return db_connection
|
||||||
|
|
||||||
|
async def create_select_query(db:AsyncSession, query:SelectQueryIn) -> Query:
|
||||||
|
db_query = Query(**query.model_dump())
|
||||||
|
db.add(db_query)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(db_query)
|
||||||
|
return db_query
|
||||||
|
|
||||||
|
async def read_all_select_queries(db:AsyncSession) -> list[Query]:
|
||||||
|
result = await db.execute(select(Query))
|
||||||
|
return result.scalars().all()
|
||||||
|
|
||||||
|
async def read_select_query(db:AsyncSession, query_id:int) -> Query:
|
||||||
|
result = await db.execute(select(Query).filter(Query.id == query_id))
|
||||||
|
return result.scalars().first()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
# models.py
|
# models.py
|
||||||
from sqlalchemy import Column, Integer, String, Enum, ForeignKey
|
from sqlalchemy import Column, Integer, String, Enum, ForeignKey, JSON
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from data.db import Base
|
from data.db import Base
|
||||||
from core.enums import ConnectionTypes, UserRole
|
from core.enums import ConnectionTypes, UserRole
|
||||||
@@ -27,3 +27,19 @@ class Connection(Base):
|
|||||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||||
|
|
||||||
# owner = relationship("User", back_populates="connections")
|
# owner = relationship("User", back_populates="connections")
|
||||||
|
|
||||||
|
class Query(Base):
|
||||||
|
__tablename__ = "queries"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, index=True)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
description = Column(String, nullable=True)
|
||||||
|
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||||
|
table_name = Column(String, nullable=False)
|
||||||
|
columns = Column(JSON, nullable=False)
|
||||||
|
filters = Column(JSON, nullable=True)
|
||||||
|
sort_by = Column(JSON, nullable=True)
|
||||||
|
limit = Column(Integer, nullable=True)
|
||||||
|
offset = Column(Integer, nullable=True)
|
||||||
|
sql= Column(String, nullable=False)
|
||||||
|
params = Column(JSON, nullable=False)
|
||||||
|
|||||||
143
data/schemas.py
143
data/schemas.py
@@ -1,7 +1,9 @@
|
|||||||
# schemas.py
|
import re
|
||||||
from pydantic import BaseModel
|
from typing import Union, List, Optional, Literal, Any
|
||||||
from typing import Optional
|
from typing_extensions import Annotated
|
||||||
from core.enums import ConnectionTypes, UserRole
|
from pydantic import BaseModel, Field, field_validator, ValidationInfo, UUID4
|
||||||
|
from core.enums import ConnectionTypes, UserRole, FilterOperator, SortOrder
|
||||||
|
from core.exceptions import QueryValidationError
|
||||||
|
|
||||||
|
|
||||||
class ConnectionBase(BaseModel):
|
class ConnectionBase(BaseModel):
|
||||||
@@ -57,3 +59,136 @@ class UserInDBBase(UserBase):
|
|||||||
|
|
||||||
class User(UserInDBBase):
|
class User(UserInDBBase):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class FilterClause(BaseModel):
|
||||||
|
column: str
|
||||||
|
operator: FilterOperator
|
||||||
|
value: Optional[Union[str, int, float, bool, list]] = None
|
||||||
|
|
||||||
|
@field_validator("column")
|
||||||
|
@classmethod
|
||||||
|
def validate_column(cls, v: str) -> str:
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["filters", "column"], msg="Invalid column name format"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("value")
|
||||||
|
@classmethod
|
||||||
|
def validate_value(
|
||||||
|
cls,
|
||||||
|
v: Optional[Union[str, int, float, bool, list]],
|
||||||
|
values: ValidationInfo,
|
||||||
|
) -> Optional[Union[str, int, float, bool, list]]:
|
||||||
|
operator = values.data.get("operator")
|
||||||
|
|
||||||
|
if operator in [FilterOperator.is_null, FilterOperator.is_not_null]:
|
||||||
|
if v is not None:
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["filters", "value"],
|
||||||
|
msg="Value must be null for IS NULL/IS NOT NULL operators",
|
||||||
|
)
|
||||||
|
elif operator == FilterOperator.in_:
|
||||||
|
if not isinstance(v, list):
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["filters", "value"],
|
||||||
|
msg="IN operator requires a list of values",
|
||||||
|
)
|
||||||
|
elif v is None:
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["filters", "value"], msg="Value required for this operator"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class SortClause(BaseModel):
|
||||||
|
column: str
|
||||||
|
order: SortOrder = SortOrder.asc
|
||||||
|
|
||||||
|
@field_validator("column")
|
||||||
|
@classmethod
|
||||||
|
def validate_column(cls, v: str) -> str:
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["sort_by", "column"], msg="Invalid column name format"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQueryBase(BaseModel):
|
||||||
|
name: str # a short name for this query.
|
||||||
|
description: str | None = None # describing what does this query do.
|
||||||
|
table_name: str
|
||||||
|
columns: Union[Literal["*"], List[str]] = "*"
|
||||||
|
filters: Optional[List[FilterClause]] = None
|
||||||
|
sort_by: Optional[List[SortClause]] = None
|
||||||
|
limit: Annotated[int, Field(strict=True, gt=0)] = None
|
||||||
|
offset: Annotated[int, Field(strict=True, ge=0)] = None
|
||||||
|
|
||||||
|
@field_validator("table_name")
|
||||||
|
@classmethod
|
||||||
|
def validate_table_name(cls, v: str) -> str:
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", v):
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["table_name"], msg="Invalid table name format"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("columns")
|
||||||
|
@classmethod
|
||||||
|
def validate_columns(
|
||||||
|
cls, v: Union[Literal["*"], List[str]]
|
||||||
|
) -> Union[Literal["*"], List[str]]:
|
||||||
|
if v == "*":
|
||||||
|
return v
|
||||||
|
|
||||||
|
for col in v:
|
||||||
|
if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", col):
|
||||||
|
raise QueryValidationError(
|
||||||
|
loc=["columns"], msg=f"Invalid column name: {col}"
|
||||||
|
)
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQuery(SelectQueryBase):
|
||||||
|
sql: str
|
||||||
|
params: list[str | int | float]
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQueryIn(SelectQuery):
|
||||||
|
owner_id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQueryInDB(SelectQueryIn):
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class SelectResultData(BaseModel):
|
||||||
|
columns: List[str]
|
||||||
|
data: List[List[Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQueryInResult(BaseModel):
|
||||||
|
id: int
|
||||||
|
sql: str
|
||||||
|
params: list[str | int | float]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
class SelectQueryMetaData(BaseModel):
|
||||||
|
cursor: Optional[UUID4] = Field(
|
||||||
|
None,
|
||||||
|
description="A UUID4 cursor for pagination. Can be None if no more data is available.",
|
||||||
|
)
|
||||||
|
total_number: int
|
||||||
|
has_more: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class SelectResult(BaseModel):
|
||||||
|
meta: SelectQueryMetaData
|
||||||
|
query: SelectQueryInResult
|
||||||
|
results: SelectResultData | None
|
||||||
|
|||||||
177
dbs/mysql.py
177
dbs/mysql.py
@@ -1,2 +1,177 @@
|
|||||||
import aiomysql
|
import aiomysql, decimal, datetime
|
||||||
|
import asyncio
|
||||||
|
from typing import Literal, Any
|
||||||
|
from data.schemas import Connection, SelectResultData
|
||||||
|
|
||||||
|
# Database configuration
|
||||||
|
DB_CONFIG = {
|
||||||
|
"host": "localhost",
|
||||||
|
"user": "me", # Replace with your MySQL username
|
||||||
|
"password": "Passwd3.14", # Replace with your MySQL password
|
||||||
|
"db": "testing", # Replace with your database name
|
||||||
|
"port": 3306, # Default MySQL port
|
||||||
|
}
|
||||||
|
|
||||||
|
pools: None | dict[str, aiomysql.Pool] = {}
|
||||||
|
|
||||||
|
|
||||||
|
async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
||||||
|
|
||||||
|
return await aiomysql.create_pool(
|
||||||
|
# **DB_CONFIG,
|
||||||
|
host=connection.host,
|
||||||
|
user=connection.username,
|
||||||
|
password=connection.password,
|
||||||
|
db=connection.db_name,
|
||||||
|
port=connection.port,
|
||||||
|
minsize=minsize,
|
||||||
|
maxsize=maxsize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
|
||||||
|
"""
|
||||||
|
Executes a SELECT query on the MySQL database asynchronously and returns the results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pool (aiomysql.Pool): Connection pool to use.
|
||||||
|
query (str): The SELECT query to execute.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: A list of rows (as dictionaries) returned by the query.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with pool.acquire() as connection:
|
||||||
|
async with connection.cursor(aiomysql.DictCursor) as cursor:
|
||||||
|
await cursor.execute(query, params)
|
||||||
|
result = await cursor.fetchmany(fetch_num)
|
||||||
|
|
||||||
|
return result, cursor.rowcount
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error executing query: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def construct_select_query(
|
||||||
|
table_name,
|
||||||
|
columns: Literal["*"] | list[str] = "*",
|
||||||
|
filters=None,
|
||||||
|
sort_by=None,
|
||||||
|
limit=None,
|
||||||
|
offset=None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Constructs a dynamic SELECT query based on the provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name (str): The name of the table to query.
|
||||||
|
columns (str or list): The columns to select. Default is "*" (all columns).
|
||||||
|
filters (dict): A dictionary of filters (e.g., {"column": "value"}).
|
||||||
|
sort_by (str): The column to sort by (e.g., "column_name ASC").
|
||||||
|
limit (int): The maximum number of rows to return.
|
||||||
|
offset (int): The number of rows to skip (for pagination).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The constructed SELECT query.
|
||||||
|
"""
|
||||||
|
# Handle columns
|
||||||
|
if isinstance(columns, list):
|
||||||
|
columns = ", ".join(columns)
|
||||||
|
elif columns != "*":
|
||||||
|
raise ValueError("Columns must be a list or '*'.")
|
||||||
|
|
||||||
|
# Base query
|
||||||
|
query = f"SELECT {columns}\n FROM {table_name}"
|
||||||
|
|
||||||
|
# Add filters
|
||||||
|
if filters:
|
||||||
|
filter_conditions = [
|
||||||
|
f"{column} = '{value}'" for column, value in filters.items()
|
||||||
|
]
|
||||||
|
query += "\n WHERE " + " AND ".join(filter_conditions)
|
||||||
|
|
||||||
|
# Add sorting
|
||||||
|
if sort_by:
|
||||||
|
query += f"\n ORDER BY {sort_by}"
|
||||||
|
|
||||||
|
# Add pagination
|
||||||
|
if limit is not None:
|
||||||
|
query += f"\n LIMIT {limit}"
|
||||||
|
if offset is not None:
|
||||||
|
query += f"\n OFFSET {offset}"
|
||||||
|
|
||||||
|
return query
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tables_and_datatypes(pool: aiomysql.Pool):
|
||||||
|
"""
|
||||||
|
Retrieves all table names and their column data types from a MySQL database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_config (dict): MySQL database configuration (host, user, password, db, port).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: A dictionary where keys are table names and values are lists of (column_name, data_type) tuples.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Connect to the MySQL database
|
||||||
|
# connection = await aiomysql.connect(**db_config)
|
||||||
|
async with pool.acquire() as connection:
|
||||||
|
async with connection.cursor() as cursor:
|
||||||
|
# Get all table names
|
||||||
|
await cursor.execute("SHOW TABLES")
|
||||||
|
tables = await cursor.fetchall()
|
||||||
|
|
||||||
|
result = {}
|
||||||
|
for table in tables:
|
||||||
|
table_name = table[0]
|
||||||
|
result[table_name] = {}
|
||||||
|
|
||||||
|
# Get column names and data types for the current table
|
||||||
|
await cursor.execute(f"DESCRIBE {table_name}")
|
||||||
|
columns = await cursor.fetchall()
|
||||||
|
# Store column names and data types
|
||||||
|
for col in columns:
|
||||||
|
result[table_name][col[0]] = col[1]
|
||||||
|
|
||||||
|
# Close the connection
|
||||||
|
# connection.close()
|
||||||
|
# await connection.wait_closed()
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def serializer(raw_result: list[dict[str, Any]]):
|
||||||
|
serialized = []
|
||||||
|
while raw_result:
|
||||||
|
re = raw_result.pop()
|
||||||
|
serialized_re = {}
|
||||||
|
for key, value in re.items():
|
||||||
|
if isinstance(value, str | int | None | float):
|
||||||
|
serialized_re[key] = str(value)
|
||||||
|
elif isinstance(value, decimal.Decimal):
|
||||||
|
serialized_re[key] = float(value)
|
||||||
|
elif isinstance(value, datetime.date):
|
||||||
|
serialized_re[key] = value.strftime("%Y-%m-%d")
|
||||||
|
else:
|
||||||
|
serialized_re[key] = str(value)
|
||||||
|
print(key, type(value))
|
||||||
|
|
||||||
|
serialized.append(serialized_re)
|
||||||
|
return serialized
|
||||||
|
|
||||||
|
|
||||||
|
def dict_result_to_list(result: list[dict[str, Any]]) -> None | SelectResultData:
|
||||||
|
if len(result) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
columns = list(result[0].keys())
|
||||||
|
data = []
|
||||||
|
for row in result:
|
||||||
|
data.append(list(row.values()))
|
||||||
|
|
||||||
|
return SelectResultData(columns=columns, data=data)
|
||||||
|
|||||||
29
main.py
29
main.py
@@ -1,26 +1,19 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from data.db import engine, Base
|
|
||||||
|
|
||||||
from app.connections import connections_router
|
from app import api_router
|
||||||
from app.users import users_router
|
|
||||||
|
|
||||||
app = FastAPI()
|
from utils.scripts import pools_creator, pools_destroy, db_startup
|
||||||
|
|
||||||
app.include_router(router=users_router, prefix="/users", tags=["Users"])
|
|
||||||
app.include_router(
|
|
||||||
router=connections_router, prefix="/connections", tags=["Connections"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# @app.on_event("startup")
|
|
||||||
async def startup():
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
|
|
||||||
# import asyncio
|
@asynccontextmanager
|
||||||
# asyncio.run(startup())
|
async def lifespan(app: FastAPI):
|
||||||
|
await pools_creator()
|
||||||
|
yield
|
||||||
|
await pools_destroy()
|
||||||
|
|
||||||
# import uvicorn
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(router=api_router)
|
||||||
# uvicorn.run(app=app)
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
# add_user.py
|
# add_user.py
|
||||||
import asyncio
|
|
||||||
|
import asyncio, logging
|
||||||
import secrets
|
import secrets
|
||||||
from sqlalchemy.future import select
|
from sqlalchemy.future import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
@@ -7,6 +8,29 @@ from getpass import getpass
|
|||||||
from data.db import engine, SessionLocal
|
from data.db import engine, SessionLocal
|
||||||
from data.models import Base, User, UserRole
|
from data.models import Base, User, UserRole
|
||||||
|
|
||||||
|
|
||||||
|
async def pools_creator():
|
||||||
|
from data.crud import read_all_connections
|
||||||
|
from dbs import mysql
|
||||||
|
|
||||||
|
async with SessionLocal() as db:
|
||||||
|
connections = await read_all_connections(db=db)
|
||||||
|
|
||||||
|
for connection in connections:
|
||||||
|
mysql.pools[connection.id] = await mysql.pool_creator(connection=connection)
|
||||||
|
logging.info(msg='Created Pools')
|
||||||
|
|
||||||
|
async def pools_destroy():
|
||||||
|
from dbs import mysql
|
||||||
|
for connection_id, pool in mysql.pools.items():
|
||||||
|
pool.close()
|
||||||
|
await pool.wait_closed()
|
||||||
|
logging.info(f'Closed pool: {connection_id}')
|
||||||
|
|
||||||
|
async def db_startup():
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
def create_secret():
|
def create_secret():
|
||||||
return secrets.token_hex(32)
|
return secrets.token_hex(32)
|
||||||
|
|
||||||
78
utils/sql_creator.py
Normal file
78
utils/sql_creator.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from data.schemas import SelectQueryBase
|
||||||
|
|
||||||
|
def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]:
|
||||||
|
"""
|
||||||
|
Builds a SQL query text and parameters from a SelectQuery schema object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (SelectQuery): The query schema object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, list]: The SQL query text and a list of parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Build SELECT clause
|
||||||
|
if query.columns == "*":
|
||||||
|
select_clause = "SELECT *"
|
||||||
|
else:
|
||||||
|
select_clause = f"SELECT {', '.join(query.columns)}"
|
||||||
|
|
||||||
|
# Build FROM clause
|
||||||
|
from_clause = f"FROM {query.table_name}"
|
||||||
|
|
||||||
|
# Build WHERE clause
|
||||||
|
where_clause = ""
|
||||||
|
params = []
|
||||||
|
if query.filters:
|
||||||
|
conditions = []
|
||||||
|
for filter in query.filters:
|
||||||
|
column = filter.column
|
||||||
|
operator = filter.operator.value
|
||||||
|
value = filter.value
|
||||||
|
|
||||||
|
if operator in ["IS NULL", "IS NOT NULL"]:
|
||||||
|
conditions.append(f"{column} {operator}")
|
||||||
|
elif operator == "IN":
|
||||||
|
placeholders = ", ".join(["%s"] * len(value))
|
||||||
|
conditions.append(f"{column} IN ({placeholders})")
|
||||||
|
params.extend(value)
|
||||||
|
else:
|
||||||
|
# operators like < > == !=
|
||||||
|
conditions.append(f"{column} {operator} %s")
|
||||||
|
params.append(value)
|
||||||
|
|
||||||
|
where_clause = "WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
|
# Build ORDER BY clause
|
||||||
|
order_by_clause = ""
|
||||||
|
if query.sort_by:
|
||||||
|
order_by = []
|
||||||
|
for sort in query.sort_by:
|
||||||
|
column = sort.column
|
||||||
|
order = sort.order.value
|
||||||
|
order_by.append(f"{column} {order}")
|
||||||
|
order_by_clause = "ORDER BY " + ", ".join(order_by)
|
||||||
|
|
||||||
|
# Build LIMIT and OFFSET clauses
|
||||||
|
limit_clause = ""
|
||||||
|
if query.limit:
|
||||||
|
limit_clause = f"LIMIT {query.limit}"
|
||||||
|
|
||||||
|
offset_clause = ""
|
||||||
|
if query.offset:
|
||||||
|
offset_clause = f"OFFSET {query.offset}"
|
||||||
|
|
||||||
|
# Combine all clauses
|
||||||
|
sql_query = """
|
||||||
|
""".join(
|
||||||
|
[
|
||||||
|
select_clause,
|
||||||
|
from_clause,
|
||||||
|
where_clause,
|
||||||
|
order_by_clause,
|
||||||
|
limit_clause,
|
||||||
|
offset_clause,
|
||||||
|
]
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
return sql_query, params
|
||||||
Reference in New Issue
Block a user