# handle_terraform.py
import logging
import json
import time
import threading
from itertools import zip_longest
from kube_terraform import KubeTerraform
import pk_config
dryrun_id = "terraform"
CONFIG_POD_NAME = "terraform_container_name"
CONFIG_PATH = "terraform_path"
TF_VARS_PATH = "/submitter/terraform.tfvars.json"
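# The tfvars file is expected to map each node name to a list of
# instance suffixes, e.g. {"worker": ["1", "2", "3"]} (shape inferred
# from _add_nodes/_del_nodes below).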
# Append this to shell commands so each output line is timestamped and
# mirrored to the Terraform container's stdout (/proc/1/fd/1)
LOG_SUFFIX = (
    " | while IFS= read -r line;"
    ' do printf "%s %s\n" "$(date "+[%Y-%m-%d %H:%M:%S]")" "$line";'
    " done | tee /proc/1/fd/1"
)
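# A line such as "Apply complete!" is logged as, e.g.:
#   [2024-05-01 12:00:00] Apply complete!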
log = logging.getLogger("pk_terraform")
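
# scaling_info_list entries are assumed to look like
# {"node_name": "worker", "replicas": 3} for scale_worker_node and
# {"node_name": "worker", "replicas": ["10.0.0.5"]} for drop_worker_node
# (shapes inferred from how the fields are used below).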


def scale_worker_node(config, scaling_info_list):
    """
    Calculate the scaling variance and modify the vars file accordingly
    """
    if pk_config.dryrun_get(dryrun_id):
        log.info("(S) DRYRUN enabled. Skipping...")
        return
    node_variables = _read_vars_file()
    for info in scaling_info_list:
        replicas = info.get("replicas")
        node_name = info.get("node_name")
        nodes = node_variables[node_name]
        variance = replicas - len(nodes)
        if variance > 0:
            _add_nodes(nodes, variance)
        elif variance < 0:
            _del_nodes(nodes, abs(variance))
        log.info("(S) {} => m_node_count: {}".format(node_name, replicas))
    _write_vars_file(node_variables)
    # Thread this action so PK can continue
    if _thread_not_active():
        tf_thread = threading.Thread(
            target=perform_scaling, args=(config,), name="TerraformThread",
        )
        tf_thread.start()


def query_number_of_worker_nodes(config, worker_name):
    """
    Return the number of instances of a worker node, pulled from tfstate
    """
    instances = 1
    if pk_config.dryrun_get(dryrun_id):
        log.info("(C) DRYRUN enabled. Skipping...")
        return instances
    try:
        resources = _get_resources_from_state(config, worker_name)
        instances = len(resources[0]["instances"])
    except Exception:
        log.error("Failed to get no. of instances for {}".format(worker_name))
    log.debug("-->instances: {0}".format(instances))
    return instances


def drop_worker_node(config, scaling_info_list):
    """
    Scan for a specific worker node IP and drop it
    """
    for info in scaling_info_list:
        ips_to_drop = info.get("replicas")
        node_name = info.get("node_name")
        if not ips_to_drop:
            continue
        ip_list = _get_ips_from_output(config, node_name)
        if not ip_list:
            continue
        # Collect indices per node so one node's matches do not leak
        # into the next node's drop list, and skip IPs with no match
        indices = []
        for target_ip in ips_to_drop:
            index = _get_target_by_ip(ip_list, target_ip)
            if index is not None:
                indices.append(index)
        _drop_nodes_by_indices(node_name, indices)
    if _thread_not_active():
        tf_thread = threading.Thread(
            target=perform_scaling, args=(config,), name="TerraformThread",
        )
        tf_thread.start()


def get_terraform(config):
    """
    Return the Terraform pod
    """
    deployment = config.get(CONFIG_POD_NAME, "terraform")
    for i in range(1, 6):
        try:
            return KubeTerraform(deployment)
        except Exception as e:
            log.debug("Failed attempt {}/5 attaching to Terraform: {}".format(i, e))
            time.sleep(5)
    log.error("Failed to get Terraform pod")
    return None


def perform_scaling(config):
    """
    Execute the terraform apply command in the TF container
    """
    terraform = get_terraform(config)
    if terraform is None:
        return
    terra_path = config.get(CONFIG_PATH)
    shell_command = "terraform apply -auto-approve -no-color" + LOG_SUFFIX
    # Run the apply from the configured Terraform directory
    terraform.exec_run(shell_command, workdir=terra_path)


def _get_json_from_command(config, command):
    """
    Return the JSON from a specific TF command
    """
    terraform = get_terraform(config)
    terra_path = config.get(CONFIG_PATH)
    out = terraform.exec_run(command, workdir=terra_path)
    if out.exit_code != 0 or not out.output:
        log.error("Could not get json data")
        return {}
    return json.loads(out.output.decode())


def _get_ips_from_output(config, node):
    """
    Parse the IP info from the terraform output command's JSON
    """
    command = ["terraform", "output", "-json"]
    ip_info = _get_json_from_command(config, command)
    if not ip_info.get(node):
        return None
    return ip_info[node]["value"]
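
# The "terraform output -json" payload is assumed to be shaped like
# {"worker": {"value": {"private_ips": [...], "public_ips": [...]}}},
# based on the lookups above and in _get_target_by_ip.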


def _get_resources_from_state(config, worker_name):
    """
    Return the resources matching the worker from the tfstate file
    """
    command = ["terraform", "state", "pull"]
    json_state = _get_json_from_command(config, command)
    resources = json_state.get("resources") or []
    resources = [x for x in resources if x.get("name") == worker_name]
    if not resources:
        log.error("No resources matching {}".format(worker_name))
    return resources


def _get_target_by_ip(ip_list, ip_to_drop):
    """
    Scan the instance IPs and return the index of the match to drop
    """
    privates = ip_list.get("private_ips", [])
    publics = ip_list.get("public_ips", [])
    for index, ips in enumerate(zip_longest(privates, publics)):
        if ip_to_drop in ips:
            return index
    return None


def _read_vars_file():
    """
    Return the data from the tfvars file
    """
    with open(TF_VARS_PATH, "r") as file:
        data = json.load(file)
    return data


def _write_vars_file(data):
    """
    Write data to the tfvars file
    """
    with open(TF_VARS_PATH, "w") as file:
        json.dump(data, file, indent=4)


def _thread_not_active():
    """
    Return True if a terraform apply is not already running
    """
    return "TerraformThread" not in [x.name for x in threading.enumerate()]


def _add_nodes(nodes, variance):
    """
    Scale up by adding node(s) to the infrastructure
    """
    if not nodes:
        # Seed an empty node list with the first suffix
        nodes.append("1")
        variance -= 1
    for _ in range(variance):
        node_name = str(int(nodes[-1]) + 1)
        nodes.append(node_name)


def _del_nodes(nodes, variance):
    """
    Scale down by removing node(s) from the infrastructure
    """
    for _ in range(variance):
        nodes.pop()


def _drop_nodes_by_indices(node_name, indices):
    """
    Drop specific nodes by removing them from the tfvars file
    """
    node_variables = _read_vars_file()
    node_variables[node_name].sort()
    # Pop the highest indices first so earlier removals do not shift
    # the positions of the remaining targets
    for index in sorted(indices, reverse=True):
        node_variables[node_name].pop(index)
    _write_vars_file(node_variables)
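

# Hypothetical usage sketch; the config keys mirror CONFIG_POD_NAME and
# CONFIG_PATH above, but the values shown here are illustrative only:
#
#     config = {
#         "terraform_container_name": "terraform",
#         "terraform_path": "/submitter/infra",
#     }
#     scale_worker_node(config, [{"node_name": "worker", "replicas": 3}])
#     count = query_number_of_worker_nodes(config, "worker")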