# File: db-middleware/data/schemas.py
# (228 lines, 5.3 KiB, Python)

import re
from typing import Any, List, Literal, Optional, Union

from pydantic import (
    UUID4,
    BaseModel,
    ConfigDict,
    Field,
    ValidationInfo,
    field_validator,
)
from typing_extensions import Annotated

from core.enums import (
    ConnectionTypes,
    DBUpdatesActions,
    FilterOperator,
    SortOrder,
    UserRole,
)
from core.exceptions import QueryValidationError
class ConnectionBase(BaseModel):
    """Location and credentials needed to open a database connection."""

    # --- target database ---
    db_name: str
    type: ConnectionTypes  # which backend kind (a ``ConnectionTypes`` member)
    host: str
    port: int

    # --- credentials ---
    # NOTE(review): ``password`` is a plain ``str`` inherited by every subclass,
    # including ``Connection``; confirm no response model built from these
    # schemas returns it to clients.
    username: str
    password: str
class ConnectionCreate(ConnectionBase):
    """Payload for registering a new connection (same fields as the base)."""
class ConnectionUpdate(ConnectionBase):
    """Payload for modifying an existing connection (same fields as the base)."""

    # Every inherited field is required, so an update carries the full set of
    # connection settings rather than a partial patch.
class ConnectionInDBBase(ConnectionBase):
    """Stored connection: the base fields plus its ``id`` and ``owner_id``."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``
    # (which emits a PydanticDeprecatedSince20 warning). ``from_attributes``
    # allows ``model_validate`` on attribute-bearing objects such as ORM rows.
    model_config = ConfigDict(from_attributes=True)

    id: int
    owner_id: int
class Connection(ConnectionInDBBase):
    """A persisted database connection."""
class UserBase(BaseModel):
    """Fields common to every user schema."""

    username: str
class UserCreate(UserBase):
    """Payload for creating a user with a given role."""

    role: UserRole
class UserOut(UserBase):
    """User representation without the API key."""

    # Deliberately omits ``api_key`` (present on ``UserInDBBase``).
    id: int
    role: UserRole
class UserInDBBase(UserBase):
    """Stored user record, including its API key."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # ``from_attributes`` allows ``model_validate`` on ORM/attribute objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    role: UserRole
    # NOTE(review): secret material; ``UserOut`` is the variant without it —
    # confirm this model (and ``User``) is never used as a response model.
    api_key: str
class User(UserInDBBase):
    """A persisted user."""
class FilterClause(BaseModel):
    """One filter condition of a select query: ``column``, ``operator``, ``value``.

    The required shape of ``value`` depends on ``operator``:

    * ``is_null`` / ``is_not_null``: ``value`` must be null;
    * ``in_``: ``value`` must be a list;
    * any other operator: ``value`` must be present (not null).
    """

    column: str
    operator: FilterOperator
    # ``validate_default=True`` matters: pydantic v2 does not run field
    # validators on defaulted fields, so an *omitted* ``value`` previously
    # skipped ``validate_value`` entirely (e.g. ``in_`` with no value at all).
    value: Optional[Union[str, int, float, bool, list]] = Field(
        default=None, validate_default=True
    )

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Accept only plain identifiers ``[A-Za-z_][A-Za-z0-9_]*``.

        Raises:
            QueryValidationError: ``v`` is not a plain identifier.
        """
        # ``fullmatch`` rather than ``match`` with ``^...$``: ``$`` also matches
        # just before a trailing newline, which let names like "col\n" through.
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["filters", "column"], msg="Invalid column name format"
            )
        return v

    @field_validator("value")
    @classmethod
    def validate_value(
        cls,
        v: Optional[Union[str, int, float, bool, list]],
        info: ValidationInfo,
    ) -> Optional[Union[str, int, float, bool, list]]:
        """Check that ``value`` matches what ``operator`` requires.

        Raises:
            QueryValidationError: ``value`` is non-null for IS NULL / IS NOT
                NULL, not a list for IN, or null for any other operator.
        """
        operator = info.data.get("operator")
        if operator is None:
            # ``operator`` failed its own validation; let pydantic report that
            # error instead of a misleading "value required" one.
            return v
        if operator in (FilterOperator.is_null, FilterOperator.is_not_null):
            if v is not None:
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="Value must be null for IS NULL/IS NOT NULL operators",
                )
        elif operator == FilterOperator.in_:
            if not isinstance(v, list):
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="IN operator requires a list of values",
                )
        elif v is None:
            raise QueryValidationError(
                loc=["filters", "value"], msg="Value required for this operator"
            )
        return v
class SortClause(BaseModel):
    """One ``ORDER BY`` term: a column and a direction (ascending by default)."""

    column: str
    order: SortOrder = SortOrder.asc

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Accept only plain identifiers ``[A-Za-z_][A-Za-z0-9_]*``.

        Raises:
            QueryValidationError: ``v`` is not a plain identifier.
        """
        # ``fullmatch`` closes the trailing-newline hole that ``match`` + ``$``
        # left open ("col\n" used to be accepted).
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["sort_by", "column"], msg="Invalid column name format"
            )
        return v
class SelectQueryBase(BaseModel):
    """Structured description of a SELECT against a single table.

    Table and column names must be plain identifiers ``[A-Za-z_][A-Za-z0-9_]*``.
    """

    name: str  # a short name for this query.
    description: str | None = None  # describes what this query does.
    table_name: str
    # ``"*"`` selects every column; otherwise an explicit list of column names.
    # NOTE(review): an empty list passes validation — confirm the SQL builder
    # handles it (a bare ``SELECT FROM`` is invalid SQL).
    columns: Union[Literal["*"], List[str]] = "*"
    filters: Optional[List[FilterClause]] = None
    sort_by: Optional[List[SortClause]] = None
    # ``strict=True``: values must already be ints; no coercion from other types.
    limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None

    @field_validator("table_name")
    @classmethod
    def validate_table_name(cls, v: str) -> str:
        """Accept only a plain identifier as the table name.

        Raises:
            QueryValidationError: ``v`` is not a plain identifier.
        """
        # ``fullmatch``: ``match`` + ``$`` accepted a trailing newline.
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["table_name"], msg="Invalid table name format"
            )
        return v

    @field_validator("columns")
    @classmethod
    def validate_columns(
        cls, v: Union[Literal["*"], List[str]]
    ) -> Union[Literal["*"], List[str]]:
        """Pass ``"*"`` through; otherwise require every entry to be an identifier.

        Raises:
            QueryValidationError: naming the first invalid column.
        """
        if v == "*":
            return v
        for col in v:
            # ``fullmatch``: ``match`` + ``$`` accepted a trailing newline.
            if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", col):
                raise QueryValidationError(
                    loc=["columns"], msg=f"Invalid column name: {col}"
                )
        return v
class SelectQuery(SelectQueryBase):
    """A structured select query together with its SQL text and bound parameters."""

    sql: str
    # NOTE(review): ``FilterClause.value`` may be a bool, which is not a member
    # of this union — confirm bool parameters survive validation as intended.
    params: list[str | int | float]
class SelectQueryIn(SelectQuery):
    """A select query tagged with the id of its owner."""

    owner_id: int
class SelectQueryInDB(SelectQueryIn):
    """A stored select query, identified by ``id``."""

    id: int
class SelectResultData(BaseModel):
    """Tabular query output: column names and a list of rows."""

    columns: List[str]
    # One inner list per row; presumably positionally aligned with ``columns``.
    data: List[List[Any]]
class SelectQueryInResult(BaseModel):
    """Compact view of a stored query: its id, SQL text and parameters."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # ``from_attributes`` allows ``model_validate`` on ORM/attribute objects.
    model_config = ConfigDict(from_attributes=True)

    id: int
    sql: str
    params: list[str | int | float]
class CachedCursorOut(BaseModel):
    """State of a cursor over the results of a select query."""

    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # ``from_attributes`` allows ``model_validate`` on attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)

    id: UUID4 | None  # required key, but may be null
    connection_id: int
    query: SelectQueryInResult
    row_count: int
    fetched_rows: int
    is_closed: bool
    has_more: bool
    # NOTE(review): units are not visible here — presumably ``close_at`` is a
    # timestamp and ``ttl`` a duration in seconds; confirm against the cache.
    close_at: int
    ttl: int
class SelectResult(BaseModel):
    """Response of a select: the cursor state plus any fetched rows."""

    cursor: CachedCursorOut
    results: SelectResultData | None  # required key; null when no rows are attached
class ConnectionChangeBase(BaseModel):
    """A data change on ``table`` of connection ``connection_id``."""

    connection_id: int
    # Required here; each concrete subclass supplies its own default.
    action: DBUpdatesActions
    table: str
class ConnectionChangeInsert(ConnectionChangeBase):
    """An insert change carrying the inserted row's values."""

    # Defaulted, not fixed: a caller can still pass a different action.
    action: DBUpdatesActions = DBUpdatesActions.insert
    values: list[Any]
class ConnectionChangeDelete(ConnectionChangeBase):
    """A delete change carrying the removed row's values."""

    # Defaulted, not fixed: a caller can still pass a different action.
    action: DBUpdatesActions = DBUpdatesActions.delete
    values: list[Any]
class ConnectionChangeUpdate(ConnectionChangeBase):
    """An update change carrying the row's values before and after the change."""

    # Defaulted, not fixed: a caller can still pass a different action.
    action: DBUpdatesActions = DBUpdatesActions.update
    before_values: list[Any]
    after_values: list[Any]