Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ jobs:
fail-fast: false
matrix:
os: ["ubuntu-latest"]
python-version: ["3.10", "3.11"]
python-version: [
"3.10",
"3.11",
"3.12",
]

env:
OS: ${{ matrix.os }}
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/.cache
/.idea
/.mypy_cache
/.pytest_cache
/.ruff_cache
/.vagrant
Expand Down
2 changes: 2 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@


## in progress
- CI: Run tests on Python 3.12
- OCI: Update to Python 3.12

## 2024-06-25 v2.14.1
- Started using more SQLAlchemy patches and polyfills from `sqlalchemy-cratedb`
Expand Down
5 changes: 3 additions & 2 deletions examples/tracking_merlion.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def import_data(data_table_name: str, anomalies_table_name: str):
("2014-02-07 14:55:00.000000", "2014-02-09 14:05:00.000000"),
]
cursor.executemany(
f"INSERT INTO {anomalies_table_name} (ts_start, ts_end) VALUES (?, ?)", known_anomalies # noqa: S608
f"INSERT INTO {anomalies_table_name} (ts_start, ts_end) VALUES (?, ?)", # noqa: S608
known_anomalies, # noqa: S608
)


Expand Down Expand Up @@ -222,7 +223,7 @@ def run_experiment(time_series: pd.DataFrame, anomalies_table_name: str):
r = TSADMetric.Recall.value(ground_truth=test_labels, predict=test_pred)
f1 = TSADMetric.F1.value(ground_truth=test_labels, predict=test_pred)
mttd = TSADMetric.MeanTimeToDetect.value(ground_truth=test_labels, predict=test_pred)
print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\n" f"Mean Time To Detect: {mttd}") # noqa: T201
print(f"Precision: {p:.4f}, Recall: {r:.4f}, F1: {f1:.4f}\nMean Time To Detect: {mttd}") # noqa: T201

mlflow.log_input(mlflow.data.from_pandas(input_test_data), context="training")
mlflow.log_metric("precision", p)
Expand Down
2 changes: 1 addition & 1 deletion examples/tracking_pycaret.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def read_data(table_name: str) -> pd.DataFrame:
FROM {table_name}
GROUP BY month
ORDER BY month
"""
""" # noqa: S608
with connect_database() as conn:
data = pd.read_sql(query, conn)

Expand Down
145 changes: 65 additions & 80 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,10 @@ requires = [
"versioningit",
]

[tool.versioningit.vcs]
method = "git"
default-tag = "0.0.0"

[project]
name = "mlflow-cratedb"
description = "MLflow adapter for CrateDB"
readme = "README.md"
requires-python = ">=3.8,<3.12"
license = {text = "Apache License 2.0"}
keywords = [
"ai",
"cratedb",
Expand All @@ -32,9 +26,11 @@ keywords = [
"mlflow-tracking",
"mlops",
]
license = { text = "Apache License 2.0" }
authors = [
{name = "Andreas Motl", email = "[email protected]"},
{ name = "Andreas Motl", email = "[email protected]" },
]
requires-python = ">=3.8,<3.13"
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
Expand All @@ -52,6 +48,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Communications",
"Topic :: Database",
"Topic :: Database :: Database Engines/Servers",
Expand Down Expand Up @@ -84,51 +81,75 @@ dependencies = [
"sqlparse<0.6",
]

[project.optional-dependencies]
develop = [
"black<25",
optional-dependencies.develop = [
"mypy<1.15",
"poethepoet<1",
"pyproject-fmt<2.6",
"ruff<0.9",
"ruff<0.10",
"validate-pyproject<0.24",
]
examples = [
'pycaret[analysis,models,parallel,test,tuner]==3.3.2; platform_machine != "aarch64"',
optional-dependencies.examples = [
"pycaret[analysis,models,parallel,test,tuner]==3.3.2; platform_machine!='aarch64'",
"salesforce-merlion<2.1",
"werkzeug==2.2.3",
]
release = [
optional-dependencies.release = [
"build<2",
"twine<7",
]
test = [
optional-dependencies.test = [
"psutil==5.9.8",
"pytest<9",
"pytest-cov<7",
]
[project.scripts]
mlflow-cratedb = "mlflow_cratedb.cli:cli"
[project.entry-points."mlflow.app"]
mlflow-cratedb = "mlflow_cratedb.server:app"
urls.changelog = "https://github.com/crate/mlflow-cratedb/blob/main/CHANGES.md"
urls.documentation = "https://github.com/crate/mlflow-cratedb"
urls.homepage = "https://github.com/crate/mlflow-cratedb"
urls.repository = "https://github.com/crate/mlflow-cratedb"
scripts.mlflow-cratedb = "mlflow_cratedb.cli:cli"
entry-points."mlflow.app".mlflow-cratedb = "mlflow_cratedb.server:app"

[tool.setuptools]
# https://setuptools.pypa.io/en/latest/userguide/package_discovery.html
packages = ["mlflow_cratedb"]

[project.urls]
changelog = "https://github.com/crate/mlflow-cratedb/blob/main/CHANGES.md"
documentation = "https://github.com/crate/mlflow-cratedb"
homepage = "https://github.com/crate/mlflow-cratedb"
repository = "https://github.com/crate/mlflow-cratedb"
[tool.black]
packages = [ "mlflow_cratedb" ]

[tool.ruff]
line-length = 120

extend-exclude = "tests/test_tracking.py"
extend-exclude = [
"tests/test_tracking.py",
]

lint.select = [
# Builtins
"A",
# Bugbear
"B",
# comprehensions
"C4",
# Pycodestyle
"E",
# eradicate
"ERA",
# Pyflakes
"F",
# isort
"I",
# pandas-vet
"PD",
# return
"RET",
# Bandit
"S",
# print
"T20",
"W",
# flake8-2020
"YTT",
]

[tool.isort]
profile = "black"
skip_glob = "**/site-packages/**"
skip_gitignore = false
lint.per-file-ignores."tests/*" = [ "S101" ] # Use of `assert` detected
lint.per-file-ignores."tests/conftest.py" = [ "E402" ] # Module level import not at top of file

[tool.pytest.ini_options]
minversion = "2.0"
Expand All @@ -139,7 +160,7 @@ addopts = """
"""
log_level = "DEBUG"
log_cli_level = "DEBUG"
testpaths = ["tests"]
testpaths = [ "tests" ]
xfail_strict = true
markers = [
"examples",
Expand All @@ -149,17 +170,17 @@ markers = [

[tool.coverage.run]
branch = false
source = ["mlflow_cratedb"]
source = [ "mlflow_cratedb" ]
omit = [
"tests/*",
"tests/*",
]

[tool.coverage.report]
fail_under = 0
show_missing = true

[tool.mypy]
packages = ["mlflow_cratedb"]
packages = [ "mlflow_cratedb" ]
exclude = [
]
check_untyped_defs = true
Expand All @@ -173,62 +194,26 @@ strict_equality = true
warn_unused_ignores = true
warn_redundant_casts = true

[tool.ruff]
line-length = 120

lint.select = [
# Bandit
"S",
# Bugbear
"B",
# Builtins
"A",
# comprehensions
"C4",
# eradicate
"ERA",
# flake8-2020
"YTT",
# isort
"I",
# pandas-vet
"PD",
# print
"T20",
# Pycodestyle
"E",
"W",
# Pyflakes
"F",
# return
"RET",
]

extend-exclude = [
]


[tool.ruff.lint.per-file-ignores]
"tests/*" = ["S101"] # Use of `assert` detected
"tests/conftest.py" = ["E402"] # Module level import not at top of file

[tool.versioningit.vcs]
method = "git"
default-tag = "0.0.0"

# ===================
# Tasks configuration
# ===================

[tool.poe.tasks]
format = [
{ cmd = "black ." },
{ cmd = "ruff format ." },
# Configure Ruff not to auto-fix (remove!):
# Ignore unused imports (F401), unused variables (F841), `print` statements (T201), and commented-out code (ERA001).
{ cmd = "ruff --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001 ." },
{ cmd = "ruff check --fix --ignore=ERA --ignore=F401 --ignore=F841 --ignore=T20 --ignore=ERA001 ." },
{ cmd = "pyproject-fmt --keep-full-version pyproject.toml" },
]

lint = [
{ cmd = "ruff format --check ." },
{ cmd = "ruff check ." },
{ cmd = "black --check ." },
{ cmd = "validate-pyproject pyproject.toml" },
{ cmd = "mypy" },
]
Expand All @@ -239,8 +224,8 @@ test-fast = [
{ cmd = "pytest -m 'not slow'" },
]
build = { cmd = "python -m build" }
check = ["lint", "test"]
check-fast = ["lint", "test-fast"]
check = [ "lint", "test" ]
check-fast = [ "lint", "test-fast" ]

release = [
{ cmd = "python -m build" },
Expand Down
2 changes: 1 addition & 1 deletion release/oci-runtime/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - https://vsupalov.com/buildkit-cache-mount-dockerfile/
# - https://github.com/FernandoMiguel/Buildkit#mounttypecache

FROM python:3.11-slim-bullseye
FROM python:3.12-slim-bullseye

ENV DEBIAN_FRONTEND noninteractive
ENV TERM linux
Expand Down
2 changes: 1 addition & 1 deletion release/oci-server/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# - https://vsupalov.com/buildkit-cache-mount-dockerfile/
# - https://github.com/FernandoMiguel/Buildkit#mounttypecache

FROM python:3.11-slim-bullseye
FROM python:3.12-slim-bullseye

ENV DEBIAN_FRONTEND noninteractive
ENV TERM linux
Expand Down
12 changes: 6 additions & 6 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,15 @@ def test_tracking_pycaret(reset_database, engine: sa.Engine, tracking_store: Sql
# We have 2 experiments - one for "Default" experiment and one for the example
assert session.query(SqlExperiment).count() == 2, "experiments should have 2 rows"
# We have 32 distinct runs in the experiment which produced metrics
assert (
session.query(sa.func.count(sa.distinct(SqlMetric.run_uuid))).scalar() == 32
), "metrics should have 32 distinct run_uuid"
assert session.query(sa.func.count(sa.distinct(SqlMetric.run_uuid))).scalar() == 32, (
"metrics should have 32 distinct run_uuid"
)
# We have 33 runs in total (1 parent + 32 child runs)
assert session.query(SqlRun).count() == 33, "runs should have 33 rows"
# We have 33 distinct runs which have parameters (1 parent + 32 child runs)
assert (
session.query(sa.func.count(sa.distinct(SqlParam.run_uuid))).scalar() == 33
), "params should have 33 distinct run_uuid"
assert session.query(sa.func.count(sa.distinct(SqlParam.run_uuid))).scalar() == 33, (
"params should have 33 distinct run_uuid"
)
# We have one model registered
assert session.query(SqlRegisteredModel).count() == 1, "registered_models should have 1 row"

Expand Down
Loading