"""Endpoints for running saved SELECT queries and paging through cached cursors.

A cursor's rows can be fetched page by page over HTTP (``/fetch_cursor``),
streamed as Server-Sent Events (``/sse-stream-cursor``) or streamed over a
WebSocket (``/ws-stream/{cursor_id}``).
"""
from fastapi import Depends, Query, WebSocket, WebSocketDisconnect
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRouter
from pydantic import Field
from typing_extensions import Annotated

from core.dependencies import get_current_user, get_db, get_user_from_api_key
from core.exceptions import (
    ConnectionNotFound,
    CursorNotFound,
    PoolNotFound,
    QueryNotFound,
)
from data.crud import (
    read_connection,
    read_select_query,
)
from data.schemas import CachedCursorOut, SelectResult
from dbs import mysql

router = APIRouter()


def _to_sse_event(result: SelectResult) -> str:
    """Serialize *result* as one Server-Sent Events ``data:`` frame."""
    return f"data: {result.model_dump_json()}\n\n"


@router.post("/execute", dependencies=[Depends(get_current_user)])
async def execute_select(
    query_id: int,
    connection_id: int,
    # FastAPI only honours validation metadata from its own param classes
    # (Query/Path/...); a bare pydantic ``Field`` is not a query parameter.
    page_size: Annotated[int, Query(ge=1, le=100)] = 50,
    db=Depends(get_db),
) -> SelectResult:
    """Run a saved SELECT query and return its first page of rows.

    No cursor is cached: the returned cursor metadata has ``id=None``, is
    marked closed (``ttl``/``close_at`` set to -1), and ``has_more`` is True
    when the number of fetched rows differs from the row count reported by
    ``mysql.execute_select_query``.

    Raises:
        QueryNotFound: no saved select query exists for ``query_id``.
        ConnectionNotFound: no connection exists for ``connection_id``.
        PoolNotFound: ``mysql.pools`` holds no pool for the connection.
    """
    query = await read_select_query(db=db, query_id=query_id)
    if query is None:
        raise QueryNotFound

    connection = await read_connection(db=db, connection_id=connection_id)
    if connection is None:
        raise ConnectionNotFound

    pool = mysql.pools.get(connection.id, None)
    if pool is None:
        raise PoolNotFound

    raw_result, rowcount = await mysql.execute_select_query(
        pool=pool, sql_query=query.sql, params=query.params, fetch_num=page_size
    )
    results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
    fetched = len(results.data)

    return SelectResult(
        cursor=CachedCursorOut(
            id=None,
            connection_id=connection_id,
            query=query,
            row_count=rowcount,
            fetched_rows=fetched,
            is_closed=True,
            has_more=fetched != rowcount,
            ttl=-1,
            close_at=-1,
        ),
        results=results,
    )


@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
async def fetch_cursor(
    cursor_id: str,
    page_size: Annotated[int, Query(ge=1, le=1000)] = 50,
) -> SelectResult:
    """Fetch the next page (at most ``page_size`` rows) from a cached cursor.

    Once the cursor reports ``done`` it is removed from
    ``mysql.cached_cursors``, so later calls with the same id raise
    ``CursorNotFound``.

    Raises:
        CursorNotFound: ``cursor_id`` is not in ``mysql.cached_cursors``.
    """
    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
    if cached_cursor is None:
        raise CursorNotFound

    result = await cached_cursor.fetch_many(size=page_size)
    if cached_cursor.done:
        mysql.cached_cursors.pop(cursor_id, None)

    return SelectResult(
        cursor=cached_cursor,
        results={"columns": cached_cursor.query.columns, "data": result},
    )


@router.get(
    "/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
)
async def server_side_events_stream_cursor(
    cursor_id: str,
    page_size: Annotated[int, Query(ge=1, le=1000)] = 50,
):
    """Stream every remaining page of a cached cursor as Server-Sent Events.

    Each page is emitted as one ``data:`` event containing the JSON-encoded
    ``SelectResult``. The stream ends after the page whose cursor is closed.

    Raises:
        CursorNotFound: ``cursor_id`` is not in ``mysql.cached_cursors``.
    """
    # Check up front so a missing cursor becomes a regular HTTP error rather
    # than a failure after the streaming response has already started.
    if mysql.cached_cursors.get(cursor_id, None) is None:
        raise CursorNotFound

    async def stream():
        while True:
            result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
            yield _to_sse_event(result)
            if result.cursor.is_closed:
                break

    return StreamingResponse(stream(), media_type="text/event-stream")


@router.websocket(
    path="/ws-stream/{cursor_id}",
)
async def websocket_stream_cursor(
    websocket: WebSocket,
    cursor_id: str,
    page_size: Annotated[int, Query(ge=1, le=1000)] = 50,
    db=Depends(get_db),
):
    """Stream every remaining page of a cached cursor over a WebSocket.

    The client authenticates with its API key in the handshake's
    ``Authorization`` header. Each page is sent as a text message framed like
    a Server-Sent Event (``data: <json>`` followed by a blank line).

    Close codes:
        1008: missing or invalid credentials.
        1011: the cursor does not exist, or disappeared between pages.
        default code with reason "Done": all pages were sent.
    If the client disconnects mid-stream, the handler simply returns.
    """
    await websocket.accept()

    api_key = websocket.headers.get("Authorization")
    # Reject a missing/empty header directly instead of looking it up.
    user = await get_user_from_api_key(db=db, api_key=api_key) if api_key else None
    if user is None:
        await websocket.close(reason="Invalid credentials", code=1008)
        return

    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
    if cached_cursor is None:
        e = CursorNotFound()
        await websocket.close(reason=str(e.detail), code=1011)
        return

    while True:
        try:
            result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
        except CursorNotFound as e:
            # The cursor can vanish between pages (for example another request
            # exhausted it and fetch_cursor removed it); report it like the
            # initial lookup instead of letting the exception escape.
            await websocket.close(reason=str(e.detail), code=1011)
            return

        try:
            await websocket.send_text(_to_sse_event(result))
        except WebSocketDisconnect:
            # Client went away; there is nobody left to send to or close.
            return

        if result.cursor.is_closed:
            break

    await websocket.close(reason="Done")