From 1abc22592301daf90b6a763c7a5ceb31c04733ca Mon Sep 17 00:00:00 2001
From: abdulhade
Date: Tue, 25 Feb 2025 23:37:55 +0300
Subject: [PATCH] Implemented `alembic` migrations, added cursors closing task.

---
 alembic/alembic.ini                           | 119 ++++++++++++++++++
 alembic/env.py                                |  80 ++++++++++++
 alembic/script.py.mako                        |  26 ++++
 ...ff091f5c_added_pool_size_config_to_the_.py |  32 +++++
 .../6eb236240aec_initial_migration.py         |  75 +++++++++++
 app/__init__.py                               |  12 +-
 app/cursors.py                                |  78 ++++++++++++
 app/operations.py                             |  93 ++++++--------
 app/queries.py                                |  50 ++++++++
 core/exceptions.py                            |  68 ++++++++--
 data/app_types.py                             |  63 ++++++++++
 data/models.py                                |   4 +-
 data/schemas.py                               |  26 ++--
 dbs/mysql.py                                  |  92 +++++++++++---
 main.py                                       |  22 ++--
 utils/{sql_creator.py => mysql_scripts.py}    |   1 +
 utils/scripts.py                              |   6 +
 17 files changed, 746 insertions(+), 101 deletions(-)
 create mode 100644 alembic/alembic.ini
 create mode 100644 alembic/env.py
 create mode 100644 alembic/script.py.mako
 create mode 100644 alembic/versions/1c62ff091f5c_added_pool_size_config_to_the_.py
 create mode 100644 alembic/versions/6eb236240aec_initial_migration.py
 create mode 100644 app/cursors.py
 create mode 100644 app/queries.py
 create mode 100644 data/app_types.py
 rename utils/{sql_creator.py => mysql_scripts.py} (99%)

diff --git a/alembic/alembic.ini b/alembic/alembic.ini
new file mode 100644
index 0000000..cf765b2
--- /dev/null
+++ b/alembic/alembic.ini
@@ -0,0 +1,119 @@
+# A generic, single database configuration.
+
+[alembic]
+# path to migration scripts
+# Use forward slashes (/) also on windows to provide an os agnostic path
+script_location = alembic
+
+# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
+# Uncomment the line below if you want the files to be prepended with date and time
+# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
+# for all available tokens
+# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# sys.path path, will be prepended to sys.path if present.
+# defaults to the current working directory.
+prepend_sys_path = .
+
+# timezone to use when rendering the date within the migration file
+# as well as the filename.
+# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
+# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
+# string value is passed to ZoneInfo()
+# leave blank for localtime
+# timezone =
+
+# max length of characters to apply to the "slug" field
+# truncate_slug_length = 40
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+# set to 'true' to allow .pyc and .pyo files without
+# a source .py file to be detected as revisions in the
+# versions/ directory
+# sourceless = false
+
+# version location specification; This defaults
+# to alembic/versions. When using multiple version
+# directories, initial revisions must be specified with --version-path.
+# The path separator used here should be the separator specified by "version_path_separator" below.
+# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
+
+# version path separator; As mentioned above, this is the character used to split
+# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
+# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
+# Valid values for version_path_separator are:
+#
+# version_path_separator = :
+# version_path_separator = ;
+# version_path_separator = space
+# version_path_separator = newline
+#
+# Use os.pathsep. Default configuration used for new projects.
+version_path_separator = os
+
+# set to 'true' to search source files recursively
+# in each "version_locations" directory
+# new in Alembic version 1.10
+# recursive_version_locations = false
+
+# the output encoding used when revision files
+# are written from script.py.mako
+# output_encoding = utf-8
+
+# sqlalchemy.url = driver://user:pass@localhost/dbname
+
+
+[post_write_hooks]
+# post_write_hooks defines scripts or Python functions that are run
+# on newly generated revision scripts. See the documentation for further
+# detail and examples
+
+# format using "black" - use the console_scripts runner, against the "black" entrypoint
+# hooks = black
+# black.type = console_scripts
+# black.entrypoint = black
+# black.options = -l 79 REVISION_SCRIPT_FILENAME
+
+# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
+# hooks = ruff
+# ruff.type = exec
+# ruff.executable = %(here)s/.venv/bin/ruff
+# ruff.options = --fix REVISION_SCRIPT_FILENAME
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARNING
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARNING
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/alembic/env.py b/alembic/env.py
new file mode 100644
index 0000000..47d786c
--- /dev/null
+++ b/alembic/env.py
@@ -0,0 +1,80 @@
+from logging.config import fileConfig
+
+from data.db import engine
+from data.models import Base
+
+from alembic import context
+
+
+"""
+    Create migration example:
+    > alembic -c alembic/alembic.ini revision --autogenerate -m "Added pool size config to the connection model."
+
+    Migrate example:
+    > alembic -c alembic/alembic.ini upgrade head
+"""
+
+
+config = context.config
+
+if config.config_file_name is not None:
+    fileConfig(config.config_file_name)
+
+from data.models import User, Query, Connection  # ensure every model is imported so autogenerate sees it
+
+
+target_metadata = Base.metadata
+
+
+def run_migrations_offline() -> None:
+    """
+    Run migrations in 'offline' mode.
+
+    This configures the context with just a URL
+    and not an Engine, though an Engine is acceptable
+    here as well. By skipping the Engine creation
+    we don't even need a DBAPI to be available.
+
+    Calls to context.execute() here emit the given string to the
+    script output.
+    """
+    url = config.get_main_option("sqlalchemy.url")
+    context.configure(
+        url=url,
+        target_metadata=target_metadata,
+        literal_binds=True,
+        dialect_opts={"paramstyle": "named"},
+    )
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+async def run_migrations_online() -> None:
+    """
+    Run migrations in 'online' mode.
+
+    In this scenario we need to create an Engine
+    and associate a connection with the context.
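One configuration detail worth noting about env.py: run_migrations_online() drives the
async engine imported from data.db and never reads sqlalchemy.url, which is why the URL
can stay commented out in alembic.ini above. Offline (--sql) runs do read sqlalchemy.url,
so generating SQL scripts would additionally need a line like the following in alembic.ini
(the URL here is a placeholder, not part of this patch):

    sqlalchemy.url = mysql+aiomysql://user:pass@localhost/dbname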
+    """
+    connectable = engine
+    async with connectable.connect() as connection:
+        await connection.run_sync(do_migrations)
+
+
+def do_migrations(connection):
+    context.configure(connection=connection, target_metadata=Base.metadata)
+
+    with context.begin_transaction():
+        context.run_migrations()
+
+
+if context.is_offline_mode():
+    run_migrations_offline()
+else:
+    import asyncio
+
+    asyncio.run(run_migrations_online())
diff --git a/alembic/script.py.mako b/alembic/script.py.mako
new file mode 100644
index 0000000..fbc4b07
--- /dev/null
+++ b/alembic/script.py.mako
@@ -0,0 +1,26 @@
+"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision: str = ${repr(up_revision)}
+down_revision: Union[str, None] = ${repr(down_revision)}
+branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
+depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
+
+
+def upgrade() -> None:
+    ${upgrades if upgrades else "pass"}
+
+
+def downgrade() -> None:
+    ${downgrades if downgrades else "pass"}
diff --git a/alembic/versions/1c62ff091f5c_added_pool_size_config_to_the_.py b/alembic/versions/1c62ff091f5c_added_pool_size_config_to_the_.py
new file mode 100644
index 0000000..cafa90d
--- /dev/null
+++ b/alembic/versions/1c62ff091f5c_added_pool_size_config_to_the_.py
@@ -0,0 +1,32 @@
+"""Added pool size config to the connection model.
+
+Revision ID: 1c62ff091f5c
+Revises: 6eb236240aec
+Create Date: 2025-02-25 19:24:21.712856
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '1c62ff091f5c'
+down_revision: Union[str, None] = '6eb236240aec'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column('connections', sa.Column('pool_minsize', sa.Integer(), nullable=False, server_default="5"))
+    op.add_column('connections', sa.Column('pool_maxsize', sa.Integer(), nullable=False, server_default="10"))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('connections', 'pool_maxsize')
+    op.drop_column('connections', 'pool_minsize')
+    # ### end Alembic commands ###
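Note the pairing with data/models.py later in this patch: the model declares default=5 and
default=10, which the ORM applies when inserting new rows, while this migration bakes
server_default into the DDL. The server_default is what allows adding NOT NULL columns to
a connections table that already contains rows. The same column, at its two layers:

    # data/models.py - client-side default, used by the ORM on INSERT
    pool_minsize = Column(Integer, nullable=False, default=5)

    # migration - server-side default, lets ALTER TABLE backfill existing rows
    op.add_column('connections',
        sa.Column('pool_minsize', sa.Integer(), nullable=False, server_default="5"))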
diff --git a/alembic/versions/6eb236240aec_initial_migration.py b/alembic/versions/6eb236240aec_initial_migration.py
new file mode 100644
index 0000000..799441b
--- /dev/null
+++ b/alembic/versions/6eb236240aec_initial_migration.py
@@ -0,0 +1,75 @@
+"""Initial migration.
+
+Revision ID: 6eb236240aec
+Revises:
+Create Date: 2025-02-25 19:18:03.125433
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '6eb236240aec'
+down_revision: Union[str, None] = None
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('users',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('username', sa.String(), nullable=False),
+    sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=False),
+    sa.Column('api_key', sa.String(), nullable=False),
+    sa.PrimaryKeyConstraint('id'),
+    sa.UniqueConstraint('api_key')
+    )
+    op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
+    op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
+    op.create_table('connections',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('db_name', sa.String(), nullable=False),
+    sa.Column('type', sa.Enum('mysql', 'postgresql', name='connectiontypes'), nullable=False),
+    sa.Column('host', sa.String(), nullable=True),
+    sa.Column('port', sa.Integer(), nullable=True),
+    sa.Column('username', sa.String(), nullable=True),
+    sa.Column('password', sa.String(), nullable=True),
+    sa.Column('owner_id', sa.Integer(), nullable=True),
+    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index(op.f('ix_connections_id'), 'connections', ['id'], unique=False)
+    op.create_table('queries',
+    sa.Column('id', sa.Integer(), nullable=False),
+    sa.Column('name', sa.String(), nullable=False),
+    sa.Column('description', sa.String(), nullable=True),
+    sa.Column('owner_id', sa.Integer(), nullable=True),
+    sa.Column('table_name', sa.String(), nullable=False),
+    sa.Column('columns', sa.JSON(), nullable=False),
+    sa.Column('filters', sa.JSON(), nullable=True),
+    sa.Column('sort_by', sa.JSON(), nullable=True),
+    sa.Column('limit', sa.Integer(), nullable=True),
+    sa.Column('offset', sa.Integer(), nullable=True),
+    sa.Column('sql', sa.String(), nullable=False),
+    sa.Column('params', sa.JSON(), nullable=False),
+    sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
+    sa.PrimaryKeyConstraint('id')
+    )
+    op.create_index(op.f('ix_queries_id'), 'queries', ['id'], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_index(op.f('ix_queries_id'), table_name='queries')
+    op.drop_table('queries')
+    op.drop_index(op.f('ix_connections_id'), table_name='connections')
+    op.drop_table('connections')
+    op.drop_index(op.f('ix_users_username'), table_name='users')
+    op.drop_index(op.f('ix_users_id'), table_name='users')
+    op.drop_table('users')
+    # ### end Alembic commands ###
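For a database that predates these migrations (one whose tables were already created by
the existing db_startup path), the initial revision should be recorded rather than
re-applied. Assuming the stock alembic CLI, roughly:

    # fresh database: build the schema from scratch
    alembic -c alembic/alembic.ini upgrade head

    # existing database that already has users/connections/queries:
    alembic -c alembic/alembic.ini stamp 6eb236240aec   # mark the initial migration as applied
    alembic -c alembic/alembic.ini upgrade head         # then apply 1c62ff091f5c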
diff --git a/app/__init__.py b/app/__init__.py
index ecce03e..d0f6644 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -1,8 +1,10 @@
 from fastapi import APIRouter
 
 from app.connections import router as connections_router
-from app.operations import router as router
+from app.operations import router as operations_router
 from app.users import router as user_router
+from app.cursors import router as cursors_router
+from app.queries import router as queries_router
 
 api_router = APIRouter()
 api_router.include_router(router=user_router, prefix="/users", tags=["Users"])
@@ -10,5 +12,11 @@ api_router.include_router(
     router=connections_router, prefix="/connections", tags=["Connections"]
 )
 api_router.include_router(
-    router=router, prefix='/operations', tags=["Operations"]
+    router=cursors_router, prefix='/cursors', tags=['Cursors']
+)
+api_router.include_router(
+    router=queries_router, prefix='/queries', tags=['Queries']
+)
+api_router.include_router(
+    router=operations_router, prefix='/operations', tags=["Operations"]
 )
\ No newline at end of file
diff --git a/app/cursors.py b/app/cursors.py
new file mode 100644
index 0000000..c188aa8
--- /dev/null
+++ b/app/cursors.py
@@ -0,0 +1,78 @@
+from fastapi.routing import APIRouter
+from data.schemas import (
+    CachedCursorOut,
+)
+from fastapi import Depends, status
+from data.crud import (
+    read_connection,
+    read_select_query,
+)
+from core.dependencies import get_db, get_current_user, get_admin_user
+from core.exceptions import (
+    QueryNotFound,
+    ConnectionNotFound,
+    CursorNotFound,
+)
+from dbs import mysql
+
+router = APIRouter()
+
+
+@router.post("/", dependencies=[Depends(get_current_user)])
+async def create_cursor_endpoint(
+    query_id: int,
+    connection_id: int,
+    db=Depends(get_db),
+) -> CachedCursorOut:
+    query = await read_select_query(db=db, query_id=query_id)
+
+    if query is None:
+        raise QueryNotFound
+    connection = await read_connection(db=db, connection_id=connection_id)
+
+    if connection is None:
+        raise ConnectionNotFound
+
+    cached_cursor = await mysql.create_cursor(query=query, connection_id=connection_id)
+
+    mysql.cached_cursors[cached_cursor.id] = cached_cursor
+
+    return cached_cursor
+
+
+@router.get("/", dependencies=[Depends(get_current_user)])
+async def get_all_cursors() -> list[CachedCursorOut]:
+    return list(mysql.cached_cursors.values())
+
+
+@router.get("/{cursor_id}", dependencies=[Depends(get_current_user)])
+async def get_cursors(cursor_id: str) -> CachedCursorOut:
+    try:
+        return mysql.cached_cursors[cursor_id]
+    except KeyError:
+        raise CursorNotFound
+
+
+@router.delete(
+    "/",
+    dependencies=[Depends(get_admin_user)],
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def close_all_cursors() -> None:
+    for cached_cursor in mysql.cached_cursors.values():
+        await cached_cursor.close()
+    mysql.cached_cursors.clear()
+
+
+@router.delete(
+    "/{cursor_id}",
+    dependencies=[Depends(get_current_user)],
+    status_code=status.HTTP_204_NO_CONTENT,
+)
+async def close_cursor(cursor_id: str) -> None:
+    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
+    if cached_cursor is None:
+        raise CursorNotFound
+
+    await cached_cursor.close()
+    del mysql.cached_cursors[cursor_id]
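A client-side sketch of the intended lifecycle for these endpoints (the base URL and auth
header below are placeholders; authentication is whatever get_current_user expects):

    import httpx

    with httpx.Client(base_url="http://localhost:8000",
                      headers={"api-key": "..."}) as client:
        # Open a server-side cursor for stored query 1 on connection 1.
        cursor = client.post(
            "/cursors/", params={"query_id": 1, "connection_id": 1}
        ).json()

        # Page through it via the operations router (see fetch_cursor below).
        page = client.get(
            "/operations/fetch_cursor",
            params={"cursor_id": cursor["id"], "page_size": 100},
        ).json()

        # A fully drained cursor is closed and evicted server-side; only an
        # abandoned one needs an explicit DELETE to free its pooled connection.
        if page["cursor"]["has_more"]:
            client.delete(f"/cursors/{cursor['id']}")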
diff --git a/app/operations.py b/app/operations.py
index d50d75d..d273ca9 100644
--- a/app/operations.py
+++ b/app/operations.py
@@ -1,59 +1,22 @@
 from fastapi.routing import APIRouter
 from typing_extensions import Annotated
 from pydantic import Field
-from data.schemas import (
-    SelectQueryBase,
-    SelectQueryInDB,
-    SelectQuery,
-    SelectQueryIn,
-    SelectResult,
-    SelectQueryMetaData,
-    SelectQueryInResult,
-)
-from fastapi import Depends, HTTPException, status
-from sqlalchemy.ext.asyncio import AsyncSession
+from data.schemas import SelectResult, CachedCursorOut
+from fastapi import Depends
 from data.crud import (
     read_connection,
-    create_select_query,
-    read_all_select_queries,
     read_select_query,
 )
 from core.dependencies import get_db, get_current_user, get_admin_user
-from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound
-from utils.sql_creator import build_sql_query_text
+from core.exceptions import (
+    QueryNotFound,
+    ConnectionNotFound,
+    PoolNotFound,
+    CursorNotFound,
+)
 from dbs import mysql
 
-router = APIRouter(prefix="/select")
-
-
-@router.post("/check-query", dependencies=[Depends(get_current_user)])
-async def check_select_query(query: SelectQueryBase) -> SelectQuery:
-    sql, params = build_sql_query_text(query)
-    q = SelectQuery(**query.model_dump(), params=params, sql=sql)
-    return q
-
-
-@router.post("/query")
-async def create_select_query_endpoint(
-    query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
-) -> SelectQueryInDB:
-    sql, params = build_sql_query_text(query)
-    query_in = SelectQueryIn(
-        **query.model_dump(), owner_id=user.id, params=params, sql=sql
-    )
-    return await create_select_query(db=db, query=query_in)
-
-
-@router.get("/query", dependencies=[Depends(get_current_user)])
-async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
-    return await read_all_select_queries(db=db)
-
-
-@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)])
-async def get_select_queries_endpoint(
-    query_id: int, db=Depends(get_db)
-) -> SelectQueryInDB:
-    return await read_select_query(db=db, query_id=query_id)
+router = APIRouter()
 
 
 @router.post("/execute", dependencies=[Depends(get_current_user)])
@@ -76,17 +39,41 @@ async def execute_select(
         raise PoolNotFound
 
     raw_result, rowcount = await mysql.execute_select_query(
-        pool=pool, query=query.sql, params=query.params, fetch_num=page_size
+        pool=pool, sql_query=query.sql, params=query.params, fetch_num=page_size
     )
 
     results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
 
-    meta = SelectQueryMetaData(
-        cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount
-    )
-
     return SelectResult(
-        meta=meta,
-        query=query,
+        cursor=CachedCursorOut(
+            id=None,
+            connection_id=connection_id,
+            query=query,
+            row_count=rowcount,
+            fetched_rows=len(results.data),
+            is_closed=True,
+            has_more=len(results.data) != rowcount,
+            ttl=-1,
+            close_at=-1,
+        ),
        results=results,
     )
+
+
+@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
+async def fetch_cursor(
+    cursor_id: str,
+    page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
+) -> SelectResult:
+    cached_cursor = mysql.cached_cursors.get(cursor_id, None)
+    if cached_cursor is None:
+        raise CursorNotFound
+    result = await cached_cursor.fetch_many(size=page_size)
+
+    if cached_cursor.done:
+        mysql.cached_cursors.pop(cursor_id, None)
+
+    return SelectResult(
+        cursor=cached_cursor,
results={"columns": cached_cursor.query.columns, "data": result}, + ) diff --git a/app/queries.py b/app/queries.py new file mode 100644 index 0000000..2df4432 --- /dev/null +++ b/app/queries.py @@ -0,0 +1,50 @@ + +from fastapi.routing import APIRouter + +from data.schemas import ( + SelectQueryBase, + SelectQueryInDB, + SelectQuery, + SelectQueryIn, +) +from fastapi import Depends +from data.crud import ( + create_select_query, + read_all_select_queries, + read_select_query, +) +from core.dependencies import get_db, get_current_user, get_admin_user +from utils.mysql_scripts import build_sql_query_text + +router = APIRouter() + + + +@router.post("/check", dependencies=[Depends(get_current_user)]) +async def check_select_query(query: SelectQueryBase) -> SelectQuery: + sql, params = build_sql_query_text(query) + q = SelectQuery(**query.model_dump(), params=params, sql=sql) + return q + + +@router.post("/") +async def create_select_query_endpoint( + query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user) +) -> SelectQueryInDB: + sql, params = build_sql_query_text(query) + query_in = SelectQueryIn( + **query.model_dump(), owner_id=user.id, params=params, sql=sql + ) + return await create_select_query(db=db, query=query_in) + + +@router.get("/", dependencies=[Depends(get_current_user)]) +async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]: + return await read_all_select_queries(db=db) + + +@router.get("/{query_id}", dependencies=[Depends(get_current_user)]) +async def get_select_queries_endpoint( + query_id: int, db=Depends(get_db) +) -> SelectQueryInDB: + return await read_select_query(db=db, query_id=query_id) diff --git a/core/exceptions.py b/core/exceptions.py index 1804b96..06ae49e 100644 --- a/core/exceptions.py +++ b/core/exceptions.py @@ -25,25 +25,67 @@ class QueryValidationError(ValueError): self.msg = msg super().__init__(msg) + class QueryNotFound(HTTPException): - def __init__(self, status_code=404, detail = { - 'message': "The referenced query was not found.", - "code": 'query-not-found' - }, headers = None): + def __init__( + self, + status_code=404, + detail={ + "message": "The referenced query was not found.", + "code": "query-not-found", + }, + headers=None, + ): super().__init__(status_code, detail, headers) class ConnectionNotFound(HTTPException): - def __init__(self, status_code=404, detail = { - 'message': "The referenced connection was not found.", - "code": 'connection-not-found' - }, headers = None): + def __init__( + self, + status_code=404, + detail={ + "message": "The referenced connection was not found.", + "code": "connection-not-found", + }, + headers=None, + ): super().__init__(status_code, detail, headers) class PoolNotFound(HTTPException): - def __init__(self, status_code=404, detail = { - 'message': "We didn't find a running Pool for the referenced connection.", - "code": 'pool-not-found' - }, headers = None): - super().__init__(status_code, detail, headers) \ No newline at end of file + def __init__( + self, + status_code=404, + detail={ + "message": "We didn't find a running Pool for the referenced connection.", + "code": "pool-not-found", + }, + headers=None, + ): + super().__init__(status_code, detail, headers) + + +class CursorNotFound(HTTPException): + def __init__( + self, + status_code=404, + detail={ + "message": "We didn't find a Cursor with the provided ID.", + "code": "cursor-not-found", + }, + headers=None, + ): + super().__init__(status_code, detail, headers) + + +class 
diff --git a/core/exceptions.py b/core/exceptions.py
index 1804b96..06ae49e 100644
--- a/core/exceptions.py
+++ b/core/exceptions.py
@@ -25,25 +25,67 @@ class QueryValidationError(ValueError):
         self.msg = msg
         super().__init__(msg)
 
+
 class QueryNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "The referenced query was not found.",
-        "code": 'query-not-found'
-    }, headers = None):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "The referenced query was not found.",
+            "code": "query-not-found",
+        },
+        headers=None,
+    ):
         super().__init__(status_code, detail, headers)
 
 
 class ConnectionNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "The referenced connection was not found.",
-        "code": 'connection-not-found'
-    }, headers = None):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "The referenced connection was not found.",
+            "code": "connection-not-found",
+        },
+        headers=None,
+    ):
         super().__init__(status_code, detail, headers)
 
 
 class PoolNotFound(HTTPException):
-    def __init__(self, status_code=404, detail = {
-        'message': "We didn't find a running Pool for the referenced connection.",
-        "code": 'pool-not-found'
-    }, headers = None):
-        super().__init__(status_code, detail, headers)
\ No newline at end of file
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "We didn't find a running Pool for the referenced connection.",
+            "code": "pool-not-found",
+        },
+        headers=None,
+    ):
+        super().__init__(status_code, detail, headers)
+
+
+class CursorNotFound(HTTPException):
+    def __init__(
+        self,
+        status_code=404,
+        detail={
+            "message": "We didn't find a Cursor with the provided ID.",
+            "code": "cursor-not-found",
+        },
+        headers=None,
+    ):
+        super().__init__(status_code, detail, headers)
+
+
+class ClosedCursorUsage(HTTPException):
+    def __init__(
+        self,
+        status_code=400,
+        detail={
+            "message": "The Cursor you are trying to use is closed.",
+            "code": "cursor-closed",
+        },
+        headers=None,
+    ):
+        super().__init__(status_code, detail, headers)
diff --git a/data/app_types.py b/data/app_types.py
new file mode 100644
index 0000000..0a357c0
--- /dev/null
+++ b/data/app_types.py
@@ -0,0 +1,63 @@
+import datetime
+from aiomysql import SSCursor, Connection, Pool
+from data.schemas import SelectQuery
+from core.exceptions import ClosedCursorUsage
+
+
+class CachedCursor:
+    def __init__(
+        self,
+        id: str,
+        cursor: SSCursor,
+        connection: Connection,
+        pool: Pool,
+        connection_id: int,
+        query: SelectQuery,
+        ttl: int = 60,
+    ):
+        self.id = id
+        self.cursor = cursor
+        self.connection = connection
+        self.connection_id = connection_id
+        self.pool = pool
+        self.query = query
+        self.row_count: int = -1 if cursor.rowcount > 10000000000000000000 else cursor.rowcount
+        # The rowcount for a SELECT is set to -1 when using a server-side cursor.
+        # The incorrect large number (> 10000000000000000000) appears because -1 is
+        # interpreted as an unsigned integer in MySQL's internal C API.
+        self.fetched_rows: int = 0
+        self.is_closed: bool = False
+        self.ttl: int = ttl
+        self.close_at = self.update_close_at()
+        self.done = False
+
+    @property
+    def has_more(self):
+        return not self.done
+
+    def update_close_at(self) -> int:
+        return int(datetime.datetime.now(tz=datetime.UTC).timestamp()) + self.ttl
+
+    async def close(self):
+        await self.cursor.close()
+        await self.pool.release(self.connection)
+        self.is_closed = True
+
+    async def fetch_many(self, size: int = 100) -> list[tuple]:
+        if self.is_closed:
+            raise ClosedCursorUsage
+
+        result = await self.cursor.fetchmany(size)
+
+        if len(result) < size:
+            # The cursor has reached the end of the result set.
+            await self.close()
+            self.done = True
+        else:
+            # Each successful full fetch extends the cursor's lifetime.
+            self.close_at = self.update_close_at()
+
+        self.fetched_rows += len(result)
+        return result
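A minimal sketch of how the exhaustion bookkeeping can be exercised with stubs (the helper
below is hypothetical test scaffolding, not part of the patch; it assumes pool.release is
awaited, as close() does):

    import asyncio
    from unittest.mock import AsyncMock, MagicMock

    from data.app_types import CachedCursor

    async def check():
        stub = AsyncMock()
        stub.rowcount = 2
        stub.fetchmany = AsyncMock(return_value=[(1,), (2,)])
        pool = MagicMock()
        pool.release = AsyncMock()

        cc = CachedCursor(id="t", cursor=stub, connection=MagicMock(),
                          pool=pool, connection_id=1, query=MagicMock(), ttl=1)
        rows = await cc.fetch_many(size=10)  # short read -> end of result set

        assert rows == [(1,), (2,)]
        assert cc.done and cc.is_closed and not cc.has_more
        pool.release.assert_awaited_once()

    asyncio.run(check())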
diff --git a/data/models.py b/data/models.py
index 54d8c3d..8eccb0e 100644
--- a/data/models.py
+++ b/data/models.py
@@ -21,9 +21,11 @@ class Connection(Base):
     db_name = Column(String, nullable=False)
     type = Column(Enum(ConnectionTypes), nullable=False)
     host = Column(String)
-    port = Column(Integer) 
+    port = Column(Integer)
     username = Column(String)
     password = Column(String)
+    pool_minsize = Column(Integer, nullable=False, default=5)
+    pool_maxsize = Column(Integer, nullable=False, default=10)
     owner_id = Column(Integer, ForeignKey("users.id"))
 
     # owner = relationship("User", back_populates="connections")
diff --git a/data/schemas.py b/data/schemas.py
index 9678df4..b2e7be5 100644
--- a/data/schemas.py
+++ b/data/schemas.py
@@ -124,8 +124,8 @@ class SelectQueryBase(BaseModel):
     columns: Union[Literal["*"], List[str]] = "*"
     filters: Optional[List[FilterClause]] = None
     sort_by: Optional[List[SortClause]] = None
-    limit: Annotated[int, Field(strict=True, gt=0)] = None
-    offset: Annotated[int, Field(strict=True, ge=0)] = None
+    limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
+    offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
 
     @field_validator("table_name")
     @classmethod
@@ -179,16 +179,20 @@ class SelectQueryInResult(BaseModel):
         from_attributes = True
 
 
-class SelectQueryMetaData(BaseModel):
-    cursor: Optional[UUID4] = Field(
-        None,
-        description="A UUID4 cursor for pagination. Can be None if no more data is available.",
-    )
-    total_number: int
-    has_more: bool = False
+class CachedCursorOut(BaseModel):
+    id: UUID4 | None
+    connection_id: int
+    query: SelectQueryInResult
+    row_count: int
+    fetched_rows: int
+    is_closed: bool
+    has_more: bool
+    close_at: int
+    ttl: int
+
+    class Config:
+        from_attributes = True
 
 
 class SelectResult(BaseModel):
-    meta: SelectQueryMetaData
-    query: SelectQueryInResult
+    cursor: CachedCursorOut
     results: SelectResultData | None
diff --git a/dbs/mysql.py b/dbs/mysql.py
index 099cbb0..1ac1128 100644
--- a/dbs/mysql.py
+++ b/dbs/mysql.py
@@ -1,18 +1,59 @@
-import aiomysql, decimal, datetime
-import asyncio
+import asyncio, aiomysql, decimal, datetime, uuid, logging
+
 from typing import Literal, Any
-from data.schemas import Connection, SelectResultData
+from data.schemas import Connection, SelectResultData, SelectQuery
+from data.app_types import CachedCursor
+from core.exceptions import PoolNotFound
 
-# Database configuration
-DB_CONFIG = {
-    "host": "localhost",
-    "user": "me",  # Replace with your MySQL username
-    "password": "Passwd3.14",  # Replace with your MySQL password
-    "db": "testing",  # Replace with your database name
-    "port": 3306,  # Default MySQL port
-}
+pools: dict[int, aiomysql.Pool] = {}
+cached_cursors: dict[str, CachedCursor] = {}
+closed_cached_cursors: dict[str, CachedCursor] = {}
 
-pools: None | dict[str, aiomysql.Pool] = {}
+cached_cursors_cleaner_task = None
+
+
+async def close_old_cached_cursors():
+    global cached_cursors, closed_cached_cursors
+
+    for cursor_id in list(cached_cursors.keys()):
+        cursor = cached_cursors.get(cursor_id, None)
+        if cursor is None:
+            continue
+        if cursor.close_at > datetime.datetime.now(datetime.UTC).timestamp():
+            continue
+
+        try:
+            await cursor.close()
+            cached_cursors.pop(cursor_id, None)
+            closed_cached_cursors[cursor_id] = cursor
+            logging.info(f"Closed cursor {cursor_id}")
+        except Exception as e:
+            logging.error(f"Error closing Cursor {cursor_id} -> {e}")
+
+
+async def remove_old_closed_cached_cursors():
+    global closed_cached_cursors
+
+    for cursor_id in list(closed_cached_cursors.keys()):
+        closed_cursor = closed_cached_cursors.get(cursor_id, None)
+        if closed_cursor is None:
+            continue
+        if (
+            closed_cursor.close_at + closed_cursor.ttl * 5
+            > datetime.datetime.now(datetime.UTC).timestamp()
+        ):
+            continue
+
+        del closed_cached_cursors[cursor_id]
+        logging.info(f"Removed cursor {cursor_id}")
+
+
+async def cached_cursors_cleaner():
+    global cached_cursors, closed_cached_cursors
+
+    while True:
+        await close_old_cached_cursors()
+        await remove_old_closed_cached_cursors()
+
+        await asyncio.sleep(10)
 
 
 async def pool_creator(connection: Connection, minsize=5, maxsize=10):
@@ -28,8 +69,31 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
         maxsize=maxsize,
     )
 
+
+async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
+    pool = pools.get(connection_id, None)
+
+    if pool is None:
+        raise PoolNotFound
+
+    connection = await pool.acquire()
+    cursor = await connection.cursor(aiomysql.SSCursor)
+    await cursor.execute(query.sql, query.params)
+
+    cached_cursor = CachedCursor(
+        id=str(uuid.uuid4()),
+        cursor=cursor,
+        connection=connection,
+        pool=pool,
+        connection_id=connection_id,
+        query=query,
+    )
+
+    return cached_cursor
+
 
-async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
+async def execute_select_query(
+    pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
+):
     """
     Executes a SELECT query on the MySQL database asynchronously and returns the results.
@@ -43,7 +107,7 @@ async def execute_select_query(
     try:
         async with pool.acquire() as connection:
             async with connection.cursor(aiomysql.DictCursor) as cursor:
-                await cursor.execute(query, params)
+                await cursor.execute(sql_query, params)
                 result = await cursor.fetchmany(fetch_num)
 
                 return result, cursor.rowcount
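Taken together with the ttl handling in CachedCursor, the two sweeps give each cursor a
two-stage lifetime; with the defaults in this patch (ttl=60 and a sweep every 10 seconds)
the timeline looks roughly like this:

    # t = 0     create_cursor()                  -> close_at = t + 60
    # t = 40    fetch_many() returns a full page -> close_at reset to t + 100
    # t >= 100  a sweep closes the cursor, releases its pooled connection,
    #           and parks it in closed_cached_cursors
    # t >= 400  a later sweep (close_at + ttl * 5 now in the past) drops the
    #           parked entry for good

Only idle cursors are reaped: every fetch that fills its page pushes close_at forward by a
full ttl.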
diff --git a/main.py b/main.py
index cb436d7..dc06e86 100644
--- a/main.py
+++ b/main.py
@@ -1,19 +1,27 @@
+import asyncio
 from contextlib import asynccontextmanager
-
 from fastapi import FastAPI
-
 from app import api_router
-
-from utils.scripts import pools_creator, pools_destroy, db_startup
-
-
+from utils.scripts import pools_creator, pools_destroy, db_startup, cursors_closer
+from dbs import mysql
+
+
 @asynccontextmanager
 async def lifespan(app: FastAPI):
     await pools_creator()
+    mysql.cached_cursors_cleaner_task = asyncio.create_task(mysql.cached_cursors_cleaner())
+
     yield
+
+    mysql.cached_cursors_cleaner_task.cancel()
+    try:
+        await mysql.cached_cursors_cleaner_task
+    except asyncio.CancelledError:
+        print('Closed cached_cursors_cleaner_task')
+
+    await cursors_closer()
     await pools_destroy()
 
+
 app = FastAPI(lifespan=lifespan)
-app.include_router(router=api_router)
\ No newline at end of file
+
+app.include_router(router=api_router)
diff --git a/utils/sql_creator.py b/utils/mysql_scripts.py
similarity index 99%
rename from utils/sql_creator.py
rename to utils/mysql_scripts.py
index 0242e8f..d9ecaca 100644
--- a/utils/sql_creator.py
+++ b/utils/mysql_scripts.py
@@ -1,3 +1,4 @@
+
 from data.schemas import SelectQueryBase
 
 def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]:
diff --git a/utils/scripts.py b/utils/scripts.py
index dea0887..681873e 100644
--- a/utils/scripts.py
+++ b/utils/scripts.py
@@ -20,6 +20,12 @@ async def pools_creator():
         mysql.pools[connection.id] = await mysql.pool_creator(connection=connection)
     logging.info(msg='Created Pools')
 
+
+async def cursors_closer():
+    from dbs import mysql
+    for cursor_id, cursor in mysql.cached_cursors.items():
+        await cursor.close()
+        logging.info(f'Closed cursor: {cursor_id}')
+
 
 async def pools_destroy():
     from dbs import mysql
     for connection_id, pool in mysql.pools.items():
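The teardown order in lifespan is deliberate and worth preserving in any refactor: cancel
the background sweeper first so it cannot race the manual cleanup, then close the remaining
cursors (each close() releases a pooled connection), and only then destroy the pools.
Condensed:

    # 1. cached_cursors_cleaner_task.cancel()  -> no further sweeps run
    # 2. await cursors_closer()                -> every cursor returns its connection
    # 3. await pools_destroy()                 -> pools can now close cleanly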