"""Pydantic schemas for connections, users, select queries, cached cursors
and connection change events."""

import re
from typing import Union, List, Optional, Literal, Any

from typing_extensions import Annotated
from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    ValidationInfo,
    UUID4,
)

from core.enums import (
    ConnectionTypes,
    UserRole,
    FilterOperator,
    SortOrder,
    DBUpdatesActions,
)
from core.exceptions import QueryValidationError


# Allowed table/column name: a letter or underscore followed by letters,
# digits or underscores. Checked with ``fullmatch`` because ``re.match`` with a
# ``$`` anchor also accepts a single trailing newline (e.g. ``"users\n"``).
_IDENTIFIER_RE = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

# Accepted types for a filter's comparison value.
_FilterValue = Optional[Union[str, int, float, bool, list]]


def _is_identifier(name: str) -> bool:
    """Return True if *name* matches ``_IDENTIFIER_RE`` in its entirety."""
    return _IDENTIFIER_RE.fullmatch(name) is not None


class ConnectionBase(BaseModel):
    """Connection fields shared by the create, update and read schemas."""

    db_name: str
    type: ConnectionTypes
    host: str
    port: int
    username: str
    password: str


class ConnectionCreate(ConnectionBase):
    """Connection schema used for creation; adds nothing to ConnectionBase."""


class ConnectionUpdate(ConnectionBase):
    """Connection schema used for updates; all base fields are required."""


class ConnectionInDBBase(ConnectionBase):
    """Connection with its ``id`` and ``owner_id``.

    ``from_attributes`` allows building it via ``model_validate`` from any
    object exposing matching attributes.
    """

    model_config = ConfigDict(from_attributes=True)

    id: int
    owner_id: int


class Connection(ConnectionInDBBase):
    """Connection read schema; identical to ConnectionInDBBase."""


class UserBase(BaseModel):
    """Fields shared by all user schemas."""

    username: str


class UserCreate(UserBase):
    """User schema used for creation: a username plus a role."""

    role: UserRole


class UserOut(UserBase):
    """User schema with ``id`` and ``role`` but without ``api_key``."""

    id: int
    role: UserRole


class UserInDBBase(UserBase):
    """User including its ``api_key``; buildable from attribute objects."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    role: UserRole
    api_key: str


class User(UserInDBBase):
    """Full user schema; identical to UserInDBBase."""


class FilterClause(BaseModel):
    """A single ``column <operator> value`` filter condition.

    ``value`` must be ``None`` for ``is_null``/``is_not_null``, a list for
    ``in_``, and non-``None`` for every other operator.
    """

    column: str
    operator: FilterOperator
    # validate_default=True makes ``validate_value`` run even when ``value`` is
    # omitted; without it a missing value for an operator that needs one
    # would silently be accepted as None.
    value: _FilterValue = Field(default=None, validate_default=True)

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Reject column names that are not plain identifiers.

        Raises:
            QueryValidationError: if *v* does not match ``_IDENTIFIER_RE``.
        """
        if not _is_identifier(v):
            raise QueryValidationError(
                loc=["filters", "column"], msg="Invalid column name format"
            )
        return v

    @field_validator("value")
    @classmethod
    def validate_value(cls, v: _FilterValue, info: ValidationInfo) -> _FilterValue:
        """Check *v* against the already-validated ``operator``.

        ``operator`` is declared before ``value`` so it is present in
        ``info.data``. If ``operator`` failed validation it is absent there,
        and the generic "value required" rule applies.

        Raises:
            QueryValidationError: if *v* does not fit the operator.
        """
        operator = info.data.get("operator")
        if operator in (FilterOperator.is_null, FilterOperator.is_not_null):
            if v is not None:
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="Value must be null for IS NULL/IS NOT NULL operators",
                )
        elif operator == FilterOperator.in_:
            if not isinstance(v, list):
                raise QueryValidationError(
                    loc=["filters", "value"],
                    msg="IN operator requires a list of values",
                )
        elif v is None:
            raise QueryValidationError(
                loc=["filters", "value"], msg="Value required for this operator"
            )
        return v


class SortClause(BaseModel):
    """Sort on ``column``; ``order`` defaults to ``SortOrder.asc``."""

    column: str
    order: SortOrder = SortOrder.asc

    @field_validator("column")
    @classmethod
    def validate_column(cls, v: str) -> str:
        """Reject column names that are not plain identifiers.

        Raises:
            QueryValidationError: if *v* does not match ``_IDENTIFIER_RE``.
        """
        if not _is_identifier(v):
            raise QueryValidationError(
                loc=["sort_by", "column"], msg="Invalid column name format"
            )
        return v


class SelectQueryBase(BaseModel):
    """Structured description of a SELECT query.

    Table and column names are restricted to plain identifiers. ``limit``
    must be an int greater than 0 and ``offset`` an int of at least 0. Both
    are strict, so no coercion from strings or floats is done.
    """

    name: str  # A short name for this query.
    description: str | None = None  # Describes what this query does.
    table_name: str
    columns: Union[Literal["*"], List[str]] = "*"
    filters: Optional[List[FilterClause]] = None
    sort_by: Optional[List[SortClause]] = None
    limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
    offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None

    @field_validator("table_name")
    @classmethod
    def validate_table_name(cls, v: str) -> str:
        """Reject table names that are not plain identifiers.

        Raises:
            QueryValidationError: if *v* does not match ``_IDENTIFIER_RE``.
        """
        if not _is_identifier(v):
            raise QueryValidationError(
                loc=["table_name"], msg="Invalid table name format"
            )
        return v

    @field_validator("columns")
    @classmethod
    def validate_columns(
        cls, v: Union[Literal["*"], List[str]]
    ) -> Union[Literal["*"], List[str]]:
        """Accept ``"*"`` or a list whose every entry is a plain identifier.

        Raises:
            QueryValidationError: naming the first invalid column.
        """
        if v == "*":
            return v
        for col in v:
            if not _is_identifier(col):
                raise QueryValidationError(
                    loc=["columns"], msg=f"Invalid column name: {col}"
                )
        return v


class SelectQuery(SelectQueryBase):
    """SelectQueryBase plus an ``sql`` string and its ``params`` list."""

    sql: str
    params: list[str | int | float]


class SelectQueryIn(SelectQuery):
    """SelectQuery tagged with the ``owner_id`` it belongs to."""

    owner_id: int


class SelectQueryInDB(SelectQueryIn):
    """SelectQueryIn with its ``id``."""

    id: int


class SelectResultData(BaseModel):
    """Result columns and rows of a select."""

    columns: List[str] | str
    data: List[List[Any]]


class SelectQueryInResult(BaseModel):
    """Minimal query info (``id``, ``sql``, ``params``) embedded in a cursor."""

    model_config = ConfigDict(from_attributes=True)

    id: int
    sql: str
    params: list[str | int | float]


class CachedCursorOut(BaseModel):
    """Cursor state returned alongside select results."""

    model_config = ConfigDict(from_attributes=True)

    id: UUID4 | None
    connection_id: int
    query: SelectQueryInResult
    row_count: int
    fetched_rows: int
    is_closed: bool
    has_more: bool
    close_at: int  # NOTE(review): units (epoch seconds?) not visible here — confirm.
    ttl: int  # NOTE(review): presumably seconds — confirm.


class SelectResult(BaseModel):
    """A cursor together with the (optional) rows fetched through it."""

    cursor: CachedCursorOut
    results: SelectResultData | None


class ConnectionChangeBase(BaseModel):
    """Fields shared by all connection change events."""

    connection_id: int
    action: DBUpdatesActions
    table: str


class ConnectionChangeInsert(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``insert``."""

    action: DBUpdatesActions = DBUpdatesActions.insert
    values: list[Any]


class ConnectionChangeDelete(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``delete``."""

    action: DBUpdatesActions = DBUpdatesActions.delete
    values: list[Any]


class ConnectionChangeUpdate(ConnectionChangeBase):
    """Change event whose ``action`` defaults to ``update``.

    Carries ``before_values`` and ``after_values``.
    """

    action: DBUpdatesActions = DBUpdatesActions.update
    before_values: list[Any]
    after_values: list[Any]