From 41d98aafe93e711d3ad950dade181a217bd99e40 Mon Sep 17 00:00:00 2001 From: abdulhade Date: Wed, 26 Feb 2025 05:13:24 +0300 Subject: [PATCH] Added support for websocket streaming cursor. --- app/operations.py | 47 +++++++++++++++++++++++++++++++++++++------- core/dependencies.py | 9 +++++++-- dbs/mysql.py | 5 +++-- 3 files changed, 50 insertions(+), 11 deletions(-) diff --git a/app/operations.py b/app/operations.py index 9257075..b5b18cf 100644 --- a/app/operations.py +++ b/app/operations.py @@ -1,3 +1,4 @@ +from fastapi import WebSocket from fastapi.routing import APIRouter from fastapi.responses import StreamingResponse from typing_extensions import Annotated @@ -8,7 +9,7 @@ from data.crud import ( read_connection, read_select_query, ) -from core.dependencies import get_db, get_current_user, get_admin_user +from core.dependencies import get_db, get_current_user, get_user_from_api_key from core.exceptions import ( QueryNotFound, ConnectionNotFound, @@ -80,9 +81,9 @@ async def fetch_cursor( ) - - -@router.get("/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200) +@router.get( + "/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200 +) async def server_side_events_stream_cursor( cursor_id: str, page_size: Annotated[int, Field(ge=1, le=1000)] = 50, @@ -95,11 +96,43 @@ async def server_side_events_stream_cursor( while True: result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size) - serialized_result = ( - result.model_dump_json() - ) + serialized_result = result.model_dump_json() yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE) if result.cursor.is_closed: break return StreamingResponse(stream(), media_type="text/event-stream") + + +@router.websocket( + path="/ws-stream/{cursor_id}", +) +async def websocket_stream_cursor( + websocket: WebSocket, + cursor_id: str, + page_size: Annotated[int, Field(ge=1, le=1000)] = 50, + db=Depends(get_db), +): + await websocket.accept() + api_key = websocket.headers.get("Authorization") + user = await get_user_from_api_key(db=db, api_key=api_key) + if user is None: + await websocket.close(reason="Invalid credentials", code=1008) + return + + cached_cursor = mysql.cached_cursors.get(cursor_id, None) + if cached_cursor is None: + e = CursorNotFound() + await websocket.close(reason=str(e.detail), code=1011) + return + + while True: + result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size) + + serialized_result = result.model_dump_json() + await websocket.send_text( + f"data: {serialized_result}\n\n" + ) # Format as Server-Sent Event (SSE) + if result.cursor.is_closed: + break + await websocket.close(reason="Done") diff --git a/core/dependencies.py b/core/dependencies.py index f7abea4..d7d43aa 100644 --- a/core/dependencies.py +++ b/core/dependencies.py @@ -13,6 +13,11 @@ async def get_db(): API_KEY_NAME = "Authorization" api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) +async def get_user_from_api_key(db:AsyncSession, api_key:str): + user = await db.execute(select(User).filter(User.api_key == api_key)) + user = user.scalars().first() + return user + async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User: if api_key_header is None: raise HTTPException( @@ -23,8 +28,8 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]" ) - user = await db.execute(select(User).filter(User.api_key == api_key)) - user = user.scalars().first() + + user= await get_user_from_api_key(db=db, api_key=api_key) if user is None: raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") return user diff --git a/dbs/mysql.py b/dbs/mysql.py index 1ac1128..d6d1d68 100644 --- a/dbs/mysql.py +++ b/dbs/mysql.py @@ -11,6 +11,7 @@ closed_cached_cursors: dict[str, CachedCursor] = {} cached_cursors_cleaner_task = None + async def close_old_cached_cursors(): global cached_cursors, closed_cached_cursors @@ -46,13 +47,12 @@ async def remove_old_closed_cached_cursors(): del closed_cached_cursors[cursor_id] print(f"Removed cursor {cursor_id}") + async def cached_cursors_cleaner(): global cached_cursors, closed_cached_cursors while True: - print("hey") await close_old_cached_cursors() await remove_old_closed_cached_cursors() - await asyncio.sleep(10) @@ -69,6 +69,7 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10): maxsize=maxsize, ) + async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor: pool = pools.get(connection_id, None)