From c0d4c4feba7da10c46022d149cb588129d1f4d16 Mon Sep 17 00:00:00 2001 From: db0 Date: Mon, 8 Apr 2024 20:53:52 +0200 Subject: [PATCH] feat: Allow quering only for custom models --- horde/__init__.py | 6 +++--- horde/apis/v2/base.py | 17 ++++++++++++++++- horde/apis/v2/stable.py | 17 +++++++++++------ horde/classes/stable/genstats.py | 23 ++++++++++++----------- horde/database/functions.py | 19 ++++++++++++++++++- 5 files changed, 60 insertions(+), 22 deletions(-) diff --git a/horde/__init__.py b/horde/__init__.py index 7a719401..0e3a55f2 100644 --- a/horde/__init__.py +++ b/horde/__init__.py @@ -19,9 +19,9 @@ def after_request(response): response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS, PUT, DELETE, PATCH" - response.headers["Access-Control-Allow-Headers"] = ( - "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, apikey, Client-Agent, X-Fields" - ) + response.headers[ + "Access-Control-Allow-Headers" + ] = "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, apikey, Client-Agent, X-Fields" response.headers["Horde-Node"] = f"{socket.gethostname()}:{args.port}:{HORDE_VERSION}" return response diff --git a/horde/apis/v2/base.py b/horde/apis/v2/base.py index fd3c6dde..cc02a2dd 100644 --- a/horde/apis/v2/base.py +++ b/horde/apis/v2/base.py @@ -1565,10 +1565,22 @@ class Models(Resource): help="Filter the models that have at most this amount of threads serving.", location="args", ) + get_parser.add_argument( + "model_state", + required=False, + default="all", + type=str, + help=( + "If 'known', only show stats for known models in the model reference. " + "If 'custom' only show stats for custom models. " + "If 'all' shows stats for all models." + ), + location="args", + ) - @logger.catch(reraise=True) @cache.cached(timeout=2, query_string=True) @api.expect(get_parser) + @api.response(400, "Validation Error", models.response_model_error) @api.marshal_with( models.response_model_active_model, code=200, @@ -1578,10 +1590,13 @@ class Models(Resource): def get(self): """Returns a list of models active currently in this horde""" self.args = self.get_parser.parse_args() + if self.args.model_state not in ["known", "custom", "all"]: + raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']") models_ret = database.retrieve_available_models( model_type=self.args.type, min_count=self.args.min_count, max_count=self.args.max_count, + model_state=self.args.model_state, ) return (models_ret, 200) diff --git a/horde/apis/v2/stable.py b/horde/apis/v2/stable.py index 0fd58d96..619a30e8 100644 --- a/horde/apis/v2/stable.py +++ b/horde/apis/v2/stable.py @@ -1274,17 +1274,22 @@ class ImageHordeStatsModels(Resource): location="headers", ) get_parser.add_argument( - "model_type", + "model_state", required=False, - default='known', + default="known", type=str, - help="If 'known', only show stats for known models in the model reference. If 'custom' only show stats for custom models. If 'all' shows stats for all models.", + help=( + "If 'known', only show stats for known models in the model reference. " + "If 'custom' only show stats for custom models. " + "If 'all' shows stats for all models." + ), location="args", ) @logger.catch(reraise=True) # @cache.cached(timeout=50, query_string=True) @api.expect(get_parser) + @api.response(400, "Validation Error", models.response_model_error) @api.marshal_with( models.response_model_stats_models, code=200, @@ -1293,6 +1298,6 @@ class ImageHordeStatsModels(Resource): def get(self): """Details how many images were generated per model for the past day, month and total""" self.args = self.get_parser.parse_args() - if self.args.model_type not in ['known', 'custom', 'all']: - return e.BadRequest("'model_type' needs to be one of ['known', 'custom', 'all']") - return compile_imagegen_stats_models(self.args.model_type), 200 + if self.args.model_state not in ["known", "custom", "all"]: + raise e.BadRequest("'model_state' needs to be one of ['known', 'custom', 'all']") + return compile_imagegen_stats_models(self.args.model_state), 200 diff --git a/horde/classes/stable/genstats.py b/horde/classes/stable/genstats.py index 5adae867..a1433e34 100644 --- a/horde/classes/stable/genstats.py +++ b/horde/classes/stable/genstats.py @@ -1,4 +1,3 @@ -from loguru import logger from datetime import datetime, timedelta from sqlalchemy import Enum, func @@ -208,30 +207,32 @@ def compile_imagegen_stats_totals(): return stats_dict -def compile_imagegen_stats_models(model_type = 'known'): +def compile_imagegen_stats_models(model_state="known"): query = db.session.query(ImageGenerationStatistic.model, func.count()).group_by(ImageGenerationStatistic.model) - def check_model_type(model_name): - if model_type == 'known' and model_reference.is_known_image_model(model_name): - return True - if model_type == 'custom' and not model_reference.is_known_image_model(model_name): - return True - if model_type == 'all': + + def check_model_state(model_name): + if model_state == "known" and model_reference.is_known_image_model(model_name): + return True + if model_state == "custom" and not model_reference.is_known_image_model(model_name): + return True + if model_state == "all": return True return False + return { - "total": {model: count for model, count in query.all() if check_model_type(model)}, + "total": {model: count for model, count in query.all() if check_model_state(model)}, "day": { model: count for model, count in query.filter( ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=1), ).all() - if check_model_type(model) + if check_model_state(model) }, "month": { model: count for model, count in query.filter( ImageGenerationStatistic.finished >= datetime.utcnow() - timedelta(days=30), ).all() - if check_model_type(model) + if check_model_state(model) }, } diff --git a/horde/database/functions.py b/horde/database/functions.py index 0bab7a24..32878914 100644 --- a/horde/database/functions.py +++ b/horde/database/functions.py @@ -366,7 +366,7 @@ def get_available_models(filter_model_name: str = None): return list(models_dict.values()) -def retrieve_available_models(model_type=None, min_count=None, max_count=None): +def retrieve_available_models(model_type=None, min_count=None, max_count=None, model_state="known"): """Retrieves model details from Redis cache, or from DB if cache is unavailable""" if hr.horde_r is None: return get_available_models() @@ -384,6 +384,23 @@ def retrieve_available_models(model_type=None, min_count=None, max_count=None): models_ret = [md for md in models_ret if md["count"] >= min_count] if max_count is not None: models_ret = [md for md in models_ret if md["count"] <= max_count] + + def check_model_state(model_name): + if model_type is None: + return True + model_check = model_reference.is_known_image_model + if model_type == "text": + model_check = model_reference.is_known_text_model + if model_state == "known" and model_check(model_name): + return True + if model_state == "custom" and not model_check(model_name): + return True + if model_state == "all": + return True + return False + + models_ret = [md for md in models_ret if check_model_state(md["name"])] + return models_ret