diff --git a/README.md b/README.md index bc53ab6..57f7336 100644 --- a/README.md +++ b/README.md @@ -160,6 +160,20 @@ db_upgrade This will run migrations using [Alembic](https://alembic.sqlalchemy.org/en/latest/) (already installed as a dependency) to create or update the required tables in your database. +### Creating New Database Revisions + +If you need to create new database migrations after modifying the database models, you can use the `db_revision` command: + +```bash +# Create a new revision with manual changes +db_revision --message "Add new column to users table" + +# Create a new revision with automatic detection of model changes +db_revision --message "Update user model" --autogenerate +``` + +The `--autogenerate` flag will automatically detect changes in your SQLAlchemy models and generate the appropriate migration code. Without this flag, you'll need to manually write the migration code in the generated revision file. + ## Launch You are now ready to go. Simply run from the virtual environment: diff --git a/brevia/alembic/__init__.py b/brevia/alembic/__init__.py index 0056ff6..6e0084a 100644 --- a/brevia/alembic/__init__.py +++ b/brevia/alembic/__init__.py @@ -18,3 +18,8 @@ def upgrade(revision="head"): def downgrade(revision): command.downgrade(alembic_cfg, revision) + + +def revision(message, autogenerate=False): + """Create a new revision file""" + command.revision(alembic_cfg, message=message, autogenerate=autogenerate) diff --git a/brevia/alembic/env.py b/brevia/alembic/env.py index 8d75819..ea8f67e 100644 --- a/brevia/alembic/env.py +++ b/brevia/alembic/env.py @@ -4,6 +4,15 @@ from alembic import context from sqlalchemy_utils import database_exists, create_database from brevia.connection import connection_string +# Import models for autogenerate support +from brevia.chat_history import ChatHistoryStore # noqa: F401 +from brevia.async_jobs import AsyncJobsStore # noqa: F401 +from brevia.settings import ConfigStore # noqa: F401 +from langchain_community.vectorstores.pgembedding import ( # noqa: F401 + BaseModel, + CollectionStore, + EmbeddingStore, +) # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -28,9 +37,8 @@ # add your model's MetaData object here # for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = None +# Models are imported at the top of the file +target_metadata = BaseModel.metadata # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/brevia/alembic/versions/eb659f4dd1c9_utc_timestamp.py b/brevia/alembic/versions/eb659f4dd1c9_utc_timestamp.py new file mode 100644 index 0000000..2a46a0a --- /dev/null +++ b/brevia/alembic/versions/eb659f4dd1c9_utc_timestamp.py @@ -0,0 +1,132 @@ +"""utc_timestamp + +Revision ID: eb659f4dd1c9 +Revises: 24b025d48e0a +Create Date: 2025-07-07 19:12:35.616363 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = 'eb659f4dd1c9' +down_revision = '24b025d48e0a' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column( + 'async_jobs', 'expires', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Job expiry time', + existing_nullable=True, + ) + op.alter_column( + 'async_jobs', 'created', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Creation timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'async_jobs', 'locked_until', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Timestamp at which the lock expires', + existing_nullable=True, + ) + op.alter_column( + 'async_jobs', 'completed', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Timestamp at which this job was marked as completed', + existing_nullable=True, + ) + op.alter_column( + 'chat_history', 'created', + existing_type=postgresql.TIMESTAMP(), + type_=sa.DateTime(timezone=True), + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'config', 'created', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Creation timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'config', 'modified', + existing_type=postgresql.TIMESTAMP(), + type_=postgresql.TIMESTAMP(timezone=True), + existing_comment='Last update timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.create_unique_constraint('uq_config_key', 'config', ['config_key']) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('uq_config_key', 'config', type_='unique') + op.alter_column( + 'config', 'modified', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Last update timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'config', 'created', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Creation timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'chat_history', 'created', + existing_type=sa.DateTime(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'async_jobs', 'completed', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Timestamp at which this job was marked as completed', + existing_nullable=True, + ) + op.alter_column( + 'async_jobs', 'locked_until', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Timestamp at which the lock expires', + existing_nullable=True, + ) + op.alter_column( + 'async_jobs', 'created', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Creation timestamp', + existing_nullable=False, + existing_server_default=sa.text('CURRENT_TIMESTAMP'), + ) + op.alter_column( + 'async_jobs', 'expires', + existing_type=postgresql.TIMESTAMP(timezone=True), + type_=postgresql.TIMESTAMP(), + existing_comment='Job expiry time', + existing_nullable=True, + ) + # ### end Alembic commands ### diff --git a/brevia/async_jobs.py b/brevia/async_jobs.py index 7a9d1c4..bf8e7fb 100644 --- a/brevia/async_jobs.py +++ b/brevia/async_jobs.py @@ -1,7 +1,7 @@ """Async Jobs table & utilities""" import logging import time -from datetime import datetime +from datetime import datetime, timezone from sqlalchemy import BinaryExpression, Column, desc, func, String, text from pydantic import BaseModel as PydanticModel from sqlalchemy.dialects.postgresql import JSON, TIMESTAMP, SMALLINT @@ -22,22 +22,30 @@ class AsyncJobsStore(BaseModel): """ Async Jobs table """ __tablename__ = "async_jobs" - service = Column(String(), nullable=False) - payload = Column(JSON()) - expires = Column(TIMESTAMP(timezone=False)) + service = Column(String(), nullable=False, comment='Service job') + payload = Column(JSON(), comment='Input data for this job') + expires = Column(TIMESTAMP(timezone=True), comment='Job expiry time') created = Column( - TIMESTAMP(timezone=False), + TIMESTAMP(timezone=True), nullable=False, server_default=func.current_timestamp(), + comment='Creation timestamp', + ) + completed = Column( + TIMESTAMP(timezone=True), + comment='Timestamp at which this job was marked as completed', + ) + locked_until = Column( + TIMESTAMP(timezone=True), + comment='Timestamp at which the lock expires' ) - completed = Column(TIMESTAMP(timezone=False)) - locked_until = Column(TIMESTAMP(timezone=False)) max_attempts = Column( SMALLINT(), nullable=False, server_default='1', + comment='Maximum number of attempts left for this job' ) - result = Column(JSON(), nullable=True) + result = Column(JSON(), nullable=True, comment='Job result') def single_job(uuid: str) -> (AsyncJobsStore | None): @@ -136,8 +144,8 @@ def create_job( """ Create async job """ max_duration = payload.get('max_duration', MAX_DURATION) # max duration in minutes max_attempts = payload.get('max_attempts', MAX_ATTEMPTS) - tstamp = int(time.time()) + (max_duration * max_attempts * 2 * 60) - expires = datetime.fromtimestamp(tstamp) + tstamp = time.time() + (max_duration * max_attempts * 2 * 60) + expires = datetime.fromtimestamp(timestamp=tstamp, tz=timezone.utc) with Session(db_connection()) as session: job_store = AsyncJobsStore( @@ -164,7 +172,7 @@ def complete_job( if not job_store: log.error("Job %s not found", uuid) return - now = datetime.now() + now = datetime.now(tz=timezone.utc) if job_store.expires and job_store.expires < now: log.warning("Job %s is expired at %s", uuid, job_store.expires) return @@ -184,7 +192,7 @@ def complete_job( def save_job_result(job_store: AsyncJobsStore, result: dict, error: bool = False): """Save Job result""" with Session(db_connection()) as session: - job_store.completed = datetime.now() + job_store.completed = datetime.now(tz=timezone.utc) job_store.result = result if error: job_store.max_attempts = max(job_store.max_attempts - 1, 0) @@ -233,8 +241,8 @@ def lock_job_service( if not is_job_available(job_store): raise RuntimeError(f'Job {job_store.uuid} is not available') payload = job_store.payload if job_store.payload else {} - tstamp = int(time.time()) + (int(payload.get('max_duration', MAX_DURATION)) * 60) - locked_until = datetime.fromtimestamp(tstamp) + tstamp = time.time() + (float(payload.get('max_duration', MAX_DURATION)) * 60) + locked_until = datetime.fromtimestamp(tstamp, tz=timezone.utc) with Session(db_connection()) as session: job_store.locked_until = locked_until @@ -247,7 +255,7 @@ def is_job_available( job_store: AsyncJobsStore ) -> bool: """Check if job is available""" - now = datetime.now() + now = datetime.now(tz=timezone.utc) if job_store.completed or (job_store.expires and job_store.expires < now): return False if job_store.locked_until and (job_store.locked_until > now): diff --git a/brevia/chat_history.py b/brevia/chat_history.py index efca025..6953650 100644 --- a/brevia/chat_history.py +++ b/brevia/chat_history.py @@ -33,7 +33,7 @@ class ChatHistoryStore(BaseModel): answer = sqlalchemy.Column(sqlalchemy.String) # pylint: disable=not-callable created = sqlalchemy.Column( - sqlalchemy.DateTime(timezone=False), + sqlalchemy.DateTime(timezone=True), nullable=False, server_default=sqlalchemy.sql.func.now() ) @@ -41,9 +41,16 @@ class ChatHistoryStore(BaseModel): user_evaluation = sqlalchemy.Column( sqlalchemy.BOOLEAN(), nullable=True, + comment='User evaluation as good (True) or bad (False)', + ) + user_feedback = sqlalchemy.Column( + sqlalchemy.String, + comment='User textual feedback on the evaluation', + ) + chat_source = sqlalchemy.Column( + sqlalchemy.String, + comment='Generic string to identify chat source (e.g. application name)', ) - user_feedback = sqlalchemy.Column(sqlalchemy.String) - chat_source = sqlalchemy.Column(sqlalchemy.String) def history(chat_history: list, session: str = None): diff --git a/brevia/commands.py b/brevia/commands.py index 69f8acd..b01dca7 100644 --- a/brevia/commands.py +++ b/brevia/commands.py @@ -5,6 +5,7 @@ from logging import config import click from brevia.alembic import current, upgrade, downgrade +from brevia.alembic import revision as create_revision from brevia.index import update_links_documents from brevia.utilities import files_import, run_service, collections_io from brevia.tokens import create_token @@ -39,6 +40,20 @@ def db_downgrade_cmd(revision): downgrade(revision) +@click.command() +@click.option("-m", "--message", required=True, help="Revision message") +@click.option( + "--autogenerate", + is_flag=True, + default=False, + help="Autogenerate migration from model changes" +) +def db_revision_cmd(message, autogenerate): + """Create a new database revision""" + create_revision(message, autogenerate=autogenerate) + print(f"New revision created: {message}") + + @click.command() @click.option("-f", "--file-path", required=True, help="File or folder path") @click.option("-c", "--collection", required=True, help="Collection name") diff --git a/brevia/settings.py b/brevia/settings.py index 7f6a304..0a86f9c 100644 --- a/brevia/settings.py +++ b/brevia/settings.py @@ -4,9 +4,11 @@ from typing import Annotated, Any from os import environ, path, getcwd from urllib import parse -from sqlalchemy import NullPool, create_engine, Column, String, func, inspect +from sqlalchemy import ( + TIMESTAMP, NullPool, create_engine, Column, String, func, inspect, + UniqueConstraint, +) from sqlalchemy.engine import Connection -from sqlalchemy.dialects.postgresql import TIMESTAMP from sqlalchemy.orm import Session from langchain_community.vectorstores.pgembedding import BaseModel from langchain.globals import set_verbose @@ -210,19 +212,26 @@ class ConfigStore(BaseModel): # pylint: disable=too-few-public-methods,not-callable """ Config table """ __tablename__ = "config" + __table_args__ = ( + UniqueConstraint('config_key', name='uq_config_key'), + ) - config_key = Column(String(), nullable=False, unique=True) - config_val = Column(String(), nullable=False) + config_key = Column( + String(), nullable=False, comment='Configuration name' + ) + config_val = Column(String(), nullable=False, comment='Configuration value') created = Column( - TIMESTAMP(timezone=False), + TIMESTAMP(timezone=True), nullable=False, server_default=func.current_timestamp(), + comment='Creation timestamp', ) modified = Column( - TIMESTAMP(timezone=False), + TIMESTAMP(timezone=True), nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), + comment='Last update timestamp', ) diff --git a/pyproject.toml b/pyproject.toml index 0475bf7..7771055 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ repository = "https://github.com/brevia-ai/brevia" db_current = "brevia.commands:db_current_cmd" db_upgrade = "brevia.commands:db_upgrade_cmd" db_downgrade = "brevia.commands:db_downgrade_cmd" + db_revision = "brevia.commands:db_revision_cmd" export_collection = "brevia.commands:export_collection" import_collection = "brevia.commands:import_collection" import_file = "brevia.commands:import_file" diff --git a/tests/test_commands.py b/tests/test_commands.py index b8781a0..8071dcf 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -1,4 +1,5 @@ """Commands module tests""" +import glob from pathlib import Path from os import unlink from os.path import exists @@ -8,6 +9,7 @@ db_current_cmd, db_upgrade_cmd, db_downgrade_cmd, + db_revision_cmd, export_collection, import_collection, import_file, @@ -145,3 +147,40 @@ def test_update_collection_links(): collection.name, ]) assert result.exit_code == 0 + + +def test_db_revision_cmd(): + """ Test db_revision_cmd function """ + runner = CliRunner() + result = runner.invoke(db_revision_cmd, [ + '--message', + 'Test revision message', + ]) + assert result.exit_code == 0 + assert 'New revision created: Test revision message' in result.output + + # Clean up generated migration files + versions_dir = f'{Path(__file__).parent.parent}/brevia/alembic/versions' + migration_files = glob.glob(f'{versions_dir}/*_test_revision_message.py') + for file_path in migration_files: + if exists(file_path): + unlink(file_path) + + +def test_db_revision_cmd_with_autogenerate(): + """ Test db_revision_cmd function with autogenerate flag """ + runner = CliRunner() + result = runner.invoke(db_revision_cmd, [ + '--message', + 'Test autogenerate revision', + '--autogenerate', + ]) + assert result.exit_code == 0 + assert 'New revision created: Test autogenerate revision' in result.output + + # Clean up generated migration files + versions_dir = f'{Path(__file__).parent.parent}/brevia/alembic/versions' + migration_files = glob.glob(f'{versions_dir}/*_test_autogenerate_revision.py') + for file_path in migration_files: + if exists(file_path): + unlink(file_path)