diff --git a/pyfarm/jobtypes/core/jobtype.py b/pyfarm/jobtypes/core/jobtype.py index 6c665626..103fe6a3 100644 --- a/pyfarm/jobtypes/core/jobtype.py +++ b/pyfarm/jobtypes/core/jobtype.py @@ -531,6 +531,10 @@ def get_environment(self): """ Constructs an environment dictionary that can be used when a process is spawned by a job type. + + :raises TypeError: + Raised if ``jobtype_default_environment`` is defined + in the config and is not a dictionary. """ environment = {} config_environment = config.get("jobtype_default_environment") @@ -538,7 +542,12 @@ def get_environment(self): if config["jobtype_include_os_environ"]: environment.update(FROZEN_ENVIRONMENT) - if isinstance(config_environment, dict): + if config_environment is not None: + if not isinstance(config_environment, dict): + raise TypeError( + "Expected the `jobtype_default_environment` configuration " + "value to be a dictionary.") + environment.update(config_environment) elif config_environment is not None: @@ -550,20 +559,20 @@ def get_environment(self): # but it's possible to create an environment using the config which # contains non-strings. for key in environment: + value = environment.pop(key) if not isinstance(key, STRING_TYPES): logger.warning( "Environment key %r is not a string. It will be converted " "to a string.", key) - - value = environment.pop(key) key = str(key) - environment[key] = value - if not isinstance(environment[key], STRING_TYPES): + if not isinstance(value, STRING_TYPES): logger.warning( "Environment value for %r is not a string. It will be " "converted to a string.", key) - environment[key] = str(environment[key]) + value = str(value) + + environment[key] = value return environment diff --git a/tests/test_jobtypes/test_core_jobtype.py b/tests/test_jobtypes/test_core_jobtype.py index 51a278ab..57af601a 100644 --- a/tests/test_jobtypes/test_core_jobtype.py +++ b/tests/test_jobtypes/test_core_jobtype.py @@ -31,6 +31,7 @@ from pyfarm.agent.sysinfo.user import is_administrator from pyfarm.jobtypes.core.internals import USER_GROUP_TYPES from pyfarm.jobtypes.core.jobtype import ( + FROZEN_ENVIRONMENT, JobType, CommandData, JobType, CommandData, process_stdout, process_stderr, logger) from pyfarm.jobtypes.core.log import STDOUT, STDERR, logpool @@ -219,6 +220,64 @@ def test_schema(self): JobType.load({}) + +class TestJobTypeGetEnvironment(TestCase): + POP_CONFIG_KEYS = [ + "jobtype_default_environment", + "jobtype_include_os_environ" + ] + + def assertEnvironmentContains(self, target, contains): + for key, value in target.iteritems(): + self.assertIn(key, contains) + self.assertEqual(target[key], contains[key]) + + def assertDoesNotEnvironmentContains(self, target, contains): + for key, value in target.iteritems(): + self.assertNotIn(key, contains) + + def test_includes_os_environ(self): + config["jobtype_include_os_environ"] = True + jobtype = JobType(fake_assignment()) + self.assertEnvironmentContains( + FROZEN_ENVIRONMENT, jobtype.get_environment()) + + def test_does_not_include_os_environ(self): + config["jobtype_include_os_environ"] = False + jobtype = JobType(fake_assignment()) + self.assertDoesNotEnvironmentContains( + FROZEN_ENVIRONMENT, jobtype.get_environment()) + + def test_include_config_environment(self): + config_env = config["jobtype_default_environment"] = { + "foo": "1", "bar": "2" + } + jobtype = JobType(fake_assignment()) + self.assertEnvironmentContains(jobtype.get_environment(), config_env) + + def test_bad_config_raises_type_error(self): + config["jobtype_default_environment"] = [""] + jobtype = JobType(fake_assignment()) + + with self.assertRaises(TypeError): + jobtype.get_environment() + + def test_converts_non_string_key_to_string(self): + jobtype = JobType(fake_assignment()) + config["jobtype_default_environment"] = { + 1: "yes" + } + for key in jobtype.get_environment().keys(): + self.assertIsInstance(key, str) + + def test_converts_non_string_value_to_string(self): + jobtype = JobType(fake_assignment()) + config["jobtype_default_environment"] = { + "yes": 1 + } + for value in jobtype.get_environment().values(): + self.assertIsInstance(value, str) + class TestJobTypeFormatError(TestCase): def test_process_terminated(self): error = Failure(exc_value=Exception())