223 lines
6.7 KiB
Python
223 lines
6.7 KiB
Python
import asyncio, time, decimal, datetime
|
|
from fastapi import WebSocket, WebSocketDisconnect, WebSocketException
|
|
from fastapi.routing import APIRouter
|
|
from fastapi.responses import StreamingResponse
|
|
from typing_extensions import Annotated
|
|
from pydantic import Field
|
|
from data.schemas import SelectResult, CachedCursorOut
|
|
from fastapi import Depends
|
|
from data.crud import (
|
|
read_connection,
|
|
read_select_query,
|
|
)
|
|
from core.dependencies import get_db, get_current_user, get_user_from_api_key
|
|
from core.exceptions import (
|
|
QueryNotFound,
|
|
ConnectionNotFound,
|
|
PoolNotFound,
|
|
CursorNotFound,
|
|
)
|
|
from dbs import mysql
|
|
from utils.binlog import changes_queue, queue
|
|
|
|
# Every HTTP and websocket endpoint defined in this module is registered here.
router = APIRouter()

# Websocket currently subscribed to /databases_changes. Only one subscriber is
# served at a time: a newer connection closes this one and takes its place.
database_changes_active_websocket: WebSocket | None = None
|
|
|
|
|
|
@router.post("/execute", dependencies=[Depends(get_current_user)])
async def execute_select(
    query_id: int,
    connection_id: int,
    page_size: Annotated[int, Field(ge=1, le=10000)] = 50,
    db=Depends(get_db),
) -> SelectResult:
    """Run a stored SELECT query once and return its first page of rows.

    The stored query and the connection are looked up by id. The query is
    executed on the connection's pool, fetching at most ``page_size`` rows.
    Nothing is cached. The returned cursor descriptor has no id and is already
    closed. ``has_more`` is set when the number of fetched rows differs from
    the row count reported by ``mysql.execute_select_query``.

    Raises:
        QueryNotFound: no stored query has ``query_id``.
        ConnectionNotFound: no connection has ``connection_id``.
        PoolNotFound: the connection has no registered pool.
    """
    stored_query = await read_select_query(db=db, query_id=query_id)
    if stored_query is None:
        raise QueryNotFound

    conn = await read_connection(db=db, connection_id=connection_id)
    if conn is None:
        raise ConnectionNotFound

    conn_pool = mysql.pools.get(conn.id, None)
    if conn_pool is None:
        raise PoolNotFound

    raw_rows, total_rows = await mysql.execute_select_query(
        pool=conn_pool,
        sql_query=stored_query.sql,
        params=stored_query.params,
        fetch_num=page_size,
    )

    serialized = mysql.serializer(raw_result=raw_rows)
    page = mysql.dict_result_to_list(result=serialized)
    fetched = len(page.data)

    # The rows are returned inline, so this cursor descriptor is synthetic:
    # no cache id, already closed, and ttl/close_at set to the placeholder -1.
    descriptor = CachedCursorOut(
        id=None,
        connection_id=connection_id,
        query=stored_query,
        row_count=total_rows,
        fetched_rows=fetched,
        is_closed=True,
        has_more=fetched != total_rows,
        ttl=-1,
        close_at=-1,
    )
    return SelectResult(cursor=descriptor, results=page)
|
|
|
|
|
|
@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
async def fetch_cursor(
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
) -> SelectResult:
    """Fetch the next page of rows from a cached cursor.

    When the cursor reports ``done`` after the fetch, it is removed from
    ``mysql.cached_cursors``. Otherwise its ``close_at`` is replaced with the
    value of ``upgrade_close_at()``. The SSE and websocket streaming endpoints
    in this module also call this function directly, once per page.

    Raises:
        CursorNotFound: ``cursor_id`` is not present in the cursor cache.
    """
    cursor = mysql.cached_cursors.get(cursor_id, None)
    if cursor is None:
        raise CursorNotFound

    rows = await cursor.fetch_many(size=page_size)

    if not cursor.done:
        # Rows remain: refresh close_at (presumably extending the cursor's
        # lifetime until the next page is requested).
        cursor.close_at = cursor.upgrade_close_at()
    else:
        mysql.cached_cursors.pop(cursor_id, None)

    page = {"columns": cursor.query.columns, "data": rows}
    return SelectResult(cursor=cursor, results=page)
|
|
|
|
|
|
@router.get(
    "/get_database_tables", dependencies=[Depends(get_current_user)], status_code=200
)
async def get_database_tables(connection_id: int):
    """Return what ``mysql.get_tables_and_datatypes`` reports for a connection's pool.

    Judging by the helper's name, this is the database's tables and their
    column data types.

    Raises:
        PoolNotFound: no pool is registered for ``connection_id``.
    """
    pool = mysql.pools.get(connection_id, None)
    # Same guard as execute_select: fail with a clear error instead of handing
    # None to the MySQL helper and surfacing an opaque server error.
    if pool is None:
        raise PoolNotFound

    return await mysql.get_tables_and_datatypes(pool=pool)
|
|
|
|
|
|
@router.get(
    "/sse-stream-cursor", dependencies=[Depends(get_current_user)], status_code=200
)
async def server_side_events_stream_cursor(
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
):
    """Stream a cached cursor page by page as Server-Sent Events.

    Each event is ``data: <SelectResult JSON>`` followed by a blank line. The
    stream ends after the first page whose cursor reports ``is_closed``.

    Raises:
        CursorNotFound: ``cursor_id`` is not cached when the request arrives.
    """
    if mysql.cached_cursors.get(cursor_id, None) is None:
        raise CursorNotFound

    async def event_source():
        # Pages are pulled lazily, one per event, through fetch_cursor.
        finished = False
        while not finished:
            page = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
            yield f"data: {page.model_dump_json()}\n\n"
            finished = page.cursor.is_closed

    return StreamingResponse(event_source(), media_type="text/event-stream")
|
|
|
|
|
|
@router.websocket(
    path="/ws-stream/{cursor_id}",
)
async def websocket_stream_cursor(
    websocket: WebSocket,
    cursor_id: str,
    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
    db=Depends(get_db),
):
    """Stream every remaining page of a cached cursor over a websocket.

    The client authenticates with its API key in the ``Authorization``
    header. Each page is sent as a text frame formatted like a Server-Sent
    Event (``data: <SelectResult JSON>`` plus a blank line). The socket is
    closed with reason ``"Done"`` after the first page whose cursor reports
    ``is_closed``.

    Close codes:
        1008: the API key did not resolve to a user.
        1011: the cursor is not (or is no longer) in the cursor cache.
    """
    await websocket.accept()

    api_key = websocket.headers.get("Authorization")
    user = await get_user_from_api_key(db=db, api_key=api_key)
    if user is None:
        await websocket.close(reason="Invalid credentials", code=1008)
        return

    if mysql.cached_cursors.get(cursor_id, None) is None:
        e = CursorNotFound()
        await websocket.close(reason=str(e.detail), code=1011)
        return

    try:
        while True:
            result = await fetch_cursor(cursor_id=cursor_id, page_size=page_size)
            serialized_result = result.model_dump_json()
            # Format as Server-Sent Event (SSE) so clients can share one parser
            # with the /sse-stream-cursor endpoint.
            await websocket.send_text(f"data: {serialized_result}\n\n")
            if result.cursor.is_closed:
                break
    except WebSocketDisconnect:
        # The client went away mid-stream; the socket is gone, nothing to close.
        return
    except CursorNotFound as e:
        # fetch_cursor evicts the cursor once it is done. If is_closed was not
        # set on that page, or the cursor was dropped from the cache some other
        # way, the next fetch raises. Report it like the initial lookup does.
        await websocket.close(reason=str(e.detail), code=1011)
        return

    await websocket.close(reason="Done")
|
|
|
|
|
|
@router.websocket("/databases_changes")
async def websocket_endpoint(
    websocket: WebSocket,
    db=Depends(get_db),
):
    """Subscribe a websocket client to the database change feed.

    The client authenticates with its API key in the ``Authorization`` header.
    The socket is closed with code 1008 if the key does not resolve to a user.

    Only one subscriber is served at a time. Any previously active socket is
    closed with code 1001 and replaced by this one. After an ``Accepted.``
    status message, ``feed_databases_changes_ws`` pushes changes until the
    client disconnects or this socket is itself replaced.
    """
    global database_changes_active_websocket

    await websocket.accept()

    api_key = websocket.headers.get("Authorization")
    user = await get_user_from_api_key(db=db, api_key=api_key)
    if user is None:
        await websocket.close(reason="Invalid credentials", code=1008)
        return

    # Explicit None check rather than truthiness: Starlette's WebSocket is a
    # Mapping over the ASGI scope, so bool() on it depends on the scope's size.
    if database_changes_active_websocket is not None:
        try:
            await database_changes_active_websocket.close(
                code=1001, reason="New connection established"
            )
        except Exception as e:
            # Best effort: the previous client may already be gone.
            print(e)

    database_changes_active_websocket = websocket
    await websocket.send_json({"message": "status", "status": "Accepted."})

    try:
        await feed_databases_changes_ws(websocket=websocket)
    except WebSocketDisconnect:
        print("Closed websocket.")
    except RuntimeError:
        # Once a newer subscriber has closed this socket, Starlette rejects any
        # further send with RuntimeError. That is the expected end of a
        # superseded subscription; while still active it is a real error.
        if database_changes_active_websocket is websocket:
            raise
        print("Closed superseded websocket.")
    finally:
        # Drop the reference to a finished socket so the next subscriber does
        # not try to close it again, unless a newer one already replaced it.
        if database_changes_active_websocket is websocket:
            database_changes_active_websocket = None
|
|
|
|
|
|
def _serialize_value(value):
    """Convert one column value into something ``send_json`` can encode.

    Conversions:
        ``Decimal``: becomes ``float``.
        Plain ``date``: becomes ``"YYYY-MM-DD"``.
        Anything else (``None``, ``str``, ``int``, ``float``, ``datetime``,
        other types): becomes ``str(value)``.
    """
    if isinstance(value, decimal.Decimal):
        return float(value)
    # datetime.datetime subclasses datetime.date, so exclude it explicitly:
    # formatting a datetime with "%Y-%m-%d" would silently drop its time part.
    if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
        return value.strftime("%Y-%m-%d")
    return str(value)


def serialize_list(l: list):
    """Return a new list with each value of ``l`` made JSON-friendly.

    The input list is not modified. Datetimes keep their time component,
    e.g. ``"2024-01-02 03:04:05"``.

    NOTE(review): ``None`` is rendered as the string ``"None"``, so clients
    cannot tell SQL NULL from the text "None". Consider JSON null if clients
    can accept it.
    """
    return [_serialize_value(value) for value in l]
|
|
|
|
|
|
async def feed_databases_changes_ws(
    websocket: WebSocket,
    *,
    heartbeat_interval: float = 10,
    poll_interval: float = 1,
):
    """Push changes from ``changes_queue`` to ``websocket`` until it is closed locally.

    Each change is sent as ``{"message": "change", "change": ...}``. Its row
    values are first passed through ``serialize_list``: ``before_values`` and
    ``after_values`` for ``UPDATE`` actions, ``values`` otherwise.

    While the queue is empty, it is re-polled every ``poll_interval`` seconds.
    A ``{"message": "status", "status": "Alive."}`` heartbeat is sent at most
    once per ``heartbeat_interval`` seconds, the first one on the first idle
    poll.

    Returns once this side has closed the socket, for example when a newer
    subscriber closed it, without taking another change off the queue.
    Exceptions from ``send_json`` (such as a client disconnect) propagate to
    the caller.
    """
    from fastapi.websockets import WebSocketState

    # -inf so the first idle poll sends a heartbeat immediately; monotonic time
    # keeps the interval immune to wall-clock adjustments.
    last_heartbeat = float("-inf")

    # Checked before every dequeue so that a socket closed by another handler
    # never takes (and then loses) a change meant for its replacement.
    while websocket.application_state == WebSocketState.CONNECTED:
        try:
            change = changes_queue.get_nowait()
        except queue.Empty:
            now = time.monotonic()
            if now - last_heartbeat > heartbeat_interval:
                await websocket.send_json({"message": "status", "status": "Alive."})
                last_heartbeat = now
            await asyncio.sleep(poll_interval)
            continue

        if change.action == "UPDATE":
            change.after_values = serialize_list(change.after_values)
            change.before_values = serialize_list(change.before_values)
        else:
            change.values = serialize_list(change.values)
        await websocket.send_json({"message": "change", "change": change.model_dump()})
|