Added support for websocket streaming cursor.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
from fastapi import WebSocket
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
@@ -8,7 +9,7 @@ from data.crud import (
|
|||||||
read_connection,
|
read_connection,
|
||||||
read_select_query,
|
read_select_query,
|
||||||
)
|
)
|
||||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
from core.dependencies import get_db, get_current_user, get_user_from_api_key
|
||||||
from core.exceptions import (
|
from core.exceptions import (
|
||||||
QueryNotFound,
|
QueryNotFound,
|
||||||
ConnectionNotFound,
|
ConnectionNotFound,
|
||||||
@@ -80,9 +81,9 @@ async def fetch_cursor(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
|
||||||
@router.get("/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200)
|
)
|
||||||
async def server_side_events_stream_cursor(
|
async def server_side_events_stream_cursor(
|
||||||
cursor_id: str,
|
cursor_id: str,
|
||||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||||
@@ -95,11 +96,43 @@ async def server_side_events_stream_cursor(
|
|||||||
while True:
|
while True:
|
||||||
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||||
|
|
||||||
serialized_result = (
|
serialized_result = result.model_dump_json()
|
||||||
result.model_dump_json()
|
|
||||||
)
|
|
||||||
yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE)
|
yield f"data: {serialized_result}\n\n" # Format as Server-Sent Event (SSE)
|
||||||
if result.cursor.is_closed:
|
if result.cursor.is_closed:
|
||||||
break
|
break
|
||||||
|
|
||||||
return StreamingResponse(stream(), media_type="text/event-stream")
|
return StreamingResponse(stream(), media_type="text/event-stream")
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket(
|
||||||
|
path="/ws-stream/{cursor_id}",
|
||||||
|
)
|
||||||
|
async def websocket_stream_cursor(
|
||||||
|
websocket: WebSocket,
|
||||||
|
cursor_id: str,
|
||||||
|
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||||
|
db=Depends(get_db),
|
||||||
|
):
|
||||||
|
await websocket.accept()
|
||||||
|
api_key = websocket.headers.get("Authorization")
|
||||||
|
user = await get_user_from_api_key(db=db, api_key=api_key)
|
||||||
|
if user is None:
|
||||||
|
await websocket.close(reason="Invalid credentials", code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
|
||||||
|
if cached_cursor is None:
|
||||||
|
e = CursorNotFound()
|
||||||
|
await websocket.close(reason=str(e.detail), code=1011)
|
||||||
|
return
|
||||||
|
|
||||||
|
while True:
|
||||||
|
result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
|
||||||
|
|
||||||
|
serialized_result = result.model_dump_json()
|
||||||
|
await websocket.send_text(
|
||||||
|
f"data: {serialized_result}\n\n"
|
||||||
|
) # Format as Server-Sent Event (SSE)
|
||||||
|
if result.cursor.is_closed:
|
||||||
|
break
|
||||||
|
await websocket.close(reason="Done")
|
||||||
|
|||||||
@@ -13,6 +13,11 @@ async def get_db():
|
|||||||
API_KEY_NAME = "Authorization"
|
API_KEY_NAME = "Authorization"
|
||||||
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False)
|
||||||
|
|
||||||
|
async def get_user_from_api_key(db:AsyncSession, api_key:str):
|
||||||
|
user = await db.execute(select(User).filter(User.api_key == api_key))
|
||||||
|
user = user.scalars().first()
|
||||||
|
return user
|
||||||
|
|
||||||
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
|
async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Security(api_key_header)) -> User:
|
||||||
if api_key_header is None:
|
if api_key_header is None:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -23,8 +28,8 @@ async def get_current_user(db: AsyncSession = Depends(get_db), api_key:str = Sec
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]"
|
status_code=status.HTTP_403_FORBIDDEN, detail="API key missing, provide it in the header value [Authorization]"
|
||||||
)
|
)
|
||||||
user = await db.execute(select(User).filter(User.api_key == api_key))
|
|
||||||
user = user.scalars().first()
|
user= await get_user_from_api_key(db=db, api_key=api_key)
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
|
||||||
return user
|
return user
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ closed_cached_cursors: dict[str, CachedCursor] = {}
|
|||||||
|
|
||||||
cached_cursors_cleaner_task = None
|
cached_cursors_cleaner_task = None
|
||||||
|
|
||||||
|
|
||||||
async def close_old_cached_cursors():
|
async def close_old_cached_cursors():
|
||||||
global cached_cursors, closed_cached_cursors
|
global cached_cursors, closed_cached_cursors
|
||||||
|
|
||||||
@@ -46,13 +47,12 @@ async def remove_old_closed_cached_cursors():
|
|||||||
del closed_cached_cursors[cursor_id]
|
del closed_cached_cursors[cursor_id]
|
||||||
print(f"Removed cursor {cursor_id}")
|
print(f"Removed cursor {cursor_id}")
|
||||||
|
|
||||||
|
|
||||||
async def cached_cursors_cleaner():
|
async def cached_cursors_cleaner():
|
||||||
global cached_cursors, closed_cached_cursors
|
global cached_cursors, closed_cached_cursors
|
||||||
while True:
|
while True:
|
||||||
print("hey")
|
|
||||||
await close_old_cached_cursors()
|
await close_old_cached_cursors()
|
||||||
await remove_old_closed_cached_cursors()
|
await remove_old_closed_cached_cursors()
|
||||||
|
|
||||||
await asyncio.sleep(10)
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
|
||||||
@@ -69,6 +69,7 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
|||||||
maxsize=maxsize,
|
maxsize=maxsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
|
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
|
||||||
pool = pools.get(connection_id, None)
|
pool = pools.get(connection_id, None)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user