Added support for websocket streaming cursor.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from fastapi import WebSocket
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing_extensions import Annotated
|
||||
@@ -8,7 +9,7 @@ from data.crud import (
|
||||
read_connection,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from core.dependencies import get_db, get_current_user, get_user_from_api_key
|
||||
from core.exceptions import (
|
||||
QueryNotFound,
|
||||
ConnectionNotFound,
|
||||
@@ -80,9 +81,9 @@ async def fetch_cursor(
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200)
|
||||
@router.get(
|
||||
"/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
|
||||
)
|
||||
async def server_side_events_stream_cursor(
|
||||
cursor_id: str,
|
||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||
@@ -95,11 +96,43 @@ async def server_side_events_stream_cursor(
|
||||
while True:
|
||||
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||
|
||||
serialized_result = (
|
||||
result.model_dump_json()
|
||||
)
|
||||
serialized_result = result.model_dump_json()
|
||||
yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE)
|
||||
if result.cursor.is_closed:
|
||||
break
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.websocket(
|
||||
path="/ws-stream/{cursor_id}",
|
||||
)
|
||||
async def websocket_stream_cursor(
|
||||
websocket: WebSocket,
|
||||
cursor_id: str,
|
||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||
db=Depends(get_db),
|
||||
):
|
||||
await websocket.accept()
|
||||
api_key = websocket.headers.get("Authorization")
|
||||
user = await get_user_from_api_key(db=db, api_key=api_key)
|
||||
if user is None:
|
||||
await websocket.close(reason="Invalid credentials", code=1008)
|
||||
return
|
||||
|
||||
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
|
||||
if cached_cursor is None:
|
||||
e = CursorNotFound()
|
||||
await websocket.close(reason=str(e.detail), code=1011)
|
||||
return
|
||||
|
||||
while True:
|
||||
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||
|
||||
serialized_result = result.model_dump_json()
|
||||
await websocket.send_text(
|
||||
f"data: {serialized_result}\n\n"
|
||||
) # Format as Server-Sent Event (SSE)
|
||||
if result.cursor.is_closed:
|
||||
break
|
||||
await websocket.close(reason="Done")
|
||||
|
||||
Reference in New Issue
Block a user