From 3c468d1a94c5c66405c0c0b189951f6bced6708a Mon Sep 17 00:00:00 2001 From: meowmeow Date: Sat, 17 Dec 2022 18:50:28 +0000 Subject: [PATCH] Detect the program name when `python -m` was executed --- fire/core.py | 42 +++++++++++++++++++++++++++++++++++++++++- fire/fire_test.py | 12 ++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/fire/core.py b/fire/core.py index c1e97367..9e333c03 100644 --- a/fire/core.py +++ b/fire/core.py @@ -109,7 +109,8 @@ def Fire(component=None, command=None, name=None, serialize=None): code 2. When used with the help or trace flags, Fire will raise a FireExit with code 0 if successful. """ - name = name or os.path.basename(sys.argv[0]) + + name = _GetProgName(name) # Get args as a list. if isinstance(command, six.string_types): @@ -281,6 +282,45 @@ def _PrintResult(component_trace, verbose=False, serialize=None): Display(output, out=sys.stdout) +def _GetProgName(name, main=None): + """Determines the program name. + + This function returns the program name that should be + displayed. + + If ``python -m`` was used to execute a module, ``python -m name`` will be + returned, instead of ``__main__.py``. + + Args: + name: Optional. The name of the command as entered at the command line. + main: Optional. This should only be passed during testing. + Returns: + The program name determined by this function. + """ + if name: + return name + + name_from_arg = os.path.basename(sys.argv[0]) + + if main: + py_module = main + else: + py_module = sys.modules['__main__'].__package__ # pylint: disable=no-member + + if py_module is not None: + if name_from_arg == '__main__.py': + return '{executable} -m {module}'.format( + executable=sys.executable, module=py_module) + else: + # For example: python -m sample.cli + name = os.path.splitext(name_from_arg)[0] + py_module = '{module}.{name}'.format(module=py_module, name=name) + return '{executable} -m {module}'.format( + executable=sys.executable, module=py_module.lstrip('.')) + else: + return name_from_arg + + def _DisplayError(component_trace): """Prints the Fire trace and the error to stdout.""" result = component_trace.GetResult() diff --git a/fire/fire_test.py b/fire/fire_test.py index 8b904c29..20071843 100644 --- a/fire/fire_test.py +++ b/fire/fire_test.py @@ -24,6 +24,7 @@ import fire from fire import test_components as tc from fire import testutils +from fire.core import _GetProgName import mock import six @@ -63,6 +64,17 @@ def testFireDefaultName(self): stderr=None): fire.Fire(tc.Empty) + def testFireProgName(self): + # ``sample/__main__.py` would be argv[0] when ``python -m sample`` was + # executed. + with mock.patch.object(sys, 'argv', ['sample/__main__.py']): + self.assertEqual(_GetProgName('', main='sample'), + '{executable} -m sample'.format(executable=sys.executable)) + + with mock.patch.object(sys, 'argv', ['sample/cli.py']): + self.assertEqual(_GetProgName('', main='sample'), + '{executable} -m sample.cli'.format(executable=sys.executable)) + def testFireNoArgs(self): self.assertEqual(fire.Fire(tc.MixedDefaults, command=['ten']), 10)