diff --git a/airflow/contrib/hooks/segment_hook.py b/airflow/contrib/hooks/segment_hook.py
new file mode 100644
index 0000000000000..874d35d0743b9
--- /dev/null
+++ b/airflow/contrib/hooks/segment_hook.py
@@ -0,0 +1,107 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+"""
+This module contains a Segment Hook
+which configures the Segment analytics package with the write key
+stored in an Airflow connection, so that data can be written to Segment.
+
+NOTE: this hook also relies on the Segment analytics package:
+    https://github.com/segmentio/analytics-python
+"""
+import analytics
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowException
+
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+
+class SegmentHook(BaseHook, LoggingMixin):
+    """
+    Configure the Segment ``analytics`` module from an Airflow connection.
+
+    :param segment_conn_id: the name of the connection that has the parameters
+        we need to connect to Segment. The connection should be of type
+        `segment` and include a write_key security token in the `Extras` field.
+    :type segment_conn_id: str
+    :param segment_debug_mode: Determines whether Segment should run in debug mode.
+        Defaults to False
+    :type segment_debug_mode: boolean
+    :raises AirflowException: if the `Extras` field holds no write_key, or an
+        empty one.
+
+    .. note::
+        You must include a JSON structure in the `Extras` field.
+        We need a user's security token to connect to Segment.
+        So we define it in the `Extras` field as:
+            `{"write_key":"YOUR_SECURITY_TOKEN"}`
+    """
+
+    def __init__(
+            self,
+            segment_conn_id='segment_default',
+            segment_debug_mode=False,
+            *args,
+            **kwargs
+    ):
+        self.segment_conn_id = segment_conn_id
+        self.segment_debug_mode = segment_debug_mode
+        self._args = args
+        self._kwargs = kwargs
+
+        # get the connection parameters
+        self.connection = self.get_connection(self.segment_conn_id)
+        self.extras = self.connection.extra_dejson
+        self.write_key = self.extras.get('write_key')
+        # An empty key is as unusable as a missing one, so reject both up
+        # front rather than handing it to the analytics module.
+        if not self.write_key:
+            raise AirflowException('No Segment write key provided')
+
+    def get_conn(self):
+        """
+        Configure the ``analytics`` module and return it.
+
+        Sets ``debug``, ``on_error`` and ``write_key`` as attributes of the
+        imported module itself, not of this hook instance.
+
+        :return: the ``analytics`` module
+        """
+        self.log.info('Setting write key for Segment analytics connection')
+        analytics.debug = self.segment_debug_mode
+        if self.segment_debug_mode:
+            self.log.info('Setting Segment analytics connection to debug mode')
+        analytics.on_error = self.on_error
+        analytics.write_key = self.write_key
+        return analytics
+
+    def on_error(self, error, items):
+        """
+        Error callback that get_conn() installs as ``analytics.on_error``.
+
+        Logs the error and the accompanying items, then raises.
+
+        :param error: the error passed to the callback
+        :param items: the items passed to the callback along with the error
+        :raises AirflowException: always
+        """
+        self.log.error('Encountered Segment error: {segment_error} with '
+                       'items: {with_items}'.format(segment_error=error,
+                                                    with_items=items))
+        raise AirflowException('Segment error: {}'.format(error))
diff --git a/airflow/contrib/operators/segment_track_event_operator.py b/airflow/contrib/operators/segment_track_event_operator.py
new file mode 100644
index 0000000000000..faacce84f03dd
--- /dev/null
+++ b/airflow/contrib/operators/segment_track_event_operator.py
@@ -0,0 +1,69 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow.contrib.hooks.segment_hook import SegmentHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+
+class SegmentTrackEventOperator(BaseOperator):
+    """
+    Send Track Event to Segment for a specified user_id and event
+
+    :param user_id: The ID for this user in your database
+    :type user_id: string
+    :param event: The name of the event you're tracking
+    :type event: string
+    :param properties: A dictionary of properties for the event.
+    :type properties: dict
+    :param segment_conn_id: The connection ID to use when connecting to Segment.
+    :type segment_conn_id: string
+    :param segment_debug_mode: Determines whether Segment should run in debug mode.
+        Defaults to False
+    :type segment_debug_mode: boolean
+    """
+    template_fields = ('user_id', 'event', 'properties')
+    ui_color = '#ffd700'
+
+    @apply_defaults
+    def __init__(self,
+                 user_id,
+                 event,
+                 properties=None,
+                 segment_conn_id='segment_default',
+                 segment_debug_mode=False,
+                 *args,
+                 **kwargs):
+        super(SegmentTrackEventOperator, self).__init__(*args, **kwargs)
+        self.user_id = user_id
+        self.event = event
+        self.properties = properties or {}
+        self.segment_debug_mode = segment_debug_mode
+        self.segment_conn_id = segment_conn_id
+
+    def execute(self, context):
+        """Send the track event through the client configured by SegmentHook."""
+        hook = SegmentHook(segment_conn_id=self.segment_conn_id,
+                           segment_debug_mode=self.segment_debug_mode)
+
+        self.log.info(
+            'Sending track event (%s) for user id: %s with properties: %s',
+            self.event, self.user_id, self.properties)
+        # SegmentHook itself defines no track(); use the client from get_conn().
+        hook.get_conn().track(self.user_id, self.event, self.properties)
diff --git a/airflow/models.py b/airflow/models.py
index 536cf6d0cf7c4..5903075b8612e 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -600,6 +600,7 @@ class Connection(Base, LoggingMixin):
         ('aws', 'Amazon Web Services',),
         ('emr', 'Elastic MapReduce',),
         ('snowflake', 'Snowflake',),
+        ('segment', 'Segment',),
     ]
 
     def __init__(
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 3ec6069809719..7bbda9328d655 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -7,9 +7,9 @@
 # to you under the Apache License, Version 2.0 (the
 # "License"); you may not use this file except in compliance
 # with the License.
You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -268,6 +268,10 @@ def initdb(rbac=False): models.Connection( conn_id='qubole_default', conn_type='qubole', host= 'localhost')) + merge_conn( + models.Connection( + conn_id='segment_default', conn_type='segment', + extra='{"write_key": "my-segment-write-key"}')) # Known event types KET = models.KnownEventType diff --git a/docs/code.rst b/docs/code.rst index 596e9c3a32302..618f376c4c72c 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -172,6 +172,7 @@ Operators .. autoclass:: airflow.contrib.operators.qubole_operator.QuboleOperator .. autoclass:: airflow.contrib.operators.s3_list_operator.S3ListOperator .. autoclass:: airflow.contrib.operators.s3_to_gcs_operator.S3ToGoogleCloudStorageOperator +.. autoclass:: airflow.contrib.operators.segment_track_event_operator.SegmentTrackEventOperator .. autoclass:: airflow.contrib.operators.sftp_operator.SFTPOperator .. autoclass:: airflow.contrib.operators.slack_webhook_operator.SlackWebhookOperator .. autoclass:: airflow.contrib.operators.snowflake_operator.SnowflakeOperator @@ -372,6 +373,7 @@ Community contributed hooks .. autoclass:: airflow.contrib.hooks.redis_hook.RedisHook .. autoclass:: airflow.contrib.hooks.redshift_hook.RedshiftHook .. autoclass:: airflow.contrib.hooks.salesforce_hook.SalesforceHook +.. autoclass:: airflow.contrib.hooks.segment_hook.SegmentHook .. autoclass:: airflow.contrib.hooks.sftp_hook.SFTPHook .. autoclass:: airflow.contrib.hooks.slack_webhook_hook.SlackWebhookHook ..
autoclass:: airflow.contrib.hooks.snowflake_hook.SnowflakeHook diff --git a/setup.py b/setup.py index 3e0b37615675d..023e6ee0ae95b 100644 --- a/setup.py +++ b/setup.py @@ -177,6 +177,7 @@ def write_version(filename=os.path.join(*['airflow', s3 = ['boto3>=1.7.0'] salesforce = ['simple-salesforce>=0.72'] samba = ['pysmbclient>=0.1.3'] +segment = ['analytics-python>=1.2.9'] slack = ['slackclient>=1.0.0'] snowflake = ['snowflake-connector-python>=1.5.2', 'snowflake-sqlalchemy>=1.1.0'] @@ -211,7 +212,7 @@ def write_version(filename=os.path.join(*['airflow', devel_all = (sendgrid + devel + all_dbs + doc + samba + s3 + slack + crypto + oracle + docker + ssh + kubernetes + celery + azure + redis + gcp_api + datadog + zendesk + jdbc + ldap + kerberos + password + webhdfs + jenkins + - druid + pinot + snowflake + elasticsearch) + druid + pinot + segment + snowflake + elasticsearch) # Snakebite & Google Cloud Dataflow are not Python 3 compatible :'( if PY3: @@ -316,6 +317,7 @@ def do_setup(): 'salesforce': salesforce, 'samba': samba, 'sendgrid': sendgrid, + 'segment': segment, 'slack': slack, 'snowflake': snowflake, 'ssh': ssh, diff --git a/tests/contrib/hooks/test_segment_hook.py b/tests/contrib/hooks/test_segment_hook.py new file mode 100644 index 0000000000000..9aa854e34aa06 --- /dev/null +++ b/tests/contrib/hooks/test_segment_hook.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import mock
+import unittest
+
+from airflow import configuration, AirflowException
+
+from airflow.contrib.hooks.segment_hook import SegmentHook
+
+TEST_CONN_ID = 'test_segment'
+WRITE_KEY = 'foo'
+
+
+class TestSegmentHook(unittest.TestCase):
+
+    def setUp(self):
+        super(TestSegmentHook, self).setUp()
+        configuration.load_test_config()
+        # Stand-in for the Airflow connection; the hook only reads extra_dejson.
+        self.conn = mock.MagicMock(extra_dejson={'write_key': WRITE_KEY})
+        patcher = mock.patch.object(
+            SegmentHook, 'get_connection', return_value=self.conn)
+        patcher.start()
+        self.addCleanup(patcher.stop)
+        self.test_hook = SegmentHook(segment_conn_id=TEST_CONN_ID)
+
+    @mock.patch('airflow.contrib.hooks.segment_hook.analytics')
+    def test_get_conn(self, mock_analytics):
+        # Run the real get_conn() against a patched `analytics` module.
+        segment = self.test_hook.get_conn()
+        self.assertIs(segment, mock_analytics)
+        self.assertEqual(segment.write_key, WRITE_KEY)
+        self.assertFalse(segment.debug)
+        self.assertEqual(segment.on_error, self.test_hook.on_error)
+
+    def test_missing_write_key(self):
+        self.conn.extra_dejson = {}
+        with self.assertRaises(AirflowException):
+            SegmentHook(segment_conn_id=TEST_CONN_ID)
+
+    def test_on_error(self):
+        with self.assertRaises(AirflowException):
+            self.test_hook.on_error('error', ['items'])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/contrib/operators/test_segment_track_event_operator.py b/tests/contrib/operators/test_segment_track_event_operator.py
new file mode 100644
index 0000000000000..9aa854e34aa06
--- /dev/null
+++ b/tests/contrib/operators/test_segment_track_event_operator.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.
See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import mock
+import unittest
+
+from airflow import configuration
+
+from airflow.contrib.operators.segment_track_event_operator \
+    import SegmentTrackEventOperator
+
+TEST_CONN_ID = 'test_segment'
+USER_ID = 'user-1'
+EVENT = 'signed_up'
+PROPERTIES = {'plan': 'free'}
+
+
+class TestSegmentTrackEventOperator(unittest.TestCase):
+
+    def setUp(self):
+        super(TestSegmentTrackEventOperator, self).setUp()
+        configuration.load_test_config()
+
+    def test_init_defaults(self):
+        operator = SegmentTrackEventOperator(
+            task_id='segment_track', user_id=USER_ID, event=EVENT)
+        self.assertEqual(operator.properties, {})
+        self.assertEqual(operator.segment_conn_id, 'segment_default')
+        self.assertFalse(operator.segment_debug_mode)
+
+    @mock.patch('airflow.contrib.operators.segment_track_event_operator.SegmentHook')
+    def test_execute(self, mock_hook_class):
+        operator = SegmentTrackEventOperator(
+            task_id='segment_track', user_id=USER_ID, event=EVENT,
+            properties=PROPERTIES, segment_conn_id=TEST_CONN_ID)
+        operator.execute(None)
+
+        mock_hook_class.assert_called_once_with(
+            segment_conn_id=TEST_CONN_ID, segment_debug_mode=False)
+        # The event must be sent exactly once, via the hook or its get_conn() client.
+        hook = mock_hook_class.return_value
+        sent = (hook.track.call_args_list +
+                hook.get_conn.return_value.track.call_args_list)
+        self.assertEqual(sent, [mock.call(USER_ID, EVENT, PROPERTIES)])
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/tests/core.py b/tests/core.py
index 5ab2e9414c40d..ce32482d04f2a 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -1038,6 +1038,7 @@ def
test_cli_connections_list(self): self.assertIn(['mysql_default', 'mysql'], conns) self.assertIn(['postgres_default', 'postgres'], conns) self.assertIn(['wasb_default', 'wasb'], conns) + self.assertIn(['segment_default', 'segment'], conns) # Attempt to list connections with invalid cli args with mock.patch('sys.stdout',