Add Redis cache. Publisher+Consumer refactoring. #57

Merged: 4 commits, Aug 31, 2020
15 changes: 15 additions & 0 deletions docker-compose.yml
@@ -21,6 +21,17 @@ services:
       - postgres:/data/postgres
     networks:
       - postgres
+  redis:
+    container_name: osint-framework-redis
+    image: redis:alpine
+    healthcheck:
+      test: redis-cli ping
+      interval: 30s
+      timeout: 5s
+      retries: 5
+    networks:
+      - redis
+    restart: unless-stopped
   rabbitmq:
     container_name: osint-framework-rabbitmq
     image: rabbitmq:alpine
@@ -76,6 +87,7 @@ services:
       POSTGRES_PORT: ${POSTGRES_PORT:-5432}
       RABBITMQ_HOST: ${RABBITMQ_HOST:-osint-framework-rabbitmq}
       RABBITMQ_PORT: ${RABBITMQ_PORT:-5672}
+      REDIS_HOST: ${REDIS_HOST:-osint-framework-redis}
       LOG_HANDLER: ${LOG_HANDLER:-stream}
     build:
       context: .
@@ -91,10 +103,13 @@
     networks:
       - postgres
       - rabbitmq
+      - redis
 networks:
   postgres:
     driver: bridge
   rabbitmq:
     driver: bridge
+  redis:
+    driver: bridge
 volumes:
   postgres:
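A quick way to sanity-check the new service from the application side, as a sketch (the env lookup mirrors src/cache/redis.py below; assumes the stack is already running):

from os import environ
from redis import Redis

# Resolve the host the same way the app does: REDIS_HOST with a localhost fallback
redis_client = Redis(host=environ.get("REDIS_HOST", default="localhost"))
assert redis_client.ping()  # same probe the compose healthcheck runs via redis-cli ping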
1 change: 1 addition & 0 deletions requirements.txt
@@ -24,6 +24,7 @@ pycares==3.1.1
 pycparser==2.20
 Pygments==2.6.1
 PyYAML==5.3.1
+redis==3.5.3
 requests==2.24.0
 rich==5.1.2
 selenium==3.141.0
25 changes: 22 additions & 3 deletions server.py
@@ -22,6 +22,8 @@
 from src.server.handlers.task_spawner import TaskSpawner
 from src.server.structures.response import ServerResponse
 from src.server.structures.task import TaskItem
+from src.server.structures.task import TaskStatus
+from src.cache.redis import RedisCache

 # Set logging level for Tornado Server
 tornado.log.access_log.setLevel(DEBUG)
@@ -32,6 +34,9 @@
 # Initialize publisher
 publisher = Publisher()

+# Initialize Redis cache
+redis = RedisCache()
+

 class BaseHandler(RequestHandler, ABC):
     """
@@ -170,12 +175,26 @@ def get(self) -> None:
         """
         try:
             task_id = self.get_argument("task_id", default=None)
-            results = json_encode(TaskCrud.get_results(task_id))
+            redis_cache = redis.get(task_id)
+            # If the cache is available, write it as the response
+            if redis_cache:
+                logger.info(msg=f"Redis cache is available, task '{task_id}'")
+                return self.write(redis_cache)
+            # If the cache is not available, get the results from the database
+            db_results = TaskCrud.get_results(task_id)
+            json_results = dumps(db_results, default=str)
+            # If the status is 'pending' (still in progress), write the database results and skip cache saving
+            if db_results.get("task", {}).get("status", "") == TaskStatus.PENDING:
+                logger.info(msg=f"Status of the task '{task_id}' is '{TaskStatus.PENDING}', skip Redis cache saving")
+                return self.write(json_results)
+            # If the status is 'error' or 'success' (finished either way), save the cache and write the database results
+            redis.set(key=task_id, value=json_results)
+            logger.info(msg=f"Save results to Redis cache, task '{task_id}'")
+            self.write(json_results)
         except Exception as get_results_error:
             return self.error(
                 msg=f"Unexpected error at getting results: {str(get_results_error)}"
             )
-        self.write(results)


 class HealthCheckHandler(BaseHandler, ABC):
@@ -219,7 +238,7 @@ def make_app() -> Application:

     # Init rabbitmq queue polling
     polling = tornado.ioloop.PeriodicCallback(
-        lambda: publisher.process_data_events(), 1000
-        lambda: publisher.process_data_events(), callback_time=1000
     )
     polling.start()

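The handler above is a cache-aside read path: finished tasks ('success' or 'error') are cached for repeat reads, while 'pending' tasks always fall through to the database so clients see fresh progress. Note that Tornado's PeriodicCallback takes callback_time in milliseconds, so 1000 keeps the original one-second publisher polling. A rough client-side sketch of the caching behavior (route and port are assumptions, not taken from this diff):

import requests  # requests==2.24.0 is already pinned in requirements.txt

task_id = "d9428888-122b-11e1-b85c-61cd3cbb3210"  # placeholder UUID
url = "http://localhost:8888/api/results"         # hypothetical endpoint

# The first call on a finished task reads PostgreSQL and fills the Redis cache;
# the second call is answered straight from Redis until the 24-hour TTL expires
for attempt in range(2):
    response = requests.get(url, params={"task_id": task_id})
    print(attempt, response.text)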
Empty file added src/cache/__init__.py
60 changes: 60 additions & 0 deletions src/cache/redis.py
@@ -0,0 +1,60 @@
+#!/usr/bin/env python3
+
+from redis import Redis
+from os import environ
+
+
+class DefaultValues:
+    # 24 hrs
+    REDIS_TIMEOUT = 86400
+    REDIS_HOST = environ.get("REDIS_HOST", default="localhost")
+
+
+class RedisCache:
+    def __init__(
+        self,
+        host: str = DefaultValues.REDIS_HOST,
+        timeout: int = DefaultValues.REDIS_TIMEOUT,
+    ):
+        self.options = dict(timeout=timeout)
+        self.redis = Redis(host=host)
+
+    def get(self, key) -> bytes or None:
+        """
+        Return redis cache value
+        :param key: key to get
+        :return: cached value as bytes, or None if the key is missing
+        """
+        if self.exists(key):
+            return self.redis.get(key)
+        return None
+
+    def set(self, key, value, timeout=None) -> None:
+        """
+        Set redis cache value
+        :param key: key to set
+        :param value: value to set
+        :param timeout: time to live, in seconds (defaults to 24 hrs)
+        :return: None
+        """
+        self.redis.set(key, value)
+        if timeout:
+            self.redis.expire(key, timeout)
+        else:
+            self.redis.expire(key, self.options["timeout"])
+
+    def delitem(self, key) -> None:
+        """
+        Delete cache value
+        :param key: key to delete
+        :return: None
+        """
+        self.redis.delete(key)
+
+    def exists(self, key) -> bool:
+        """
+        Check if value exists
+        :param key: key to check
+        :return: bool
+        """
+        return bool(self.redis.exists(key))
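A minimal usage sketch of RedisCache, assuming a reachable Redis instance (note that redis-py returns raw bytes unless decode_responses is enabled):

from src.cache.redis import RedisCache

cache = RedisCache()  # host from REDIS_HOST, TTL defaults to 24 hrs
cache.set(key="task-id", value='{"status": "success"}')
print(cache.exists("task-id"))  # True
print(cache.get("task-id"))     # b'{"status": "success"}'
cache.delitem("task-id")
print(cache.get("task-id"))     # None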
17 changes: 10 additions & 7 deletions src/db/crud.py
@@ -118,25 +118,28 @@ def update_task(task: TaskItem, db: Session = SessionLocal()) -> None:

     @staticmethod
     @retry()
-    def get_results(task_id: str, db: Session = SessionLocal()) -> list:
+    def get_results(task_id: str, db: Session = SessionLocal()) -> dict:
         """
         Return results
         :param task_id: task id to use
         :param db: database to use
         :return: dict
         """
+        # fmt: off
         try:
-            results = (
-                db.query(models.Result).filter(models.Result.owner_id == task_id).all()
-            )
+            db_results = db.query(models.Result).filter(models.Result.owner_id == task_id).all()
+            db_task_status = db.query(models.Task).filter_by(task_id=task_id).first()
         except exc.DBAPIError as api_err:
             raise api_err from api_err
         except:
-            return []
+            return {}
         else:
-            return [loads(str(data.result)) for data in results]
+            results = [loads(str(data.result)) for data in db_results]
+            task_status = object_as_dict(db_task_status)
+            return {"task": task_status, "results": results}
         finally:
             db.close()
+        # fmt: on

     @staticmethod
     @retry()
@@ -164,7 +167,7 @@ def get_results_count(task_id: str, db: Session = SessionLocal()) -> int:
     @retry()
     def get_task(task_id: str, db: Session = SessionLocal()) -> dict:
         """
-        Return task results by UUID
+        Return task status by UUID
         :param task_id: task id to use
         :param db: database to use
         :return: dict
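With this change, get_results returns the task record next to its results, which is what the server handler inspects for the 'pending' status check. A hypothetical shape of the returned dict (field names other than "task", "status", and "results" are illustrative):

example = {
    "task": {
        "task_id": "d9428888-122b-11e1-b85c-61cd3cbb3210",
        "status": "success",  # compared against TaskStatus.PENDING in server.py
    },
    "results": [
        {"case": "base", "data": {}},
    ],
}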
60 changes: 36 additions & 24 deletions src/queue/consumer.py
@@ -2,7 +2,9 @@

 from json import loads, dumps

-import pika
+from pika import BlockingConnection, ConnectionParameters
+from pika.adapters.blocking_connection import BlockingChannel
+from pika.spec import Basic, BasicProperties

 from src.core.runner.manager import CaseManager
 from src.core.utils.log import Logger
@@ -15,62 +17,72 @@

 class Consumer:
     def __init__(
-        self, host: str = Default.RABBITMQ_HOST, port: int = Default.RABBITMQ_PORT
+        self,
+        host: str = Default.RABBITMQ_HOST,
+        port: int = Default.RABBITMQ_PORT,
+        task_queue: str = Default.TASK_QUEUE
     ):
         """
         Init rabbitmq consumer
         :param host: rabbitmq host
         :param port: rabbitmq port
+        :param task_queue: queue name
         """
-        self.queue = Default.QUEUE
-        self.connection = pika.BlockingConnection(
-            pika.ConnectionParameters(host=host, port=port,)
+        self.connection = BlockingConnection(
+            ConnectionParameters(host=host, port=port,)
         )
         self.channel = self.connection.channel()
-        self.channel.queue_declare(queue=self.queue)
+        self.channel.queue_declare(queue=task_queue)
+        self.channel.basic_consume(
+            queue=task_queue,
+            on_message_callback=self.task_process,
+        )

         self.manager = CaseManager()

-    def callback(self, ch, method, properties, body) -> None:
+    def task_process(
+        self,
+        channel: BlockingChannel,
+        method: Basic.Deliver,
+        properties: BasicProperties,
+        body: bytes
+    ) -> None:
         """
         Process the received task
-        :param ch: channel
+        :param channel: channel
         :param method: method
         :param properties: task properties
         :param body: task body
         :return: None
         """
-        raw_body = loads(body)
+        raw_body = loads(body.decode(encoding="utf-8"))
         cases = raw_body.get("cases", {})
         task = TaskItem(**raw_body.get("task", {}))

-        done_tasks = 0
-        cases_len = len(cases)
-        for result in self.manager.multi_case_runner(cases=cases):
-            done_tasks += 1
-            TaskCrud.create_task_result(task, result or {})
-            message = f"Done {done_tasks} out of {cases_len} cases"
-            task.set_pending(message)
-            logger.info(message)
-            TaskCrud.update_task(task)
+        try:
+            results = list(self.manager.multi_case_runner(cases=cases))
+            for result in results:
+                TaskCrud.create_task_result(task, result or {})
+            task.set_success(msg=f"Task done: {len(results)} out of {len(cases)} cases")
+        except Exception as cases_err:
+            task.set_error(msg=f"Task error: {str(cases_err)}")

-        task.set_success(msg=f"All cases done ({done_tasks} out of {cases_len})")
         TaskCrud.update_task(task)
         logger.info(msg=f"Done task {task.task_id}")

-        ch.basic_publish(
+        channel.basic_publish(
             exchange="",
             routing_key=properties.reply_to,
-            properties=pika.BasicProperties(correlation_id=properties.correlation_id),
-            body=dumps(task.as_json()),
+            properties=BasicProperties(correlation_id=properties.correlation_id),
+            body=dumps(task.as_json()).encode(encoding="utf-8"),
         )
-        ch.basic_ack(delivery_tag=method.delivery_tag)
+        channel.basic_ack(delivery_tag=method.delivery_tag)

     def start_consuming(self) -> None:
         """
         Run consumer
         :return: None
         """
-        self.channel.basic_consume(queue=self.queue, on_message_callback=self.callback)
         self.channel.start_consuming()

     def __del__(self):
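The consumer now wires its callback in __init__ and reports a single terminal status ('success' or 'error') per task instead of per-case 'pending' updates; replies follow RabbitMQ's RPC convention (publish to properties.reply_to, echo the correlation_id, then ack). A minimal runner sketch under those assumptions:

from src.queue.consumer import Consumer

if __name__ == "__main__":
    consumer = Consumer()  # host, port, and queue come from DefaultValues
    try:
        consumer.start_consuming()  # blocks; each task is acked after its reply is published
    except KeyboardInterrupt:
        pass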
3 changes: 2 additions & 1 deletion src/queue/defaults.py
@@ -7,4 +7,5 @@ class DefaultValues:
     RABBITMQ_HOST = str(environ.get("RABBITMQ_HOST", default="localhost"))
     RABBITMQ_PORT = int(environ.get("RABBITMQ_PORT", default=5672))

-    QUEUE = "case_queue"
+    TASK_QUEUE = "task_queue"
+    RESPONSE_QUEUE = "response_queue"
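The publisher side is not part of this view, but the split into TASK_QUEUE and RESPONSE_QUEUE suggests a classic pika RPC round trip; a hedged sketch (queue names from the defaults above, everything else assumed):

from json import dumps
from uuid import uuid4

from pika import BasicProperties, BlockingConnection, ConnectionParameters

connection = BlockingConnection(ConnectionParameters(host="localhost", port=5672))
channel = connection.channel()
channel.queue_declare(queue="task_queue")
channel.queue_declare(queue="response_queue")

# Send a task and tell the consumer where to reply and how to correlate the answer
channel.basic_publish(
    exchange="",
    routing_key="task_queue",
    properties=BasicProperties(reply_to="response_queue", correlation_id=str(uuid4())),
    body=dumps({"task": {}, "cases": {}}).encode("utf-8"),
)
connection.close()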