Implemented Alembic migrations, added a cursor-closing background task.

2025-02-25 23:37:55 +03:00
parent 836ce1dc82
commit 1abc225923
17 changed files with 746 additions and 101 deletions

alembic/alembic.ini Normal file

@@ -0,0 +1,119 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts
# Use forward slashes (/) also on Windows to provide an OS-agnostic path
script_location = alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires Python>=3.9 or the backports.zoneinfo library, and the tzdata library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to alembic/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
# version_path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
version_path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

alembic/env.py Normal file

@@ -0,0 +1,80 @@
from logging.config import fileConfig
from data.db import engine
from data.models import Base
from alembic import context

"""
Create migration example:
> alembic -c alembic/alembic.ini revision --autogenerate -m "Added pool size config to the connection model."
Migrate example:
> alembic -c alembic/alembic.ini upgrade head
"""

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Imported for their side effect: the models must be registered on
# Base.metadata so that autogenerate can detect them.
from data.models import User, Query, Connection  # noqa: F401

target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """
    Run migrations in 'offline' mode.
    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.
    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )
    with context.begin_transaction():
        context.run_migrations()


async def run_migrations_online() -> None:
    """
    Run migrations in 'online' mode.
    In this scenario we need to create an Engine
    and associate a connection with the context.
    """
    connectable = engine
    async with connectable.connect() as connection:
        await connection.run_sync(do_migrations)


def do_migrations(connection):
    context.configure(connection=connection, target_metadata=Base.metadata)
    with context.begin_transaction():
        context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    import asyncio
    asyncio.run(run_migrations_online())
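
A note on the imports above: env.py pulls `engine` from data.db, a module that is not part of this diff. For the run_sync-based online path it presumably exposes an async SQLAlchemy engine, roughly like the sketch below (module contents and URL are assumptions, not part of the commit):

# data/db.py -- hypothetical sketch; env.py only needs `engine` from it
from sqlalchemy.ext.asyncio import create_async_engine

# the URL is a placeholder; any SQLAlchemy async driver works here
engine = create_async_engine("mysql+aiomysql://user:pass@localhost/dbname")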

alembic/script.py.mako Normal file

@@ -0,0 +1,26 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}


@@ -0,0 +1,32 @@
"""Added pool size config to the connection model.
Revision ID: 1c62ff091f5c
Revises: 6eb236240aec
Create Date: 2025-02-25 19:24:21.712856
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1c62ff091f5c'
down_revision: Union[str, None] = '6eb236240aec'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column('connections', sa.Column('pool_minsize', sa.Integer(), nullable=False, server_default="5"))
    op.add_column('connections', sa.Column('pool_maxsize', sa.Integer(), nullable=False, server_default="10"))
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column('connections', 'pool_maxsize')
    op.drop_column('connections', 'pool_minsize')
    # ### end Alembic commands ###


@@ -0,0 +1,75 @@
"""Initial migration.
Revision ID: 6eb236240aec
Revises:
Create Date: 2025-02-25 19:18:03.125433
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '6eb236240aec'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('users',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('username', sa.String(), nullable=False),
    sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=False),
    sa.Column('api_key', sa.String(), nullable=False),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('api_key')
    )
    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
    op.create_table('connections',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('db_name', sa.String(), nullable=False),
    sa.Column('type', sa.Enum('mysql', 'postgresql', name='connectiontypes'), nullable=False),
    sa.Column('host', sa.String(), nullable=True),
    sa.Column('port', sa.Integer(), nullable=True),
    sa.Column('username', sa.String(), nullable=True),
    sa.Column('password', sa.String(), nullable=True),
    sa.Column('owner_id', sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_connections_id'), 'connections', ['id'], unique=False)
    op.create_table('queries',
    sa.Column('id', sa.Integer(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('description', sa.String(), nullable=True),
    sa.Column('owner_id', sa.Integer(), nullable=True),
    sa.Column('table_name', sa.String(), nullable=False),
    sa.Column('columns', sa.JSON(), nullable=False),
    sa.Column('filters', sa.JSON(), nullable=True),
    sa.Column('sort_by', sa.JSON(), nullable=True),
    sa.Column('limit', sa.Integer(), nullable=True),
    sa.Column('offset', sa.Integer(), nullable=True),
    sa.Column('sql', sa.String(), nullable=False),
    sa.Column('params', sa.JSON(), nullable=False),
    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_queries_id'), 'queries', ['id'], unique=False)
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_queries_id'), table_name='queries')
    op.drop_table('queries')
    op.drop_index(op.f('ix_connections_id'), table_name='connections')
    op.drop_table('connections')
    op.drop_index(op.f('ix_users_username'), table_name='users')
    op.drop_index(op.f('ix_users_id'), table_name='users')
    op.drop_table('users')
    # ### end Alembic commands ###
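
A caveat for existing deployments: if the schema was already created by the pre-Alembic startup path (db_startup), this initial migration will fail on the existing tables. The standard way out is to mark such a database as already being at this revision, then upgrade from there:

> alembic -c alembic/alembic.ini stamp 6eb236240aec
> alembic -c alembic/alembic.ini upgrade head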

app/__init__.py

@@ -1,8 +1,10 @@
 from fastapi import APIRouter
 from app.connections import router as connections_router
-from app.operations import router as router
+from app.operations import router as operations_router
 from app.users import router as user_router
+from app.cursors import router as cursors_router
+from app.queries import router as queries_router

 api_router = APIRouter()

 api_router.include_router(router=user_router, prefix="/users", tags=["Users"])
@@ -10,5 +12,11 @@ api_router.include_router(
     router=connections_router, prefix="/connections", tags=["Connections"]
 )
 api_router.include_router(
-    router=router, prefix='/operations', tags=["Operations"]
+    router=cursors_router, prefix='/cursors', tags=['Cursors']
+)
+api_router.include_router(
+    router=queries_router, prefix='/queries', tags=['Queries']
+)
+api_router.include_router(
+    router=operations_router, prefix='/operations', tags=["Operations"]
 )

app/cursors.py Normal file

@@ -0,0 +1,78 @@
from fastapi.routing import APIRouter
from data.schemas import (
    CachedCursorOut,
)
from fastapi import Depends, status
from data.crud import (
    read_connection,
    read_select_query,
)
from core.dependencies import get_db, get_current_user, get_admin_user
from core.exceptions import (
    QueryNotFound,
    ConnectionNotFound,
    CursorNotFound,
)
from dbs import mysql

router = APIRouter()


@router.post("/", dependencies=[Depends(get_current_user)])
async def create_cursor_endpoint(
    query_id: int,
    connection_id: int,
    db=Depends(get_db),
) -> CachedCursorOut:
    query = await read_select_query(db=db, query_id=query_id)
    if query is None:
        raise QueryNotFound
    connection = await read_connection(db=db, connection_id=connection_id)
    if connection is None:
        raise ConnectionNotFound
    cached_cursor = await mysql.create_cursor(query=query, connection_id=connection_id)
    mysql.cached_cursors[cached_cursor.id] = cached_cursor
    return cached_cursor


@router.get("/", dependencies=[Depends(get_current_user)])
async def get_all_cursors() -> list[CachedCursorOut]:
    return list(mysql.cached_cursors.values())


@router.get("/{cursor_id}", dependencies=[Depends(get_current_user)])
async def get_cursor(cursor_id: str) -> CachedCursorOut:
    try:
        return mysql.cached_cursors[cursor_id]
    except KeyError:
        raise CursorNotFound


@router.delete(
    "/",
    dependencies=[Depends(get_admin_user)],
    status_code=status.HTTP_204_NO_CONTENT,
)
async def close_all_cursors() -> None:
    for cached_cursor in mysql.cached_cursors.values():
        await cached_cursor.close()
    mysql.cached_cursors.clear()


@router.delete(
    "/{cursor_id}",
    dependencies=[Depends(get_current_user)],
    status_code=status.HTTP_204_NO_CONTENT,
)
async def close_cursor(cursor_id: str) -> None:
    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
    if cached_cursor is None:
        raise CursorNotFound
    await cached_cursor.close()
    del mysql.cached_cursors[cursor_id]
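
Taken together with the fetch_cursor endpoint in app/operations.py below, the cursor lifecycle from a client's point of view looks roughly like the following sketch (the base URL and auth header name are assumptions; the paths come from the router prefixes in this commit):

import httpx

client = httpx.Client(base_url="http://localhost:8000", headers={"api-key": "..."})

# open a server-side cursor for stored query 1 on connection 1
cursor = client.post("/cursors/", params={"query_id": 1, "connection_id": 1}).json()

# fetch one page; repeat until the cursor reports has_more == False
page = client.get(
    "/operations/fetch_cursor",
    params={"cursor_id": cursor["id"], "page_size": 100},
).json()

# close early if the result set was not exhausted; a fully drained cursor
# is removed automatically, and stale ones are reaped by the TTL cleaner
if page["cursor"]["has_more"]:
    client.delete(f"/cursors/{cursor['id']}")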

app/operations.py

@@ -1,59 +1,22 @@
 from fastapi.routing import APIRouter
 from typing_extensions import Annotated
 from pydantic import Field
-from data.schemas import (
-    SelectQueryBase,
-    SelectQueryInDB,
-    SelectQuery,
-    SelectQueryIn,
-    SelectResult,
-    SelectQueryMetaData,
-    SelectQueryInResult,
-)
-from fastapi import Depends, HTTPException, status
-from sqlalchemy.ext.asyncio import AsyncSession
+from data.schemas import SelectResult, CachedCursorOut
+from fastapi import Depends
 from data.crud import (
     read_connection,
-    create_select_query,
-    read_all_select_queries,
     read_select_query,
 )
 from core.dependencies import get_db, get_current_user, get_admin_user
-from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound
-from utils.sql_creator import build_sql_query_text
+from core.exceptions import (
+    QueryNotFound,
+    ConnectionNotFound,
+    PoolNotFound,
+    CursorNotFound,
+)
 from dbs import mysql

-router = APIRouter(prefix="/select")
+router = APIRouter()

-@router.post("/check-query", dependencies=[Depends(get_current_user)])
-async def check_select_query(query: SelectQueryBase) -> SelectQuery:
-    sql, params = build_sql_query_text(query)
-    q = SelectQuery(**query.model_dump(), params=params, sql=sql)
-    return q
-@router.post("/query")
-async def create_select_query_endpoint(
-    query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
-) -> SelectQueryInDB:
-    sql, params = build_sql_query_text(query)
-    query_in = SelectQueryIn(
-        **query.model_dump(), owner_id=user.id, params=params, sql=sql
-    )
-    return await create_select_query(db=db, query=query_in)
-@router.get("/query", dependencies=[Depends(get_current_user)])
-async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
-    return await read_all_select_queries(db=db)
-@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)])
-async def get_select_queries_endpoint(
-    query_id: int, db=Depends(get_db)
-) -> SelectQueryInDB:
-    return await read_select_query(db=db, query_id=query_id)
 @router.post("/execute", dependencies=[Depends(get_current_user)])
@@ -76,17 +39,41 @@ async def execute_select(
         raise PoolNotFound
     raw_result, rowcount = await mysql.execute_select_query(
-        pool=pool, query=query.sql, params=query.params, fetch_num=page_size
+        pool=pool, sql_query=query.sql, params=query.params, fetch_num=page_size
     )
     results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
-    meta = SelectQueryMetaData(
-        cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount
-    )
     return SelectResult(
-        meta=meta,
-        query=query,
+        cursor=CachedCursorOut(
+            id=None,
+            connection_id=connection_id,
+            query=query,
+            row_count=rowcount,
+            fetched_rows=len(results.data),
+            is_closed=True,
+            has_more=len(results.data) != rowcount,
+            ttl=-1,
+            close_at=-1,
+        ),
         results=results,
     )
+@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
+async def fetch_cursor(
+    cursor_id: str,
+    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
+) -> SelectResult:
+    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
+    if cached_cursor is None:
+        raise CursorNotFound
+    result = await cached_cursor.fetch_many(size=page_size)
+    if cached_cursor.done:
+        mysql.cached_cursors.pop(cursor_id, None)
+    return SelectResult(
+        cursor=cached_cursor,
+        results={"columns": cached_cursor.query.columns, "data": result},
+    )

app/queries.py Normal file

@@ -0,0 +1,50 @@
from fastapi.routing import APIRouter
from data.schemas import (
    SelectQueryBase,
    SelectQueryInDB,
    SelectQuery,
    SelectQueryIn,
)
from fastapi import Depends
from data.crud import (
    create_select_query,
    read_all_select_queries,
    read_select_query,
)
from core.dependencies import get_db, get_current_user, get_admin_user
from utils.mysql_scripts import build_sql_query_text

router = APIRouter()


@router.post("/check", dependencies=[Depends(get_current_user)])
async def check_select_query(query: SelectQueryBase) -> SelectQuery:
    sql, params = build_sql_query_text(query)
    q = SelectQuery(**query.model_dump(), params=params, sql=sql)
    return q


@router.post("/")
async def create_select_query_endpoint(
    query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
) -> SelectQueryInDB:
    sql, params = build_sql_query_text(query)
    query_in = SelectQueryIn(
        **query.model_dump(), owner_id=user.id, params=params, sql=sql
    )
    return await create_select_query(db=db, query=query_in)


@router.get("/", dependencies=[Depends(get_current_user)])
async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
    return await read_all_select_queries(db=db)


@router.get("/{query_id}", dependencies=[Depends(get_current_user)])
async def get_select_query_endpoint(
    query_id: int, db=Depends(get_db)
) -> SelectQueryInDB:
    return await read_select_query(db=db, query_id=query_id)
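
For reference, POST /queries/check compiles a query definition without persisting it. A minimal request using only the SelectQueryBase fields visible in this diff follows; the filter and sort shapes, and the table-name validator, live elsewhere in data/schemas.py, and the URL and auth header are assumptions:

import httpx

payload = {"table_name": "users", "columns": ["id", "username"], "limit": 10}
resp = httpx.post(
    "http://localhost:8000/queries/check",
    json=payload,
    headers={"api-key": "..."},  # assumed auth scheme
)
print(resp.json()["sql"], resp.json()["params"])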

core/exceptions.py

@@ -25,25 +25,67 @@ class QueryValidationError(ValueError):
         self.msg = msg
         super().__init__(msg)
 class QueryNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "The referenced query was not found.",
-        "code": 'query-not-found'
-    }, headers = None):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "The referenced query was not found.",
+            "code": "query-not-found",
+        },
+        headers=None,
+    ):
         super().__init__(status_code, detail, headers)
 class ConnectionNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "The referenced connection was not found.",
-        "code": 'connection-not-found'
-    }, headers = None):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "The referenced connection was not found.",
+            "code": "connection-not-found",
+        },
+        headers=None,
+    ):
         super().__init__(status_code, detail, headers)
 class PoolNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "We didn't find a running Pool for the referenced connection.",
-        "code": 'pool-not-found'
-    }, headers = None):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "We didn't find a running Pool for the referenced connection.",
+            "code": "pool-not-found",
+        },
+        headers=None,
+    ):
         super().__init__(status_code, detail, headers)
+class CursorNotFound(HTTPException):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "We didn't find a Cursor with the provided ID.",
+            "code": "cursor-not-found",
+        },
+        headers=None,
+    ):
+        super().__init__(status_code, detail, headers)
+class ClosedCursorUsage(HTTPException):
+    def __init__(
+        self,
+        status_code=400,
+        detail={
+            "message": "The Cursor you are trying to use is closed.",
+            "code": "cursor-closed",
+        },
+        headers=None,
+    ):
+        super().__init__(status_code, detail, headers)

data/app_types.py Normal file

@@ -0,0 +1,63 @@
import datetime
from aiomysql import SSCursor, Connection, Pool
from data.schemas import SelectQuery
from core.exceptions import ClosedCursorUsage


class CachedCursor:
    def __init__(
        self,
        id: str,
        cursor: SSCursor,
        connection: Connection,
        pool: Pool,
        connection_id: int,
        query: SelectQuery,
        ttl: int = 60,
    ):
        self.id = id
        self.cursor = cursor
        self.connection = connection
        self.connection_id = connection_id
        self.pool = pool
        self.query = query
        # The rowcount of a SELECT is set to -1 when using a server-side
        # cursor; the huge value (> 10000000000000000000) appears because -1
        # is interpreted as an unsigned integer in MySQL's internal C API.
        self.row_count: int = -1 if cursor.rowcount > 10000000000000000000 else cursor.rowcount
        self.fetched_rows: int = 0
        self.is_closed: bool = False
        self.ttl: int = ttl
        self.close_at = self.update_close_at()
        self.done = False

    @property
    def has_more(self):
        return not self.done

    def update_close_at(self) -> int:
        return int(datetime.datetime.now(tz=datetime.UTC).timestamp()) + self.ttl

    async def close(self):
        await self.cursor.close()
        await self.pool.release(self.connection)
        self.is_closed = True

    async def fetch_many(self, size: int = 100) -> list[tuple]:
        if self.is_closed:
            raise ClosedCursorUsage
        result = await self.cursor.fetchmany(size)
        if len(result) < size:
            # The cursor has reached the end of the result set.
            await self.close()
            self.done = True
        else:
            # A successful fetch keeps the cursor alive for another TTL window.
            self.close_at = self.update_close_at()
        self.fetched_rows += len(result)
        return result
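
The threshold in row_count guards against -1 coming back as a huge unsigned 64-bit value; the arithmetic behind that comment, for reference:

# -1 reinterpreted as an unsigned 64-bit integer, as in MySQL's C API:
assert (-1) % 2**64 == 18_446_744_073_709_551_615  # comfortably > 10000000000000000000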

data/models.py

@@ -21,9 +21,11 @@ class Connection(Base):
     db_name = Column(String, nullable=False)
     type = Column(Enum(ConnectionTypes), nullable=False)
     host = Column(String)
     port = Column(Integer)
     username = Column(String)
     password = Column(String)
+    pool_minsize = Column(Integer, nullable=False, default=5)
+    pool_maxsize = Column(Integer, nullable=False, default=10)
     owner_id = Column(Integer, ForeignKey("users.id"))
     # owner = relationship("User", back_populates="connections")

data/schemas.py

@@ -124,8 +124,8 @@ class SelectQueryBase(BaseModel):
     columns: Union[Literal["*"], List[str]] = "*"
     filters: Optional[List[FilterClause]] = None
     sort_by: Optional[List[SortClause]] = None
-    limit: Annotated[int, Field(strict=True, gt=0)] = None
-    offset: Annotated[int, Field(strict=True, ge=0)] = None
+    limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
+    offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
     @field_validator("table_name")
     @classmethod
@@ -179,16 +179,20 @@ class SelectQueryInResult(BaseModel):
         from_attributes = True
-class SelectQueryMetaData(BaseModel):
-    cursor: Optional[UUID4] = Field(
-        None,
-        description="A UUID4 cursor for pagination. Can be None if no more data is available.",
-    )
-    total_number: int
-    has_more: bool = False
+class CachedCursorOut(BaseModel):
+    id: UUID4 | None
+    connection_id: int
+    query: SelectQueryInResult
+    row_count: int
+    fetched_rows: int
+    is_closed: bool
+    has_more: bool
+    close_at: int
+    ttl: int
+    class Config:
+        from_attributes = True
 class SelectResult(BaseModel):
-    meta: SelectQueryMetaData
-    query: SelectQueryInResult
+    cursor: CachedCursorOut
     results: SelectResultData | None

dbs/mysql.py

@@ -1,18 +1,59 @@
-import aiomysql, decimal, datetime
-import asyncio
+import asyncio, aiomysql, decimal, datetime, uuid, logging
 from typing import Literal, Any
-from data.schemas import Connection, SelectResultData
+from data.schemas import Connection, SelectResultData, SelectQuery
+from data.app_types import CachedCursor
+from core.exceptions import PoolNotFound

-# Database configuration
-DB_CONFIG = {
-    "host": "localhost",
-    "user": "me",  # Replace with your MySQL username
-    "password": "Passwd3.14",  # Replace with your MySQL password
-    "db": "testing",  # Replace with your database name
-    "port": 3306,  # Default MySQL port
-}
-pools: None | dict[str, aiomysql.Pool] = {}
+pools: dict[int, aiomysql.Pool] = {}
+cached_cursors: dict[str, CachedCursor] = {}
+closed_cached_cursors: dict[str, CachedCursor] = {}
+cached_cursors_cleaner_task = None
+async def close_old_cached_cursors():
+    global cached_cursors, closed_cached_cursors
+    for cursor_id in list(cached_cursors.keys()):
+        cursor = cached_cursors.get(cursor_id, None)
+        if cursor is None:
+            continue
+        if cursor.close_at > datetime.datetime.now(datetime.UTC).timestamp():
+            continue
+        try:
+            await cursor.close()
+            cached_cursors.pop(cursor_id, None)
+            closed_cached_cursors[cursor_id] = cursor
+            logging.info(f"Closed cursor {cursor_id}")
+        except Exception as e:
+            logging.error(f"Error closing Cursor {cursor_id} -> {e}")
+async def remove_old_closed_cached_cursors():
+    global closed_cached_cursors
+    for cursor_id in list(closed_cached_cursors.keys()):
+        closed_cursor = closed_cached_cursors.get(cursor_id, None)
+        if closed_cursor is None:
+            continue
+        if (
+            closed_cursor.close_at + closed_cursor.ttl * 5
+            > datetime.datetime.now(datetime.UTC).timestamp()
+        ):
+            continue
+        del closed_cached_cursors[cursor_id]
+        logging.info(f"Removed cursor {cursor_id}")
+async def cached_cursors_cleaner():
+    while True:
+        await close_old_cached_cursors()
+        await remove_old_closed_cached_cursors()
+        await asyncio.sleep(10)
 async def pool_creator(connection: Connection, minsize=5, maxsize=10):
@@ -28,8 +69,31 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
         maxsize=maxsize,
     )
+async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
+    pool = pools.get(connection_id, None)
+    if pool is None:
+        raise PoolNotFound
+    connection = await pool.acquire()
+    cursor = await connection.cursor(aiomysql.SSCursor)
+    await cursor.execute(query.sql, query.params)
+    cached_cursor = CachedCursor(
+        id=str(uuid.uuid4()),
+        cursor=cursor,
+        connection=connection,
+        pool=pool,
+        connection_id=connection_id,
+        query=query,
+    )
+    return cached_cursor
-async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
+async def execute_select_query(
+    pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
+):
     """
     Executes a SELECT query on the MySQL database asynchronously and returns the results.
@@ -43,7 +107,7 @@ async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
     try:
         async with pool.acquire() as connection:
             async with connection.cursor(aiomysql.DictCursor) as cursor:
-                await cursor.execute(query, params)
+                await cursor.execute(sql_query, params)
                 result = await cursor.fetchmany(fetch_num)
                 return result, cursor.rowcount
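
The cleaner gives every cursor a two-stage expiry; a worked timeline with the defaults (ttl=60, 10-second sweep interval):

# t = 0     cursor created: close_at = 60 (each successful fetch pushes it forward)
# t >= 60   close_old_cached_cursors() closes it and moves it from
#           cached_cursors to closed_cached_cursors
# t >= 360  remove_old_closed_cached_cursors() drops it entirely,
#           since close_at + ttl * 5 = 60 + 300
# the sweep runs every 10 s, so each transition may lag by up to one interval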

main.py

@@ -1,19 +1,27 @@
+import asyncio
 from contextlib import asynccontextmanager
 from fastapi import FastAPI
 from app import api_router
-from utils.scripts import pools_creator, pools_destroy, db_startup
+from utils.scripts import pools_creator, pools_destroy, db_startup, cursors_closer
+from dbs import mysql

 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await pools_creator()
+    mysql.cached_cursors_cleaner_task = asyncio.create_task(mysql.cached_cursors_cleaner())
     yield
+    mysql.cached_cursors_cleaner_task.cancel()
+    try:
+        await mysql.cached_cursors_cleaner_task
+    except asyncio.CancelledError:
+        print('Closed cached_cursors_cleaner_task')
+    await cursors_closer()
     await pools_destroy()

 app = FastAPI(lifespan=lifespan)
 app.include_router(router=api_router)
utils/mysql_scripts.py

@@ -1,3 +1,4 @@
 from data.schemas import SelectQueryBase
+
 def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]:

utils/scripts.py

@@ -20,6 +20,12 @@ async def pools_creator():
         mysql.pools[connection.id] = await mysql.pool_creator(connection=connection)
     logging.info(msg='Created Pools')
+async def cursors_closer():
+    from dbs import mysql
+    for cursor_id, cursor in mysql.cached_cursors.items():
+        await cursor.close()
+        logging.info(f'Closed cursor: {cursor_id}')
 async def pools_destroy():
     from dbs import mysql
     for connection_id, pool in mysql.pools.items():