Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ jobs:
\"service-port\": ${RABBITMQ_PORT},
\"service-host\": \"${RABBITMQ_HOST}\",
\"rabbitmq-vhost\": \"/\",
\"rabbitmq-queue-physics\": \"test-ci\",
\"rabbitmq-exchange-physics\": \"exchange-ci\",
\"rabbitmq-key-physics\": \"queue-ci\",
\"rabbitmq-exchange-training\": \"ams-fanout\",
\"rabbitmq-key-training\": \"training\"
}""" > rmq.json
Expand Down
149 changes: 96 additions & 53 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import json
import pika


class AMSMessage(object):
"""
Represents a RabbitMQ incoming message from AMSLib.
Expand Down Expand Up @@ -225,12 +224,16 @@ class AMSChannel:
def __init__(
self,
connection,
q_name,
exchange,
routing_key,
queue = "",
callback: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
):
self.connection = connection
self.q_name = q_name
self.exchange = exchange
self.routing_key = routing_key
self.q_name = queue
self.logger = logger if logger else logging.getLogger(__name__)
self.callback = callback if callback else self.default_callback

Expand All @@ -247,7 +250,16 @@ def default_callback(self, method, properties, body):

def open(self):
self.channel = self.connection.channel()
self.channel.queue_declare(queue=self.q_name)
if self.exchange != '':
self.logger.info(f"Declared exchange {self.exchange}")
self.channel.exchange_declare(exchange = self.exchange, exchange_type = "direct")

result = self.channel.queue_declare(queue = self.q_name, exclusive = False, durable = False)
self.q_name = result.method.queue
self.logger.info(f"Declared queue {self.q_name}")
if self.exchange != '':
self.logger.info(f"Binding queue {self.q_name} to exchange {self.exchange}")
self.channel.queue_bind(exchange = self.exchange, queue = self.q_name, routing_key = self.routing_key)

def close(self):
self.channel.close()
Expand Down Expand Up @@ -308,7 +320,7 @@ def send(self, text: str, exchange: str = ""):
@param text The text to send
@param exchange Exchange to use
"""
self.channel.basic_publish(exchange=exchange, routing_key=self.q_name, body=text)
self.channel.basic_publish(exchange = exchange, routing_key = self.routing_key, body=text)
return

def get_messages(self):
Expand Down Expand Up @@ -374,9 +386,11 @@ def __enter__(self):
def __exit__(self, exc_type, exc_val, exc_tb):
self.connection.close()

def connect(self, queue):
"""Connect to the queue"""
return AMSChannel(self.connection, queue, self.callback)
def connect(self, exchange, routing_key, queue = ""):
"""
Connect to the exchange and routing key.
"""
return AMSChannel(self.connection, exchange, routing_key, queue, self.callback)


class StatusPoller(BlockingClient):
Expand All @@ -401,8 +415,11 @@ def __init__(
user: str,
password: str,
cert: str,
queue: str,
prefetch_count: int = 1,
exchange: str,
routing_key: str,
queue: str = "",
prefetch_count: int = 0,
exchange_type: str = "direct",
on_message_cb: Optional[Callable] = None,
on_close_cb: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
Expand All @@ -425,6 +442,9 @@ def __init__(
self._vhost = vhost
self._cacert = cert
self._queue = queue
self._exchange = exchange
self._exchange_type = exchange_type
self._routing_key = routing_key

self.should_reconnect = False
# Holds the latest error/reason to reconnect
Expand Down Expand Up @@ -568,8 +588,7 @@ def on_channel_open(self, channel):
self._channel = channel
self.logger.debug("Channel opened")
self.add_on_channel_close_callback()
# we do not set up exchange first here, we use the default exchange ''
self.setup_queue(self._queue)
self.setup_exchange(self._exchange, self._exchange_type)

def add_on_channel_close_callback(self):
"""This method tells pika to call the on_channel_closed method if
Expand All @@ -595,6 +614,33 @@ def on_channel_closed(self, channel, reason):
self._on_close_cb() # running user callback
self.close_connection()

def setup_exchange(self, exchange_name, exchange_type):
"""Setup the exchange on RabbitMQ by invoking the Exchange.Declare RPC
command. When it is complete, the on_exchange_declareok method will
be invoked by pika.

:param str|unicode exchange_name: The name of the exchange to declare

"""
self.logger.debug(f"Declaring exchange: '{exchange_name}'")
cb = functools.partial(
self.on_exchange_declareok, userdata = exchange_name)
self._channel.exchange_declare(
exchange = exchange_name,
exchange_type = exchange_type,
callback = cb)

def on_exchange_declareok(self, _unused_frame, userdata):
"""Invoked by pika when RabbitMQ has finished the Exchange.Declare RPC
command.

:param pika.Frame.Method unused_frame: Exchange.DeclareOk response frame
:param str|unicode userdata: Extra user data (exchange name)

"""
self.logger.debug(f"Exchange declared: '{userdata}'")
self.setup_queue(self._queue)

def setup_queue(self, queue_name):
"""Setup the queue on RabbitMQ by invoking the Queue.Declare RPC
command. When it is complete, the on_queue_declareok method will
Expand All @@ -603,7 +649,7 @@ def setup_queue(self, queue_name):
:param str|unicode queue_name: The name of the queue to declare.

"""
self.logger.debug(f'Declaring queue "{queue_name}"')
self.logger.debug(f"Declaring queue '{queue_name}'")
cb = functools.partial(self.on_queue_declareok, userdata=queue_name)
# arguments = {"x-consumer-timeout":1800000} # 30 minutes in ms
self._channel.queue_declare(queue=queue_name, exclusive=False, callback=cb)
Expand All @@ -620,7 +666,23 @@ def on_queue_declareok(self, _unused_frame, userdata):

"""
queue_name = userdata
self.logger.debug(f'Queue "{queue_name}" declared')
self.logger.info(f"Binding {self._exchange} to queue '{queue_name}' with key '{self._routing_key}'")
cb = functools.partial(self.on_bindok, userdata=queue_name)
self._channel.queue_bind(
queue_name,
self._exchange,
routing_key=self._routing_key,
callback=cb)

def on_bindok(self, _unused_frame, userdata):
"""Invoked by pika when the Queue.Bind method has completed. At this
point we will set the prefetch count for the channel.

:param pika.frame.Method _unused_frame: The Queue.BindOk response frame
:param str|unicode userdata: Extra user data (queue name)

"""
self.logger.debug(f"Queue bound: '{userdata}'")
self.set_qos()

def set_qos(self):
Expand Down Expand Up @@ -772,51 +834,29 @@ def __init__(
user: str,
password: str,
cert: str,
queue: str,
prefetch_count: int = 1,
routing_key: str,
prefetch_count: int = 0,
on_message_cb: Optional[Callable] = None,
on_close_cb: Optional[Callable] = None,
logger: Optional[logging.Logger] = None,
):
super().__init__(
host,
port,
vhost,
user,
password,
cert,
queue,
prefetch_count,
on_message_cb,
on_close_cb,
logger,
)

# Callback when the channel is open
def on_channel_open(self, channel):
self._channel = channel
self.logger.debug("Channel opened")
self.add_on_channel_close_callback()
self._channel.exchange_declare(
host=host,
port=port,
vhost=vhost,
user=user,
password=password,
cert=cert,
exchange="control-panel",
routing_key=routing_key,
queue="",
exchange_type="fanout",
callback=self.on_exchange_declared,
prefetch_count=prefetch_count,
on_message_cb=on_message_cb,
on_close_cb=on_close_cb,
logger=logger,
)

# Callback when the exchange is declared
def on_exchange_declared(self, frame):
self._channel.queue_declare(queue="", exclusive=True, callback=self.on_queue_declared)

# Callback when the queue is declared
def on_queue_declared(self, queue_result):
self._queue = queue_result.method.queue
self._channel.queue_bind(exchange="control-panel", queue=self._queue, callback=self.on_queue_bound)

# Callback when the queue is bound to the exchange
def on_queue_bound(self, frame):
self.set_qos()


class AMSSyncProducer:
def __init__(
self,
Expand Down Expand Up @@ -947,7 +987,8 @@ class AMSRMQConfiguration:
"rabbitmq-user": "",
"rabbitmq-vhost": "",
"rabbitmq-cert": "",
"rabbitmq-queue-physics": "",
"rabbitmq-exchange-physics": "",
"rabbitmq-key-physics": "",
"rabbitmq-exchange-training": "",
"rabbitmq-key-training": ""
},
Expand All @@ -962,7 +1003,8 @@ class AMSRMQConfiguration:
rabbitmq_user: str
rabbitmq_vhost: str
rabbitmq_cert: str
rabbitmq_queue_physics: str
rabbitmq_exchange_physics: str
rabbitmq_key_physics: str
rabbitmq_exchange_training: str = ""
rabbitmq_key_training: str = ""
rabbitmq_ml_submit_queue: str = ""
Expand Down Expand Up @@ -994,7 +1036,8 @@ def to_dict(self, AMSlib=False):
"rabbitmq-user": self.rabbitmq_user,
"rabbitmq-vhost": self.rabbitmq_vhost,
"rabbitmq-cert": self.rabbitmq_cert,
"rabbitmq-queue-physics": self.rabbitmq_queue_physics,
"rabbitmq-exchange-physics": self.rabbitmq_exchange_physics,
"rabbitmq-key-physics": self.rabbitmq_key_physics,
"rabbitmq-exchange-training": self.rabbitmq_exchange_training,
"rabbitmq-key-training": self.rabbitmq_key_training,
"rabbitmq-ml-submit-queue": self.rabbitmq_ml_submit_queue,
Expand Down
36 changes: 23 additions & 13 deletions src/AMSWorkflow/ams/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class RMQDomainDataLoaderTask(Task):
"""
A RMQDomainDataLoaderTask consumes 'AMSMessages' from RabbitMQ bundles the data of
the files into batches and forwards them to the next task waiting on the
output queuee.
output queue.

Attributes:
o_queue: The output queue to write the transformed messages
Expand All @@ -281,14 +281,17 @@ def __init__(
user,
password,
cert,
rmq_queue,
rmq_exchange,
rmq_routing_key,
policy,
prefetch_count=1,
prefetch_count = 0,
signals=[signal.SIGINT, signal.SIGUSR1],
):
self.o_queue = o_queue
self.cert = cert
self.rmq_queue = rmq_queue
# self.rmq_queue = rmq_queue
self.rmq_exchange = rmq_exchange
self.rmq_routing_key = rmq_routing_key
self.prefetch_count = prefetch_count
self.datasize_byte = 0
self.total_time_ns = 0
Expand All @@ -311,7 +314,9 @@ def __init__(
user=user,
password=password,
cert=self.cert,
queue=self.rmq_queue,
exchange=self.rmq_exchange,
routing_key=self.rmq_routing_key,
queue="",
on_message_cb=self.callback_message,
on_close_cb=self.callback_close,
prefetch_count=self.prefetch_count,
Expand Down Expand Up @@ -428,7 +433,7 @@ def __init__(
user: str,
password: str,
cert: str,
prefetch_count: int = 1,
prefetch_count: int = 0,
):
self._consumers = consumers
super().__init__(
Expand Down Expand Up @@ -1061,7 +1066,8 @@ def __init__(
user,
password,
cert,
data_queue,
exchange,
routing_key,
model_update_queue=None,
):
"""
Expand All @@ -1075,10 +1081,12 @@ def __init__(
self._user = user
self._password = password
self._cert = Path(cert)
self._data_queue = data_queue
# self._data_queue = data_queue
self._exchange = exchange
self._routing_key = routing_key
self._model_update_queue = model_update_queue
print("Received a data queue of", self._data_queue)
print("Received a model_update queue of", self._model_update_queue)
print(f"Received data from exchange {self._exchange} / rkey {self._routing_key}")
print(f"Received a model_update queue of {self._model_update_queue}")
self._gracefull_shutdown = None
self._o_queue = None

Expand All @@ -1101,9 +1109,10 @@ def get_load_task(self, o_queue, policy):
self._user,
self._password,
self._cert,
self._data_queue,
self._exchange,
self._routing_key,
policy,
prefetch_count=1,
prefetch_count = 0,
)
self._o_queue = o_queue
self._gracefull_shutdown = AMSShutdown(
Expand Down Expand Up @@ -1176,7 +1185,8 @@ def from_cli(cls, args):
config.rabbitmq_user,
config.rabbitmq_password,
config.rabbitmq_cert,
config.rabbitmq_queue_physics,
config.rabbitmq_exchange_physics,
config.rabbitmq_key_physics,
config.rabbitmq_exchange_training if args.update_rmq_models else None,
)

Expand Down
Loading