diff --git a/greenplumpython/experimental/file.py b/greenplumpython/experimental/file.py index ac8d49bd..99c5a719 100644 --- a/greenplumpython/experimental/file.py +++ b/greenplumpython/experimental/file.py @@ -2,11 +2,14 @@ import inspect import io import pathlib +import sys import tarfile +import tempfile import uuid -from typing import get_type_hints +from typing import Any, get_type_hints import psycopg2 +import psycopg2.extensions import greenplumpython as gp from greenplumpython.func import NormalFunction @@ -15,20 +18,30 @@ @gp.create_function -def _dump_file_chunk(tmp_archive_name: str, chunk_base64: str) -> int: - tmp_archive_base = pathlib.Path("/") / "tmp" / tmp_archive_name - tmp_archive_base.mkdir(parents=True, exist_ok=True) - tmp_archive_path = tmp_archive_base / f"{tmp_archive_name}.tar.gz" +def _dump_file_chunk(tmp_dir_handle: str, chunk_base64: str) -> str: + try: + _gd = globals()["GD"] # type: ignore reportUnknownVariableType + except KeyError: + _gd = sys.modules["plpy"]._GD + if tmp_dir_handle not in _gd: + server_tmp_dir = tempfile.TemporaryDirectory(prefix="pygp.srv.") + _gd[tmp_dir_handle] = server_tmp_dir # Pin to GD for later UDFs + else: + server_tmp_dir = _gd[tmp_dir_handle] # type: ignore reportUnknownVariableType + + server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir.name) # type: ignore reportUnknownVariableType + server_tmp_dir_path.mkdir(parents=True, exist_ok=True) + tmp_archive_path = server_tmp_dir_path / f"{tmp_dir_handle}.tar.gz" with open(tmp_archive_path, "ab") as tmp_archive: tmp_archive.write(base64.b64decode(chunk_base64)) - return 0 + return server_tmp_dir.name @gp.create_function -def _extract_files(tmp_archive_name: str, returning: str) -> list[str]: - tmp_archive_base = pathlib.Path("/") / "tmp" / tmp_archive_name - tmp_archive_path = tmp_archive_base / f"{tmp_archive_name}.tar.gz" - extracted_root = tmp_archive_base / "extracted" +def _extract_files(server_tmp_dir: str, tmp_dir_handle: str, returning: str) -> list[str]: + server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir) + tmp_archive_path = server_tmp_dir_path / f"{tmp_dir_handle}.tar.gz" + extracted_root = server_tmp_dir_path / "extracted" if not extracted_root.exists(): with tarfile.open(tmp_archive_path, "r:gz") as tmp_archive: extracted_root.mkdir() @@ -43,17 +56,34 @@ def _extract_files(tmp_archive_name: str, returning: str) -> list[str]: yield str(path.resolve()) -def _archive_and_upload(tmp_archive_name: str, files: list[str], db: gp.Database): - tmp_archive_base = pathlib.Path("/") / "tmp" / tmp_archive_name - tmp_archive_base.mkdir(exist_ok=True) - tmp_archive_path = tmp_archive_base / f"{tmp_archive_name}.tar.gz" - with tarfile.open(tmp_archive_path, "w:gz") as tmp_archive: - for file_path in files: - tmp_archive.add(pathlib.Path(file_path)) - server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None - with psycopg2.connect(db._dsn, options=server_options) as util_conn: # type: ignore reportUnknownVariableType - with util_conn.cursor() as cursor: # type: ignore reportUnknownVariableType - cursor.execute(f"CREATE TEMP TABLE {tmp_archive_name} (id serial, text_base64 text);") +def _remove_tmp_dir(conn: psycopg2.extensions.connection, db: gp.Database, tmp_dir_handle: str): + @gp.create_function + def udf(tmp_dir_handle: str) -> None: + try: + _gd = globals()["GD"] # type: ignore reportUnknownVariableType + except KeyError: + _gd = sys.modules["plpy"]._GD + _gd[tmp_dir_handle].cleanup() + + with conn.cursor() as cursor: + cursor.execute(udf._serialize(db)) + cursor.execute(f"SELECT {udf._qualified_name_str}('{tmp_dir_handle}');") + + +def _archive_and_upload( + util_conn: psycopg2.extensions.connection, + tmp_dir_handle: str, + files: list[str], + db: gp.Database, +) -> str: + with tempfile.TemporaryDirectory(prefix="pygp.cln.") as local_tmp_dir: + local_tmp_dir_path: pathlib.Path = pathlib.Path(local_tmp_dir) + tmp_archive_path = local_tmp_dir_path / f"{tmp_dir_handle}.tar.gz" + with tarfile.open(tmp_archive_path, "w:gz") as tmp_archive: + for file_path in files: + tmp_archive.add(pathlib.Path(file_path)) + with util_conn.cursor() as cursor: + cursor.execute(f"CREATE TEMP TABLE {tmp_dir_handle} (id serial, text_base64 text);") with open(tmp_archive_path, "rb") as tmp_archive: while True: chunk = tmp_archive.read(_CHUNK_SIZE) @@ -61,45 +91,47 @@ def _archive_and_upload(tmp_archive_name: str, files: list[str], db: gp.Database break chunk_base64 = base64.b64encode(chunk) cursor.copy_expert( - f"COPY {tmp_archive_name} (text_base64) FROM STDIN", + f"COPY {tmp_dir_handle} (text_base64) FROM STDIN", io.BytesIO(chunk_base64), ) util_conn.commit() - cursor.execute(_dump_file_chunk._serialize(db)) # type: ignore reportUnknownArgumentType + cursor.execute(_dump_file_chunk._serialize(db)) cursor.execute( f""" - SELECT {_dump_file_chunk._qualified_name_str}('{tmp_archive_name}', text_base64) - FROM "{tmp_archive_name}" + SELECT {_dump_file_chunk._qualified_name_str}('{tmp_dir_handle}', text_base64) + FROM "{tmp_dir_handle}" ORDER BY id; """ ) + return cursor.fetchall()[0][0] @classmethod def _from_files(_, files: list[str], parser: NormalFunction, db: gp.Database) -> gp.DataFrame: - tmp_archive_name = f"tar_{uuid.uuid4().hex}" - _archive_and_upload(tmp_archive_name, files, db) - func_sig = inspect.signature(parser.unwrap()) - result_members = get_type_hints(func_sig.return_annotation) - return db.apply( - lambda: parser(_extract_files(tmp_archive_name, "files")), - expand=len(result_members) == 0, - ) + tmp_dir_handle = f"__pygp_tar_{uuid.uuid4().hex}" + server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None + with psycopg2.connect(db._dsn, options=server_options) as util_conn: # type: ignore reportUnknownVariableType + server_tmp_dir = _archive_and_upload(util_conn, tmp_dir_handle, files, db) # type: ignore reportUnknownArgumentType + func_sig = inspect.signature(parser.unwrap()) + result_members = get_type_hints(func_sig.return_annotation) + df = db.apply( + lambda: parser(_extract_files(server_tmp_dir, tmp_dir_handle, "files")), + expand=len(result_members) == 0, + ) + # _remove_tmp_dir(util_conn, db, tmp_dir_handle) # Cannot remove now since the returned DataFrame depends on it. + return df setattr(gp.DataFrame, "from_files", _from_files) - -import subprocess as sp -import sys +import subprocess @gp.create_function -def _install_on_server(pkg_dir: str, requirements: str) -> str: - import subprocess as sp - import sys - +def _install_on_server(server_tmp_dir: str, local_tmp_dir: str, requirements: str) -> str: assert sys.executable, "Python executable is required to install packages." + server_tmp_dir_path: pathlib.Path = pathlib.Path(server_tmp_dir) + local_tmp_dir_path = pathlib.Path(local_tmp_dir) cmd = [ sys.executable, "-m", @@ -109,48 +141,49 @@ def _install_on_server(pkg_dir: str, requirements: str) -> str: "--requirement", "/dev/stdin", "--find-links", - pkg_dir, + str( + server_tmp_dir_path + / "extracted" + / local_tmp_dir_path.relative_to(local_tmp_dir_path.root) + ), ] try: - output = sp.check_output(cmd, text=True, stderr=sp.STDOUT, input=requirements) + output = subprocess.check_output( + cmd, text=True, stderr=subprocess.STDOUT, input=requirements + ) return output - except sp.CalledProcessError as e: + except subprocess.CalledProcessError as e: raise Exception(e.stdout) def _install_packages(db: gp.Database, requirements: str): - tmp_archive_name = f"tar_{uuid.uuid4().hex}" - # FIXME: Windows client is not supported yet. - local_dir = pathlib.Path("/") / "tmp" / tmp_archive_name / "pip" - local_dir.mkdir(parents=True) - cmd = [ - sys.executable, - "-m", - "pip", - "download", - "--requirement", - "/dev/stdin", - "--dest", - local_dir, - ] - try: - sp.check_output(cmd, text=True, stderr=sp.STDOUT, input=requirements) - except sp.CalledProcessError as e: - raise e from Exception(e.stdout) - _archive_and_upload(tmp_archive_name, [local_dir.resolve()], db) - extracted = db.apply(lambda: _extract_files(tmp_archive_name, "root"), column_name="cache_dir") - assert len(list(extracted)) == 1 - server_dir = ( - pathlib.Path("/") - / "tmp" - / tmp_archive_name - / "extracted" - / local_dir.relative_to(local_dir.root) - ) - installed = extracted.apply( - lambda _: _install_on_server(server_dir.as_uri(), requirements), column_name="result" - ) - assert len(list(installed)) == 1 + tmp_dir_handle = f"__pygp_tar_{uuid.uuid4().hex}" + with tempfile.TemporaryDirectory(prefix="pygp.cln.") as local_pkg_dir: + local_tmp_dir_path = pathlib.Path(local_pkg_dir) + cmd = [ + sys.executable, + "-m", + "pip", + "download", + "--requirement", + "/dev/stdin", + "--dest", + str(local_tmp_dir_path), + ] + try: + subprocess.check_output(cmd, text=True, stderr=subprocess.STDOUT, input=requirements) + except subprocess.CalledProcessError as e: + raise e from Exception(e.stdout) + server_options = "-c gp_session_role=utility" if db._is_variant("greenplum") else None + with psycopg2.connect(db._dsn, options=server_options) as util_conn: # type: ignore reportUnknownVariableType + server_tmp_dir = _archive_and_upload(util_conn, tmp_dir_handle, [local_pkg_dir], db) # type: ignore reportUnknownArgumentType + extracted = db.apply(lambda: _extract_files(server_tmp_dir, tmp_dir_handle, "root")) + assert len(list(extracted)) == 1 + installed = extracted.apply( + lambda _: _install_on_server(server_tmp_dir, local_pkg_dir, requirements) + ) + assert len(list(installed)) == 1 + _remove_tmp_dir(util_conn, db, tmp_dir_handle) # type: ignore reportUnknownArgumentType setattr(gp.Database, "install_packages", _install_packages) diff --git a/greenplumpython/func.py b/greenplumpython/func.py index f4676f5e..b68198c2 100644 --- a/greenplumpython/func.py +++ b/greenplumpython/func.py @@ -345,6 +345,7 @@ def _serialize(self, db: Database) -> str: f" if {sysconfig_lib_name}.get_python_version() != '{python_version}':\n" f" raise ModuleNotFoundError\n" f" setattr({sys_lib_name}.modules['plpy'], '_SD', SD)\n" + f" setattr({sys_lib_name}.modules['plpy'], '_GD', GD)\n" f" GD['{func_ast.name}'] = {pickle_lib_name}.loads({func_pickled})\n" f" except ModuleNotFoundError:\n" f" exec({json.dumps(ast.unparse(func_ast))}, globals())\n" diff --git a/tests/test_file.py b/tests/test_file.py index f56ea3c7..9a946b3b 100644 --- a/tests/test_file.py +++ b/tests/test_file.py @@ -41,8 +41,10 @@ def test_csv_multi_chunks(db: gp.Database): default_chunk_size = greenplumpython.experimental.file._CHUNK_SIZE greenplumpython.experimental.file._CHUNK_SIZE = 3 assert greenplumpython.experimental.file._CHUNK_SIZE == 3 - test_csv_single_chunk(db) - greenplumpython.experimental.file._CHUNK_SIZE = default_chunk_size + try: + test_csv_single_chunk(db) + finally: + greenplumpython.experimental.file._CHUNK_SIZE = default_chunk_size import subprocess