11from botocore .config import Config
2+ from functools import reduce
23from mage_ai .services .aws .ecs .config import EcsConfig
3- from mage_ai .services .aws .ecs .ecs import list_tasks , run_task
4+ from mage_ai .services .aws .ecs .ecs import list_tasks , run_task , stop_task
45from mage_ai .shared .array import find
56from mage_ai .shared .hash import dig
6- from typing import List
7+ from typing import Dict , List
78
89import boto3
10+ import json
911import os
1012
11-
12- CLUSTER_NAME = 'mage-data-prep-development-cluster'
13-
1413class EcsTaskManager :
15- def __init__ (self , cluster_name = CLUSTER_NAME ):
14+ def __init__ (self , cluster_name ):
1615 self .cluster_name = cluster_name
1716
17+ self .metadata_file = os .path .join (
18+ os .getcwd (),
19+ 'instance_metadata.json' ,
20+ )
21+
22+ if not os .path .exists (self .metadata_file ):
23+ self .instance_metadata = {}
24+
25+ @property
26+ def instance_metadata (self ):
27+ metadata = {}
28+ with open (self .metadata_file , 'r' , encoding = 'utf-8' ) as file :
29+ metadata = json .load (file )
30+ return metadata
31+
32+ @instance_metadata .setter
33+ def instance_metadata (self , metadata ):
34+ with open (self .metadata_file , 'w' , encoding = 'utf-8' ) as file :
35+ json .dump (metadata , file )
36+
1837 def list_tasks (self ):
1938 region_name = os .getenv ('AWS_REGION_NAME' , 'us-west-2' )
2039 config = Config (region_name = region_name )
2140 ec2_client = boto3 .client ('ec2' , config = config )
22- response = list_tasks (self .cluster_name )['tasks' ]
2341
42+ response = list_tasks (self .cluster_name )['tasks' ]
2443 network_interfaces = self .__get_network_interfaces (response , ec2_client )
2544
2645 tasks = []
27-
28- for index , task in enumerate (response ):
29- public_ip = dig (network_interfaces [index ], 'Association.PublicIp' )
46+ for task in response :
47+ public_ip = dig (network_interfaces .get (task ['taskArn' ]), 'Association.PublicIp' )
3048
3149 tags = task ['tags' ]
3250 name = find (lambda tag : tag .get ('key' ) == 'name' , tags )
3351
3452 tasks .append (dict (
3553 ip = public_ip ,
36- group = task ['group' ],
3754 name = name .get ('value' ) if name is not None else None ,
3855 status = task ['lastStatus' ],
56+ task_arn = task ['taskArn' ],
3957 type = task ['launchType' ],
4058 ))
4159
42- return tasks
60+ running_instance_names = set (map (lambda x : x ['name' ], tasks ))
61+
62+ stopped_instance_names = \
63+ [name for name in list (self .instance_metadata .keys ()) if name not in running_instance_names ]
64+ stopped_instances = \
65+ list (
66+ map (
67+ lambda name : { 'name' : name , 'status' : 'STOPPED' },
68+ stopped_instance_names
69+ )
70+ )
71+
72+ return tasks + stopped_instances
4373
4474 def create_task (self , name : str , task_definition : str , container_name : str ):
4575 region_name = os .getenv ('AWS_REGION_NAME' , 'us-west-2' )
4676 config = Config (region_name = region_name )
4777 ec2_client = boto3 .client ('ec2' , config = config )
4878
4979 # create new task
50- task = list_tasks (self .cluster_name )['tasks' ][0 ]
51- network_interface = self .__get_network_interfaces ([task ], ec2_client )[0 ]
80+ task = find (
81+ lambda task : task .get ('lastStatus' ) == 'RUNNING' ,
82+ list_tasks (self .cluster_name )['tasks' ],
83+ )
84+ network_interface = self .__get_network_interfaces ([task ], ec2_client )[task ['taskArn' ]]
5285
5386 subnets = [network_interface ['SubnetId' ]]
5487 security_groups = [g ['GroupId' ] for g in network_interface ['Groups' ]]
@@ -67,19 +100,56 @@ def create_task(self, name: str, task_definition: str, container_name: str):
67100 ],
68101 )
69102
103+ self .instance_metadata = {
104+ ** self .instance_metadata ,
105+ name : dict ()
106+ }
107+
70108 return run_task (f'mage start { name } ' , ecs_config = ecs_config )
71109
72- def __get_network_interface_id (self , task : str ):
110+ def stop_task (self , task_arn : str ):
111+ return stop_task (task_arn , self .cluster_name )
112+
113+ def delete_task (self , name , task_arn : str = None ):
114+ if task_arn :
115+ self .stop_task (task_arn )
116+
117+ updated_metadata = self .instance_metadata
118+
119+ if name in updated_metadata :
120+ del updated_metadata [name ]
121+ self .instance_metadata = updated_metadata
122+
123+ def __get_network_interface_id (self , task ):
124+ if task .get ('lastStatus' ) != 'RUNNING' :
125+ return None
126+
73127 attachment = \
74128 find (lambda a : a ['type' ] == 'ElasticNetworkInterface' , task .get ('attachments' , []))
75129 network_interface = \
76130 find (lambda d : d ['name' ] == 'networkInterfaceId' , attachment .get ('details' , []))
77131 return network_interface .get ('value' , None )
78132
133+ def __get_network_interfaces (self , tasks : List , ec2_client ) -> Dict :
134+ task_mapping = dict ()
135+ for task in tasks :
136+ nii = self .__get_network_interface_id (task )
137+ if nii is not None :
138+ task_mapping [task ['taskArn' ]] = nii
79139
80- def __get_network_interfaces (self , tasks : List , ec2_client ):
81- network_interface_ids = [self .__get_network_interface_id (task ) for task in tasks ]
140+ network_interface_ids = list (task_mapping .values ())
82141
83- return ec2_client .describe_network_interfaces (
142+ network_interfaces = ec2_client .describe_network_interfaces (
84143 NetworkInterfaceIds = network_interface_ids
85144 )['NetworkInterfaces' ]
145+
146+ def aggregate (obj , task ):
147+ task_arn = task ['taskArn' ]
148+ if task_arn in task_mapping :
149+ obj [task_arn ] = find (
150+ lambda i : i ['NetworkInterfaceId' ] == task_mapping [task_arn ],
151+ network_interfaces ,
152+ )
153+ return obj
154+
155+ return reduce (aggregate , tasks , {})
0 commit comments