diff --git a/sdks/python/apache_beam/runners/portability/stager.py b/sdks/python/apache_beam/runners/portability/stager.py
index 9147410c2463..5c63c745ba92 100644
--- a/sdks/python/apache_beam/runners/portability/stager.py
+++ b/sdks/python/apache_beam/runners/portability/stager.py
@@ -165,7 +165,7 @@ def create_job_resources(
       build_setup_args: Optional[List[str]] = None,
       pypi_requirements: Optional[List[str]] = None,
       populate_requirements_cache: Optional[Callable[[str, str, bool],
-                                                     None]] = None,
+                                                     List[str]]] = None,
       skip_prestaged_dependencies: Optional[bool] = False,
       log_submission_env_dependencies: Optional[bool] = True,
   ):
@@ -220,6 +220,7 @@ def create_job_resources(
         not os.path.exists(requirements_cache_path)):
       os.makedirs(requirements_cache_path)

+    downloaded_packages = []
     # Stage a requirements file if present.
     if setup_options.requirements_file is not None:
       if not os.path.isfile(setup_options.requirements_file):
@@ -245,12 +246,14 @@ def create_job_resources(
             'such as --requirements_file. ')

       if setup_options.requirements_cache != SKIP_REQUIREMENTS_CACHE:
-        (
+        result = (
             populate_requirements_cache if populate_requirements_cache else
             Stager._populate_requirements_cache)(
                 setup_options.requirements_file,
                 requirements_cache_path,
                 setup_options.requirements_cache_only_sources)
+        if result is not None:
+          downloaded_packages.extend(result)

     if pypi_requirements:
       tf = tempfile.NamedTemporaryFile(mode='w', delete=False)
@@ -260,18 +263,18 @@ def create_job_resources(
       # Populate cache with packages from PyPI requirements and stage
       # the files in the cache.
       if setup_options.requirements_cache != SKIP_REQUIREMENTS_CACHE:
-        (
+        result = (
             populate_requirements_cache if populate_requirements_cache else
             Stager._populate_requirements_cache)(
                 tf.name,
                 requirements_cache_path,
                 setup_options.requirements_cache_only_sources)
+        if result is not None:
+          downloaded_packages.extend(result)

-    if (setup_options.requirements_cache != SKIP_REQUIREMENTS_CACHE) and (
-        setup_options.requirements_file is not None or pypi_requirements):
-      for pkg in glob.glob(os.path.join(requirements_cache_path, '*')):
-        resources.append(
-            Stager._create_file_stage_to_artifact(pkg, os.path.basename(pkg)))
+    for pkg in downloaded_packages:
+      resources.append(
+          Stager._create_file_stage_to_artifact(pkg, os.path.basename(pkg)))

     # Handle a setup file if present.
     # We will build the setup package locally and then copy it to the staging
@@ -431,7 +434,7 @@ def create_and_stage_job_resources(
       temp_dir: Optional[str] = None,
       pypi_requirements: Optional[List[str]] = None,
       populate_requirements_cache: Optional[Callable[[str, str, bool],
-                                                     None]] = None,
+                                                     List[str]]] = None,
       staging_location: Optional[str] = None):
     """For internal use only; no backwards-compatibility guarantees.

@@ -735,7 +738,9 @@ def _get_platform_for_default_sdk_container():
   @retry.with_exponential_backoff(
       num_retries=4, retry_filter=retry_on_non_zero_exit)
   def _populate_requirements_cache(
-      requirements_file, cache_dir, populate_cache_with_sdists=False):
+      requirements_file,
+      cache_dir,
+      populate_cache_with_sdists=False) -> List[str]:
     # The 'pip download' command will not download again if it finds the
     # tarball with the proper version already present.
     # It will get the packages downloaded in the order they are presented in
@@ -780,7 +785,12 @@ def _populate_requirements_cache(
           platform_tag
       ])
     _LOGGER.info('Executing command: %s', cmd_args)
-    processes.check_output(cmd_args, stderr=processes.STDOUT)
+    output = processes.check_output(cmd_args, stderr=processes.STDOUT)
+    downloaded_packages = []
+    for line in output.decode('utf-8').split('\n'):
+      if line.startswith('Saved '):
+        downloaded_packages.append(line.split(' ')[1])
+    return downloaded_packages

   @staticmethod
   def _build_setup_package(
diff --git a/sdks/python/apache_beam/runners/portability/stager_test.py b/sdks/python/apache_beam/runners/portability/stager_test.py
index 22a41e592c2b..92e463c4a58a 100644
--- a/sdks/python/apache_beam/runners/portability/stager_test.py
+++ b/sdks/python/apache_beam/runners/portability/stager_test.py
@@ -98,8 +98,13 @@ def file_copy(self, from_path, to_path):
   def populate_requirements_cache(
       self, requirements_file, cache_dir, populate_cache_with_sdists=False):
     _ = requirements_file
-    self.create_temp_file(os.path.join(cache_dir, 'abc.txt'), 'nothing')
-    self.create_temp_file(os.path.join(cache_dir, 'def.txt'), 'nothing')
+    _ = populate_cache_with_sdists
+    pkgs = [
+        os.path.join(cache_dir, 'abc.txt'), os.path.join(cache_dir, 'def.txt')
+    ]
+    for pkg in pkgs:
+      self.create_temp_file(pkg, 'nothing')
+    return pkgs

   @mock.patch('apache_beam.runners.portability.stager.open')
   @mock.patch('apache_beam.runners.portability.stager.get_new_http')
@@ -807,10 +812,15 @@ def test_remove_dependency_from_requirements(self):

   def _populate_requitements_cache_fake(
       self, requirements_file, temp_dir, populate_cache_with_sdists):
+    paths = []
     if not populate_cache_with_sdists:
-      self.create_temp_file(os.path.join(temp_dir, 'nothing.whl'), 'Fake whl')
-    self.create_temp_file(
-        os.path.join(temp_dir, 'nothing.tar.gz'), 'Fake tarball')
+      path = os.path.join(temp_dir, 'nothing.whl')
+      self.create_temp_file(path, 'nothing')
+      paths.append(path)
+    path = os.path.join(temp_dir, 'nothing.tar.gz')
+    self.create_temp_file(path, 'Fake tarball content')
+    paths.append(path)
+    return paths

   # requirements cache will popultated with bdist/whl if present
   # else source would be downloaded.
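
Reviewer note (illustrative, not part of the patch): the stager change above relies on `pip download` printing a "Saved <path>" line for each file it actually writes into the cache directory, and stages only those paths instead of globbing the whole requirements cache. Below is a minimal standalone sketch of that parsing approach, assuming pip's console output format; parse_saved_packages and the example command are hypothetical and not Beam APIs.

import subprocess
from typing import List


def parse_saved_packages(pip_output: str) -> List[str]:
  # 'pip download' prints one "Saved <path>" line per file it writes to the
  # destination directory; already-cached files produce no such line.
  saved = []
  for line in pip_output.splitlines():
    if line.startswith('Saved '):
      saved.append(line[len('Saved '):].strip())
  return saved


if __name__ == '__main__':
  # Hypothetical invocation: download requirements into ./cache and list only
  # the files fetched by this run.
  out = subprocess.check_output(
      ['pip', 'download', '-r', 'requirements.txt', '--dest', './cache'],
      stderr=subprocess.STDOUT).decode('utf-8')
  print(parse_saved_packages(out))

Slicing off the 'Saved ' prefix (rather than splitting on spaces) keeps paths that contain spaces intact, which may be worth considering for the `line.split(' ')[1]` parsing in the patch as well.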