Generate Data Package DAGs (#18)
* Added generation of data package DAGs; moved some reusable helpers to utils

* Loop syntax correction for simple string array

* Corrected reference to data package builder ECS task definition

* Removed irrelevant params

cpcundill authored Nov 5, 2024
1 parent 87e927f commit 52510b1
Showing 4 changed files with 169 additions and 74 deletions.
76 changes: 9 additions & 67 deletions dags/collection_generator.py
@@ -12,72 +12,14 @@
from airflow.operators.python import PythonOperator
from airflow.models.param import Param

from utils import get_config, get_task_log_config, load_specification_datasets
from utils import dag_default_args, get_config, load_specification_datasets, setup_configure_dag_callable

# read config from file and environment
config = get_config()

# set some variables needed for ECS tasks
ecs_cluster = f"{config['env']}-cluster"
collection_task_defn = f"{config['env']}-mwaa-collection-task"

# set some default_args for all collections
default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2024, 1, 1),
"dagrun_timeout": timedelta(minutes=5),
}

# set task id for the initialisation task at the start
configure_dag_task_id = "configure-dag"


def configure_dag(**kwargs):
"""
function which returns the relevant configuration details
and stores them in xcoms for other tasks. this includes:
- get and process params into correct formats
- read in env variables
- access options defined in the task definitions
"""
aws_vpc_config = {
"subnets": kwargs['conf'].get(section='custom', key='ecs_task_subnets').split(","),
"securityGroups": kwargs['conf'].get(section='custom', key='ecs_task_security_groups').split(","),
"assignPublicIp": "ENABLED",
}

# retrieve and process parameters
params = kwargs['params']

memory = int(params.get('memory'))
cpu = int(params.get('cpu'))
transformed_jobs = str(kwargs['params'].get('transformed-jobs'))
dataset_jobs = str(kwargs['params'].get('dataset-jobs'))

# get ecs-task logging configuration
ecs_client = boto3.client('ecs')
collection_task_log_config = get_task_log_config(ecs_client, collection_task_defn)
collection_task_log_config_options = collection_task_log_config['options']
collection_task_log_group = str(collection_task_log_config_options.get('awslogs-group'))
# add container name to prefix
collection_task_log_stream_prefix = str(collection_task_log_config_options.get('awslogs-stream-prefix')) + f'/{collection_task_defn}'
collection_task_log_region = str(collection_task_log_config_options.get('awslogs-region'))
collection_dataset_bucket_name = kwargs['conf'].get(section='custom', key='collection_dataset_bucket_name')

# Push values to XCom
ti = kwargs['ti']
ti.xcom_push(key='env', value=config['env'])
ti.xcom_push(key='aws_vpc_config', value=aws_vpc_config)
ti.xcom_push(key='memory', value=memory)
ti.xcom_push(key='cpu', value=cpu)
ti.xcom_push(key='transformed-jobs', value=transformed_jobs)
ti.xcom_push(key='dataset-jobs', value=dataset_jobs)
ti.xcom_push(key='collection-task-log-group', value=collection_task_log_group)
ti.xcom_push(key='collection-task-log-stream-prefix', value=collection_task_log_stream_prefix)
ti.xcom_push(key='collection-task-log-region', value=collection_task_log_region)
ti.xcom_push(key='collection-dataset-bucket-name', value=collection_dataset_bucket_name)

collection_task_name = f"{config['env']}-mwaa-collection-task"

collections = load_specification_datasets()

@@ -86,7 +28,7 @@ def configure_dag(**kwargs):

with DAG(
f"{collection}-collection",
default_args=default_args,
default_args=dag_default_args,
description=f"Collection task for the {collection} collection",
schedule=None,
catchup=False,
@@ -99,9 +41,9 @@ def configure_dag(**kwargs):
render_template_as_native_obj=True,
is_paused_upon_creation=False
) as dag:
convert_params_task = PythonOperator(
task_id=configure_dag_task_id,
python_callable=configure_dag,
configure_dag_task = PythonOperator(
task_id="configure-dag",
python_callable=setup_configure_dag_callable(config, collection_task_name),
dag=dag,
)

@@ -110,12 +52,12 @@ def configure_dag(**kwargs):
dag=dag,
execution_timeout=timedelta(minutes=600),
cluster=ecs_cluster,
task_definition=collection_task_defn,
task_definition=collection_task_name,
launch_type="FARGATE",
overrides={
"containerOverrides": [
{
"name": collection_task_defn,
"name": collection_task_name,
'cpu': '{{ task_instance.xcom_pull(task_ids="configure-dag", key="cpu") | int }}',
'memory': '{{ task_instance.xcom_pull(task_ids="configure-dag", key="memory") | int }}',
"environment": [
@@ -145,4 +87,4 @@ def configure_dag(**kwargs):
awslogs_fetch_interval=timedelta(seconds=1)
)

convert_params_task >> collection_ecs_task
configure_dag_task >> collection_ecs_task
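A note on the templating pattern used throughout this DAG: `render_template_as_native_obj=True` switches the DAG's Jinja environment to native rendering, so the `xcom_pull` templates resolve to real Python objects (the VPC config dict, integer cpu/memory) rather than their string representations. This also appears to be why the string-valued environment entries are wrapped in literal single quotes: under native rendering the result passes through `ast.literal_eval`, and the quotes keep plain strings intact. A minimal sketch of the behaviour, assuming Airflow 2.x (the DAG and task ids here are illustrative only):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator


def push_config(ti, **_):
    # Push a dict; with native rendering a downstream templated field
    # receives it back as a dict rather than its repr() string.
    ti.xcom_push(
        key="aws_vpc_config",
        value={"subnets": ["subnet-123"], "assignPublicIp": "ENABLED"},
    )


def show(value):
    print(type(value), value)  # <class 'dict'> when native rendering is on


with DAG(
    "xcom-native-demo",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
    render_template_as_native_obj=True,
) as dag:
    push = PythonOperator(task_id="push-config", python_callable=push_config)
    show_config = PythonOperator(
        task_id="show-config",
        python_callable=show,
        op_args=['{{ ti.xcom_pull(task_ids="push-config", key="aws_vpc_config") }}'],
    )
    push >> show_config
```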
22 changes: 18 additions & 4 deletions dags/dag_triggers.py
@@ -36,10 +36,16 @@ def collection_selected(collection_name, configuration):
organisation_collection_selected = collection_selected('organisation', config)

if organisation_collection_selected:
run_org_dag = TriggerDagRunOperator(
run_org_collection_dag = TriggerDagRunOperator(
task_id='trigger-organisation-collection-dag',
trigger_dag_id=f'organisation-collection'
)
run_org_builder_dag = TriggerDagRunOperator(
task_id='trigger-organisation-builder-dag',
trigger_dag_id=f'organisation-builder',
wait_for_completion=True
)
run_org_collection_dag >> run_org_builder_dag

collections = load_specification_datasets()

@@ -51,7 +57,7 @@ def collection_selected(collection_name, configuration):
trigger_dag_id=f'{collection}-collection'
)
if organisation_collection_selected:
run_org_dag >> collection_dag
run_org_builder_dag >> collection_dag


with DAG(
@@ -63,11 +69,19 @@ def collection_selected(collection_name, configuration):
is_paused_upon_creation=False
):

run_org_dag = TriggerDagRunOperator(
run_org_collection_dag = TriggerDagRunOperator(
task_id='trigger-organisation-collection-dag',
trigger_dag_id=f'organisation-collection'
)

run_org_builder_dag = TriggerDagRunOperator(
task_id='trigger-organisation-builder-dag',
trigger_dag_id=f'organisation-builder',
wait_for_completion=True
)

run_org_collection_dag >> run_org_builder_dag

collections = load_specification_datasets()

for collection, datasets in collections.items():
@@ -78,4 +92,4 @@ def collection_selected(collection_name, configuration):
trigger_dag_id=f'{collection}-collection'
)

run_org_dag >> collection_dag
run_org_builder_dag >> collection_dag
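One behavioural note on the chains above: `TriggerDagRunOperator` succeeds as soon as the child run is created unless `wait_for_completion=True` is set, so the organisation-collection trigger returns immediately while the builder trigger gates everything chained after it. While waiting, the builder task holds a worker slot and polls every `poke_interval` seconds (60 by default). A stripped-down sketch of the dependency shape, assuming Airflow 2.x (the demo DAG id and interval value are illustrative):

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.trigger_dagrun import TriggerDagRunOperator

with DAG(
    "trigger-chain-demo",
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    # Fire-and-forget: succeeds once the child DagRun exists.
    run_org_collection = TriggerDagRunOperator(
        task_id="trigger-organisation-collection-dag",
        trigger_dag_id="organisation-collection",
    )
    # Blocks until the triggered run reaches a terminal state, so every
    # task chained after this one is gated on the builder finishing.
    run_org_builder = TriggerDagRunOperator(
        task_id="trigger-organisation-builder-dag",
        trigger_dag_id="organisation-builder",
        wait_for_completion=True,
        poke_interval=30,  # check the child run every 30 seconds
    )
    run_org_collection >> run_org_builder
```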
80 changes: 80 additions & 0 deletions dags/data_package_generator.py
@@ -0,0 +1,80 @@
from datetime import timedelta

from airflow import DAG

from utils import dag_default_args, get_config, setup_configure_dag_callable
from airflow.providers.amazon.aws.operators.ecs import (
EcsRegisterTaskDefinitionOperator,
EcsRunTaskOperator,
)
from airflow.operators.python import PythonOperator
from airflow.models.param import Param

data_packages = ["organisation"]

# read config from file and environment
config = get_config()

# set some variables needed for ECS tasks
ecs_cluster = f"{config['env']}-cluster"
task_definition_name = f"{config['env']}-mwaa-data-package-builder-task"

for package in data_packages:
with DAG(
f"{package}-builder",
default_args=dag_default_args,
description=f"Data package builder task for the {package} data package",
schedule=None,
catchup=False,
params={
"cpu": Param(default=8192, type="integer"),
"memory": Param(default=32768, type="integer")
},
render_template_as_native_obj=True,
is_paused_upon_creation=False
) as dag:
configure_dag_task = PythonOperator(
task_id="configure-dag",
python_callable=setup_configure_dag_callable(config, task_definition_name),
dag=dag,
)

builder_ecs_task = EcsRunTaskOperator(
task_id=f"build-data-package",
dag=dag,
execution_timeout=timedelta(minutes=600),
cluster=ecs_cluster,
task_definition=task_definition_name,
launch_type="FARGATE",
overrides={
"containerOverrides": [
{
"name": task_definition_name,
'cpu': '{{ task_instance.xcom_pull(task_ids="configure-dag", key="cpu") | int }}',
'memory': '{{ task_instance.xcom_pull(task_ids="configure-dag", key="memory") | int }}',
"environment": [
{"name": "ENVIRONMENT",
"value": "'{{ task_instance.xcom_pull(task_ids=\"configure-dag\", key=\"env\") | string }}'"},
{"name": "DATA_PACKAGE_NAME", "value": package},
{
"name": "READ_S3_BUCKET",
"value": "'{{ task_instance.xcom_pull(task_ids=\"configure-dag\", key=\"collection-dataset-bucket-name\") | string }}'"
},
{
"name": "WRITE_S3_BUCKET",
"value": "'{{ task_instance.xcom_pull(task_ids=\"configure-dag\", key=\"collection-dataset-bucket-name\") | string }}'"
}
],
},
]
},
network_configuration={
"awsvpcConfiguration": '{{ task_instance.xcom_pull(task_ids="configure-dag", key="aws_vpc_config") }}'
},
awslogs_group='{{ task_instance.xcom_pull(task_ids="configure-dag", key="collection-task-log-group") }}',
awslogs_region='{{ task_instance.xcom_pull(task_ids="configure-dag", key="collection-task-log-region") }}',
awslogs_stream_prefix='{{ task_instance.xcom_pull(task_ids="configure-dag", key="collection-task-log-stream-prefix") }}',
awslogs_fetch_interval=timedelta(seconds=1)
)

configure_dag_task >> builder_ecs_task
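The builder container receives all of its inputs through environment variables, which keeps the ECS task definition generic across data packages. A hypothetical sketch of an entrypoint consuming them (the `build_package` call is illustrative, not part of this repository):

```python
import os


def main() -> None:
    # Values injected via the containerOverrides block above.
    package = os.environ["DATA_PACKAGE_NAME"]
    read_bucket = os.environ["READ_S3_BUCKET"]
    write_bucket = os.environ["WRITE_S3_BUCKET"]
    environment = os.environ.get("ENVIRONMENT", "development")

    print(
        f"[{environment}] building data package {package!r}: "
        f"s3://{read_bucket} -> s3://{write_bucket}"
    )
    # build_package(package, read_bucket, write_bucket)  # hypothetical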
65 changes: 62 additions & 3 deletions dags/utils.py
@@ -3,11 +3,20 @@
import os
import tempfile
import urllib
from datetime import datetime, timedelta
from pathlib import Path

import boto3
import logging

# Some useful default args for all DAGs
dag_default_args = {
"owner": "airflow",
"depends_on_past": False,
"start_date": datetime(2024, 1, 1),
"dagrun_timeout": timedelta(minutes=5),
}


def get_config(path=None):
if path is None:
@@ -42,7 +51,7 @@ def load_specification_datasets():
return collections_dict


def get_task_log_config(ecs_client,task_definition_family):
def get_task_log_config(ecs_client, task_definition_family):
"""
returns the log configuration of a task definition stored in aws
assumes the local environment is set up to access aws
@@ -52,6 +61,56 @@ def get_task_log_config(ecs_client, task_definition_family):
response = ecs_client.describe_task_definition(taskDefinition=task_definition_family)

# Extract the log configuration from the container definitions
log_config = response['taskDefinition']['containerDefinitions'][0].get('logConfiguration',{})
log_config = response['taskDefinition']['containerDefinitions'][0].get('logConfiguration', {})

return log_config
return log_config


def setup_configure_dag_callable(config, task_definition_name):
def configure_dag(**kwargs):
"""
function which returns the relevant configuration details
and stores them in xcoms for other tasks. this includes:
- get and process params into correct formats
- read in env variables
- access options defined in the task definitions
"""
aws_vpc_config = {
"subnets": kwargs['conf'].get(section='custom', key='ecs_task_subnets').split(","),
"securityGroups": kwargs['conf'].get(section='custom', key='ecs_task_security_groups').split(","),
"assignPublicIp": "ENABLED",
}

# retrieve and process parameters
params = kwargs['params']

memory = int(params.get('memory'))
cpu = int(params.get('cpu'))
transformed_jobs = str(kwargs['params'].get('transformed-jobs'))
dataset_jobs = str(kwargs['params'].get('dataset-jobs'))

# get ecs-task logging configuration
ecs_client = boto3.client('ecs')
collection_task_log_config = get_task_log_config(ecs_client, task_definition_name)
collection_task_log_config_options = collection_task_log_config['options']
collection_task_log_group = str(collection_task_log_config_options.get('awslogs-group'))
# add container name to prefix
collection_task_log_stream_prefix = (str(collection_task_log_config_options.get('awslogs-stream-prefix'))
+ f'/{task_definition_name}')
collection_task_log_region = str(collection_task_log_config_options.get('awslogs-region'))
collection_dataset_bucket_name = kwargs['conf'].get(section='custom', key='collection_dataset_bucket_name')

# Push values to XCom
ti = kwargs['ti']
ti.xcom_push(key='env', value=config['env'])
ti.xcom_push(key='aws_vpc_config', value=aws_vpc_config)
ti.xcom_push(key='memory', value=memory)
ti.xcom_push(key='cpu', value=cpu)
ti.xcom_push(key='transformed-jobs', value=transformed_jobs)
ti.xcom_push(key='dataset-jobs', value=dataset_jobs)
ti.xcom_push(key='collection-task-log-group', value=collection_task_log_group)
ti.xcom_push(key='collection-task-log-stream-prefix', value=collection_task_log_stream_prefix)
ti.xcom_push(key='collection-task-log-region', value=collection_task_log_region)
ti.xcom_push(key='collection-dataset-bucket-name', value=collection_dataset_bucket_name)

return configure_dag
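`setup_configure_dag_callable` is a closure factory: each generated DAG gets a `configure_dag` callable with `config` and its own `task_definition_name` already bound at DAG-definition time, instead of reading module-level globals when the task runs. A stripped-down illustration of the pattern (not repo code; the names and values are illustrative):

```python
def setup_callable(config, task_definition_name):
    def configure(**kwargs):
        # Both outer arguments are captured by the closure, so each
        # returned callable remembers its own task definition.
        print(config["env"], task_definition_name)

    return configure


collection_cb = setup_callable({"env": "dev"}, "dev-mwaa-collection-task")
builder_cb = setup_callable({"env": "dev"}, "dev-mwaa-data-package-builder-task")

collection_cb()  # dev dev-mwaa-collection-task
builder_cb()     # dev dev-mwaa-data-package-builder-task
```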
