Skip to content

Commit

Permalink
[COMP-554] Utility functions to support Kuberay Executor. (#327)
Browse files Browse the repository at this point in the history
* Avoid mutable function args.
* Additional k8s utility functions for kuberay.
* Modify code packaging utilities to support kuberay.
* unit test for package_code optional args
  • Loading branch information
echee-insitro authored Nov 3, 2023
1 parent 55e2bef commit fb78bb0
Show file tree
Hide file tree
Showing 4 changed files with 136 additions and 13 deletions.
41 changes: 32 additions & 9 deletions redun/executors/code_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,22 @@ def find_code_files(
return files


def create_tar(tar_path: str, file_paths: Iterable[str]) -> File:
def create_tar(
tar_path: str, file_paths: Iterable[str], arcname_prefix: Optional[str] = None
) -> File:
"""
Create a tar file from local file paths.
Args:
arcname_prefix: prefix to add to each file path in the tar file.
"""
tar_file = File(tar_path)

with tar_file.open("wb") as out:
with tarfile.open(fileobj=out, mode="w|gz") as tar:
for file_path in file_paths:
tar.add(file_path)
# add prefix to each file
arcname = os.path.join(arcname_prefix, file_path) if arcname_prefix else file_path
tar.add(file_path, arcname=arcname)

return tar_file

Expand All @@ -54,16 +60,21 @@ def extract_tar(tar_file: File, dest_dir: str = ".") -> None:
tar.extractall(dest_dir)


def create_zip(zip_path: str, base_path: str, file_paths: Iterable[str]) -> File:
def create_zip(
zip_path: str, base_path: str, file_paths: Iterable[str], arcname_prefix: Optional[str] = None
) -> File:
"""
Create a zip file from local file paths.
Args:
arcname_prefix: prefix to add to each file path in the zip file.
"""
zip_file = File(zip_path)

with zip_file.open("wb") as out:
with zipfile.ZipFile(out, mode="w") as stream:
for file_path in file_paths:
arcname = os.path.relpath(file_path, base_path)
arcname = os.path.join(arcname_prefix, arcname) if arcname_prefix else arcname
stream.write(file_path, arcname)

return zip_file
Expand All @@ -82,9 +93,19 @@ def parse_code_package_config(config) -> Union[dict, bool]:
return {"includes": shlex.split(include_config), "excludes": shlex.split(exclude_config)}


def package_code(scratch_prefix: str, code_package: dict = {}, use_zip: bool = False) -> File:
def package_code(
scratch_prefix: str,
code_package: dict = {},
use_zip: bool = False,
basename: Optional[str] = None,
arcname_prefix: Optional[str] = None,
) -> File:
"""
Package code to scratch directory.
Args:
basename: If provided, uses this string as the basename instead of the
calculated tarball hash.
arcname_prefix: Optional prefix to add to each file path inside the archive.
"""
with tempfile.TemporaryDirectory() as tmpdir:
file_paths = find_code_files(
Expand All @@ -93,14 +114,16 @@ def package_code(scratch_prefix: str, code_package: dict = {}, use_zip: bool = F

if use_zip:
temp_file = File(os.path.join(tmpdir, "code.zip"))
create_zip(temp_file.path, ".", file_paths)
create_zip(temp_file.path, ".", file_paths, arcname_prefix=arcname_prefix)
else:
temp_file = File(os.path.join(tmpdir, "code.tar.gz"))
create_tar(temp_file.path, file_paths)
create_tar(temp_file.path, file_paths, arcname_prefix=arcname_prefix)

with temp_file.open("rb") as infile:
tar_hash = hash_stream(infile)
code_file = File(get_code_scratch_file(scratch_prefix, tar_hash, use_zip=use_zip))
if not basename:
with temp_file.open("rb") as infile:
basename = hash_stream(infile)

code_file = File(get_code_scratch_file(scratch_prefix, basename, use_zip=use_zip))
if not code_file.exists():
temp_file.copy_to(code_file)

Expand Down
9 changes: 6 additions & 3 deletions redun/executors/k8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,11 +663,14 @@ def _process_redun_job(
pod: V1Pod,
job_status: str,
status_reason: Optional[str],
k8s_labels: List[Tuple[str, str]] = [],
k8s_labels: List[Tuple[str, str]] = None,
) -> None:
"""
Complete a redun job (done or reject).
"""
if k8s_labels is None:
k8s_labels = []

assert self._scheduler

if job_status == SUCCEEDED:
Expand All @@ -691,7 +694,7 @@ def _process_redun_job(

elif job_status == FAILED:
error, error_traceback = parse_job_error(self.scratch_prefix, job)
logs = [f"*** Logs for K8S pod {pod.metadata.name}:\n"]
logs = [f"*** Logs for K8S pod {pod.metadata.name}: \n"]

# TODO: Consider displaying events in the logs since this can have
# helpful info as well.
Expand Down Expand Up @@ -770,7 +773,7 @@ def _process_k8s_job_status(self, job: kubernetes.client.V1Job) -> None:

if "batch.kubernetes.io/job-completion-index" not in pod.metadata.annotations:
self.log(
f"Pod {pod.metadata.name} is missing job-completion-index:",
f"Pod {pod.metadata.name} is missing job-completion-index: ",
pod.metadata.annotations,
)
continue
Expand Down
63 changes: 62 additions & 1 deletion redun/executors/k8s_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ class K8SClient:
def __init__(self):
    """
    Create a client with no configuration loaded and no API handles built.
    """
    # NOTE(review): presumably set by load_config() once the k8s config has
    # been loaded -- confirm against load_config, which is not fully visible here.
    self._is_loaded = False
    # Each API handle starts as None; the matching property accessors
    # (e.g. `custom_objects`, `rbac`) build the handle on first access.
    self._core: Optional[client.CoreV1Api] = None
    self._custom_objects: Optional[client.CustomObjectsApi] = None
    self._batch: Optional[client.BatchV1Api] = None
    self._rbac: Optional[client.RbacAuthorizationV1Api] = None

def load_config(self) -> None:
"""
Expand Down Expand Up @@ -63,6 +65,18 @@ def core(self) -> client.CoreV1Api:
self._core = client.CoreV1Api()
return self._core

@property
def custom_objects(self) -> client.CustomObjectsApi:
    """
    Return an API client supporting the k8s custom_objects API.

    The client is built on first access (after calling `load_config()`) and
    the same instance is returned on later accesses.
    https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/CustomObjectsApi.md
    """
    # Fast path: a handle was already built by an earlier access.
    if self._custom_objects:
        return self._custom_objects

    # First access: make sure config is loaded before building the handle.
    self.load_config()
    self._custom_objects = client.CustomObjectsApi()
    return self._custom_objects

@property
def batch(self) -> client.BatchV1Api:
"""
Expand All @@ -75,6 +89,18 @@ def batch(self) -> client.BatchV1Api:
self._batch = client.BatchV1Api()
return self._batch

@property
def rbac(self) -> client.RbacAuthorizationV1Api:
    """
    Return an API client supporting the k8s RbacAuthorizationV1 API.

    The client is built on first access (after calling `load_config()`) and
    the same instance is returned on later accesses.
    https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/RbacAuthorizationV1Api.md
    """
    # Fast path: a handle was already built by an earlier access.
    if self._rbac:
        return self._rbac

    # First access: make sure config is loaded before building the handle.
    self.load_config()
    self._rbac = client.RbacAuthorizationV1Api()
    return self._rbac


def delete_k8s_secret(k8s_client: K8SClient, secret_name: str, namespace: str) -> None:
"""
Expand Down Expand Up @@ -215,7 +241,7 @@ def create_job_object(


def create_namespace(k8s_client: K8SClient, namespace: str) -> None:
"""Create a k8s namespace job"""
"""Create a k8s namespace"""
try:
k8s_client.core.create_namespace(
client.V1Namespace(metadata=client.V1ObjectMeta(name=namespace))
Expand All @@ -228,6 +254,41 @@ def create_namespace(k8s_client: K8SClient, namespace: str) -> None:
raise


def annotate_service_account(
    k8s_client: K8SClient, namespace: str, sa_name: str, annotations_to_append: dict
) -> None:
    """Annotate a service account with additional annotations.

    Reads the service account, merges `annotations_to_append` into its current
    annotations (values in `annotations_to_append` win on key collisions), and
    writes the service account back with a replace call.

    Args:
        k8s_client: Client wrapper whose `core` API performs the read/replace.
        namespace: Namespace wherein service account resides.
        sa_name: Service account name.
        annotations_to_append: Dict of annotations to append.

    Raises:
        client.exceptions.ApiException: If the replace call fails; the error
            body is logged and the exception re-raised. Errors raised by the
            initial read propagate without being logged here.
    """
    sa = k8s_client.core.read_namespaced_service_account(sa_name, namespace)

    # The service account may have no annotations at all (None).
    existing_annotations = sa.metadata.annotations or {}
    existing_annotations.update(annotations_to_append)
    sa.metadata.annotations = existing_annotations
    try:
        k8s_client.core.replace_namespaced_service_account(sa_name, namespace, sa)
    except client.exceptions.ApiException as error:
        # The %s placeholder is required: without it the logging call cannot
        # format its argument and the error body never reaches the log.
        logger.error("Unexpected exception annotating SA: %s", error.body)
        raise
    else:
        logger.info(
            f"ServiceAccount: {sa_name} in namespace: {namespace} annotated with "
            f"{existing_annotations}"
        )


def delete_namespace(k8s_client: K8SClient, namespace: str) -> None:
    """Delete a k8s namespace.

    Args:
        k8s_client: Client wrapper whose `core` API performs the delete.
        namespace: Name of the namespace to delete.

    Raises:
        client.exceptions.ApiException: If the delete call fails; the error
            body is logged and the exception re-raised.
    """
    try:
        k8s_client.core.delete_namespace(name=namespace)
    except client.exceptions.ApiException as error:
        # The %s placeholder is required: without it the logging call cannot
        # format its argument and the error body never reaches the log.
        logger.error("Unexpected exception deleting namespace: %s", error.body)
        raise


def create_job(k8s_client: K8SClient, job: client.V1Job, namespace: str) -> client.V1Job:
"""
Creates an actual k8s job.
Expand Down
36 changes: 36 additions & 0 deletions redun/tests/test_code_packaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,42 @@ def test_package_job_code() -> None:
}


@use_tempdir
def test_package_job_code_basename_and_arcname() -> None:
    """
    An explicit basename should show up in the packaged file's name, and
    arcname_prefix should show up in the paths of the extracted files.
    """
    # Lay out a minimal project: one top-level module and one nested module.
    File("workflow.py").write("")
    File("lib/lib.py").write("")

    # Package the code with both optional arguments supplied.
    scratch_prefix = "s3/"
    packaged = package_code(
        scratch_prefix,
        {"include": ["**/*.py"]},
        basename="my_tarball",
        arcname_prefix="some_prefix",
    )

    # The packaged file's name should start with the supplied basename and
    # keep the tarball extension.
    expected_start = os.path.join(scratch_prefix, "code/my_tarball")
    assert packaged.path.startswith(expected_start)
    assert packaged.path.endswith(".tar.gz")

    # After extraction every file should sit underneath `some_prefix`.
    os.makedirs("dest")
    extract_tar(packaged, "dest")

    extracted_paths = set()
    for entry in Dir("dest"):
        extracted_paths.add(entry.path)
    assert extracted_paths == {
        "dest/some_prefix/workflow.py",
        "dest/some_prefix/lib/lib.py",
    }


def test_parse_code_package_config():
# Parse default code_package patterns.
config = Config({"batch": {}})
Expand Down

0 comments on commit fb78bb0

Please sign in to comment.