Implemented alembic migrations, added cursors closing task.
This commit is contained in:
119
alembic/alembic.ini
Normal file
119
alembic/alembic.ini
Normal file
@@ -0,0 +1,119 @@
|
||||
# A generic, single database configuration.
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
# Use forward slashes (/) also on windows to provide an os agnostic path
|
||||
script_location = alembic
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
|
||||
# for all available tokens
|
||||
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
|
||||
|
||||
# sys.path path, will be prepended to sys.path if present.
|
||||
# defaults to the current working directory.
|
||||
prepend_sys_path = .
|
||||
|
||||
# timezone to use when rendering the date within the migration file
|
||||
# as well as the filename.
|
||||
# If specified, requires the python>=3.9 or backports.zoneinfo library and tzdata library.
|
||||
# Any required deps can installed by adding `alembic[tz]` to the pip requirements
|
||||
# string value is passed to ZoneInfo()
|
||||
# leave blank for localtime
|
||||
# timezone =
|
||||
|
||||
# max length of characters to apply to the "slug" field
|
||||
# truncate_slug_length = 40
|
||||
|
||||
# set to 'true' to run the environment during
|
||||
# the 'revision' command, regardless of autogenerate
|
||||
# revision_environment = false
|
||||
|
||||
# set to 'true' to allow .pyc and .pyo files without
|
||||
# a source .py file to be detected as revisions in the
|
||||
# versions/ directory
|
||||
# sourceless = false
|
||||
|
||||
# version location specification; This defaults
|
||||
# to alembic/versions. When using multiple version
|
||||
# directories, initial revisions must be specified with --version-path.
|
||||
# The path separator used here should be the separator specified by "version_path_separator" below.
|
||||
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions
|
||||
|
||||
# version path separator; As mentioned above, this is the character used to split
|
||||
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
|
||||
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
|
||||
# Valid values for version_path_separator are:
|
||||
#
|
||||
# version_path_separator = :
|
||||
# version_path_separator = ;
|
||||
# version_path_separator = space
|
||||
# version_path_separator = newline
|
||||
#
|
||||
# Use os.pathsep. Default configuration used for new projects.
|
||||
version_path_separator = os
|
||||
|
||||
# set to 'true' to search source files recursively
|
||||
# in each "version_locations" directory
|
||||
# new in Alembic version 1.10
|
||||
# recursive_version_locations = false
|
||||
|
||||
# the output encoding used when revision files
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
# sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
# post_write_hooks defines scripts or Python functions that are run
|
||||
# on newly generated revision scripts. See the documentation for further
|
||||
# detail and examples
|
||||
|
||||
# format using "black" - use the console_scripts runner, against the "black" entrypoint
|
||||
# hooks = black
|
||||
# black.type = console_scripts
|
||||
# black.entrypoint = black
|
||||
# black.options = -l 79 REVISION_SCRIPT_FILENAME
|
||||
|
||||
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
|
||||
# hooks = ruff
|
||||
# ruff.type = exec
|
||||
# ruff.executable = %(here)s/.venv/bin/ruff
|
||||
# ruff.options = --fix REVISION_SCRIPT_FILENAME
|
||||
|
||||
# Logging configuration
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARNING
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARNING
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
80
alembic/env.py
Normal file
80
alembic/env.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from logging.config import fileConfig
|
||||
|
||||
from data.db import engine
|
||||
from data.models import Base
|
||||
|
||||
from alembic import context
|
||||
|
||||
|
||||
"""
|
||||
Create migration example:
|
||||
> alembic -c alembic/alembic.ini revision --autogenerate -m "Added pool size config to the connection model."
|
||||
|
||||
Migrate example:
|
||||
> alembic -c alembic/alembic.ini upgrade head
|
||||
"""
|
||||
|
||||
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
from data.models import User, Query, Connection
|
||||
|
||||
|
||||
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""
|
||||
Run migrations in 'offline' mode.
|
||||
|
||||
This configures the context with just a URL
|
||||
and not an Engine, though an Engine is acceptable
|
||||
here as well. By skipping the Engine creation
|
||||
we don't even need a DBAPI to be available.
|
||||
|
||||
Calls to context.execute() here emit the given string to the
|
||||
script output.
|
||||
|
||||
"""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
async def run_migrations_online() -> None:
|
||||
"""
|
||||
Run migrations in 'online' mode.
|
||||
|
||||
In this scenario we need to create an Engine
|
||||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
|
||||
connectable = engine
|
||||
async with connectable.connect() as connection:
|
||||
await connection.run_sync(do_migrations)
|
||||
|
||||
|
||||
def do_migrations(connection):
|
||||
context.configure(connection=connection, target_metadata=Base.metadata)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
import asyncio
|
||||
|
||||
asyncio.run(run_migrations_online())
|
||||
26
alembic/script.py.mako
Normal file
26
alembic/script.py.mako
Normal file
@@ -0,0 +1,26 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
${downgrades if downgrades else "pass"}
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Added pool size config to the connection model.
|
||||
|
||||
Revision ID: 1c62ff091f5c
|
||||
Revises: 6eb236240aec
|
||||
Create Date: 2025-02-25 19:24:21.712856
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1c62ff091f5c'
|
||||
down_revision: Union[str, None] = '6eb236240aec'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('connections', sa.Column('pool_minsize', sa.Integer(), nullable=False, server_default="5"))
|
||||
op.add_column('connections', sa.Column('pool_maxsize', sa.Integer(), nullable=False, server_default="10"))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('connections', 'pool_maxsize')
|
||||
op.drop_column('connections', 'pool_minsize')
|
||||
# ### end Alembic commands ###
|
||||
75
alembic/versions/6eb236240aec_initial_migration.py
Normal file
75
alembic/versions/6eb236240aec_initial_migration.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Initial migration.
|
||||
|
||||
Revision ID: 6eb236240aec
|
||||
Revises:
|
||||
Create Date: 2025-02-25 19:18:03.125433
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6eb236240aec'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('users',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('username', sa.String(), nullable=False),
|
||||
sa.Column('role', sa.Enum('admin', 'user', name='userrole'), nullable=False),
|
||||
sa.Column('api_key', sa.String(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('api_key')
|
||||
)
|
||||
op.create_index(op.f('ix_users_id'), 'users', ['id'], unique=False)
|
||||
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
|
||||
op.create_table('connections',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('db_name', sa.String(), nullable=False),
|
||||
sa.Column('type', sa.Enum('mysql', 'postgresql', name='connectiontypes'), nullable=False),
|
||||
sa.Column('host', sa.String(), nullable=True),
|
||||
sa.Column('port', sa.Integer(), nullable=True),
|
||||
sa.Column('username', sa.String(), nullable=True),
|
||||
sa.Column('password', sa.String(), nullable=True),
|
||||
sa.Column('owner_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_connections_id'), 'connections', ['id'], unique=False)
|
||||
op.create_table('queries',
|
||||
sa.Column('id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('description', sa.String(), nullable=True),
|
||||
sa.Column('owner_id', sa.Integer(), nullable=True),
|
||||
sa.Column('table_name', sa.String(), nullable=False),
|
||||
sa.Column('columns', sa.JSON(), nullable=False),
|
||||
sa.Column('filters', sa.JSON(), nullable=True),
|
||||
sa.Column('sort_by', sa.JSON(), nullable=True),
|
||||
sa.Column('limit', sa.Integer(), nullable=True),
|
||||
sa.Column('offset', sa.Integer(), nullable=True),
|
||||
sa.Column('sql', sa.String(), nullable=False),
|
||||
sa.Column('params', sa.JSON(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['owner_id'], ['users.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_queries_id'), 'queries', ['id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(op.f('ix_queries_id'), table_name='queries')
|
||||
op.drop_table('queries')
|
||||
op.drop_index(op.f('ix_connections_id'), table_name='connections')
|
||||
op.drop_table('connections')
|
||||
op.drop_index(op.f('ix_users_username'), table_name='users')
|
||||
op.drop_index(op.f('ix_users_id'), table_name='users')
|
||||
op.drop_table('users')
|
||||
# ### end Alembic commands ###
|
||||
@@ -1,8 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.connections import router as connections_router
|
||||
from app.operations import router as router
|
||||
from app.operations import router as operations_router
|
||||
from app.users import router as user_router
|
||||
from app.cursors import router as cursors_router
|
||||
from app.queries import router as queries_router
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(router=user_router, prefix="/users", tags=["Users"])
|
||||
@@ -10,5 +12,11 @@ api_router.include_router(
|
||||
router=connections_router, prefix="/connections", tags=["Connections"]
|
||||
)
|
||||
api_router.include_router(
|
||||
router=router, prefix='/operations', tags=["Operations"]
|
||||
router=cursors_router, prefix='/cursors', tags=['Cursors']
|
||||
)
|
||||
api_router.include_router(
|
||||
router=queries_router, prefix='/queries', tags=['Queries']
|
||||
)
|
||||
api_router.include_router(
|
||||
router=operations_router, prefix='/operations', tags=["Operations"]
|
||||
)
|
||||
78
app/cursors.py
Normal file
78
app/cursors.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from fastapi.routing import APIRouter
|
||||
from data.schemas import (
|
||||
CachedCursorOut,
|
||||
)
|
||||
from fastapi import Depends, status
|
||||
from data.crud import (
|
||||
read_connection,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from core.exceptions import (
|
||||
QueryNotFound,
|
||||
ConnectionNotFound,
|
||||
CursorNotFound,
|
||||
)
|
||||
from dbs import mysql
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/", dependencies=[Depends(get_current_user)])
|
||||
async def create_cursor_endpoint(
|
||||
query_id: int,
|
||||
connection_id: int,
|
||||
db=Depends(get_db),
|
||||
) -> CachedCursorOut:
|
||||
query = await read_select_query(db=db, query_id=query_id)
|
||||
|
||||
if query is None:
|
||||
raise QueryNotFound
|
||||
connection = await read_connection(db=db, connection_id=connection_id)
|
||||
|
||||
if connection is None:
|
||||
raise ConnectionNotFound
|
||||
|
||||
cached_cursor = await mysql.create_cursor(query=query, connection_id=connection_id)
|
||||
|
||||
mysql.cached_cursors[cached_cursor.id] = cached_cursor
|
||||
print(mysql.cached_cursors)
|
||||
return cached_cursor
|
||||
|
||||
|
||||
@router.get("/", dependencies=[Depends(get_current_user)])
|
||||
async def get_all_cursors() -> list[CachedCursorOut]:
|
||||
return mysql.cached_cursors.values()
|
||||
|
||||
|
||||
@router.get("/{cursor_id}", dependencies=[Depends(get_current_user)])
|
||||
async def get_cursors(cursor_id: str) -> CachedCursorOut:
|
||||
try:
|
||||
return mysql.cached_cursors[cursor_id]
|
||||
except KeyError:
|
||||
raise CursorNotFound
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/",
|
||||
dependencies=[Depends(get_admin_user)],
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def close_all_cursor() -> None:
|
||||
for cached_cursor in mysql.cached_cursors.values():
|
||||
await cached_cursor.close()
|
||||
mysql.cached_cursors.clear()
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{cursor_id}",
|
||||
dependencies=[Depends(get_current_user)],
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
)
|
||||
async def close_cursor(cursor_id: str) -> None:
|
||||
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
|
||||
if cached_cursor is None:
|
||||
raise CursorNotFound
|
||||
|
||||
await cached_cursor.close()
|
||||
del mysql.cached_cursors[cursor_id]
|
||||
@@ -1,59 +1,22 @@
|
||||
from fastapi.routing import APIRouter
|
||||
from typing_extensions import Annotated
|
||||
from pydantic import Field
|
||||
from data.schemas import (
|
||||
SelectQueryBase,
|
||||
SelectQueryInDB,
|
||||
SelectQuery,
|
||||
SelectQueryIn,
|
||||
SelectResult,
|
||||
SelectQueryMetaData,
|
||||
SelectQueryInResult,
|
||||
)
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from data.schemas import SelectResult, CachedCursorOut
|
||||
from fastapi import Depends
|
||||
from data.crud import (
|
||||
read_connection,
|
||||
create_select_query,
|
||||
read_all_select_queries,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from core.exceptions import QueryNotFound, ConnectionNotFound, PoolNotFound
|
||||
from utils.sql_creator import build_sql_query_text
|
||||
from core.exceptions import (
|
||||
QueryNotFound,
|
||||
ConnectionNotFound,
|
||||
PoolNotFound,
|
||||
CursorNotFound,
|
||||
)
|
||||
from dbs import mysql
|
||||
|
||||
router = APIRouter(prefix="/select")
|
||||
|
||||
|
||||
@router.post("/check-query", dependencies=[Depends(get_current_user)])
|
||||
async def check_select_query(query: SelectQueryBase) -> SelectQuery:
|
||||
sql, params = build_sql_query_text(query)
|
||||
q = SelectQuery(**query.model_dump(), params=params, sql=sql)
|
||||
return q
|
||||
|
||||
|
||||
@router.post("/query")
|
||||
async def create_select_query_endpoint(
|
||||
query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
|
||||
) -> SelectQueryInDB:
|
||||
sql, params = build_sql_query_text(query)
|
||||
query_in = SelectQueryIn(
|
||||
**query.model_dump(), owner_id=user.id, params=params, sql=sql
|
||||
)
|
||||
return await create_select_query(db=db, query=query_in)
|
||||
|
||||
|
||||
@router.get("/query", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
|
||||
return await read_all_select_queries(db=db)
|
||||
|
||||
|
||||
@router.get("/query/{query_id}", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(
|
||||
query_id: int, db=Depends(get_db)
|
||||
) -> SelectQueryInDB:
|
||||
return await read_select_query(db=db, query_id=query_id)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/execute", dependencies=[Depends(get_current_user)])
|
||||
@@ -76,17 +39,41 @@ async def execute_select(
|
||||
raise PoolNotFound
|
||||
|
||||
raw_result, rowcount = await mysql.execute_select_query(
|
||||
pool=pool, query=query.sql, params=query.params, fetch_num=page_size
|
||||
pool=pool, sql_query=query.sql, params=query.params, fetch_num=page_size
|
||||
)
|
||||
|
||||
results = mysql.dict_result_to_list(result=mysql.serializer(raw_result=raw_result))
|
||||
|
||||
meta = SelectQueryMetaData(
|
||||
cursor=None, total_number=rowcount, has_more=len(results.data) != rowcount
|
||||
)
|
||||
|
||||
return SelectResult(
|
||||
meta=meta,
|
||||
cursor=CachedCursorOut(
|
||||
id=None,
|
||||
connection_id=connection_id,
|
||||
query=query,
|
||||
row_count=rowcount,
|
||||
fetched_rows=len(results.data),
|
||||
is_closed=True,
|
||||
has_more=len(results.data) != rowcount,
|
||||
ttl=-1,
|
||||
close_at=-1,
|
||||
),
|
||||
results=results,
|
||||
)
|
||||
|
||||
|
||||
@router.get(path="/fetch_cursor", dependencies=[Depends(get_current_user)])
|
||||
async def fetch_cursor(
|
||||
cursor_id: str,
|
||||
page_size: Annotated[int, Field(ge=1, le=1000)] = 50,
|
||||
) -> SelectResult:
|
||||
cached_cursor = mysql.cached_cursors.get(cursor_id, None)
|
||||
if cached_cursor is None:
|
||||
raise CursorNotFound
|
||||
result = await cached_cursor.fetch_many(size=page_size)
|
||||
|
||||
if cached_cursor.done:
|
||||
mysql.cached_cursors.pop(cursor_id, None)
|
||||
|
||||
return SelectResult(
|
||||
cursor=cached_cursor,
|
||||
results={"columns": cached_cursor.query.columns, "data": result},
|
||||
)
|
||||
|
||||
50
app/queries.py
Normal file
50
app/queries.py
Normal file
@@ -0,0 +1,50 @@
|
||||
|
||||
from fastapi.routing import APIRouter
|
||||
|
||||
from data.schemas import (
|
||||
SelectQueryBase,
|
||||
SelectQueryInDB,
|
||||
SelectQuery,
|
||||
SelectQueryIn,
|
||||
)
|
||||
from fastapi import Depends
|
||||
from data.crud import (
|
||||
create_select_query,
|
||||
read_all_select_queries,
|
||||
read_select_query,
|
||||
)
|
||||
from core.dependencies import get_db, get_current_user, get_admin_user
|
||||
from utils.mysql_scripts import build_sql_query_text
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
||||
@router.post("/check", dependencies=[Depends(get_current_user)])
|
||||
async def check_select_query(query: SelectQueryBase) -> SelectQuery:
|
||||
sql, params = build_sql_query_text(query)
|
||||
q = SelectQuery(**query.model_dump(), params=params, sql=sql)
|
||||
return q
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_select_query_endpoint(
|
||||
query: SelectQueryBase, db=Depends(get_db), user=Depends(get_current_user)
|
||||
) -> SelectQueryInDB:
|
||||
sql, params = build_sql_query_text(query)
|
||||
query_in = SelectQueryIn(
|
||||
**query.model_dump(), owner_id=user.id, params=params, sql=sql
|
||||
)
|
||||
return await create_select_query(db=db, query=query_in)
|
||||
|
||||
|
||||
@router.get("/", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(db=Depends(get_db)) -> list[SelectQueryInDB]:
|
||||
return await read_all_select_queries(db=db)
|
||||
|
||||
|
||||
@router.get("/{query_id}", dependencies=[Depends(get_current_user)])
|
||||
async def get_select_queries_endpoint(
|
||||
query_id: int, db=Depends(get_db)
|
||||
) -> SelectQueryInDB:
|
||||
return await read_select_query(db=db, query_id=query_id)
|
||||
@@ -25,25 +25,67 @@ class QueryValidationError(ValueError):
|
||||
self.msg = msg
|
||||
super().__init__(msg)
|
||||
|
||||
|
||||
class QueryNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "The referenced query was not found.",
|
||||
"code": 'query-not-found'
|
||||
}, headers = None):
|
||||
def __init__(
|
||||
self,
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": "The referenced query was not found.",
|
||||
"code": "query-not-found",
|
||||
},
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class ConnectionNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "The referenced connection was not found.",
|
||||
"code": 'connection-not-found'
|
||||
}, headers = None):
|
||||
def __init__(
|
||||
self,
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": "The referenced connection was not found.",
|
||||
"code": "connection-not-found",
|
||||
},
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class PoolNotFound(HTTPException):
|
||||
def __init__(self, status_code=404, detail = {
|
||||
'message': "We didn't find a running Pool for the referenced connection.",
|
||||
"code": 'pool-not-found'
|
||||
}, headers = None):
|
||||
def __init__(
|
||||
self,
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": "We didn't find a running Pool for the referenced connection.",
|
||||
"code": "pool-not-found",
|
||||
},
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class CursorNotFound(HTTPException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code=404,
|
||||
detail={
|
||||
"message": "We didn't find a Cursor with the provided ID.",
|
||||
"code": "cursor-not-found",
|
||||
},
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
|
||||
|
||||
class ClosedCursorUsage(HTTPException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code=400,
|
||||
detail={
|
||||
"message": "The Cursor you are trying to use is closed.",
|
||||
"code": "cursor-closed",
|
||||
},
|
||||
headers=None,
|
||||
):
|
||||
super().__init__(status_code, detail, headers)
|
||||
63
data/app_types.py
Normal file
63
data/app_types.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import datetime
|
||||
from aiomysql import SSCursor, Connection, Pool, SSCursor
|
||||
from data.schemas import SelectQuery
|
||||
from core.exceptions import ClosedCursorUsage
|
||||
|
||||
|
||||
class CachedCursor:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
cursor: SSCursor,
|
||||
connection: Connection,
|
||||
pool: Pool,
|
||||
connection_id: int,
|
||||
query: SelectQuery,
|
||||
ttl: int=60
|
||||
):
|
||||
|
||||
self.id = id
|
||||
self.cursor = cursor
|
||||
self.connection = connection
|
||||
self.connection_id = connection_id
|
||||
self.pool = pool
|
||||
self.query = query
|
||||
self.row_count: int = -1 if cursor.rowcount > 10000000000000000000 else cursor.rowcount
|
||||
# The rowcount for a SELECT is set to -1 when using a server-side cursor.
|
||||
# The incorrect large number (> 10000000000000000000) is because -1 is
|
||||
# interpreted as an unsigned integer in MySQL's internal C API.
|
||||
self.fetched_rows: int = 0
|
||||
self.is_closed: bool=False
|
||||
self.ttl:int = ttl
|
||||
self.close_at = self.upgrade_close_at()
|
||||
self.done=False
|
||||
|
||||
@property
|
||||
def has_more(self):
|
||||
return not self.done
|
||||
|
||||
def upgrade_close_at(self) -> int:
|
||||
return int(datetime.datetime.now(tz=datetime.UTC).timestamp()) + self.ttl
|
||||
|
||||
async def close(self):
|
||||
await self.cursor.close()
|
||||
await self.pool.release(self.connection)
|
||||
self.is_closed=True
|
||||
|
||||
async def fetch_many(self, size: int = 100) -> tuple[list[tuple], bool]:
|
||||
if self.is_closed:
|
||||
raise ClosedCursorUsage
|
||||
|
||||
result = await self.cursor.fetchmany(size)
|
||||
|
||||
if len(result) < size:
|
||||
# The cursor has reached the end of the set.
|
||||
await self.close()
|
||||
self.done=True
|
||||
else:
|
||||
self.upgrade_close_at()
|
||||
|
||||
self.fetched_rows += len(result)
|
||||
return result
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ class Connection(Base):
|
||||
port = Column(Integer)
|
||||
username = Column(String)
|
||||
password = Column(String)
|
||||
pool_minsize = Column(Integer, nullable=False, default=5)
|
||||
pool_maxsize = Column(Integer, nullable=False, default=10)
|
||||
owner_id = Column(Integer, ForeignKey("users.id"))
|
||||
|
||||
# owner = relationship("User", back_populates="connections")
|
||||
|
||||
@@ -124,8 +124,8 @@ class SelectQueryBase(BaseModel):
|
||||
columns: Union[Literal["*"], List[str]] = "*"
|
||||
filters: Optional[List[FilterClause]] = None
|
||||
sort_by: Optional[List[SortClause]] = None
|
||||
limit: Annotated[int, Field(strict=True, gt=0)] = None
|
||||
offset: Annotated[int, Field(strict=True, ge=0)] = None
|
||||
limit: Optional[Annotated[int, Field(strict=True, gt=0)]] = None
|
||||
offset: Optional[Annotated[int, Field(strict=True, ge=0)]] = None
|
||||
|
||||
@field_validator("table_name")
|
||||
@classmethod
|
||||
@@ -179,16 +179,20 @@ class SelectQueryInResult(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SelectQueryMetaData(BaseModel):
|
||||
cursor: Optional[UUID4] = Field(
|
||||
None,
|
||||
description="A UUID4 cursor for pagination. Can be None if no more data is available.",
|
||||
)
|
||||
total_number: int
|
||||
has_more: bool = False
|
||||
class CachedCursorOut(BaseModel):
|
||||
id: UUID4 | None
|
||||
connection_id: int
|
||||
query: SelectQueryInResult
|
||||
row_count: int
|
||||
fetched_rows: int
|
||||
is_closed: bool
|
||||
has_more: bool
|
||||
close_at: int
|
||||
ttl: int
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SelectResult(BaseModel):
|
||||
meta: SelectQueryMetaData
|
||||
query: SelectQueryInResult
|
||||
cursor: CachedCursorOut
|
||||
results: SelectResultData | None
|
||||
|
||||
92
dbs/mysql.py
92
dbs/mysql.py
@@ -1,18 +1,59 @@
|
||||
import aiomysql, decimal, datetime
|
||||
import asyncio
|
||||
import asyncio, aiomysql, decimal, datetime, uuid, logging
|
||||
|
||||
from typing import Literal, Any
|
||||
from data.schemas import Connection, SelectResultData
|
||||
from data.schemas import Connection, SelectResultData, SelectQuery
|
||||
from data.app_types import CachedCursor
|
||||
from core.exceptions import PoolNotFound
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": "localhost",
|
||||
"user": "me", # Replace with your MySQL username
|
||||
"password": "Passwd3.14", # Replace with your MySQL password
|
||||
"db": "testing", # Replace with your database name
|
||||
"port": 3306, # Default MySQL port
|
||||
}
|
||||
pools: dict[str, aiomysql.Pool] = {}
|
||||
cached_cursors: dict[str, CachedCursor] = {}
|
||||
closed_cached_cursors: dict[str, CachedCursor] = {}
|
||||
|
||||
pools: None | dict[str, aiomysql.Pool] = {}
|
||||
cached_cursors_cleaner_task = None
|
||||
|
||||
async def close_old_cached_cursors():
|
||||
global cached_cursors, closed_cached_cursors
|
||||
|
||||
for cursor_id in list(cached_cursors.keys()):
|
||||
cursor = cached_cursors.get(cursor_id, None)
|
||||
if cursor is None:
|
||||
continue
|
||||
print(cursor.close_at, datetime.datetime.now(datetime.UTC).timestamp())
|
||||
if cursor.close_at > datetime.datetime.now(datetime.UTC).timestamp():
|
||||
continue
|
||||
|
||||
try:
|
||||
await cursor.close()
|
||||
cached_cursors.pop(cursor_id, None)
|
||||
closed_cached_cursors[cursor_id] = cursor
|
||||
print(f"Closed cursor {cursor_id}")
|
||||
except Exception as e:
|
||||
print(f"Error closing Cursor {cursor_id} -> {e}")
|
||||
|
||||
|
||||
async def remove_old_closed_cached_cursors():
|
||||
global closed_cached_cursors
|
||||
for cursor_id in set(closed_cached_cursors.keys()):
|
||||
closed_cursor = closed_cached_cursors.get(cursor_id, None)
|
||||
if closed_cursor is None:
|
||||
continue
|
||||
if (
|
||||
closed_cursor.close_at + closed_cursor.ttl * 5
|
||||
> datetime.datetime.now(datetime.UTC).timestamp()
|
||||
):
|
||||
continue
|
||||
|
||||
del closed_cached_cursors[cursor_id]
|
||||
print(f"Removed cursor {cursor_id}")
|
||||
|
||||
async def cached_cursors_cleaner():
|
||||
global cached_cursors, closed_cached_cursors
|
||||
while True:
|
||||
print("hey")
|
||||
await close_old_cached_cursors()
|
||||
await remove_old_closed_cached_cursors()
|
||||
|
||||
await asyncio.sleep(10)
|
||||
|
||||
|
||||
async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
||||
@@ -28,8 +69,31 @@ async def pool_creator(connection: Connection, minsize=5, maxsize=10):
|
||||
maxsize=maxsize,
|
||||
)
|
||||
|
||||
async def create_cursor(connection_id: int, query: SelectQuery) -> CachedCursor:
|
||||
pool = pools.get(connection_id, None)
|
||||
|
||||
async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fetch_num:int = 100):
|
||||
if pool is None:
|
||||
raise PoolNotFound
|
||||
|
||||
connection = await pool.acquire()
|
||||
cursor = await connection.cursor(aiomysql.SSCursor)
|
||||
await cursor.execute(query.sql, query.params)
|
||||
|
||||
cached_cursor = CachedCursor(
|
||||
id=str(uuid.uuid4()),
|
||||
cursor=cursor,
|
||||
connection=connection,
|
||||
pool=pool,
|
||||
connection_id=connection_id,
|
||||
query=query,
|
||||
)
|
||||
|
||||
return cached_cursor
|
||||
|
||||
|
||||
async def execute_select_query(
|
||||
pool: aiomysql.Pool, sql_query: str, params: list, fetch_num: int = 100
|
||||
):
|
||||
"""
|
||||
Executes a SELECT query on the MySQL database asynchronously and returns the results.
|
||||
|
||||
@@ -43,7 +107,7 @@ async def execute_select_query(pool: aiomysql.Pool, query: str, params: list, fe
|
||||
try:
|
||||
async with pool.acquire() as connection:
|
||||
async with connection.cursor(aiomysql.DictCursor) as cursor:
|
||||
await cursor.execute(query, params)
|
||||
await cursor.execute(sql_query, params)
|
||||
result = await cursor.fetchmany(fetch_num)
|
||||
|
||||
return result, cursor.rowcount
|
||||
|
||||
20
main.py
20
main.py
@@ -1,19 +1,27 @@
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app import api_router
|
||||
|
||||
from utils.scripts import pools_creator, pools_destroy, db_startup
|
||||
|
||||
|
||||
from utils.scripts import pools_creator, pools_destroy, db_startup, cursors_closer
|
||||
from dbs import mysql
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
await pools_creator()
|
||||
mysql.cached_cursors_cleaner_task = asyncio.create_task(mysql.cached_cursors_cleaner())
|
||||
|
||||
yield
|
||||
|
||||
mysql.cached_cursors_cleaner_task.cancel()
|
||||
try:
|
||||
await mysql.cached_cursors_cleaner_task
|
||||
except asyncio.CancelledError:
|
||||
print('Closed cached_cursors_cleaner_task')
|
||||
await cursors_closer()
|
||||
await pools_destroy()
|
||||
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
app.include_router(router=api_router)
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
from data.schemas import SelectQueryBase
|
||||
|
||||
def build_sql_query_text(query: SelectQueryBase) -> tuple[str, list]:
|
||||
@@ -20,6 +20,12 @@ async def pools_creator():
|
||||
mysql.pools[connection.id] = await mysql.pool_creator(connection=connection)
|
||||
logging.info(msg='Created Pools')
|
||||
|
||||
async def cursors_closer():
|
||||
from dbs import mysql
|
||||
for cursor_id, cursor in mysql.cached_cursors.items():
|
||||
await cursor.close()
|
||||
logging.info(f'Closed cursor: {cursor_id}')
|
||||
|
||||
async def pools_destroy():
|
||||
from dbs import mysql
|
||||
for connection_id, pool in mysql.pools.items():
|
||||
|
||||
Reference in New Issue
Block a user