Commit d888d8d
[xy] Add AzureEventHub source to streaming pipeline (mage-ai#1260)
* [xy] Add AzureEventHub source.
* [xy] Add unit tests to streaming module.
* [xy] Add unit test for AzureEventHub.
* [xy] Update event handling logic.
1 parent 2437064 commit d888d8d

28 files changed: +320 -44 lines changed

mage_ai/data_preparation/executors/streaming_pipeline_executor.py

+9-4
@@ -1,9 +1,9 @@
-from contextlib import redirect_stdout
+from contextlib import redirect_stderr, redirect_stdout
 from mage_ai.data_preparation.executors.pipeline_executor import PipelineExecutor
 from mage_ai.data_preparation.models.constants import BlockType
 from mage_ai.data_preparation.models.pipeline import Pipeline
 from mage_ai.data_preparation.shared.stream import StreamToLogger
-from typing import Callable, Dict
+from typing import Callable, Dict, List, Union
 import yaml


@@ -74,7 +74,8 @@ def execute(
         stdout = StreamToLogger(self.logger)
         try:
             with redirect_stdout(stdout):
-                self.__execute_in_python()
+                with redirect_stderr(stdout):
+                    self.__execute_in_python()
         except Exception as e:
             if not build_block_output_stdout:
                 self.logger.exception(
@@ -90,13 +91,17 @@ def __execute_in_python(self):
         sink_config = yaml.safe_load(self.sink_block.content)
         source = SourceFactory.get_source(source_config)
         sink = SinkFactory.get_sink(sink_config)
-        for messages in source.batch_read():
+
+        def handle_batch_events(messages: List[Union[Dict, str]]):
             if self.transformer_block is not None:
                 messages = self.transformer_block.execute_block(
                     input_args=[messages],
                 )['output']
             sink.batch_write(messages)

+        # Long running method
+        source.batch_read(handler=handle_batch_events)
+
     def __excute_in_flink(self):
         """
         TODO: Implement this method
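The executor no longer pulls batches from a generator; the source now drives the loop and pushes each batch into a handler. A minimal sketch of that contract, using a hypothetical in-memory source (not part of this commit):

from typing import Callable, Dict, List, Union


class ListSource:
    """Hypothetical in-memory source illustrating the handler-based contract."""

    def __init__(self, batches: List[List[Union[Dict, str]]]):
        self.batches = batches

    def batch_read(self, handler: Callable):
        # A real source blocks here indefinitely (hence "long running");
        # this toy version just replays a fixed set of batches.
        for messages in self.batches:
            handler(messages)


source = ListSource([[dict(data='event_1')], [dict(data='event_2'), dict(data='event_3')]])
source.batch_read(handler=lambda messages: print(f'writing {len(messages)} messages to sink'))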

mage_ai/streaming/constants.py

+1
@@ -2,6 +2,7 @@


 class SourceType(str, Enum):
+    AZURE_EVENT_HUB = 'azure_event_hub'
     KAFKA = 'kafka'


mage_ai/streaming/sinks/base.py

+13
@@ -1,7 +1,20 @@
 from abc import ABC, abstractmethod
+from typing import Dict


 class BaseSink(ABC):
+    config_class = None
+
+    def __init__(self, config: Dict):
+        if self.config_class is not None:
+            if 'connector_type' in config:
+                config.pop('connector_type')
+            self.config = self.config_class.load(config=config)
+        self.init_client()
+
+    def init_client():
+        pass
+
     @abstractmethod
     def write(self, data):
         pass
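Subclasses now only declare a config dataclass and override init_client(); BaseSink.__init__ handles stripping connector_type and loading the config (BaseSource below follows the same pattern). A sketch of a hypothetical sink built on this hook, assuming BaseConfig.load accepts a plain dict as shown in kafka.py:

from dataclasses import dataclass

from mage_ai.shared.config import BaseConfig
from mage_ai.streaming.sinks.base import BaseSink


@dataclass
class PrintConfig(BaseConfig):
    # Hypothetical config; fields become attributes on self.config.
    prefix: str = '[sink]'


class PrintSink(BaseSink):
    config_class = PrintConfig

    def init_client(self):
        # Called by BaseSink.__init__ after the config is loaded;
        # a real sink would open its client connection here.
        pass

    def write(self, data):
        print(self.config.prefix, data)

    def batch_write(self, data):
        # batch_write is what the streaming executor calls; defined for completeness.
        for record in data:
            self.write(record)


sink = PrintSink({'connector_type': 'print', 'prefix': '>>'})
sink.batch_write([dict(data='hello'), dict(data='world')])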

mage_ai/streaming/sinks/opensearch.py

+6-6
@@ -1,4 +1,6 @@
 from dataclasses import dataclass
+from mage_ai.shared.config import BaseConfig
+from mage_ai.streaming.sinks.base import BaseSink
 from opensearchpy import OpenSearch, RequestsHttpConnection
 from opensearchpy.helpers import bulk
 from requests_aws4auth import AWS4Auth
@@ -8,19 +10,17 @@


 @dataclass
-class OpensearchSinkConfig:
+class OpensearchConfig(BaseConfig):
     host: str
     index_name: str
     verify_certs: bool = True
     http_auth: str = '@awsauth'


-class OpenSearchSink():
-    def __init__(self, config: Dict):
-        if 'connector_type' in config:
-            config.pop('connector_type')
-        self.config = OpensearchSinkConfig(**config)
+class OpenSearchSink(BaseSink):
+    config_class = OpensearchConfig

+    def init_client(self):
         # Initialize opensearch client
         if self.config.http_auth == '@awsauth':
             session = boto3.Session()
mage_ai/streaming/sources/azure_event_hub.py

+90

@@ -0,0 +1,90 @@
+from azure.eventhub import EventHubConsumerClient
+from dataclasses import dataclass
+from mage_ai.shared.config import BaseConfig
+from mage_ai.streaming.sources.base import BaseSource
+from typing import Callable, List
+import traceback
+
+
+@dataclass
+class AzureEventHubConfig(BaseConfig):
+    connection_str: str
+    eventhub_name: str
+    consumer_group: str = '$Default'
+
+
+class AzureEventHubSource(BaseSource):
+    config_class = AzureEventHubConfig
+
+    def init_client(self):
+        self.consumer_client = EventHubConsumerClient.from_connection_string(
+            conn_str=self.config.connection_str,
+            consumer_group=self.config.consumer_group,
+            eventhub_name=self.config.eventhub_name,
+        )
+
+    def read(self, handler: Callable):
+        try:
+            def on_event(partition_context, event):
+                print(f'Received event from partition: {partition_context.partition_id}.')
+                print(f'Event: {event}')
+
+                handler(dict(data=event.body_as_str()))
+
+            with self.consumer_client:
+                self.consumer_client.receive(
+                    on_event=on_event,
+                    on_partition_initialize=self.on_partition_initialize,
+                    on_partition_close=self.on_partition_close,
+                    on_error=self.on_error,
+                    starting_position='-1',  # '-1' is from the beginning of the partition.
+                )
+        except KeyboardInterrupt:
+            print('Stopped receiving.')
+
+    def batch_read(self, handler: Callable):
+        try:
+            def on_event_batch(partition_context, event_batch: List):
+                if len(event_batch) == 0:
+                    return
+                print(f'Partition {partition_context.partition_id},'
+                      f'Received count: {len(event_batch)}')
+                print(f'Sample event: {event_batch[0]}')
+
+                # Handle events
+                try:
+                    handler([dict(data=e.body_as_str()) for e in event_batch])
+                except Exception as e:
+                    traceback.print_exc()
+                    raise e
+
+                partition_context.update_checkpoint()
+
+            with self.consumer_client:
+                self.consumer_client.receive_batch(
+                    on_event_batch=on_event_batch,
+                    max_batch_size=100,
+                    on_partition_initialize=self.on_partition_initialize,
+                    on_partition_close=self.on_partition_close,
+                    on_error=self.on_error,
+                    starting_position='-1',  # '-1' is from the beginning of the partition.
+                )
+        except KeyboardInterrupt:
+            print('Stopped receiving.')
+
+    def test_connection(self):
+        return True
+
+    def on_partition_initialize(partition_context):
+        print(f'Partition: {partition_context.partition_id} has been initialized.')
+
+    def on_partition_close(partition_context, reason):
+        print(f'Partition: {partition_context.partition_id} has been closed, '
+              f'reason for closing: {reason}.')
+
+    def on_error(partition_context, error):
+        if partition_context:
+            print(f'An exception: {partition_context.partition_id} occurred during'
+                  f' receiving from Partition: {error}.')
+        else:
+            print(f'An exception: {error} occurred during the load balance process.')
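A usage sketch for the new source; the connection string and hub name below are placeholders, and batch_read blocks until interrupted:

from mage_ai.streaming.sources.azure_event_hub import AzureEventHubSource

source = AzureEventHubSource({
    'connector_type': 'azure_event_hub',
    'connection_str': 'Endpoint=sb://<namespace>.servicebus.windows.net/;SharedAccessKeyName=...',
    'eventhub_name': 'my_event_hub',
    # consumer_group defaults to '$Default'
})

# Each batch arrives as a list of dicts shaped like {'data': '<event body>'};
# update_checkpoint() runs only after the handler returns without raising.
source.batch_read(handler=lambda messages: print(f'Received {len(messages)} events'))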

mage_ai/streaming/sources/base.py

+13
@@ -1,7 +1,20 @@
 from abc import ABC, abstractmethod
+from typing import Dict


 class BaseSource(ABC):
+    config_class = None
+
+    def __init__(self, config: Dict):
+        if self.config_class is not None:
+            if 'connector_type' in config:
+                config.pop('connector_type')
+            self.config = self.config_class.load(config=config)
+        self.init_client()
+
+    def init_client():
+        pass
+
     @abstractmethod
     def read(self):
         pass

mage_ai/streaming/sources/kafka.py

+10-11
@@ -1,7 +1,8 @@
 from dataclasses import dataclass
 from kafka import KafkaConsumer
 from mage_ai.shared.config import BaseConfig
-from typing import Dict
+from mage_ai.streaming.sources.base import BaseSource
+from typing import Callable, Dict
 import json
 import time

@@ -18,7 +19,7 @@ class SSLConfig:


 @dataclass
-class KafkaSourceConfig(BaseConfig):
+class KafkaConfig(BaseConfig):
     bootstrap_server: str
     consumer_group: str
     topic: str
@@ -34,12 +35,10 @@ def parse_config(self, config: Dict) -> Dict:
         return config


-class KafkaSource:
-    def __init__(self, config: Dict):
-        if 'connector_type' in config:
-            config.pop('connector_type')
-        self.config = KafkaSourceConfig.load(config=config)
+class KafkaSource(BaseSource):
+    config_class = KafkaConfig

+    def init_client(self):
         print('Start initializing kafka consumer.')
         # Initialize kafka consumer
         consumer_kwargs = dict(
@@ -61,14 +60,14 @@ def __init__(self, config: Dict):
         )
         print('Finish initializing kafka consumer.')

-    def read(self):
+    def read(self, handler: Callable):
         print('Start consuming messages from kafka.')
         for message in self.consumer:
             self.__print_message(message)
             data = json.loads(message.value.decode('utf-8'))
-            yield data
+            handler(data)

-    def batch_read(self):
+    def batch_read(self, handler: Callable):
         print('Start consuming messages from kafka.')
         if self.config.batch_size > 0:
             batch_size = self.config.batch_size
@@ -87,7 +86,7 @@ def batch_read(self):
                 self.__print_message(message)
                 message_values.append(json.loads(message.value.decode('utf-8')))
             if len(message_values) > 0:
-                yield message_values
+                handler(message_values)

     def test_connection(self):
         return True
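read and batch_read switch from pull (generator) to push (handler) so the Kafka and Event Hub sources expose the same interface to the executor. A before/after sketch of the calling convention, assuming a broker at localhost:9092 (placeholder):

from mage_ai.streaming.sources.kafka import KafkaSource

source = KafkaSource({
    'connector_type': 'kafka',
    'bootstrap_server': 'localhost:9092',  # placeholder broker address
    'consumer_group': 'test_group',
    'topic': 'test_topic',
})

# Before: the caller pulled messages from a generator.
#     for data in source.read():
#         process(data)
# After: the source pushes each decoded message into the handler.
source.read(handler=lambda data: print(data))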

mage_ai/streaming/sources/source_factory.py

+6
@@ -9,3 +9,9 @@ def get_source(self, config: Dict):
         if connector_type == SourceType.KAFKA:
             from mage_ai.streaming.sources.kafka import KafkaSource
             return KafkaSource(config)
+        elif connector_type == SourceType.AZURE_EVENT_HUB:
+            from mage_ai.streaming.sources.azure_event_hub import AzureEventHubSource
+            return AzureEventHubSource(config)
+        raise Exception(
+            f'Consuming data from {connector_type} is not supported in streaming pipelines yet.',
+        )
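With the fallthrough in place, an unrecognized connector_type now fails loudly instead of returning None. Dispatch sketch (connection values are placeholders):

from mage_ai.streaming.sources.source_factory import SourceFactory

source = SourceFactory.get_source(dict(
    connector_type='azure_event_hub',
    connection_str='Endpoint=sb://<namespace>.servicebus.windows.net/;...',  # placeholder
    eventhub_name='my_event_hub',
))
# -> AzureEventHubSource; connector_type='kafka' returns a KafkaSource,
#    and anything else raises the new exception.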

mage_ai/tests/base_test.py

+9-1
@@ -6,7 +6,7 @@
 import unittest


-class TestCase(unittest.TestCase):
+class DBTestCase(unittest.TestCase):
     def setUp(self):
         pass

@@ -28,3 +28,11 @@ def tearDownClass(self):
         db_connection.close_session()
         os.remove(TEST_DB)
         super().tearDownClass()
+
+
+class TestCase(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
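DBTestCase keeps the repo/DB setup for the data-preparation tests below, while the new bare TestCase suits streaming tests that only need mocks. A sketch of how the commit's AzureEventHub unit test could look (the actual test files are not shown in this diff; the mock target is illustrative):

from unittest import mock

from mage_ai.tests.base_test import TestCase


class AzureEventHubTests(TestCase):
    @mock.patch('mage_ai.streaming.sources.azure_event_hub.EventHubConsumerClient')
    def test_init_client(self, mock_client_cls):
        from mage_ai.streaming.sources.azure_event_hub import AzureEventHubSource

        # Constructing the source loads the config and initializes the client.
        AzureEventHubSource({
            'connection_str': 'test_connection_str',
            'eventhub_name': 'test_event_hub',
        })
        mock_client_cls.from_connection_string.assert_called_once_with(
            conn_str='test_connection_str',
            consumer_group='$Default',
            eventhub_name='test_event_hub',
        )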

mage_ai/tests/data_preparation/models/test_block.py

+2-2
@@ -3,13 +3,13 @@
 from mage_ai.data_preparation.models.block import Block, BlockType
 from mage_ai.data_preparation.models.pipeline import Pipeline
 from mage_ai.data_preparation.variable_manager import VariableManager
-from mage_ai.tests.base_test import TestCase
+from mage_ai.tests.base_test import DBTestCase
 from pandas.util.testing import assert_frame_equal
 import os
 import pandas as pd


-class BlockTest(TestCase):
+class BlockTest(DBTestCase):
     def test_create(self):
         block1 = Block.create('test_transformer', 'transformer', self.repo_path)
         block2 = Block.create('test data loader', BlockType.DATA_LOADER, self.repo_path)

mage_ai/tests/data_preparation/models/test_pipeline.py

+2-2
@@ -1,11 +1,11 @@
 from mage_ai.data_preparation.models.block import Block
 from mage_ai.data_preparation.models.pipeline import InvalidPipelineError, Pipeline
 from mage_ai.data_preparation.models.widget import Widget
-from mage_ai.tests.base_test import TestCase
+from mage_ai.tests.base_test import DBTestCase
 import os


-class PipelineTest(TestCase):
+class PipelineTest(DBTestCase):
     def test_create(self):
         pipeline = Pipeline.create(
             'test pipeline',

mage_ai/tests/data_preparation/models/test_variable.py

+2-2
@@ -1,13 +1,13 @@
 from mage_ai.data_preparation.models.block import Block
 from mage_ai.data_preparation.models.pipeline import Pipeline
 from mage_ai.data_preparation.models.variable import Variable, VariableType
-from mage_ai.tests.base_test import TestCase
+from mage_ai.tests.base_test import DBTestCase
 from pandas.util.testing import assert_frame_equal
 import os
 import pandas as pd


-class VariableTest(TestCase):
+class VariableTest(DBTestCase):
     def test_write_and_read_data(self):
         pipeline = self.__create_pipeline('test pipeline 1')
         variable1 = Variable('var1', pipeline.dir_path, 'block1')

mage_ai/tests/data_preparation/test_variable_manager.py

+2-2
@@ -7,12 +7,12 @@
     get_global_variable,
     set_global_variable,
 )
-from mage_ai.tests.base_test import TestCase
+from mage_ai.tests.base_test import DBTestCase
 from pandas.util.testing import assert_frame_equal
 import pandas as pd


-class VariableManagerTest(TestCase):
+class VariableManagerTest(DBTestCase):
     def test_add_and_get_variable(self):
         self.__create_pipeline('test pipeline 1')
         variable_manager = VariableManager(self.repo_path)

mage_ai/tests/io/test_config.py

+2-2
@@ -1,10 +1,10 @@
 from mage_ai.io.config import ConfigFileLoader, ConfigKey, EnvironmentVariableLoader
-from mage_ai.tests.base_test import TestCase
+from mage_ai.tests.base_test import DBTestCase
 from pathlib import Path
 from unittest import mock


-class ConfigLoaderTests(TestCase):
+class ConfigLoaderTests(DBTestCase):
     def setUp(self):
         super().setUp()
         self.test_path = Path('./test')
