From 60a2fb80d6b8cc30fe98422e29e70929d18864cd Mon Sep 17 00:00:00 2001 From: nik Date: Thu, 6 Jun 2019 15:11:21 +0300 Subject: [PATCH 1/2] make use threads and custom redis client --- htx/htx.py | 6 ++++-- htx/model_manager.py | 16 +++++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/htx/htx.py b/htx/htx.py index fe73a95..8dcb096 100644 --- a/htx/htx.py +++ b/htx/htx.py @@ -1,7 +1,9 @@ import json import multiprocessing as mp import logging +import threading +from queue import Queue from flask import Flask, request, jsonify from htx.model_manager import ModelManager @@ -19,9 +21,9 @@ def init_model_server(**kwargs): @_server.before_first_request def launch_train_loop(): - train_process = mp.Process( + train_process = threading.Thread( target=_model_manager.train_loop, - args=(_model_manager.queue, _model_manager.train_script) + args=(_model_manager.queue, _model_manager.train_script, _model_manager.redis, ) ) train_process.start() diff --git a/htx/model_manager.py b/htx/model_manager.py index 4628a9d..9d3f9c9 100644 --- a/htx/model_manager.py +++ b/htx/model_manager.py @@ -1,10 +1,10 @@ import os -import multiprocessing as mp import logging import json import attr import io import shutil +import queue from operator import attrgetter from redis import Redis @@ -52,7 +52,7 @@ class ModelManager(object): _MODEL_LIST_FILE = 'model_list.txt' _DEFAULT_MODEL_VERSION = 'model' - queue = mp.Queue() + queue = queue.Queue() def __init__( self, @@ -62,6 +62,7 @@ def __init__( data_dir='~/.heartex/data', min_examples_for_train=1, retrain_after_num_examples=1, + redis=None, redis_host='localhost', redis_port=6379, redis_queue='default', @@ -85,11 +86,13 @@ def __init__( self.model_list_file = os.path.join(self.model_dir, self._MODEL_LIST_FILE) self._current_model = {} - self._redis = Redis(host=redis_host, port=redis_port) + if redis is not None: + self._redis = redis + else: + self._redis = Redis(host=redis_host, port=redis_port) def _get_latest_finished_train_job(self, project): - queue = Queue(name=self.redis_queue, connection=self._redis) - registry = FinishedJobRegistry(queue.name, queue.connection) + registry = FinishedJobRegistry(self.redis_queue, self._redis) if registry.count == 0: logger.info('Train job registry is empty.') return None @@ -281,8 +284,7 @@ def _flush_all(self, project, redis, reqis_queue): # TODO: do we need the locks here? shutil.rmtree(project_data_dir) - def train_loop(self, data_queue, train_script): - redis = Redis(host=self.redis_host, port=self.redis_port) + def train_loop(self, data_queue, train_script, redis): redis_queue = Queue(name=self.redis_queue, connection=redis) logger.info(f'Train loop starts: PID={os.getpid()}, Redis connection: {redis}, queue: {redis_queue}') for queued_items, in iter(data_queue.get, None): From c68d3172a64359148cf4d0645a81a0199a768501 Mon Sep 17 00:00:00 2001 From: nik Date: Sun, 16 Jun 2019 11:57:47 +0300 Subject: [PATCH 2/2] fix bbox, model wrapper, fix retrain --- htx/base_model.py | 8 ++++---- htx/model_manager.py | 26 ++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/htx/base_model.py b/htx/base_model.py index 5cf5d01..19694c0 100644 --- a/htx/base_model.py +++ b/htx/base_model.py @@ -18,7 +18,7 @@ # Output types CHOICES_TYPE = 'Choices' LABELS_TYPE = 'Labels' -BOUNDING_BOX_TYPE = 'AddRectangleButton' +BOUNDING_BOX_TYPE = 'RectangleLabels' LIST_TYPE = 'Ranker' @@ -242,7 +242,7 @@ def make_results(self, list_of_spans, scores): class BoundingBoxBaseModel(BaseModel): - OUTPUT_TYPES = (LABELS_TYPE, BOUNDING_BOX_TYPE) + OUTPUT_TYPES = (BOUNDING_BOX_TYPE,) def get_output(self, task): input_name = self.input_names[0] @@ -257,7 +257,7 @@ def get_output(self, task): 'y': value['y'], 'width': value['width'], 'height': value['height'], - 'label': value['labels'][0] + 'label': value['rectanglelabels'][0] }) return output @@ -272,7 +272,7 @@ def make_result(self, list_of_bboxes, scores): 'from_name': output_name, 'to_name': input_name, 'value': { - 'labels': [bbox['label']], + 'rectanglelabels': [bbox['label']], 'x': bbox['x'], 'y': bbox['y'], 'height': bbox['height'], diff --git a/htx/model_manager.py b/htx/model_manager.py index 9d3f9c9..c9e8fad 100644 --- a/htx/model_manager.py +++ b/htx/model_manager.py @@ -16,6 +16,18 @@ logger = logging.getLogger(__name__) +@attr.s +class ModelWrapper(object): + model = attr.ib() + version = attr.ib() + + def predict(self, *args, **kwargs): + return self.model.predict(*args, **kwargs) + + def get_data_item(self, *args, **kwargs): + return self.model.get_data_item(*args, **kwargs) + + class QueuedItem(object): def __init__(self, project): @@ -41,7 +53,10 @@ class QueuedWaitSignal(QueuedItem): class QueuedTrainSignal(QueuedItem): - pass + + def __init__(self, project, force=True): + super(QueuedTrainSignal, self).__init__(project) + self.force = force class QueuedFlushAllSignal(QueuedItem): @@ -141,7 +156,7 @@ def setup(self, project, schema): logger.error(f'Found resources {resources}, but model is not loaded for project {project}. ' f'Consequent API calls (e.g. predict) will fail.') return None - self._current_model[project] = model + self._current_model[project] = ModelWrapper(model, model_version) self._stash_resources(project, resources) logger.info(f'Model {model_version} successfully loaded for project {project}.') return model_version @@ -177,7 +192,7 @@ def update(self, task, project, schema): else: queued_items = [ QueuedDataItem(data_item, project), - QueuedTrainSignal(project) + QueuedTrainSignal(project, force=False) ] self.queue.put((queued_items,)) @@ -308,7 +323,10 @@ def train_loop(self, data_queue, train_script, redis): elif isinstance(queued_item, QueuedTrainSignal): try: total_items = self._update_counters(redis, project) - if total_items >= self.min_examples_for_train and total_items % self.retrain_after_num_examples == 0: + if queued_item.force or ( + total_items >= self.min_examples_for_train and + total_items % self.retrain_after_num_examples == 0 + ): self._run_train_script(redis_queue, train_script, project_data_dir, project) except Exception as error: logger.error(f'Failed to start training job. Reason: {error}', exc_info=True)