Skip to content
This repository was archived by the owner on Feb 10, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions htx/htx.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import multiprocessing as mp
import logging
import threading

from queue import Queue
from flask import Flask, request, jsonify

from htx.model_manager import ModelManager
Expand All @@ -19,9 +21,9 @@ def init_model_server(**kwargs):

@_server.before_first_request
def launch_train_loop():
train_process = mp.Process(
train_process = threading.Thread(
target=_model_manager.train_loop,
args=(_model_manager.queue, _model_manager.train_script)
args=(_model_manager.queue, _model_manager.train_script, _model_manager.redis, )
)
train_process.start()

Expand Down
30 changes: 22 additions & 8 deletions htx/model_manager.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import os
import multiprocessing as mp
import logging
import json
import attr
import io
import shutil
import queue

from operator import attrgetter
from redis import Redis
Expand All @@ -16,6 +16,18 @@
logger = logging.getLogger(__name__)


@attr.s
class ModelWrapper(object):
    """Pair a loaded model with its version and proxy calls to it.

    Thin delegation wrapper: every call is forwarded verbatim to the
    wrapped model object, while ``version`` travels alongside so callers
    can report which model produced a result.
    """

    # The underlying model object; assumed to expose predict() and
    # get_data_item() — TODO(review): confirm against model loader.
    model = attr.ib()
    # Version identifier associated with the wrapped model.
    version = attr.ib()

    def predict(self, *args, **kwargs):
        """Forward a prediction request to the wrapped model unchanged."""
        wrapped = self.model
        return wrapped.predict(*args, **kwargs)

    def get_data_item(self, *args, **kwargs):
        """Forward a data-item lookup to the wrapped model unchanged."""
        wrapped = self.model
        return wrapped.get_data_item(*args, **kwargs)

class QueuedItem(object):

def __init__(self, project):
Expand Down Expand Up @@ -54,7 +66,7 @@ class ModelManager(object):

_MODEL_LIST_FILE = 'model_list.txt'
_DEFAULT_MODEL_VERSION = 'model'
queue = mp.Queue()
queue = queue.Queue()

def __init__(
self,
Expand All @@ -64,6 +76,7 @@ def __init__(
data_dir='~/.heartex/data',
min_examples_for_train=1,
retrain_after_num_examples=1,
redis=None,
redis_host='localhost',
redis_port=6379,
redis_queue='default',
Expand All @@ -87,11 +100,13 @@ def __init__(
self.model_list_file = os.path.join(self.model_dir, self._MODEL_LIST_FILE)

self._current_model = {}
self._redis = Redis(host=redis_host, port=redis_port)
if redis is not None:
self._redis = redis
else:
self._redis = Redis(host=redis_host, port=redis_port)

def _get_latest_finished_train_job(self, project):
queue = Queue(name=self.redis_queue, connection=self._redis)
registry = FinishedJobRegistry(queue.name, queue.connection)
registry = FinishedJobRegistry(self.redis_queue, self._redis)
if registry.count == 0:
logger.info('Train job registry is empty.')
return None
Expand Down Expand Up @@ -140,7 +155,7 @@ def setup(self, project, schema):
logger.error(f'Found resources {resources}, but model is not loaded for project {project}. '
f'Consequent API calls (e.g. predict) will fail.')
return None
self._current_model[project] = model
self._current_model[project] = ModelWrapper(model, model_version)
self._stash_resources(project, resources)
logger.info(f'Model {model_version} successfully loaded for project {project}.')
return model_version
Expand Down Expand Up @@ -283,8 +298,7 @@ def _flush_all(self, project, redis, reqis_queue):
# TODO: do we need the locks here?
shutil.rmtree(project_data_dir)

def train_loop(self, data_queue, train_script):
redis = Redis(host=self.redis_host, port=self.redis_port)
def train_loop(self, data_queue, train_script, redis):
redis_queue = Queue(name=self.redis_queue, connection=redis)
logger.info(f'Train loop starts: PID={os.getpid()}, Redis connection: {redis}, queue: {redis_queue}')
for queued_items, in iter(data_queue.get, None):
Expand Down