diff --git a/beeswithmachineguns/bees.py b/beeswithmachineguns/bees.py index 7ef56e2..aef7279 100644 --- a/beeswithmachineguns/bees.py +++ b/beeswithmachineguns/bees.py @@ -56,6 +56,10 @@ import boto.ec2 import boto.exception + +from boto.sts import STSConnection +from boto.ec2.connection import EC2Connection + import paramiko import json from collections import defaultdict @@ -63,6 +67,7 @@ STATE_FILENAME = os.path.expanduser('~/.bees') +CONFIG_FILENAME = os.path.expanduser('~/.bees_config') # Utilities @@ -117,10 +122,13 @@ def _get_security_group_id(connection, security_group_name, subnet): print('The bees need a security group to run under. Need to open a port from where you are to the target subnet.') return - security_groups = connection.get_all_security_groups(filters={'group-name': [security_group_name]}) + print('in this function') + ec2_connection = connection + + security_groups = ec2_connection.get_all_security_groups(filters={'group-name': [security_group_name]}) if not security_groups: - security_groups = connection.get_all_security_groups(filters={'group-id': [security_group_name]}) + security_groups = ec2_connection.get_all_security_groups(filters={'group-id': [security_group_name]}) if not security_groups: print('The bees need a security group to run under. The one specified was not found.') return @@ -134,11 +142,57 @@ def up(count, group, zone, image_id, instance_type, username, key_name, subnet, Startup the load testing server. """ + try: + + file_exists = os.path.isfile(CONFIG_FILENAME) + if file_exists: + file = open(CONFIG_FILENAME, "r") + lines = file.readlines() + remembered_mfa_serial=lines[0].replace("\n","") + remembered_region=lines[1] + mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial + else: + mfa_serial = raw_input("Enter the MFA serial (for example arn:aws:iam::1234567891011:mfa/myusername): ") + mfa_TOTP = raw_input("Enter the MFA code: ") + + sts_connection = STSConnection() + + tempCredentials = sts_connection.get_session_token( + duration=3600, + mfa_serial_number=mfa_serial, + mfa_token=mfa_TOTP + ) + + region = boto.ec2.get_region(_get_region(zone)) + ec2_connection = EC2Connection( + region=region, + aws_access_key_id=tempCredentials.access_key, + aws_secret_access_key=tempCredentials.secret_key, + security_token=tempCredentials.session_token + ) + + file = open(CONFIG_FILENAME, "w") + file.write(mfa_serial + "\n") + file.write(_get_region(zone)) + file.close() + + except boto.exception.NoAuthHandlerFound as e: + print("Authenciation config error, perhaps you do not have a ~/.boto file with correct permissions?") + print(e.message) + return e + except Exception as e: + print("Unknown error occured:") + print(e.message) + return e + + if ec2_connection == None: + raise Exception("Invalid zone specified? Unable to connect to region using zone name") + existing_username, existing_key_name, existing_zone, instance_ids = _read_server_list(zone) count = int(count) if existing_username == username and existing_key_name == key_name and existing_zone == zone: - ec2_connection = boto.ec2.connect_to_region(_get_region(zone)) + existing_reservations = ec2_connection.get_all_instances(instance_ids=instance_ids) existing_instances = [instance for reservation in existing_reservations for instance in reservation.instances if instance.state == 'running'] # User, key and zone match existing values and instance ids are found on state file @@ -165,41 +219,38 @@ def up(count, group, zone, image_id, instance_type, username, key_name, subnet, print('Warning. No key file found for %s. You will need to add this key to your SSH agent to connect.' % pem_path) print('Connecting to the hive.') - - try: - ec2_connection = boto.ec2.connect_to_region(_get_region(zone)) - except boto.exception.NoAuthHandlerFound as e: - print("Authenciation config error, perhaps you do not have a ~/.boto file with correct permissions?") - print(e.message) - return e - except Exception as e: - print("Unknown error occured:") - print(e.message) - return e - - if ec2_connection == None: - raise Exception("Invalid zone specified? Unable to connect to region using zone name") - groupId = group if subnet is None else _get_security_group_id(ec2_connection, group, subnet) print("GroupId found: %s" % groupId) - + placement = None if 'gov' in zone else zone print("Placement: %s" % placement) if bid: - print('Attempting to call up %i spot bees, this can take a while...' % count) - - spot_requests = ec2_connection.request_spot_instances( - image_id=image_id, - price=bid, - count=count, - key_name=key_name, - security_group_ids=[groupId], - instance_type=instance_type, - placement=placement, - subnet_id=subnet) - - # it can take a few seconds before the spot requests are fully processed + print('Attempting to call up %i spot bees, this can take a while...' % count) + + if "sg-" not in groupId: + spot_requests = ec2_connection.request_spot_instances( + image_id=image_id, + price=bid, + count=count, + key_name=key_name, + security_groups=[groupId], + instance_type=instance_type, + placement=placement, + subnet_id=subnet) + + else: + spot_requests = ec2_connection.request_spot_instances( + image_id=image_id, + price=bid, + count=count, + key_name=key_name, + security_group_ids=[groupId], + instance_type=instance_type, + placement=placement, + subnet_id=subnet) + + # it can take a few seconds before the spot requests are fully processed time.sleep(5) instances = _wait_for_spot_request_fulfillment(ec2_connection, spot_requests) @@ -272,7 +323,29 @@ def _check_instances(): print('No bees have been mobilized.') return - ec2_connection = boto.ec2.connect_to_region(_get_region(zone)) + file = open(CONFIG_FILENAME, "r") + lines = file.readlines() + remembered_mfa_serial=lines[0].replace("\n","") + remembered_region=lines[1] + + mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial + mfa_TOTP = raw_input("Enter the MFA code: ") + + sts_connection = STSConnection() + + tempCredentials = sts_connection.get_session_token( + duration=3600, + mfa_serial_number=mfa_serial, + mfa_token=mfa_TOTP + ) + + region = boto.ec2.get_region(remembered_region) + ec2_connection = EC2Connection( + region=region, + aws_access_key_id=tempCredentials.access_key, + aws_secret_access_key=tempCredentials.secret_key, + security_token=tempCredentials.session_token + ) reservations = ec2_connection.get_all_instances(instance_ids=instance_ids) @@ -292,9 +365,11 @@ def down(*mr_zone): """ Shutdown the load testing server. """ - def _check_to_down_it(): + + def _check_to_down_it(region): '''check if we can bring down some bees''' - username, key_name, zone, instance_ids = _read_server_list(region) + + username, key_name, zone, instance_ids = _read_server_list(region) if not instance_ids: print('No bees have been mobilized.') @@ -302,7 +377,29 @@ def _check_to_down_it(): print('Connecting to the hive.') - ec2_connection = boto.ec2.connect_to_region(_get_region(zone)) + file = open(CONFIG_FILENAME, "r") + lines = file.readlines() + remembered_mfa_serial=lines[0].replace("\n","") + remembered_region=lines[1] + + mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial + mfa_TOTP = raw_input("Enter the MFA code: ") + + sts_connection = STSConnection() + + tempCredentials = sts_connection.get_session_token( + duration=3600, + mfa_serial_number=mfa_serial, + mfa_token=mfa_TOTP + ) + + region = boto.ec2.get_region(remembered_region) + ec2_connection = EC2Connection( + region=region, + aws_access_key_id=tempCredentials.access_key, + aws_secret_access_key=tempCredentials.secret_key, + security_token=tempCredentials.session_token + ) print(('Calling off the swarm for {}.').format(region)) @@ -313,12 +410,11 @@ def _check_to_down_it(): _delete_server_list(zone) - if len(mr_zone) > 0: username, key_name, zone, instance_ids = _read_server_list(mr_zone[-1]) else: for region in _get_existing_regions(): - _check_to_down_it() + _check_to_down_it(region) def _wait_for_spot_request_fulfillment(conn, requests, fulfilled_requests = []): """ @@ -842,7 +938,29 @@ def hurl_attack(url, n, c, **options): print('Connecting to the hive.') - ec2_connection = boto.ec2.connect_to_region(_get_region(zone)) + file = open(CONFIG_FILENAME, "r") + lines = file.readlines() + remembered_mfa_serial=lines[0].replace("\n","") + remembered_region=lines[1] + + mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial + mfa_TOTP = raw_input("Enter the MFA code: ") + + sts_connection = STSConnection() + + tempCredentials = sts_connection.get_session_token( + duration=3600, + mfa_serial_number=mfa_serial, + mfa_token=mfa_TOTP + ) + + region = boto.ec2.get_region(remembered_region) + ec2_connection = EC2Connection( + region=region, + aws_access_key_id=tempCredentials.access_key, + aws_secret_access_key=tempCredentials.secret_key, + security_token=tempCredentials.session_token + ) print('Assembling bees.') @@ -1307,3 +1425,5 @@ def _get_existing_regions(): something= re.search(r'\.bees\.(.*)', f) existing_regions.append( something.group(1)) if something else "no" return existing_regions + +