228 lines
5.3 KiB
Python
228 lines
5.3 KiB
Python
import re
from typing import Any, List, Literal, Optional, Union

from pydantic import (
    UUID4,
    BaseModel,
    ConfigDict,
    Field,
    ValidationInfo,
    field_validator,
)
from typing_extensions import Annotated

from core.enums import (
    ConnectionTypes,
    DBUpdatesActions,
    FilterOperator,
    SortOrder,
    UserRole,
)
from core.exceptions import QueryValidationError
|
|
|
|
|
|
class ConnectionBase(BaseModel):
    """Fields shared by every connection schema in this module."""

    db_name: str
    type: ConnectionTypes
    host: str
    port: int
    username: str
    # repr=False keeps the credential out of repr()/str() of the model (and of
    # every subclass), so printing or logging a connection does not reveal it.
    # The field stays required, and model_dump()/JSON output are unaffected.
    password: str = Field(repr=False)
|
|
|
|
|
|
class ConnectionCreate(ConnectionBase):
    """Create-variant of ConnectionBase; declares no additional fields."""
|
|
|
|
|
|
class ConnectionUpdate(ConnectionBase):
    """Update-variant of ConnectionBase; declares no additional fields."""

    # NOTE(review): every ConnectionBase field is still required here, so an
    # update must resend the full connection — confirm that is intended.
|
|
|
|
|
|
class ConnectionInDBBase(ConnectionBase):
    """ConnectionBase plus the ``id`` and ``owner_id`` of a stored connection.

    ``from_attributes=True`` lets ``model_validate`` read the fields from an
    object's attributes (for example an ORM instance) instead of a dict.
    """

    # Pydantic v2 configuration. It replaces the inner ``class Config``, which
    # is deprecated in v2 and emits a deprecation warning at class creation.
    model_config = ConfigDict(from_attributes=True)

    id: int
    owner_id: int
|
|
|
|
|
|
class Connection(ConnectionInDBBase):
    """Concrete connection schema; identical to ConnectionInDBBase."""
|
|
|
|
|
|
class UserBase(BaseModel):
    """Fields shared by every user schema in this module."""

    username: str
|
|
|
|
|
|
class UserCreate(UserBase):
    """UserBase plus the ``role`` to give the user."""

    role: UserRole
|
|
|
|
|
|
class UserOut(UserBase):
    """UserBase plus ``id`` and ``role``; deliberately has no ``api_key``."""

    id: int
    role: UserRole
|
|
|
|
|
|
class UserInDBBase(UserBase):
    """UserBase plus ``id``, ``role`` and the user's ``api_key``.

    ``from_attributes=True`` lets ``model_validate`` read the fields from an
    object's attributes (for example an ORM instance) instead of a dict.
    """

    # Pydantic v2 configuration. It replaces the inner ``class Config``, which
    # is deprecated in v2 and emits a deprecation warning at class creation.
    model_config = ConfigDict(from_attributes=True)

    id: int
    role: UserRole
    # repr=False keeps the secret out of repr()/str(). The field stays
    # required, and model_dump()/JSON output are unaffected.
    api_key: str = Field(repr=False)
|
|
|
|
|
|
class User(UserInDBBase):
    """Concrete user schema; identical to UserInDBBase."""
|
|
|
|
|
|
class FilterClause(BaseModel):
    """One filter condition: ``column``, ``operator`` and ``value``.

    ``value`` must be None for ``is_null`` / ``is_not_null``, a list for
    ``in_``, and non-None for every other operator.
    """

    column: str
    operator: FilterOperator
    # validate_default=True makes ``validate_value`` run even when ``value`` is
    # omitted. Without it the default None skips the validator, so a missing
    # value for e.g. ``in_`` or an equality operator would be accepted.
    value: Optional[Union[str, int, float, bool, list]] = Field(
        default=None, validate_default=True
    )

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Require ``column`` to match ``[a-zA-Z_][a-zA-Z0-9_]*``.

        Raises:
            QueryValidationError: if ``v`` is not such an identifier.
        """
        # fullmatch instead of match + "^...$": "$" also matches just before a
        # trailing newline, so "name\n" would otherwise pass.
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["filters", "column"], msg="Invalid column name format"
            )
        return v

    @field_validator("value")
    @classmethod
    def validate_value(
        cls,
        v: Optional[Union[str, int, float, bool, list]],
        info: ValidationInfo,
    ) -> Optional[Union[str, int, float, bool, list]]:
        """Check ``value`` against the already-validated ``operator``.

        ``operator`` is declared before ``value``, so it is available in
        ``info.data`` unless it failed its own validation.

        Raises:
            QueryValidationError: if ``value`` is non-None for
                ``is_null``/``is_not_null``, is not a list for ``in_``, or is
                None for any other operator. A missing ``operator`` is treated
                as "any other operator".
        """
        operator = info.data.get("operator")

        if operator in [FilterOperator.is_null, FilterOperator.is_not_null]:
            if v is not None:
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="Value must be null for IS NULL/IS NOT NULL operators",
                )
        elif operator == FilterOperator.in_:
            if not isinstance(v, list):
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="IN operator requires a list of values",
                )
        elif v is None:
            raise QueryValidationError(
                loc=["filters", "value"], msg="Value required for this operator"
            )

        return v
|
|
|
|
|
|
class SortClause(BaseModel):
    """One ordering term: a ``column`` and its ``order`` (ascending by default)."""

    column: str
    order: SortOrder = SortOrder.asc

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Require ``column`` to match ``[a-zA-Z_][a-zA-Z0-9_]*``.

        Raises:
            QueryValidationError: if ``v`` is not such an identifier.
        """
        # fullmatch instead of match + "^...$": "$" also matches just before a
        # trailing newline, so "name\n" would otherwise pass.
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["sort_by", "column"], msg="Invalid column name format"
            )
        return v
|
|
|
|
|
|
class SelectQueryBase(BaseModel):
    """Structured SELECT description: table, columns, filters, sort and paging.

    ``table_name`` and every entry of ``columns`` must match
    ``[a-zA-Z_][a-zA-Z0-9_]*``. FilterClause and SortClause validate their
    own column names.
    """

    name: str  # a short name for this query.
    description: str | None = None  # describes what this query does.
    table_name: str
    # Either the literal "*" or an explicit list of column names.
    columns: Union[Literal["*"], List[str]] = "*"
    filters: Optional[List[FilterClause]] = None
    sort_by: Optional[List[SortClause]] = None
    # strict=True rejects inputs that are not already ints (e.g. "10" or 10.0)
    # instead of coercing them.
    limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None

    @field_validator("table_name")
    @classmethod
    def validate_table_name(cls, v: str) -> str:
        """Require ``table_name`` to match ``[a-zA-Z_][a-zA-Z0-9_]*``.

        Raises:
            QueryValidationError: if ``v`` is not such an identifier.
        """
        # fullmatch instead of match + "^...$": "$" also matches just before a
        # trailing newline, so "name\n" would otherwise pass.
        if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", v):
            raise QueryValidationError(
                loc=["table_name"], msg="Invalid table name format"
            )
        return v

    @field_validator("columns")
    @classmethod
    def validate_columns(
        cls, v: Union[Literal["*"], List[str]]
    ) -> Union[Literal["*"], List[str]]:
        """Accept "*" unchanged; otherwise require each name to be an identifier.

        Raises:
            QueryValidationError: naming the first column that does not match
                ``[a-zA-Z_][a-zA-Z0-9_]*``.
        """
        if v == "*":
            return v

        for col in v:
            # fullmatch so a trailing newline cannot slip through (see above).
            if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", col):
                raise QueryValidationError(
                    loc=["columns"], msg=f"Invalid column name: {col}"
                )
        return v
|
|
|
|
|
|
class SelectQuery(SelectQueryBase):
    """SelectQueryBase plus ``sql`` text and its ``params`` values."""

    # NOTE(review): presumably the rendered statement and its bind parameters
    # — confirm against the code that builds these.
    sql: str
    params: list[str | int | float]
|
|
|
|
|
|
class SelectQueryIn(SelectQuery):
    """SelectQuery plus the ``owner_id`` it belongs to."""

    owner_id: int
|
|
|
|
|
|
class SelectQueryInDB(SelectQueryIn):
    """SelectQueryIn plus its ``id``."""

    id: int
|
|
|
|
|
|
class SelectResultData(BaseModel):
    """Column names plus rows of arbitrary values."""

    # NOTE(review): presumably each row in ``data`` is ordered like
    # ``columns`` — not enforced here.
    columns: list[str]
    data: list[list[Any]]
|
|
|
|
|
|
class SelectQueryInResult(BaseModel):
    """Query ``id`` together with its ``sql`` and ``params``.

    ``from_attributes=True`` lets ``model_validate`` read the fields from an
    object's attributes instead of a dict.
    """

    # Pydantic v2 configuration. It replaces the inner ``class Config``, which
    # is deprecated in v2 and emits a deprecation warning at class creation.
    model_config = ConfigDict(from_attributes=True)

    id: int
    sql: str
    params: list[str | int | float]
|
|
|
|
|
|
class CachedCursorOut(BaseModel):
    """State of a cached cursor: its query, row counts and lifetime fields.

    ``from_attributes=True`` lets ``model_validate`` read the fields from an
    object's attributes instead of a dict.
    """

    # Pydantic v2 configuration. It replaces the inner ``class Config``, which
    # is deprecated in v2 and emits a deprecation warning at class creation.
    model_config = ConfigDict(from_attributes=True)

    id: UUID4 | None  # required, but may be None
    connection_id: int
    query: SelectQueryInResult
    row_count: int
    fetched_rows: int
    is_closed: bool
    has_more: bool
    # NOTE(review): units of close_at / ttl are not visible here (epoch
    # seconds? seconds remaining?) — confirm with the cursor cache.
    close_at: int
    ttl: int
|
|
|
|
|
|
class SelectResult(BaseModel):
    """A cursor's state plus optional result data (``results`` may be None)."""

    cursor: CachedCursorOut
    results: SelectResultData | None  # required key, nullable value
|
|
|
|
|
|
class ConnectionChangeBase(BaseModel):
    """Common shape of a change event: connection, action and table."""

    connection_id: int
    action: DBUpdatesActions
    table: str  # NOTE(review): not identifier-validated, unlike SelectQueryBase
|
|
|
|
|
|
class ConnectionChangeInsert(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``insert``; carries ``values``."""

    action: DBUpdatesActions = DBUpdatesActions.insert
    values: list[Any]
|
|
|
|
|
|
class ConnectionChangeDelete(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``delete``; carries ``values``."""

    action: DBUpdatesActions = DBUpdatesActions.delete
    values: list[Any]
|
|
|
|
|
|
class ConnectionChangeUpdate(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``update``.

    Carries the row values both before and after the change.
    """

    action: DBUpdatesActions = DBUpdatesActions.update
    before_values: list[Any]
    after_values: list[Any]
|