Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ jobs:
\"service-port\": ${RABBITMQ_PORT},
\"service-host\": \"${RABBITMQ_HOST}\",
\"rabbitmq-vhost\": \"/\",
\"rabbitmq-queue-physics\": \"test-ci\",
\"rabbitmq-exchange-physics\": \"exchange-ci\",
\"rabbitmq-key-physics\": \"queue-ci\",
\"rabbitmq-exchange-training\": \"ams-fanout\",
\"rabbitmq-key-training\": \"training\"
}""" > rmq.json
Expand Down
167 changes: 119 additions & 48 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import json
import pika
# from pika.exchange_type import ExchangeType


class AMSMessage(object):
Expand Down Expand Up @@ -225,12 +226,15 @@ class AMSChannel:
def __init__(
self,
connection,
q_name,
exchange,
routing_key,
callback: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
):
self.connection = connection
self.q_name = q_name
self.exchange = exchange
self.routing_key = routing_key
self.q_name = None
self.logger = logger if logger else logging.getLogger(__name__)
self.callback = callback if callback else self.default_callback

Expand All @@ -247,7 +251,18 @@ def default_callback(self, method, properties, body):

def open(self):
self.channel = self.connection.channel()
self.channel.queue_declare(queue=self.q_name)
q_name = self.routing_key
if self.exchange != '':
self.logger.info(f"Declared exchange {self.exchange}")
self.channel.exchange_declare(exchange = self.exchange, exchange_type = "direct")
q_name = "ams-debug" #TODO CHANGE

result = self.channel.queue_declare(queue = q_name, exclusive = False, durable = False)
self.q_name = result.method.queue
self.logger.info(f"Declared queue {self.q_name}")
if self.exchange != '':
self.logger.info(f"Binding queue {self.q_name} to exchange {self.exchange}")
self.channel.queue_bind(exchange = self.exchange, queue = self.q_name, routing_key = self.routing_key)

def close(self):
self.channel.close()
Expand Down Expand Up @@ -308,7 +323,7 @@ def send(self, text: str, exchange: str = ""):
@param text The text to send
@param exchange Exchange to use
"""
self.channel.basic_publish(exchange=exchange, routing_key=self.q_name, body=text)
self.channel.basic_publish(exchange = exchange, routing_key = self.routing_key, body=text)
return

def get_messages(self):
Expand Down Expand Up @@ -374,9 +389,11 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.connection.close()

def connect(self, queue):
"""Connect to the queue"""
return AMSChannel(self.connection, queue, self.callback)
def connect(self, exchange, routing_key):
"""
Connect to the exchange and routing key.
"""
return AMSChannel(self.connection, exchange, routing_key, self.callback)


class StatusPoller(BlockingClient):
Expand All @@ -401,8 +418,11 @@ def __init__(
user: str,
password: str,
cert: str,
queue: str,
prefetch_count: int = 1,
exchange: str,
routing_key: str,
queue: str = "",
prefetch_count: int = 0,
exchange_type: str = "direct",
on_message_cb: Optional[Callable] = None,
on_close_cb: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
Expand All @@ -425,6 +445,9 @@ def __init__(
self._vhost = vhost
self._cacert = cert
self._queue = queue
self._exchange = exchange
self._exchange_type = exchange_type
self._routing_key = routing_key

self.should_reconnect = False
# Holds the latest error/reason to reconnect
Expand Down Expand Up @@ -568,8 +591,7 @@ def on_channel_open(self, channel):
self._channel = channel
self.logger.debug("Channel opened")
self.add_on_channel_close_callback()
# we do not set up exchange first here, we use the default exchange ''
self.setup_queue(self._queue)
self.setup_exchange(self._exchange, self._exchange_type)

def add_on_channel_close_callback(self):
"""This method tells pika to call the on_channel_closed method if
Expand All @@ -595,6 +617,33 @@ def on_channel_closed(self, channel, reason):
self._on_close_cb() # running user callback
self.close_connection()

def setup_exchange(self, exchange_name, exchange_type):
"""Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
command. When it is complete, the on_exchange_declareok method will
be invoked by pika.

:param str|unicode exchange_name: The name of the exchange to declare

"""
self.logger.debug(f"Declaring exchange: '{exchange_name}'")
cb = functools.partial(
self.on_exchange_declareok, userdata = exchange_name)
self._channel.exchange_declare(
exchange = exchange_name,
exchange_type = exchange_type,
callback = cb)

def on_exchange_declareok(self, _unused_frame, userdata):
"""Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
command.

:param pika.Frame.Method unused_frame: Exchange.DeclareOk response frame
:param str|unicode userdata: Extra user data (exchange name)

"""
self.logger.debug(f"Exchange declared: '{userdata}'")
self.setup_queue(self._queue)

def setup_queue(self, queue_name):
"""Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
command. When it is complete, the on_queue_declareok method will
Expand All @@ -603,7 +652,7 @@ def setup_queue(self, queue_name):
:param str|unicode queue_name: The name of the queue to declare.

"""
self.logger.debug(f'Declaring queue "{queue_name}"')
self.logger.debug(f"Declaring queue '{queue_name}'")
cb = functools.partial(self.on_queue_declareok, userdata=queue_name)
# arguments = {"x-consumer-timeout":1800000} # 30 minutes in ms
self._channel.queue_declare(queue=queue_name, exclusive=False, callback=cb)
Expand All @@ -620,7 +669,23 @@ def on_queue_declareok(self, _unused_frame, userdata):

"""
queue_name = userdata
self.logger.debug(f'Queue "{queue_name}" declared')
self.logger.info(f"Binding {self._exchange} to queue '{queue_name}' with key '{self._routing_key}'")
cb = functools.partial(self.on_bindok, userdata=queue_name)
self._channel.queue_bind(
queue_name,
self._exchange,
routing_key=self._routing_key,
callback=cb)

def on_bindok(self, _unused_frame, userdata):
"""Invoked by pika when the Queue.Bind method has completed. At this
point we will set the prefetch count for the channel.

:param pika.frame.Method _unused_frame: The Queue.BindOk response frame
:param str|unicode userdata: Extra user data (queue name)

"""
self.logger.debug(f"Queue bound: '{userdata}'")
self.set_qos()

def set_qos(self):
Expand Down Expand Up @@ -772,49 +837,52 @@ def __init__(
user: str,
password: str,
cert: str,
queue: str,
prefetch_count: int = 1,
routing_key: str,
prefetch_count: int = 0,
on_message_cb: Optional[Callable] = None,
on_close_cb: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
):
super().__init__(
host,
port,
vhost,
user,
password,
cert,
queue,
prefetch_count,
on_message_cb,
on_close_cb,
logger,
)

# Callback when the channel is open
def on_channel_open(self, channel):
self._channel = channel
self.logger.debug("Channel opened")
self.add_on_channel_close_callback()
self._channel.exchange_declare(
host=host,
port=port,
vhost=vhost,
user=user,
password=password,
cert=cert,
exchange="control-panel",
routing_key=routing_key,
queue="",
exchange_type="fanout",
callback=self.on_exchange_declared,
prefetch_count=prefetch_count,
on_message_cb=on_message_cb,
on_close_cb=on_close_cb,
logger=logger,
)

# Callback when the exchange is declared
def on_exchange_declared(self, frame):
self._channel.queue_declare(queue="", exclusive=True, callback=self.on_queue_declared)
# # Callback when the channel is open
# def on_channel_open(self, channel):
# self._channel = channel
# self.logger.debug("Channel opened")
# self.add_on_channel_close_callback()
# self._channel.exchange_declare(
# exchange="control-panel",
# exchange_type="fanout",
# callback=self.on_exchange_declared,
# )

# Callback when the queue is declared
def on_queue_declared(self, queue_result):
self._queue = queue_result.method.queue
self._channel.queue_bind(exchange="control-panel", queue=self._queue, callback=self.on_queue_bound)
# # Callback when the exchange is declared
# def on_exchange_declared(self, frame):
# self._channel.queue_declare(queue="", exclusive=True, callback=self.on_queue_declared)

# Callback when the queue is bound to the exchange
def on_queue_bound(self, frame):
self.set_qos()
# # Callback when the queue is declared
# def on_queue_declared(self, queue_result):
# self._queue = queue_result.method.queue
# self._channel.queue_bind(exchange="control-panel", queue=self._queue, callback=self.on_queue_bound)

# # Callback when the queue is bound to the exchange
# def on_queue_bound(self, frame):
# self.set_qos()


class AMSSyncProducer:
Expand Down Expand Up @@ -947,7 +1015,8 @@ class AMSRMQConfiguration:
"rabbitmq-user": "",
"rabbitmq-vhost": "",
"rabbitmq-cert": "",
"rabbitmq-queue-physics": "",
"rabbitmq-exchange-physics": "",
"rabbitmq-key-physics": "",
"rabbitmq-exchange-training": "",
"rabbitmq-key-training": ""
},
Expand All @@ -962,7 +1031,8 @@ class AMSRMQConfiguration:
rabbitmq_user: str
rabbitmq_vhost: str
rabbitmq_cert: str
rabbitmq_queue_physics: str
rabbitmq_exchange_physics: str
rabbitmq_key_physics: str
rabbitmq_exchange_training: str = ""
rabbitmq_key_training: str = ""
rabbitmq_ml_submit_queue: str = ""
Expand Down Expand Up @@ -994,7 +1064,8 @@ def to_dict(self, AMSlib=False):
"rabbitmq-user": self.rabbitmq_user,
"rabbitmq-vhost": self.rabbitmq_vhost,
"rabbitmq-cert": self.rabbitmq_cert,
"rabbitmq-queue-physics": self.rabbitmq_queue_physics,
"rabbitmq-exchange-physics": self.rabbitmq_exchange_physics,
"rabbitmq-key-physics": self.rabbitmq_key_physics,
"rabbitmq-exchange-training": self.rabbitmq_exchange_training,
"rabbitmq-key-training": self.rabbitmq_key_training,
"rabbitmq-ml-submit-queue": self.rabbitmq_ml_submit_queue,
Expand Down
36 changes: 23 additions & 13 deletions src/AMSWorkflow/ams/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class RMQDomainDataLoaderTask(Task):
"""
A RMQDomainDataLoaderTask consumes 'AMSMessages' from RabbitMQ bundles the data of
the files into batches and forwards them to the next task waiting on the
output queuee.
output queue.

Attributes:
o_queue: The output queue to write the transformed messages
Expand All @@ -281,14 +281,17 @@ def __init__(
user,
password,
cert,
rmq_queue,
rmq_exchange,
rmq_routing_key,
policy,
prefetch_count=1,
prefetch_count = 0,
signals=[signal.SIGINT, signal.SIGUSR1],
):
self.o_queue = o_queue
self.cert = cert
self.rmq_queue = rmq_queue
# self.rmq_queue = rmq_queue
self.rmq_exchange = rmq_exchange
self.rmq_routing_key = rmq_routing_key
self.prefetch_count = prefetch_count
self.datasize_byte = 0
self.total_time_ns = 0
Expand All @@ -311,7 +314,9 @@ def __init__(
user=user,
password=password,
cert=self.cert,
queue=self.rmq_queue,
exchange=self.rmq_exchange,
routing_key=self.rmq_routing_key,
queue="",
on_message_cb=self.callback_message,
on_close_cb=self.callback_close,
prefetch_count=self.prefetch_count,
Expand Down Expand Up @@ -428,7 +433,7 @@ def __init__(
user: str,
password: str,
cert: str,
prefetch_count: int = 1,
prefetch_count: int = 0,
):
self._consumers = consumers
super().__init__(
Expand Down Expand Up @@ -1061,7 +1066,8 @@ def __init__(
user,
password,
cert,
data_queue,
exchange,
routing_key,
model_update_queue=None,
):
"""
Expand All @@ -1075,10 +1081,12 @@ def __init__(
self._user = user
self._password = password
self._cert = Path(cert)
self._data_queue = data_queue
# self._data_queue = data_queue
self._exchange = exchange
self._routing_key = routing_key
self._model_update_queue = model_update_queue
print("Received a data queue of", self._data_queue)
print("Received a model_update queue of", self._model_update_queue)
print(f"Received data from exchange {self._exchange} / rkey {self._routing_key}")
print(f"Received a model_update queue of {self._model_update_queue}")
self._gracefull_shutdown = None
self._o_queue = None

Expand All @@ -1101,9 +1109,10 @@ def get_load_task(self, o_queue, policy):
self._user,
self._password,
self._cert,
self._data_queue,
self._exchange,
self._routing_key,
policy,
prefetch_count=1,
prefetch_count = 0,
)
self._o_queue = o_queue
self._gracefull_shutdown = AMSShutdown(
Expand Down Expand Up @@ -1176,7 +1185,8 @@ def from_cli(cls, args):
config.rabbitmq_user,
config.rabbitmq_password,
config.rabbitmq_cert,
config.rabbitmq_queue_physics,
config.rabbitmq_exchange_physics,
config.rabbitmq_key_physics,
config.rabbitmq_exchange_training if args.update_rmq_models else None,
)

Expand Down
Loading
Loading