diff --git a/htx/htx.py b/htx/htx.py
index fe73a95..8dcb096 100644
--- a/htx/htx.py
+++ b/htx/htx.py
@@ -1,7 +1,9 @@
 import json
 import multiprocessing as mp
 import logging
+import threading
+from queue import Queue
 
 from flask import Flask, request, jsonify
 from htx.model_manager import ModelManager
 
@@ -19,9 +21,9 @@ def init_model_server(**kwargs):
 
     @_server.before_first_request
     def launch_train_loop():
-        train_process = mp.Process(
+        train_process = threading.Thread(
             target=_model_manager.train_loop,
-            args=(_model_manager.queue, _model_manager.train_script)
+            args=(_model_manager.queue, _model_manager.train_script, _model_manager.redis, )
         )
         train_process.start()
 
diff --git a/htx/model_manager.py b/htx/model_manager.py
index 9555461..c8ffced 100644
--- a/htx/model_manager.py
+++ b/htx/model_manager.py
@@ -1,10 +1,10 @@
 import os
-import multiprocessing as mp
 import logging
 import json
 import attr
 import io
 import shutil
+import queue
 
 from operator import attrgetter
 from redis import Redis
@@ -16,6 +16,18 @@
 logger = logging.getLogger(__name__)
 
 
+@attr.s
+class ModelWrapper(object):
+    model = attr.ib()
+    version = attr.ib()
+
+    def predict(self, *args, **kwargs):
+        return self.model.predict(*args, **kwargs)
+
+    def get_data_item(self, *args, **kwargs):
+        return self.model.get_data_item(*args, **kwargs)
+
+
 class QueuedItem(object):
 
     def __init__(self, project):
@@ -54,7 +66,7 @@ class ModelManager(object):
     _MODEL_LIST_FILE = 'model_list.txt'
     _DEFAULT_MODEL_VERSION = 'model'
 
-    queue = mp.Queue()
+    queue = queue.Queue()
 
     def __init__(
         self,
@@ -64,6 +76,7 @@ def __init__(
         data_dir='~/.heartex/data',
         min_examples_for_train=1,
         retrain_after_num_examples=1,
+        redis=None,
         redis_host='localhost',
         redis_port=6379,
         redis_queue='default',
@@ -87,11 +100,13 @@
         self.model_list_file = os.path.join(self.model_dir, self._MODEL_LIST_FILE)
 
         self._current_model = {}
-        self._redis = Redis(host=redis_host, port=redis_port)
+        if redis is not None:
+            self._redis = redis
+        else:
+            self._redis = Redis(host=redis_host, port=redis_port)
 
     def _get_latest_finished_train_job(self, project):
-        queue = Queue(name=self.redis_queue, connection=self._redis)
-        registry = FinishedJobRegistry(queue.name, queue.connection)
+        registry = FinishedJobRegistry(self.redis_queue, self._redis)
         if registry.count == 0:
             logger.info('Train job registry is empty.')
             return None
@@ -140,7 +155,7 @@ def setup(self, project, schema):
             logger.error(f'Found resources {resources}, but model is not loaded for project {project}. '
                          f'Consequent API calls (e.g. predict) will fail.')
             return None
-        self._current_model[project] = model
+        self._current_model[project] = ModelWrapper(model, model_version)
         self._stash_resources(project, resources)
         logger.info(f'Model {model_version} successfully loaded for project {project}.')
         return model_version
@@ -283,8 +298,7 @@ def _flush_all(self, project, redis, reqis_queue):
         # TODO: do we need the locks here?
         shutil.rmtree(project_data_dir)
 
-    def train_loop(self, data_queue, train_script):
-        redis = Redis(host=self.redis_host, port=self.redis_port)
+    def train_loop(self, data_queue, train_script, redis):
         redis_queue = Queue(name=self.redis_queue, connection=redis)
         logger.info(f'Train loop starts: PID={os.getpid()}, Redis connection: {redis}, queue: {redis_queue}')
         for queued_items, in iter(data_queue.get, None):