Added support for websocket streaming cursor.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
from fastapi import WebSocket
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.responses import StreamingResponse
|
||||
from typing_extensions import Annotated
|
||||
@@ -8,7 +9,7 @@ from data.crud import (
|
||||
read_connection,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from core.dependencies import get_db, get_current_user, get_user_from_api_key
|
||||
from core.exceptions import (
|
||||
QueryNotFound,
|
||||
ConnectionNotFound,
|
||||
@@ -80,9 +81,9 @@ async def fetch_cursor(
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200)
|
||||
@router.get(
|
||||
"/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
|
||||
)
|
||||
async def server_side_events_stream_cursor(
|
||||
cursor_id: str,
|
||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||
@@ -95,11 +96,43 @@ async def server_side_events_stream_cursor(
|
||||
while True:
|
||||
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||
|
||||
serialized_result = (
|
||||
result.model_dump_json()
|
||||
)
|
||||
serialized_result = result.model_dump_json()
|
||||
yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE)
|
||||
if result.cursor.is_closed:
|
||||
break
|
||||
|
||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.websocket(
|
||||
path="/ws-stream/{cursor_id}",
|
||||
)
|
||||
async def websocket_stream_cursor(
|
||||
websocket: WebSocket,
|
||||
cursor_id: str,
|
||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||
db=Depends(get_db),
|
||||
):
|
||||
await websocket.accept()
|
||||
api_key = websocket.headers.get("Authorization")
|
||||
user = await get_user_from_api_key(db=db, api_key=api_key)
|
||||
if user is None:
|
||||
await websocket.close(reason="Invalid credentials", code=1008)
|
||||
return
|
||||
|
||||
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
|
||||
if cached_cursor is None:
|
||||
e = CursorNotFound()
|
||||
await websocket.close(reason=str(e.detail), code=1011)
|
||||
return
|
||||
|
||||
while True:
|
||||
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||
|
||||
serialized_result = result.model_dump_json()
|
||||
await websocket.send_text(
|
||||
f"data: {serialized_result}\n\n"
|
||||
) # Format as Server-Sent Event (SSE)
|
||||
if result.cursor.is_closed:
|
||||
break
|
||||
await websocket.close(reason="Done")
|
||||
|
||||
@@ -13,6 +13,11 @@ async def get_db():
|
||||
API_KEY_NAME = "Authorization"
|
||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||
|
||||
async def get_user_from_api_key(db:AsyncSession, api_key:str):
|
||||
user = await db.execute(select(User).filter(User.api_key == api_key))
|
||||
user = user.scalars().first()
|
||||
return user
|
||||
|
||||
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
|
||||
if api_key_header is None:
|
||||
raise HTTPException(
|
||||
@@ -23,8 +28,8 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]"
|
||||
)
|
||||
user = await db.execute(select(User).filter(User.api_key == api_key))
|
||||
user = user.scalars().first()
|
||||
|
||||
user= await get_user_from_api_key(db=db, api_key=api_key)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||
return user
|
||||
|
||||
@@ -11,6 +11,7 @@ closed_cached_cursors: dict[str, CachedCursor] = {}
|
||||
|
||||
cached_cursors_cleaner_task = None
|
||||
|
||||
|
||||
async def close_old_cached_cursors():
|
||||
global cached_cursors, closed_cached_cursors
|
||||
|
||||
@@ -46,13 +47,12 @@ async def remove_old_closed_cached_cursors():
|
||||
del closed_cached_cursors[cursor_id]
|
||||
print(f"Removed cursor {cursor_id}")
|
||||
|
||||
|
||||
async def cached_cursors_cleaner():
|
||||
global cached_cursors, closed_cached_cursors
|
||||
while True:
|
||||
print("hey")
|
||||
await close_old_cached_cursors()
|
||||
await remove_old_closed_cached_cursors()
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
||||
maxsize=maxsize,
|
||||
)
|
||||
|
||||
|
||||
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
|
||||
pool = pools.get(connection_id, None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user