"""Routes for running stored SELECT queries and paging or streaming their cursors."""

import asyncio
import datetime
import decimal
import time

from fastapi import Depends, WebSocket, WebSocketDisconnect, WebSocketException
from fastapi.routing import APIRouter
from fastapi.responses import StreamingResponse
from typing_extensions import Annotated
from pydantic import Field

from data.schemas import SelectResult, CachedCursorOut
from data.crud import (
    read_connection,
    read_select_query,
)
from core.dependencies import get_db, get_current_user, get_user_from_api_key
from core.exceptions import (
    QueryNotFound,
    ConnectionNotFound,
    PoolNotFound,
    CursorNotFound,
)
from dbs import mysql
from utils.binlog import changes_queue, queue

router = APIRouter()

# The single websocket currently registered for the "/databases_changes" feed.
database_changes_active_websocket: WebSocket | None = None


@router.post("/execute", dependencies=[Depends(get_current_user)])
async def execute_select(
    query_id: int,
    connection_id: int,
    page_size: Annotated[int, Field(ge=1, le=10000)] = 50,
    db=Depends(get_db),
) -> SelectResult:
    """Run a stored SELECT query on a connection and return the first page.

    Args:
        query_id: Identifier passed to ``read_select_query``.
        connection_id: Identifier passed to ``read_connection``.
        page_size: Number of rows requested from the database (``fetch_num``).
        db: Database handle injected by ``get_db``.

    Returns:
        A ``SelectResult`` whose cursor is reported as closed (``id=None``,
        ``ttl=-1``, ``close_at=-1``). ``has_more`` is true when the number of
        returned rows differs from the reported row count.

    Raises:
        QueryNotFound: No query exists for ``query_id``.
        ConnectionNotFound: No connection exists for ``connection_id``.
        PoolNotFound: No pool is registered for the connection.
    """
    query = await read_select_query(db=db, query_id=query_id)
    if query is None:
        raise QueryNotFound

    connection = await read_connection(db=db, connection_id=connection_id)
    if connection is None:
        raise ConnectionNotFound

    pool = mysql.pools.get(connection.id, None)
    if pool is None:
        raise PoolNotFound

    raw_result, rowcount = await mysql.execute_select_query(
        pool=pool, sql_query=query.sql, params=query.params, fetch_num=page_size
    )
    results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
    fetched_rows = len(results.data)

    return SelectResult(
        cursor=CachedCursorOut(
            id=None,
            connection_id=connection_id,
            query=query,
            row_count=rowcount,
            fetched_rows=fetched_rows,
            is_closed=True,
            has_more=fetched_rows != rowcount,
            ttl=-1,
            close_at=-1,
        ),
        results=results,
    )


@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
async def fetch_cursor(
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
) -> SelectResult:
    """Fetch the next page of rows from a cached cursor.

    The cursor is removed from the cache once it reports ``done``.

    Args:
        cursor_id: Key of the cursor in ``mysql.cached_cursors``.
        page_size: Maximum number of rows to fetch in this call.

    Raises:
        CursorNotFound: No cached cursor exists for ``cursor_id``.
    """
    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
    if cached_cursor is None:
        raise CursorNotFound

    result = await cached_cursor.fetch_many(size=page_size)
    if cached_cursor.done:
        mysql.cached_cursors.pop(cursor_id, None)

    return SelectResult(
        cursor=cached_cursor,
        results={"columns": cached_cursor.query.columns, "data": result},
    )


@router.get(
    "/get_database_tables", dependencies=[Depends(get_current_user)], status_code=200
)
async def get_database_tables(connection_id: int):
    """Return the tables and datatypes visible through a connection's pool.

    Args:
        connection_id: Key of the pool in ``mysql.pools``.

    Raises:
        PoolNotFound: No pool is registered for ``connection_id``.
    """
    pool = mysql.pools.get(connection_id, None)
    if pool is None:
        # Same contract as execute_select: report the missing pool explicitly
        # instead of handing None to the database helper.
        raise PoolNotFound
    return await mysql.get_tables_and_datatypes(pool=pool)


@router.get(
    "/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
)
async def server_side_events_stream_cursor(
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
):
    """Stream every remaining page of a cached cursor as ``text/event-stream``.

    Each event is one ``SelectResult`` serialized as JSON. The stream ends
    after the first page whose cursor reports ``is_closed``.

    Raises:
        CursorNotFound: No cached cursor exists for ``cursor_id``.
    """
    # Checked up front so a missing cursor is reported before streaming starts.
    if mysql.cached_cursors.get(cursor_id, None) is None:
        raise CursorNotFound

    async def stream():
        while True:
            result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
            serialized_result = result.model_dump_json()
            yield f"data: {serialized_result}\n\n"
            if result.cursor.is_closed:
                break

    return StreamingResponse(stream(), media_type="text/event-stream")


@router.websocket(
    path="/ws-stream/{cursor_id}",
)
async def websocket_stream_cursor(
    websocket: WebSocket,
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
    db=Depends(get_db),
):
    """Stream every remaining page of a cached cursor over a websocket.

    The ``Authorization`` header is passed to ``get_user_from_api_key``. The
    socket is closed with code 1008 when no user is returned, and with code
    1011 when the cursor is not cached. Otherwise pages are sent as text
    frames until the cursor reports ``is_closed``, then the socket is closed
    with reason ``"Done"``.
    """
    await websocket.accept()

    api_key = websocket.headers.get("Authorization")
    user = await get_user_from_api_key(db=db, api_key=api_key)
    if user is None:
        await websocket.close(reason="Invalid credentials", code=1008)
        return

    if mysql.cached_cursors.get(cursor_id, None) is None:
        e = CursorNotFound()
        await websocket.close(reason=str(e.detail), code=1011)
        return

    try:
        while True:
            result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
            serialized_result = result.model_dump_json()
            await websocket.send_text(
                f"data: {serialized_result}\n\n"
            )  # Format as Server-Sent Event (SSE)
            if result.cursor.is_closed:
                break
        await websocket.close(reason="Done")
    except WebSocketDisconnect:
        # The client went away mid-stream; there is nothing left to send or
        # close, so end the handler quietly instead of surfacing an error.
        return
@router.websocket("/databases_changes")
async def websocket_endpoint(
    websocket: WebSocket,
    db=Depends(get_db),
):
    """Serve the database-change feed to one authenticated websocket client.

    The ``Authorization`` header is passed to ``get_user_from_api_key``; the
    socket is closed with code 1008 when no user is returned. Only one feed
    socket is kept registered at a time: a new connection replaces the
    previously registered one, which is closed with code 1001.
    """
    global database_changes_active_websocket

    await websocket.accept()

    api_key = websocket.headers.get("Authorization")
    user = await get_user_from_api_key(db=db, api_key=api_key)
    if user is None:
        await websocket.close(reason="Invalid credentials", code=1008)
        return

    # Register this socket before closing the previous one, so the previous
    # feed loop sees it has been superseded the next time it checks.
    previous = database_changes_active_websocket
    database_changes_active_websocket = websocket
    if previous is not None:
        try:
            await previous.close(code=1001, reason="New connection established")
        except Exception as e:
            # Best effort: the previous socket may already be closed.
            print(e)

    try:
        await websocket.send_json({"message": "status", "status": "Accepted."})
        await feed_databases_changes_ws(websocket=websocket)
    except WebSocketDisconnect:
        print("Closed websocket.")
    finally:
        # Drop the registration on exit so a later connection does not try to
        # close a dead socket. Leave it alone if a newer socket took over.
        if database_changes_active_websocket is websocket:
            database_changes_active_websocket = None


def _serialize_value(value):
    """Convert one column value into a JSON-friendly ``str`` or ``float``."""
    if isinstance(value, decimal.Decimal):
        return float(value)
    # datetime.datetime subclasses datetime.date, so it must be tested first;
    # otherwise the time-of-day is dropped by the date format below.
    if isinstance(value, datetime.datetime):
        return value.isoformat(sep=" ")
    if isinstance(value, datetime.date):
        return value.strftime("%Y-%m-%d")
    # NOTE(review): None is emitted as the string "None" (unchanged from the
    # previous behaviour) — confirm the consumer does not expect JSON null.
    return str(value)


def serialize_list(l: list) -> list[str | float]:
    """Return a new list with every value converted for JSON transport.

    ``Decimal`` values become ``float``. ``datetime`` values become
    ``"YYYY-MM-DD HH:MM:SS"``, with fractional seconds and UTC offset appended
    when present. ``date`` values become ``"YYYY-MM-DD"``. Everything else,
    including ``int``, ``float`` and ``None``, becomes ``str(value)``.
    """
    return [_serialize_value(value) for value in l]


async def feed_databases_changes_ws(websocket: WebSocket):
    """Forward queued database changes to ``websocket`` until it is replaced.

    Changes are taken from ``changes_queue`` and sent as
    ``{"message": "change", ...}``. While the queue is empty the loop sleeps
    one second between polls and sends an ``"Alive."`` status message when
    more than ten seconds have passed since the last one. The loop returns
    once a different websocket has been registered as the active feed;
    otherwise it runs until a send raises.
    """
    heartbeat_interval = 10  # seconds between "Alive." status messages
    poll_interval = 1  # seconds to wait before polling an empty queue again
    last_heartbeat = None  # monotonic timestamp of the last heartbeat sent

    while True:
        # Stop before dequeuing if another connection has taken over; a change
        # pulled here could not be delivered on the superseded socket.
        active = database_changes_active_websocket
        if active is not None and active is not websocket:
            return

        try:
            change = changes_queue.get_nowait()
        except queue.Empty:
            now = time.monotonic()
            if last_heartbeat is None or now - last_heartbeat > heartbeat_interval:
                await websocket.send_json({"message": "status", "status": "Alive."})
                last_heartbeat = time.monotonic()
            await asyncio.sleep(poll_interval)
            continue

        if change.action == "UPDATE":
            change.after_values = serialize_list(change.after_values)
            change.before_values = serialize_list(change.before_values)
        else:
            change.values = serialize_list(change.values)
        await websocket.send_json(
            {"message": "change", "change": change.model_dump()}
        )