Added support for websocket streaming cursor.

This commit is contained in:
2025-02-26 05:13:24 +03:00
parent efbca3b7cb
commit 41d98aafe9
3 changed files with 50 additions and 11 deletions

View File

@@ -1,3 +1,4 @@
from fastapi import WebSocket
from fastapi.routing import APIRouter from fastapi.routing import APIRouter
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from typing_extensions import Annotated from typing_extensions import Annotated
@@ -8,7 +9,7 @@ from data.crud import (
read_connection, read_connection,
read_select_query, read_select_query,
) )
from core.dependencies import get_db, get_current_user, get_admin_user from core.dependencies import get_db, get_current_user, get_user_from_api_key
from core.exceptions import ( from core.exceptions import (
QueryNotFound, QueryNotFound,
ConnectionNotFound, ConnectionNotFound,
@@ -80,9 +81,9 @@ async def fetch_cursor(
) )
@router.get(
"/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
@router.get("/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200) )
async def server_side_events_stream_cursor( async def server_side_events_stream_cursor(
cursor_id: str, cursor_id: str,
page_size: Annotated[int, Field(ge=1, le=1000)] = 50, page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
@@ -95,11 +96,43 @@ async def server_side_events_stream_cursor(
while True: while True:
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size) result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
serialized_result = ( serialized_result = result.model_dump_json()
result.model_dump_json()
)
yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE) yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE)
if result.cursor.is_closed: if result.cursor.is_closed:
break break
return StreamingResponse(stream(), media_type="text/event-stream") return StreamingResponse(stream(), media_type="text/event-stream")
@router.websocket(
path="/ws-stream/{cursor_id}",
)
async def websocket_stream_cursor(
websocket: WebSocket,
cursor_id: str,
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
db=Depends(get_db),
):
await websocket.accept()
api_key = websocket.headers.get("Authorization")
user = await get_user_from_api_key(db=db, api_key=api_key)
if user is None:
await websocket.close(reason="Invalid credentials", code=1008)
return
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
if cached_cursor is None:
e = CursorNotFound()
await websocket.close(reason=str(e.detail), code=1011)
return
while True:
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
serialized_result = result.model_dump_json()
await websocket.send_text(
f"data: {serialized_result}\n\n"
) # Format as Server-Sent Event (SSE)
if result.cursor.is_closed:
break
await websocket.close(reason="Done")

View File

@@ -13,6 +13,11 @@ async def get_db():
API_KEY_NAME = "Authorization" API_KEY_NAME = "Authorization"
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
async def get_user_from_api_key(db:AsyncSession, api_key:str):
user = await db.execute(select(User).filter(User.api_key == api_key))
user = user.scalars().first()
return user
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User: async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
if api_key_header is None: if api_key_header is None:
raise HTTPException( raise HTTPException(
@@ -23,8 +28,8 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]" status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]"
) )
user = await db.execute(select(User).filter(User.api_key == api_key))
user = user.scalars().first() user= await get_user_from_api_key(db=db, api_key=api_key)
if user is None: if user is None:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
return user return user

View File

@@ -11,6 +11,7 @@ closed_cached_cursors: dict[str, CachedCursor] = {}
cached_cursors_cleaner_task = None cached_cursors_cleaner_task = None
async def close_old_cached_cursors(): async def close_old_cached_cursors():
global cached_cursors, closed_cached_cursors global cached_cursors, closed_cached_cursors
@@ -46,13 +47,12 @@ async def remove_old_closed_cached_cursors():
del closed_cached_cursors[cursor_id] del closed_cached_cursors[cursor_id]
print(f"Removed cursor {cursor_id}") print(f"Removed cursor {cursor_id}")
async def cached_cursors_cleaner(): async def cached_cursors_cleaner():
global cached_cursors, closed_cached_cursors global cached_cursors, closed_cached_cursors
while True: while True:
print("hey")
await close_old_cached_cursors() await close_old_cached_cursors()
await remove_old_closed_cached_cursors() await remove_old_closed_cached_cursors()
await asyncio.sleep(10) await asyncio.sleep(10)
@@ -69,6 +69,7 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
maxsize=maxsize, maxsize=maxsize,
) )
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor: async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
pool = pools.get(connection_id, None) pool = pools.get(connection_id, None)