diff --git a/.gitignore b/.gitignore index d9516057c..4fd51e92a 100755 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ libs/armeabi doc/source/developer_information/developer_guide/instrument_method_map.rst doc/source/run_config/ .eggs +*venv*/* +.vscode/* diff --git a/doc/build_instrument_method_map.py b/doc/build_instrument_method_map.py index 29d2d6f25..cbcc5a88a 100644 --- a/doc/build_instrument_method_map.py +++ b/doc/build_instrument_method_map.py @@ -22,7 +22,7 @@ from wa.framework.signal import CallbackPriority from wa.utils.doc import format_simple_table -OUTPUT_TEMPLATE_FILE = os.path.join(os.path.dirname(__file__), 'source', 'instrument_method_map.template') +OUTPUT_TEMPLATE_FILE = os.path.join(os.path.dirname(__file__), 'source', 'instrument_method_map.template') def generate_instrument_method_map(outfile): @@ -30,7 +30,7 @@ def generate_instrument_method_map(outfile): headers=['method name', 'signal'], align='<<') decorator_names = map(lambda x: x.replace('high', 'fast').replace('low', 'slow'), CallbackPriority.names) priority_table = format_simple_table(zip(decorator_names, CallbackPriority.names, CallbackPriority.values), - headers=['decorator', 'CallbackPriority name', 'CallbackPriority value'], align='<>') + headers=['decorator', 'CallbackPriority name', 'CallbackPriority value'], align='<>') with open(OUTPUT_TEMPLATE_FILE) as fh: template = string.Template(fh.read()) with open(outfile, 'w') as wfh: diff --git a/wa/__init__.py b/wa/__init__.py index c7969623f..f52333d1b 100644 --- a/wa/__init__.py +++ b/wa/__init__.py @@ -30,7 +30,7 @@ from wa.framework.resource import (NO_ONE, JarFile, ApkFile, ReventFile, File, Executable) from wa.framework.target.descriptor import (TargetDescriptor, TargetDescription, - create_target_description, add_description_for_target) + create_target_description) from wa.framework.workload import (Workload, ApkWorkload, ApkUiautoWorkload, ApkReventWorkload, UIWorkload, UiautoWorkload, PackageHandler, ReventWorkload, TestPackageHandler) diff --git a/wa/commands/create.py b/wa/commands/create.py index 85b4c6cc4..c87a29be2 100644 --- a/wa/commands/create.py +++ b/wa/commands/create.py @@ -26,9 +26,9 @@ from devlib.utils.types import identifier try: - import psycopg2 - from psycopg2 import connect, OperationalError, extras - from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT + import psycopg2 # type: ignore + from psycopg2 import connect, OperationalError, extras, _psycopg # type: ignore + from psycopg2.extensions import ISOLATION_LEVEL_AUTOCOMMIT # type: ignore except ImportError as e: psycopg2 = None import_error_msg = e.args[0] if e.args else str(e) @@ -41,6 +41,18 @@ ensure_file_directory_exists as _f) from wa.utils.postgres import get_schema, POSTGRES_SCHEMA_DIR from wa.utils.serializer import yaml +from typing import (Optional, TYPE_CHECKING, cast, OrderedDict as od, Any, IO, + Dict, Tuple, Pattern, List, Callable, Type) +from typing_extensions import TypedDict +from argparse import Namespace +from typing_extensions import Required +from uuid import UUID +if TYPE_CHECKING: + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.plugin import Plugin + from wa.framework.execution import ExecutionContext, ConfigManager + from wa.framework.target.descriptor import TargetDescriptionProtocol + if sys.version_info >= (3, 8): def copy_tree(src, dst): @@ -64,31 +76,39 @@ def copy_tree(src, dst): TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), 'templates') +class PostgresType(TypedDict, total=False): + host: str + 
port: int + dbname: str + username: str + password: str + + class CreateDatabaseSubcommand(SubCommand): - name = 'database' - description = """ + name: str = 'database' + description: str = """ Create a Postgresql database which is compatible with the WA Postgres output processor. """ - schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql') - schemaupdatefilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema_update_v{}.{}.sql') + schemafilepath: str = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql') + schemaupdatefilepath: str = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema_update_v{}.{}.sql') - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(CreateDatabaseSubcommand, self).__init__(*args, **kwargs) - self.sql_commands = None - self.schema_major = None - self.schema_minor = None - self.postgres_host = None - self.postgres_port = None - self.username = None - self.password = None - self.dbname = None - self.config_file = None - self.force = None - - def initialize(self, context): + self.sql_commands: Optional[str] = None + self.schema_major: Optional[int] = None + self.schema_minor: Optional[int] = None + self.postgres_host: Optional[str] = None + self.postgres_port: Optional[int] = None + self.username: Optional[str] = None + self.password: Optional[str] = None + self.dbname: Optional[str] = None + self.config_file: Optional[str] = None + self.force: Optional[bool] = None + + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument( '-a', '--postgres-host', default='localhost', help='The host on which to create the database.') @@ -120,7 +140,7 @@ def initialize(self, context): '-U', '--upgrade', action='store_true', help='Upgrade the database to use the latest schema version.') - def execute(self, state, args): # pylint: disable=too-many-branches + def execute(self, state: 'ConfigManager', args: Namespace) -> None: # pylint: disable=too-many-branches if not psycopg2: raise CommandError( 'The module psycopg2 is required for the wa ' @@ -144,14 +164,14 @@ def execute(self, state, args): # pylint: disable=too-many-branches return # Open user configuration - with open(self.config_file, 'r') as config_file: - config = yaml.load(config_file) + with open(self.config_file or '', 'r') as config_file: + config: Dict[str, Any] = yaml.load(config_file) if 'postgres' in config and not args.force_update_config: raise CommandError( "The entry 'postgres' already exists in the config file. 
" + "Please specify the -F flag to force an update.") - possible_connection_errors = [ + possible_connection_errors: List[Tuple[Pattern[str], str]] = [ ( re.compile('FATAL: role ".*" does not exist'), 'Username does not exist or password is incorrect' @@ -178,7 +198,10 @@ def execute(self, state, args): # pylint: disable=too-many-branches ), ] - def predicate(error, handle): + def predicate(error: OperationalError, handle: Tuple[Pattern[str], str]) -> None: + """ + raise appropriate exception + """ if handle[0].match(str(error)): raise CommandError(handle[1] + ': \n' + str(error)) @@ -193,7 +216,10 @@ def predicate(error, handle): # Update the configuration file self._update_configuration_file(config) - def create_database(self): + def create_database(self) -> None: + """ + create postgresql database + """ self._validate_version() self._check_database_existence() @@ -205,21 +231,29 @@ def create_database(self): self.logger.info( "Successfully created the database {}".format(self.dbname)) - def update_schema(self): + def update_schema(self) -> None: + """ + update database schema + """ self._validate_version() schema_major, schema_minor, _ = get_schema(self.schemafilepath) - meta_oid, current_major, current_minor = self._get_database_schema_version() + meta_oid, current_major, current_minor = self._get_database_schema_version() or (None, None, None) while not (schema_major == current_major and schema_minor == current_minor): - current_minor = self._update_schema_minors(current_major, current_minor, meta_oid) - current_major, current_minor = self._update_schema_major(current_major, current_minor, meta_oid) + current_minor = self._update_schema_minors(current_major or 0, current_minor or 0, meta_oid) + current_major, current_minor = self._update_schema_major(current_major or 0, current_minor, meta_oid) msg = "Database schema update of '{}' to v{}.{} complete" self.logger.info(msg.format(self.dbname, schema_major, schema_minor)) - def _update_schema_minors(self, major, minor, meta_oid): + def _update_schema_minors(self, major: int, minor: int, meta_oid: Optional[UUID]) -> int: + """ + update schema minor versions + """ # Upgrade all available minor versions while True: + minor += 1 + schema_update = os.path.join(POSTGRES_SCHEMA_DIR, self.schemaupdatefilepath.format(major, minor)) if not os.path.exists(schema_update): @@ -227,16 +261,20 @@ def _update_schema_minors(self, major, minor, meta_oid): _, _, sql_commands = get_schema(schema_update) self._apply_database_schema(sql_commands, major, minor, meta_oid) - msg = "Updated the database schema to v{}.{}" + msg: str = "Updated the database schema to v{}.{}" self.logger.debug(msg.format(major, minor)) # Return last existing update file version return minor - 1 - def _update_schema_major(self, current_major, current_minor, meta_oid): + def _update_schema_major(self, current_major: int, current_minor: int, + meta_oid: Optional[UUID]) -> Tuple[int, int]: + """ + update schema major versions + """ current_major += 1 - schema_update = os.path.join(POSTGRES_SCHEMA_DIR, - self.schemaupdatefilepath.format(current_major, 0)) + schema_update: str = os.path.join(POSTGRES_SCHEMA_DIR, + self.schemaupdatefilepath.format(current_major, 0)) if not os.path.exists(schema_update): return (current_major - 1, current_minor) @@ -248,17 +286,25 @@ def _update_schema_major(self, current_major, current_minor, meta_oid): self.logger.debug(msg.format(current_major, current_minor)) return (current_major, current_minor) - def _validate_version(self): - conn = 
connect(user=self.username, - password=self.password, host=self.postgres_host, port=self.postgres_port) + def _validate_version(self) -> None: + """ + validate schema version + """ + conn: _psycopg.connection = connect(user=self.username, + password=self.password, host=self.postgres_host, + port=self.postgres_port) if conn.server_version < 90400: msg = 'Postgres version too low. Please ensure that you are using atleast v9.4' raise CommandError(msg) - def _get_database_schema_version(self): - conn = connect(dbname=self.dbname, user=self.username, - password=self.password, host=self.postgres_host, port=self.postgres_port) - cursor = conn.cursor() + def _get_database_schema_version(self) -> Optional[Tuple[UUID, int, int]]: + """ + get database schema version + """ + conn: _psycopg.connection = connect(dbname=self.dbname, user=self.username, + password=self.password, host=self.postgres_host, + port=self.postgres_port) + cursor: _psycopg.cursor = conn.cursor() cursor.execute('''SELECT DatabaseMeta.oid, DatabaseMeta.schema_major, @@ -267,7 +313,10 @@ def _get_database_schema_version(self): DatabaseMeta;''') return cursor.fetchone() - def _check_database_existence(self): + def _check_database_existence(self) -> None: + """ + check whether database exists + """ try: connect(dbname=self.dbname, user=self.username, password=self.password, host=self.postgres_host, port=self.postgres_port) @@ -282,22 +331,33 @@ def _check_database_existence(self): + "Please specify the -f flag to create it from afresh." ) - def _create_database_postgres(self): - conn = connect(dbname='postgres', user=self.username, - password=self.password, host=self.postgres_host, port=self.postgres_port) + def _create_database_postgres(self) -> None: + """ + create a postgresql database + """ + conn: _psycopg.connection = connect(dbname='postgres', user=self.username, + password=self.password, host=self.postgres_host, + port=self.postgres_port) conn.set_isolation_level(ISOLATION_LEVEL_AUTOCOMMIT) - cursor = conn.cursor() - cursor.execute('DROP DATABASE IF EXISTS ' + self.dbname) - cursor.execute('CREATE DATABASE ' + self.dbname) + cursor: _psycopg.cursor = conn.cursor() + cursor.execute('DROP DATABASE IF EXISTS ' + (self.dbname or '')) + cursor.execute('CREATE DATABASE ' + (self.dbname or '')) conn.commit() cursor.close() conn.close() - def _apply_database_schema(self, sql_commands, schema_major, schema_minor, meta_uuid=None): - conn = connect(dbname=self.dbname, user=self.username, - password=self.password, host=self.postgres_host, port=self.postgres_port) - cursor = conn.cursor() - cursor.execute(sql_commands) + def _apply_database_schema(self, sql_commands: Optional[str], schema_major: Optional[int], + schema_minor: Optional[int], + meta_uuid: Optional[UUID] = None) -> None: + """ + apply database schema + """ + conn: _psycopg.connection = connect(dbname=self.dbname, user=self.username, + password=self.password, host=self.postgres_host, + port=self.postgres_port) + cursor: _psycopg.cursor = conn.cursor() + if sql_commands: + cursor.execute(sql_commands) if not meta_uuid: extras.register_uuid() @@ -318,17 +378,17 @@ def _apply_database_schema(self, sql_commands, schema_major, schema_minor, meta_ cursor.close() conn.close() - def _update_configuration_file(self, config): + def _update_configuration_file(self, config: Dict[str, PostgresType]): ''' Update the user configuration file with the newly created database's configuration. 
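A note on the `_validate_version` check above: psycopg2 reports `connection.server_version` as a single integer, which is why the code compares against `90400`. A minimal sketch of the encoding (for pre-10 PostgreSQL servers it is major*10000 + minor*100 + patch; the helper name here is mine, not part of the patch):

```python
# Sketch only: how psycopg2's integer server_version maps to a version tuple.
# The 90400 threshold in _validate_version() therefore means "at least 9.4.0".
from typing import Tuple

def parse_server_version(v: int) -> Tuple[int, int, int]:
    return v // 10000, (v % 10000) // 100, v % 100

assert parse_server_version(90400) == (9, 4, 0)
```

Two related observations: the `_psycopg.connection` / `_psycopg.cursor` annotations reach into a private psycopg2 module, which is why the import earlier in this file carries a `# type: ignore`; and `_create_database_postgres` switches to `ISOLATION_LEVEL_AUTOCOMMIT` because `CREATE DATABASE` and `DROP DATABASE` cannot run inside a transaction block.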
''' - config['postgres'] = OrderedDict( + config['postgres'] = cast(PostgresType, OrderedDict( [('host', self.postgres_host), ('port', self.postgres_port), - ('dbname', self.dbname), ('username', self.username), ('password', self.password)]) - with open(self.config_file, 'w+') as config_file: + ('dbname', self.dbname), ('username', self.username), ('password', self.password)])) + with open(self.config_file or '', 'w+') as config_file: yaml.dump(config, config_file) - def _parse_args(self, args): + def _parse_args(self, args: Namespace) -> None: self.postgres_host = args.postgres_host self.postgres_port = args.postgres_port self.username = args.username @@ -338,6 +398,25 @@ def _parse_args(self, args): self.force = args.force +class ConfigType(TypedDict, total=False): + device: str + device_config: Dict[str, Any] + augmentations: Required[List[str]] + energy_measurement: Required[Dict[str, Any]] + """more keys can be pulled in from loaded plugins""" + + +class WorkloadEntryType(TypedDict, total=False): + name: Optional[str] + label: str + params: Optional[Dict[str, Any]] + + +class AgendaType(TypedDict, total=False): + config: ConfigType + workloads: List + + class CreateAgendaSubcommand(SubCommand): name = 'agenda' @@ -346,7 +425,7 @@ class CreateAgendaSubcommand(SubCommand): to their default values. """ - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']): self.parser.add_argument('plugins', nargs='+', help='Plugins to be added to the agendas') self.parser.add_argument('-i', '--iterations', type=int, default=1, @@ -355,13 +434,13 @@ def initialize(self, context): help='Output file. If not specfied, STDOUT will be used instead.') # pylint: disable=too-many-branches - def execute(self, state, args): - agenda = OrderedDict() - agenda['config'] = OrderedDict(augmentations=[], iterations=args.iterations) + def execute(self, state: 'ConfigManager', args: Namespace) -> None: + agenda: AgendaType = cast(AgendaType, OrderedDict()) + agenda['config'] = cast(ConfigType, OrderedDict(augmentations=[], iterations=args.iterations)) agenda['workloads'] = [] - target_desc = None + target_desc: Optional['TargetDescriptionProtocol'] = None - targets = {td.name: td for td in list_target_descriptions()} + targets: Dict[str, 'TargetDescriptionProtocol'] = {td.name: td for td in list_target_descriptions()} for name in args.plugins: if name in targets: @@ -372,34 +451,37 @@ def execute(self, state, args): agenda['config']['device_config'] = target_desc.get_default_config() continue - extcls = pluginloader.get_plugin_class(name) - config = pluginloader.get_default_config(name) + extcls: Type[Plugin] = cast('__LoaderWrapper', pluginloader).get_plugin_class(name) + config: Optional[Dict[str, Any]] = cast('__LoaderWrapper', pluginloader).get_default_config(name) # Handle special case for EnergyInstrumentBackends if issubclass(extcls, EnergyInstrumentBackend): if 'energy_measurement' not in agenda['config']['augmentations']: - energy_config = pluginloader.get_default_config('energy_measurement') + energy_config = cast('__LoaderWrapper', pluginloader).get_default_config('energy_measurement') agenda['config']['augmentations'].append('energy_measurement') - agenda['config']['energy_measurement'] = energy_config - agenda['config']['energy_measurement']['instrument'] = extcls.name + agenda['config']['energy_measurement'] = cast(Dict[str, Any], energy_config) + agenda['config']['energy_measurement']['instrument'] = cast(EnergyInstrumentBackend, extcls).name 
agenda['config']['energy_measurement']['instrument_parameters'] = config - elif extcls.kind == 'workload': - entry = OrderedDict() - entry['name'] = extcls.name - if name != extcls.name: + elif cast(Plugin, extcls).kind == 'workload': + entry: WorkloadEntryType = cast(WorkloadEntryType, OrderedDict()) + entry['name'] = cast(Plugin, extcls).name + if name != cast(Plugin, extcls).name: entry['label'] = name entry['params'] = config agenda['workloads'].append(entry) else: - if extcls.kind in ('instrument', 'output_processor'): - if extcls.name not in agenda['config']['augmentations']: - agenda['config']['augmentations'].append(extcls.name) + if cast(Plugin, extcls).kind in ('instrument', 'output_processor'): + if cast(Plugin, extcls).name not in agenda['config']['augmentations']: + agenda['config']['augmentations'].append(cast(Plugin, extcls).name or '') - if extcls.name not in agenda['config']: - agenda['config'][extcls.name] = config + if cast(Plugin, extcls).name not in agenda['config']: + # type error saying that the key should be one of those in the typeddict ConfigType + # but here it is assigning new keys from loaded plugin. so can be any name. + # so ignoring the type error + agenda['config'][cast(Plugin, extcls).name or ''] = config # type:ignore if args.output: - wfh = open(args.output, 'w') + wfh: IO = open(args.output, 'w') else: wfh = sys.stdout yaml.dump(agenda, wfh, indent=4, default_flow_style=False) @@ -409,11 +491,11 @@ def execute(self, state, args): class CreateWorkloadSubcommand(SubCommand): - name = 'workload' - description = '''Create a new workload. By default, a basic workload template will be + name: str = 'workload' + description: str = '''Create a new workload. By default, a basic workload template will be used but you can specify the `KIND` to choose a different template.''' - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('name', metavar='NAME', help='Name of the workload to be created') self.parser.add_argument('-p', '--path', metavar='PATH', default=None, @@ -426,9 +508,9 @@ def initialize(self, context): help='The type of workload to be created. The available options ' + 'are: {}'.format(', '.join(list(create_funcs.keys())))) - def execute(self, state, args): # pylint: disable=R0201 - where = args.path or 'local' - check_name = not args.force + def execute(self, state: 'ConfigManager', args: Namespace) -> None: # pylint: disable=R0201 + where: str = args.path or 'local' + check_name: bool = not args.force try: create_workload(args.name, args.kind, where, check_name) @@ -438,12 +520,12 @@ def execute(self, state, args): # pylint: disable=R0201 class CreatePackageSubcommand(SubCommand): - name = 'package' - description = '''Create a new empty Python package for WA extensions. On installation, + name: str = 'package' + description: str = '''Create a new empty Python package for WA extensions. 
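The `ConfigType` TypedDict used above combines `total=False` with PEP 655's `Required[]` marker (via `typing_extensions`), and the inline comment explains why a `# type: ignore` is still needed for plugin-derived keys. A minimal sketch of the checker behaviour this buys:

```python
# Sketch of the TypedDict semantics relied on above (names mirror the patch).
from typing import List
from typing_extensions import TypedDict, Required

class ConfigType(TypedDict, total=False):
    device: str                            # optional key (total=False)
    augmentations: Required[List[str]]     # mandatory despite total=False

ok: ConfigType = {'augmentations': []}     # fine: 'device' may be omitted
# bad: ConfigType = {'device': 'gem5'}     # checker error: 'augmentations' missing
# TypedDict keys are a closed set, so dynamic, plugin-named keys such as
# config['trace-cmd'] = {...} need the `# type: ignore` seen in execute().
```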
On installation, this package will "advertise" itself to WA so that Plugins within it will be loaded by WA when it runs.''' - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('name', metavar='NAME', help='Name of the package to be created') self.parser.add_argument('-p', '--path', metavar='PATH', default=None, @@ -453,22 +535,26 @@ def initialize(self, context): help='Create the new package even if a file or directory with the same name ' 'already exists at the specified location.') - def execute(self, state, args): # pylint: disable=R0201 - package_dir = args.path or os.path.abspath('.') - template_path = os.path.join(TEMPLATES_DIR, 'setup.template') + def execute(self, state: 'ConfigManager', args: Namespace) -> None: # pylint: disable=R0201 + package_dir: str = args.path or os.path.abspath('.') + template_path: str = os.path.join(TEMPLATES_DIR, 'setup.template') self.create_extensions_package(package_dir, args.name, template_path, args.force) - def create_extensions_package(self, location, name, setup_template_path, overwrite=False): - package_path = os.path.join(location, name) + def create_extensions_package(self, location: str, name: str, setup_template_path: str, + overwrite: bool = False) -> None: + """ + create extensions package + """ + package_path: str = os.path.join(location, name) if os.path.exists(package_path): if overwrite: self.logger.info('overwriting existing "{}"'.format(package_path)) shutil.rmtree(package_path) else: raise CommandError('Location "{}" already exists.'.format(package_path)) - actual_package_path = os.path.join(package_path, name) + actual_package_path: str = os.path.join(package_path, name) os.makedirs(actual_package_path) - setup_text = render_template(setup_template_path, {'package_name': name, 'user': getpass.getuser()}) + setup_text: str = render_template(setup_template_path, {'package_name': name, 'user': getpass.getuser()}) with open(os.path.join(package_path, 'setup.py'), 'w') as wfh: wfh.write(setup_text) touch(os.path.join(actual_package_path, '__init__.py')) @@ -476,8 +562,8 @@ def create_extensions_package(self, location, name, setup_template_path, overwri class CreateCommand(ComplexCommand): - name = 'create' - description = ''' + name: str = 'create' + description: str = ''' Used to create various WA-related objects (see positional arguments list for what objects may be created).\n\nUse "wa create -h" for object-specific arguments. 
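For reference, a standalone sketch of the scaffold `create_extensions_package()` above builds: `<location>/<name>/setup.py` plus an inner importable package directory. This is simplified (no overwrite handling or template rendering) and the function name is illustrative:

```python
import os

def scaffold(location: str, name: str) -> None:
    inner = os.path.join(location, name, name)   # <location>/<name>/<name>/
    os.makedirs(inner)                           # mirrors the os.makedirs call above
    with open(os.path.join(location, name, 'setup.py'), 'w') as fh:
        fh.write('# rendered from templates/setup.template in the real command\n')
    open(os.path.join(inner, '__init__.py'), 'w').close()   # the touch() step
```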
@@ -490,15 +576,19 @@ class CreateCommand(ComplexCommand): ] -def create_workload(name, kind='basic', where='local', check_name=True, **kwargs): +def create_workload(name: str, kind: str = 'basic', where: str = 'local', + check_name: bool = True, **kwargs) -> None: + """ + create workload + """ if check_name: - if name in [wl.name for wl in pluginloader.list_plugins('workload')]: + if name in [wl.name for wl in cast('__LoaderWrapper', pluginloader).list_plugins('workload')]: raise CommandError('Workload with name "{}" already exists.'.format(name)) - class_name = get_class_name(name) + class_name: str = get_class_name(name) if where == 'local': - workload_dir = _d(os.path.join(settings.plugins_directory, name)) + workload_dir: str = _d(os.path.join(settings.plugins_directory, name)) else: workload_dir = _d(os.path.join(where, name)) @@ -512,51 +602,62 @@ def create_workload(name, kind='basic', where='local', check_name=True, **kwargs print('Workload created in {}'.format(workload_dir)) -def create_template_workload(path, name, kind, class_name): - source_file = os.path.join(path, '__init__.py') +def create_template_workload(path: str, name: str, kind: str, + class_name: str) -> None: + """ + create template workload + """ + source_file: str = os.path.join(path, '__init__.py') with open(source_file, 'w') as wfh: wfh.write(render_template('{}_workload'.format(kind), {'name': name, 'class_name': class_name})) -def create_uiautomator_template_workload(path, name, kind, class_name): - uiauto_path = os.path.join(path, 'uiauto') +def create_uiautomator_template_workload(path: str, name: str, kind: str, + class_name: str) -> None: + """ + create ui automator template workload + """ + uiauto_path: str = os.path.join(path, 'uiauto') create_uiauto_project(uiauto_path, name) create_template_workload(path, name, kind, class_name) -def create_uiauto_project(path, name): - package_name = 'com.arm.wa.uiauto.' + name.lower() +def create_uiauto_project(path: str, name: str) -> None: + """ + create ui automator project + """ + package_name: str = 'com.arm.wa.uiauto.' 
+ name.lower() copy_tree(os.path.join(TEMPLATES_DIR, 'uiauto', 'uiauto_workload_template'), path) - manifest_path = os.path.join(path, 'app', 'src', 'main') - mainifest = os.path.join(_d(manifest_path), 'AndroidManifest.xml') + manifest_path: str = os.path.join(path, 'app', 'src', 'main') + mainifest: str = os.path.join(_d(manifest_path), 'AndroidManifest.xml') with open(mainifest, 'w') as wfh: wfh.write(render_template(os.path.join('uiauto', 'uiauto_AndroidManifest.xml'), {'package_name': package_name})) - build_gradle_path = os.path.join(path, 'app') - build_gradle = os.path.join(_d(build_gradle_path), 'build.gradle') + build_gradle_path: str = os.path.join(path, 'app') + build_gradle: str = os.path.join(_d(build_gradle_path), 'build.gradle') with open(build_gradle, 'w') as wfh: wfh.write(render_template(os.path.join('uiauto', 'uiauto_build.gradle'), {'package_name': package_name})) - build_script = os.path.join(path, 'build.sh') + build_script: str = os.path.join(path, 'build.sh') with open(build_script, 'w') as wfh: wfh.write(render_template(os.path.join('uiauto', 'uiauto_build_script'), {'package_name': package_name})) os.chmod(build_script, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO) - source_file = _f(os.path.join(path, 'app', 'src', 'main', 'java', - os.sep.join(package_name.split('.')[:-1]), - 'UiAutomation.java')) + source_file: str = _f(os.path.join(path, 'app', 'src', 'main', 'java', + os.sep.join(package_name.split('.')[:-1]), + 'UiAutomation.java')) with open(source_file, 'w') as wfh: wfh.write(render_template(os.path.join('uiauto', 'UiAutomation.java'), {'name': name, 'package_name': package_name})) # Mapping of workload types to their corresponding creation method -create_funcs = { +create_funcs: Dict[str, Callable[[str, str, str, str], None]] = { 'basic': create_template_workload, 'apk': create_template_workload, 'revent': create_template_workload, @@ -567,19 +668,28 @@ def create_uiauto_project(path, name): # Utility functions -def render_template(name, params): - filepath = os.path.join(TEMPLATES_DIR, name) +def render_template(name: str, params: Dict[str, Any]) -> str: + """ + render template + """ + filepath: str = os.path.join(TEMPLATES_DIR, name) with open(filepath) as fh: - text = fh.read() + text: str = fh.read() template = string.Template(text) return template.substitute(params) -def get_class_name(name, postfix=''): +def get_class_name(name: str, postfix='') -> str: + """ + get class name + """ name = identifier(name) return ''.join(map(capitalize, name.split('_'))) + postfix -def touch(path): +def touch(path: str): + """ + clear file contents + """ with open(path, 'w') as _: # NOQA pass diff --git a/wa/commands/list.py b/wa/commands/list.py index f8a5cf0c4..60b0f017f 100644 --- a/wa/commands/list.py +++ b/wa/commands/list.py @@ -19,15 +19,21 @@ from wa.framework.target.descriptor import list_target_descriptions from wa.utils.doc import get_summary from wa.utils.formatter import DescriptionListFormatter +from argparse import Namespace +from typing import TYPE_CHECKING, cast, Optional, List, Dict, Type +if TYPE_CHECKING: + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.execution import ExecutionContext, ConfigManager + from wa.framework.plugin import Plugin class ListCommand(Command): - name = 'list' - description = 'List available WA plugins with a short description of each.' + name: str = 'list' + description: str = 'List available WA plugins with a short description of each.' 
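The `create_*` helpers above all funnel through `render_template()`, which is a thin wrapper over the standard library's `string.Template`. A minimal sketch of that API, with an illustrative template string:

```python
import string

# substitute() fills ${name} placeholders and raises KeyError if one is missing,
# which is why render_template() passes a complete params dict.
tpl = string.Template("package_name = '${package_name}'\n")
print(tpl.substitute({'package_name': 'com.arm.wa.uiauto.myworkload'}))
```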
- def initialize(self, context): - kinds = get_kinds() + def initialize(self, context: Optional['ExecutionContext']) -> None: + kinds: List[str] = get_kinds() kinds.extend(['augmentations', 'all']) self.parser.add_argument('kind', metavar='KIND', help=('Specify the kind of plugin to list. Must be ' @@ -48,8 +54,8 @@ def initialize(self, context): ''') # pylint: disable=superfluous-parens - def execute(self, state, args): - filters = {} + def execute(self, state: 'ConfigManager', args: Namespace) -> None: + filters: Dict[str, str] = {} if args.name: filters['name'] = args.name @@ -74,8 +80,11 @@ def execute(self, state, args): list_plugins(args, filters) -def get_kinds(): - kinds = pluginloader.kinds +def get_kinds() -> List[str]: + """ + get a list of kinds of commands + """ + kinds = cast('__LoaderWrapper', pluginloader).kinds if 'target_descriptor' in kinds: kinds.remove('target_descriptor') kinds.append('target') @@ -83,7 +92,10 @@ def get_kinds(): # pylint: disable=superfluous-parens -def list_targets(): +def list_targets() -> None: + """ + print out target descriptions + """ targets = list_target_descriptions() targets = sorted(targets, key=lambda x: x.name) @@ -94,8 +106,11 @@ def list_targets(): print('') -def list_plugins(args, filters): - results = pluginloader.list_plugins(args.kind[:-1]) +def list_plugins(args, filters: Dict[str, str]) -> None: + """ + print list of plugins + """ + results = cast('__LoaderWrapper', pluginloader).list_plugins(args.kind[:-1]) if filters or args.platform: filtered_results = [] for result in results: @@ -113,14 +128,14 @@ def list_plugins(args, filters): if filtered_results: output = DescriptionListFormatter() - for result in sorted(filtered_results, key=lambda x: x.name): - output.add_item(get_summary(result), result.name) + for result in sorted(filtered_results, key=lambda x: x.name or ''): + output.add_item(get_summary(result), result.name or '') print(output.format_data()) print('') -def check_platform(plugin, platform): +def check_platform(plugin: Type['Plugin'], platform: str) -> bool: supported_platforms = getattr(plugin, 'supported_platforms', []) if supported_platforms: return platform in supported_platforms diff --git a/wa/commands/process.py b/wa/commands/process.py index 7d312e602..97942c177 100644 --- a/wa/commands/process.py +++ b/wa/commands/process.py @@ -19,17 +19,25 @@ from wa import discover_wa_outputs from wa.framework.configuration.core import Status from wa.framework.exception import CommandError -from wa.framework.output import RunOutput +from wa.framework.output import RunOutput, JobOutput from wa.framework.output_processor import ProcessorManager from wa.utils import log +from argparse import Namespace +from typing import Optional, TYPE_CHECKING, List, cast +from types import ModuleType +if TYPE_CHECKING: + from wa.framework.target.info import TargetInfo + from wa.framework.execution import ExecutionContext, ConfigManager class ProcessContext(object): - - def __init__(self): - self.run_output = None - self.target_info = None - self.job_output = None + """ + process context + """ + def __init__(self) -> None: + self.run_output: Optional[RunOutput] = None + self.target_info: Optional['TargetInfo'] = None + self.job_output: Optional[JobOutput] = None def add_augmentation(self, aug): pass @@ -37,10 +45,10 @@ def add_augmentation(self, aug): class ProcessCommand(Command): - name = 'process' - description = 'Process the output from previously run workloads.' 
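The repeated `cast('__LoaderWrapper', pluginloader)` calls in list.py above exist because WA's pluginloader module replaces itself in `sys.modules` with a wrapper object at import time, so static analysis still types the imported name as a plain module. A sketch of that idiom; the wrapper class here is a stand-in, not WA's actual implementation:

```python
import sys

class _LoaderWrapper:                      # stand-in for WA's __LoaderWrapper
    kinds = ['workload', 'instrument']

    def list_plugins(self, kind=None):
        return []

# A module may swap itself for an object; later imports receive the wrapper,
# but checkers still see ModuleType, hence the casts in the command code.
sys.modules[__name__] = _LoaderWrapper()   # type: ignore[assignment]
```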
+ name: str = 'process' + description: str = 'Process the output from previously run workloads.' - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('directory', metavar='DIR', help=""" Specify a directory containing the data @@ -69,14 +77,14 @@ def initialize(self, context): instead of just processing the root. """) - def execute(self, config, args): # pylint: disable=arguments-differ,too-many-branches,too-many-statements - process_directory = os.path.expandvars(args.directory) + def execute(self, config: 'ConfigManager', args: Namespace): # pylint: disable=arguments-differ,too-many-branches,too-many-statements + process_directory: str = os.path.expandvars(args.directory) self.logger.debug('Using process directory: {}'.format(process_directory)) if not os.path.exists(process_directory): - msg = 'Path `{}` does not exist, please specify a valid path.' + msg: str = 'Path `{}` does not exist, please specify a valid path.' raise CommandError(msg.format(process_directory)) if not args.recursive: - output_list = [RunOutput(process_directory)] + output_list: List[RunOutput] = [RunOutput(process_directory)] else: output_list = list(discover_wa_outputs(process_directory)) @@ -96,14 +104,14 @@ def execute(self, config, args): # pylint: disable=arguments-differ,too-many-br self.logger.info('Install output processors for run in path `{}`' .format(run_output.basepath)) - logfile = os.path.join(run_output.basepath, 'process.log') + logfile: str = os.path.join(run_output.basepath, 'process.log') i = 0 while os.path.exists(logfile): i += 1 logfile = os.path.join(run_output.basepath, 'process-{}.log'.format(i)) log.add_file(logfile) - pm = ProcessorManager(loader=config.plugin_cache) + pm = ProcessorManager(loader=cast(ModuleType, config.plugin_cache)) for proc in config.get_processors(): pm.install(proc, pc) if args.additional_processors: @@ -128,11 +136,12 @@ def execute(self, config, args): # pylint: disable=arguments-differ,too-many-br pc.job_output = job_output pm.enable_all() if not args.force: - for augmentation in job_output.spec.augmentations: - try: - pm.disable(augmentation) - except ValueError: - pass + if job_output.spec: + for augmentation in job_output.spec.augmentations: + try: + pm.disable(augmentation) + except ValueError: + pass msg = 'Processing job {} {} iteration {}' self.logger.info(msg.format(job_output.id, job_output.label, diff --git a/wa/commands/report.py b/wa/commands/report.py index f9390d961..63590aa5b 100644 --- a/wa/commands/report.py +++ b/wa/commands/report.py @@ -4,17 +4,24 @@ import os from wa import Command, settings -from wa.framework.configuration.core import Status -from wa.framework.output import RunOutput, discover_wa_outputs +from wa.framework.configuration.core import Status, RunConfigurationProtocol +from wa.framework.output import RunOutput, discover_wa_outputs, JobOutput from wa.utils.doc import underline from wa.utils.log import COLOR_MAP, RESET_COLOR from wa.utils.terminalsize import get_terminal_size +from argparse import Namespace +from typing import Optional, TYPE_CHECKING, List, cast, Dict, Tuple + +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext, ConfigManager + from wa.framework.run import RunInfo, RunState, JobState + from wa.framework.configuration.core import StatusType class ReportCommand(Command): - name = 'report' - description = ''' + name: str = 'report' + description: str = ''' Monitor an ongoing run and provide information on its progress. 
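The collision-avoiding log naming in `ProcessCommand.execute()` above (process.log, then process-1.log, process-2.log, ...) extracted as a standalone sketch; the function name is mine:

```python
import os

def unique_logfile(basepath: str) -> str:
    logfile = os.path.join(basepath, 'process.log')
    i = 0
    while os.path.exists(logfile):
        i += 1
        logfile = os.path.join(basepath, 'process-{}.log'.format(i))
    return logfile
```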
Specify the output directory of the run you would like the monitor; @@ -47,7 +54,7 @@ class ReportCommand(Command): zero -- will not appear). ''' - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('-d', '--directory', help=''' Specify the WA output path. report will @@ -55,13 +62,13 @@ def initialize(self, context): directories in the current directory. ''') - def execute(self, state, args): + def execute(self, state: 'ConfigManager', args: Namespace) -> None: if args.directory: - output_path = args.directory + output_path: str = args.directory run_output = RunOutput(output_path) else: - possible_outputs = list(discover_wa_outputs(os.getcwd())) - num_paths = len(possible_outputs) + possible_outputs: List[RunOutput] = list(discover_wa_outputs(os.getcwd())) + num_paths: int = len(possible_outputs) if num_paths > 1: print('More than one possible output directory found,' @@ -73,7 +80,7 @@ def execute(self, state, args): while True: try: - select = int(input()) + select: int = int(input()) except ValueError: print("Please select a valid path number") continue @@ -93,18 +100,35 @@ def execute(self, state, args): class RunMonitor: + """ + run monitor + """ + def __init__(self, ro: RunOutput): + self.ro = ro + self._elapsed: Optional[timedelta] = None + self._p_duration: Optional[timedelta] = None + self._job_outputs: Optional[Dict[Tuple[str, str, int], JobOutput]] = None + self._termwidth = None + self._fmt = _simple_formatter() + self.get_data() @property - def elapsed_time(self): + def elapsed_time(self) -> Optional[timedelta]: + """ + elapsed time + """ if self._elapsed is None: - if self.ro.info.duration is None: - self._elapsed = datetime.utcnow() - self.ro.info.start_time + if cast('RunInfo', self.ro.info).duration is None: + self._elapsed = datetime.utcnow() - cast(datetime, cast('RunInfo', self.ro.info).start_time) else: - self._elapsed = self.ro.info.duration + self._elapsed = cast('RunInfo', self.ro.info).duration return self._elapsed @property - def job_outputs(self): + def job_outputs(self) -> Dict[Tuple[str, str, int], JobOutput]: + """ + job outputs + """ if self._job_outputs is None: self._job_outputs = { (j_o.id, j_o.label, j_o.iteration): j_o for j_o in self.ro.jobs @@ -112,69 +136,74 @@ def job_outputs(self): return self._job_outputs @property - def projected_duration(self): - elapsed = self.elapsed_time.total_seconds() + def projected_duration(self) -> timedelta: + """ + projected duration for the run + """ + elapsed: float = cast(timedelta, self.elapsed_time).total_seconds() proj = timedelta(seconds=elapsed * (len(self.jobs) / len(self.segmented['finished']))) - return proj - self.elapsed_time + return proj - cast(timedelta, self.elapsed_time) - def __init__(self, ro): - self.ro = ro - self._elapsed = None - self._p_duration = None - self._job_outputs = None - self._termwidth = None - self._fmt = _simple_formatter() - self.get_data() - - def get_data(self): - self.jobs = [state for label_id, state in self.ro.state.jobs.items()] + def get_data(self) -> None: + """ + get job run data + """ + self.jobs: List['JobState'] = [state for label_id, state in cast('RunState', self.ro.state).jobs.items()] if self.jobs: - rc = self.ro.run_config - self.segmented = segment_jobs_by_state(self.jobs, - rc.max_retries, - rc.retry_on_status - ) - - def generate_run_header(self): - info = self.ro.info + rc = cast(RunConfigurationProtocol, self.ro.run_config) + if rc: + self.segmented = segment_jobs_by_state(self.jobs, + 
rc.max_retries, + rc.retry_on_status + ) + + def generate_run_header(self) -> str: + """ + generate run header + """ + info: Optional['RunInfo'] = self.ro.info header = underline('Run Info') - header += "UUID: {}\n".format(info.uuid) - if info.run_name: - header += "Run name: {}\n".format(info.run_name) - if info.project: - header += "Project: {}\n".format(info.project) - if info.project_stage: - header += "Project stage: {}\n".format(info.project_stage) - - if info.start_time: - duration = _seconds_as_smh(self.elapsed_time.total_seconds()) - header += ("Start time: {}\n" - "Duration: {:02}:{:02}:{:02}\n" - ).format(info.start_time, - duration[2], duration[1], duration[0], - ) - if self.segmented['finished'] and not info.end_time: - p_duration = _seconds_as_smh(self.projected_duration.total_seconds()) - header += "Projected time remaining: {:02}:{:02}:{:02}\n".format( - p_duration[2], p_duration[1], p_duration[0] - ) - - elif self.ro.info.end_time: - header += "End time: {}\n".format(info.end_time) + if info: + header += "UUID: {}\n".format(info.uuid) + if info.run_name: + header += "Run name: {}\n".format(info.run_name) + if info.project: + header += "Project: {}\n".format(info.project) + if info.project_stage: + header += "Project stage: {}\n".format(info.project_stage) + + if info.start_time: + duration = _seconds_as_smh(cast(timedelta, self.elapsed_time).total_seconds()) + header += ("Start time: {}\n" + "Duration: {:02}:{:02}:{:02}\n" + ).format(info.start_time, + duration[2], duration[1], duration[0], + ) + if self.segmented['finished'] and not info.end_time: + p_duration = _seconds_as_smh(self.projected_duration.total_seconds()) + header += "Projected time remaining: {:02}:{:02}:{:02}\n".format( + p_duration[2], p_duration[1], p_duration[0] + ) + + elif info.end_time: + header += "End time: {}\n".format(info.end_time) return header + '\n' - def generate_job_summary(self): - total = len(self.jobs) - num_fin = len(self.segmented['finished']) + def generate_job_summary(self) -> str: + """ + generate job summary + """ + total: int = len(self.jobs) + num_fin: int = len(self.segmented['finished']) - summary = underline('Job Summary') + summary: str = underline('Job Summary') summary += 'Total: {}, Completed: {} ({}%)\n'.format( total, num_fin, (num_fin / total) * 100 ) if total > 0 else 'No jobs created\n' - ctr = Counter() + ctr: Counter = Counter() for run_state, jobs in ((k, v) for k, v in self.segmented.items() if v): if run_state == 'finished': ctr.update([job.status.name.lower() for job in jobs]) @@ -185,8 +214,11 @@ def generate_job_summary(self): [str(count) + ' ' + self._fmt.highlight_keyword(status) for status, count in ctr.items()] ) + '\n\n' - def generate_job_detail(self): - detail = underline('Job Detail') + def generate_job_detail(self) -> str: + """ + generate job detail + """ + detail: str = underline('Job Detail') for job in self.jobs: detail += ('{} ({}) [{}]{}, {}\n').format( job.id, @@ -196,26 +228,32 @@ def generate_job_detail(self): self._fmt.highlight_keyword(str(job.status)) ) - job_output = self.job_outputs[(job.id, job.label, job.iteration)] + job_output: JobOutput = self.job_outputs[(job.id, job.label, job.iteration)] for event in job_output.events: detail += self._fmt.fit_term_width( '\t{}\n'.format(event.summary) ) return detail - def generate_run_detail(self): - detail = underline('Run Events') if self.ro.events else '' + def generate_run_detail(self) -> str: + """ + generate run detail + """ + detail: str = underline('Run Events') if self.ro.events else '' 
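`generate_job_summary()` above tallies per-status counts with `collections.Counter`; a small self-contained illustration of the tally-and-join it performs:

```python
from collections import Counter

statuses = ['ok', 'ok', 'failed', 'partial', 'ok']   # illustrative job statuses
ctr = Counter()
ctr.update(statuses)
print(', '.join('{} {}'.format(count, status) for status, count in ctr.items()))
# -> 3 ok, 1 failed, 1 partial
```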
for event in self.ro.events: detail += '{}\n'.format(event.summary) return detail + '\n' - def generate_output(self, verbose): + def generate_output(self, verbose: bool) -> str: + """ + generate output + """ if not self.jobs: return 'No jobs found in output directory\n' - output = self.generate_run_header() + output: str = self.generate_run_header() output += self.generate_job_summary() if verbose: @@ -225,7 +263,7 @@ def generate_output(self, verbose): return output -def _seconds_as_smh(seconds): +def _seconds_as_smh(seconds: float) -> Tuple[int, int, int]: seconds = int(seconds) hours = seconds // 3600 minutes = (seconds % 3600) // 60 @@ -233,13 +271,17 @@ def _seconds_as_smh(seconds): return seconds, minutes, hours -def segment_jobs_by_state(jobstates, max_retries, retry_status): - finished_states = [ +def segment_jobs_by_state(jobstates: List['JobState'], max_retries: int, + retry_status: List['StatusType']) -> Dict[str, List['JobState']]: + """ + segment jobs by jobstate + """ + finished_states: List['StatusType'] = [ Status.PARTIAL, Status.FAILED, Status.ABORTED, Status.OK, Status.SKIPPED ] - segmented = { + segmented: Dict[str, List['JobState']] = { 'finished': [], 'other': [], 'running': [], 'pending': [], 'uninitialized': [] } @@ -262,25 +304,34 @@ def segment_jobs_by_state(jobstates, max_retries, retry_status): class _simple_formatter: - color_map = { + """ + formatter for output in report + """ + color_map: Dict[str, str] = { 'running': COLOR_MAP[logging.INFO], 'partial': COLOR_MAP[logging.WARNING], 'failed': COLOR_MAP[logging.CRITICAL], 'aborted': COLOR_MAP[logging.ERROR] } - def __init__(self): - self.termwidth = get_terminal_size()[0] - self.color = settings.logging['color'] + def __init__(self) -> None: + self.termwidth: int = get_terminal_size()[0] + self.color: bool = settings.logging['color'] - def fit_term_width(self, text): + def fit_term_width(self, text: str) -> str: + """ + fit to the terminal width + """ text = text.expandtabs() if len(text) <= self.termwidth: return text else: return text[0:self.termwidth - 4] + " ...\n" - def highlight_keyword(self, kw): + def highlight_keyword(self, kw: str) -> str: + """ + highlight keyword + """ if not self.color or kw not in _simple_formatter.color_map: return kw diff --git a/wa/commands/revent.py b/wa/commands/revent.py index ab46f7e6d..6bf8eb2b6 100644 --- a/wa/commands/revent.py +++ b/wa/commands/revent.py @@ -20,15 +20,24 @@ from wa import Command from wa.framework import pluginloader from wa.framework.exception import ConfigError -from wa.framework.resource import ResourceResolver +from wa.framework.resource import ResourceResolver, Resource from wa.framework.target.manager import TargetManager from wa.utils.revent import ReventRecorder +from devlib.target import Target +from argparse import _MutuallyExclusiveGroup, Namespace +from typing import (cast, Optional, TYPE_CHECKING, Dict, + Tuple, Callable) +if TYPE_CHECKING: + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.execution import ExecutionContext, ConfigManager + from wa.framework.configuration.core import ConfigurationPoint, RunConfigurationProtocol + from wa.framework.workload import Workload, ApkWorkload class RecordCommand(Command): - name = 'record' - description = ''' + name: str = 'record' + description: str = ''' Performs a revent recording This command helps making revent recordings. It will automatically @@ -53,13 +62,13 @@ class RecordCommand(Command): or optionally ``-a`` to indicate all stages should be recorded. 
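Worth noting about `_seconds_as_smh()` above: it returns (seconds, minutes, hours), which is why its callers format `duration[2], duration[1], duration[0]` to print HH:MM:SS. A behaviour-equivalent sketch:

```python
from typing import Tuple

def seconds_as_smh(seconds: float) -> Tuple[int, int, int]:
    s = int(seconds)
    return s % 60, (s % 3600) // 60, s // 3600   # note the reversed order

s, m, h = seconds_as_smh(3725)
print('{:02}:{:02}:{:02}'.format(h, m, s))       # 01:02:05
```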
''' - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(RecordCommand, self).__init__(**kwargs) - self.tm = None - self.target = None - self.revent_recorder = None + self.tm: Optional[TargetManager] = None + self.target: Optional[Target] = None + self.revent_recorder: Optional[ReventRecorder] = None - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('-d', '--device', metavar='DEVICE', help=''' Specify the device on which to run. This will @@ -81,11 +90,14 @@ def initialize(self, context): # Need validation self.parser.add_argument('-C', '--clear', help='Clear app cache before launching it', action='store_true') - group = self.parser.add_mutually_exclusive_group(required=False) + group: _MutuallyExclusiveGroup = self.parser.add_mutually_exclusive_group(required=False) group.add_argument('-p', '--package', help='Android package to launch before recording') group.add_argument('-w', '--workload', help='Name of a revent workload (mostly games)') - def validate_args(self, args): + def validate_args(self, args: Namespace) -> None: + """ + validate arguments + """ if args.clear and not (args.package or args.workload): self.logger.error("Package/Workload must be specified if you want to clear cache") sys.exit() @@ -100,15 +112,15 @@ def validate_args(self, args): self.logger.error("Please specify which workload stages you wish to record") sys.exit() - def execute(self, state, args): + def execute(self, state: 'ConfigManager', args: Namespace) -> None: self.validate_args(args) state.run_config.merge_device_config(state.plugin_cache) if args.device: device = args.device - device_config = {} + device_config: Dict[str, 'ConfigurationPoint'] = {} else: - device = state.run_config.device - device_config = state.run_config.device_config or {} + device = cast('RunConfigurationProtocol', state.run_config).device + device_config = cast(Dict[str, 'ConfigurationPoint'], state.run_config.device_config) or {} if args.output: outdir = os.path.basename(args.output) @@ -118,8 +130,9 @@ def execute(self, state, args): self.tm = TargetManager(device, device_config, outdir) self.tm.initialize() self.target = self.tm.target - self.revent_recorder = ReventRecorder(self.target) - self.revent_recorder.deploy() + if self.target: + self.revent_recorder = ReventRecorder(self.target) + self.revent_recorder.deploy() if args.workload: self.workload_record(args) @@ -127,24 +140,29 @@ def execute(self, state, args): self.package_record(args) else: self.manual_record(args) - - self.revent_recorder.remove() - - def record(self, revent_file, name, output_path): - msg = 'Press Enter when you are ready to record {}...' + if self.revent_recorder: + self.revent_recorder.remove() + + def record(self, revent_file: Optional[str], name: str, output_path: str) -> None: + """ + record commands + """ + msg: str = 'Press Enter when you are ready to record {}...' self.logger.info(msg.format(name)) input('') - self.revent_recorder.start_record(revent_file) + if self.revent_recorder: + self.revent_recorder.start_record(revent_file) msg = 'Press Enter when you have finished recording {}...' 
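`initialize()` above registers `-p/--package` and `-w/--workload` in a mutually exclusive group; a minimal standalone demonstration of what argparse enforces for such a group:

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_mutually_exclusive_group(required=False)
group.add_argument('-p', '--package')
group.add_argument('-w', '--workload')

print(parser.parse_args(['-p', 'com.example.app']))  # package set, workload None
# parser.parse_args(['-p', 'x', '-w', 'y'])  # exits: "not allowed with argument"
```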
self.logger.info(msg.format(name)) input('') - self.revent_recorder.stop_record() + if self.revent_recorder: + self.revent_recorder.stop_record() if not os.path.isdir(output_path): os.makedirs(output_path) - revent_file_name = self.target.path.basename(revent_file) - host_path = os.path.join(output_path, revent_file_name) + revent_file_name: str = self.target.path.basename(revent_file) if self.target and self.target.path else '' + host_path: str = os.path.join(output_path, revent_file_name) if os.path.exists(host_path): msg = 'Revent file \'{}\' already exists, overwrite? [y/n]' self.logger.info(msg.format(revent_file_name)) @@ -155,17 +173,26 @@ def record(self, revent_file, name, output_path): self.logger.warning(msg.format(revent_file_name)) return msg = 'Pulling \'{}\' from device' - self.logger.info(msg.format(self.target.path.basename(revent_file))) - self.target.pull(revent_file, output_path, as_root=self.target.is_rooted) - - def manual_record(self, args): + self.logger.info(msg.format(self.target.path.basename(revent_file) if self.target and self.target.path else '')) + if self.target: + self.target.pull(revent_file, output_path, as_root=self.target.is_rooted) + + def manual_record(self, args: Namespace) -> None: + """ + record manually + """ output_path, file_name = self._split_revent_location(args.output) - revent_file = self.target.get_workpath(file_name) + revent_file = self.target.get_workpath(file_name) if self.target else '' self.record(revent_file, '', output_path) msg = 'Recording is available at: \'{}\'' self.logger.info(msg.format(os.path.join(output_path, file_name))) - def package_record(self, args): + def package_record(self, args: Namespace) -> None: + """ + record package execution on android + """ + if self.target is None: + raise ConfigError('Target is None') if self.target.os != 'android' and self.target.os != 'chromeos': raise ConfigError('Target does not appear to be running Android') if self.target.os == 'chromeos' and not self.target.supports_android: @@ -173,36 +200,39 @@ def package_record(self, args): if args.clear: self.target.execute('pm clear {}'.format(args.package)) self.logger.info('Starting {}'.format(args.package)) - cmd = 'monkey -p {} -c android.intent.category.LAUNCHER 1' + cmd: str = 'monkey -p {} -c android.intent.category.LAUNCHER 1' self.target.execute(cmd.format(args.package)) output_path, file_name = self._split_revent_location(args.output) - revent_file = self.target.get_workpath(file_name) + revent_file: Optional[str] = self.target.get_workpath(file_name) self.record(revent_file, '', output_path) msg = 'Recording is available at: \'{}\'' self.logger.info(msg.format(os.path.join(output_path, file_name))) - def workload_record(self, args): + def workload_record(self, args: Namespace) -> None: + """ + record workload execution + """ context = LightContext(self.tm) - setup_revent = '{}.setup.revent'.format(self.target.model) - run_revent = '{}.run.revent'.format(self.target.model) - extract_results_revent = '{}.extract_results.revent'.format(self.target.model) - teardown_file_revent = '{}.teardown.revent'.format(self.target.model) - setup_file = self.target.get_workpath(setup_revent) - run_file = self.target.get_workpath(run_revent) - extract_results_file = self.target.get_workpath(extract_results_revent) - teardown_file = self.target.get_workpath(teardown_file_revent) + setup_revent: str = '{}.setup.revent'.format(self.target.model if self.target else '') + run_revent: str = '{}.run.revent'.format(self.target.model if self.target else '') 
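The record/pull paths above guard each use of the Optional `self.target` separately (`if self.target ... else ''`). An alternative pattern, shown only as a sketch and not part of this patch, narrows the Optional once so later calls need no guards:

```python
from typing import Optional, TypeVar

T = TypeVar('T')

def require(value: Optional[T], what: str = 'value') -> T:
    """Narrow an Optional once, instead of guarding every call site."""
    if value is None:
        raise RuntimeError('{} is not initialized'.format(what))
    return value

# usage sketch:  target = require(self.target, 'target'); target.pull(src, dst)
```

The trade-off is an explicit failure at the narrowing point rather than silently substituting `''`, which the guarded style above does.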
+ extract_results_revent: str = '{}.extract_results.revent'.format(self.target.model if self.target else '') + teardown_file_revent: str = '{}.teardown.revent'.format(self.target.model if self.target else '') + setup_file: Optional[str] = self.target.get_workpath(setup_revent) if self.target else '' + run_file: Optional[str] = self.target.get_workpath(run_revent) if self.target else '' + extract_results_file: Optional[str] = self.target.get_workpath(extract_results_revent) if self.target else '' + teardown_file: Optional[str] = self.target.get_workpath(teardown_file_revent) if self.target else '' self.logger.info('Deploying {}'.format(args.workload)) - workload = pluginloader.get_workload(args.workload, self.target) + workload: 'Workload' = cast('__LoaderWrapper', pluginloader).get_workload(args.workload, self.target) # Setup apk if android workload if hasattr(workload, 'apk'): - workload.apk.initialize(context) - workload.apk.setup(context) - sleep(workload.loading_time) + cast('ApkWorkload', workload).apk.initialize(cast('ExecutionContext', context)) + cast('ApkWorkload', workload).apk.setup(cast('ExecutionContext', context)) + sleep(cast('ApkWorkload', workload).loading_time) - output_path = os.path.join(workload.dependencies_directory, - 'revent_files') + output_path: str = os.path.join(workload.dependencies_directory, + 'revent_files') if args.setup or args.all: self.record(setup_file, 'SETUP', output_path) if args.run or args.all: @@ -212,17 +242,20 @@ def workload_record(self, args): if args.teardown or args.all: self.record(teardown_file, 'TEARDOWN', output_path) self.logger.info('Tearing down {}'.format(args.workload)) - workload.teardown(context) + workload.teardown(cast('ExecutionContext', context)) self.logger.info('Recording(s) are available at: \'{}\''.format(output_path)) - def _split_revent_location(self, output): - output_path = None - file_name = None + def _split_revent_location(self, output: str) -> Tuple[str, str]: + """ + split the output location string into path and file name + """ + output_path: Optional[str] = None + file_name: Optional[str] = None if output: output_path, file_name, = os.path.split(output) if not file_name: - file_name = '{}.revent'.format(self.target.model) + file_name = '{}.revent'.format(self.target.model if self.target else '') if not output_path: output_path = os.getcwd() @@ -231,15 +264,15 @@ def _split_revent_location(self, output): class ReplayCommand(Command): - name = 'replay' - description = ''' + name: str = 'replay' + description: str = ''' Replay a revent recording Revent allows you to record raw inputs such as screen swipes or button presses. See ``wa show record`` to see how to make an revent recording. 
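`_split_revent_location()` above leans on `os.path.split()` semantics: each of its two fallbacks corresponds to an empty component in the split result.

```python
import os

print(os.path.split('recordings/device.revent'))  # ('recordings', 'device.revent')
print(os.path.split('device.revent'))             # ('', 'device.revent') -> cwd fallback
print(os.path.split('recordings/'))               # ('recordings', '') -> model-name fallback
```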
''' - def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('recording', help='The name of the file to replay', metavar='FILE') self.parser.add_argument('-d', '--device', help='The name of the device') @@ -248,34 +281,36 @@ def initialize(self, context): action="store_true") # pylint: disable=W0201 - def execute(self, state, args): + def execute(self, state: 'ConfigManager', args: Namespace) -> None: state.run_config.merge_device_config(state.plugin_cache) if args.device: device = args.device - device_config = {} + device_config: Dict[str, 'ConfigurationPoint'] = {} else: - device = state.run_config.device - device_config = state.run_config.device_config or {} + device = cast('RunConfigurationProtocol', state.run_config).device + device_config = cast(Dict[str, 'ConfigurationPoint'], state.run_config.device_config) or {} target_manager = TargetManager(device, device_config, None) target_manager.initialize() self.target = target_manager.target - revent_file = self.target.path.join(self.target.working_directory, - os.path.split(args.recording)[1]) + revent_file: str = self.target.path.join(self.target.working_directory, + os.path.split(args.recording)[1]) if self.target and self.target.path else '' self.logger.info("Pushing file to target") - self.target.push(args.recording, self.target.working_directory) - - revent_recorder = ReventRecorder(target_manager.target) - revent_recorder.deploy() + self.target.push(args.recording, self.target.working_directory) if self.target else '' + if target_manager.target: + revent_recorder = ReventRecorder(target_manager.target) + revent_recorder.deploy() if args.clear: - self.target.execute('pm clear {}'.format(args.package)) + if self.target: + self.target.execute('pm clear {}'.format(args.package)) if args.package: self.logger.info('Starting {}'.format(args.package)) cmd = 'monkey -p {} -c android.intent.category.LAUNCHER 1' - self.target.execute(cmd.format(args.package)) + if self.target: + self.target.execute(cmd.format(args.package)) self.logger.info("Starting replay") revent_recorder.replay(revent_file) @@ -285,16 +320,24 @@ def execute(self, state, args): # Used to satisfy the workload API class LightContext(object): - + """ + light execution context for satisfying workload api + """ def __init__(self, tm): self.tm = tm self.resolver = ResourceResolver() self.resolver.load() - def get_resource(self, resource, strict=True): + def get_resource(self, resource: Resource, strict: bool = True) -> Optional[str]: + """ + get path to the resource + """ return self.resolver.get(resource, strict) - def update_metadata(self, key, *args): + def update_metadata(self, key: str, *args): + """ + update metadata + """ pass - get = get_resource + get: Callable[..., Optional[str]] = get_resource diff --git a/wa/commands/run.py b/wa/commands/run.py index 91d6296bc..0bfc5954c 100644 --- a/wa/commands/run.py +++ b/wa/commands/run.py @@ -23,21 +23,28 @@ from wa.framework import pluginloader from wa.framework.configuration.parsers import AgendaParser from wa.framework.execution import Executor -from wa.framework.output import init_run_output +from wa.framework.output import init_run_output, RunOutput from wa.framework.exception import NotFoundError, ConfigError from wa.utils import log from wa.utils.types import toggle_set +from argparse import Namespace +from typing import (Optional, TYPE_CHECKING, cast, List, Dict, + Any) +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext, 
+    from wa.framework.pluginloader import __LoaderWrapper
+    from wa.framework.configuration.core import RunConfigurationProtocol


 class RunCommand(Command):

-    name = 'run'
-    description = '''
+    name: str = 'run'
+    description: str = '''
     Execute automated workloads on a remote device and process the resulting
     output.
     '''

-    def initialize(self, context):
+    def initialize(self, context: Optional['ExecutionContext']) -> None:
         self.parser.add_argument('agenda', metavar='AGENDA',
                                  help="""
                                  Agenda for this workload automation run. This
@@ -84,8 +91,8 @@ def initialize(self, context):
                                  be specified multiple times.
                                  """)

-    def execute(self, config, args):  # pylint: disable=arguments-differ
-        output = self.set_up_output_directory(config, args)
+    def execute(self, config: 'ConfigManager', args: Namespace) -> None:  # pylint: disable=arguments-differ
+        output: RunOutput = self.set_up_output_directory(config, args)
         log.add_file(output.logfile)
         output.add_artifact('runlog', output.logfile, kind='log',
                             description='Run log.')
@@ -97,30 +104,34 @@ def execute(self, config, args):  # pylint: disable=arguments-differ

         parser = AgendaParser()
         if os.path.isfile(args.agenda):
-            includes = parser.load_from_path(config, args.agenda)
+            includes: List[str] = parser.load_from_path(config, args.agenda)
             shutil.copy(args.agenda, output.raw_config_dir)
             for inc in includes:
                 shutil.copy(inc, output.raw_config_dir)
         else:
             try:
-                pluginloader.get_plugin_class(args.agenda, kind='workload')
-                agenda = {'workloads': [{'name': args.agenda}]}
+                cast('__LoaderWrapper', pluginloader).get_plugin_class(args.agenda, kind='workload')
+                agenda: Dict[str, List[Dict[str, Any]]] = {'workloads': [{'name': args.agenda}]}
                 parser.load(config, agenda, 'CMDLINE_ARGS')
             except NotFoundError:
-                msg = 'Agenda file "{}" does not exist, and there no workload '\
-                      'with that name.\nYou can get a list of available '\
-                      'by running "wa list workloads".'
+                msg: str = 'Agenda file "{}" does not exist, and there is no workload '\
+                           'with that name.\nYou can get a list of available workloads '\
+                           'by running "wa list workloads".'
raise ConfigError(msg.format(args.agenda)) # Update run info with newly parsed config values - output.info.project = config.run_config.project - output.info.project_stage = config.run_config.project_stage - output.info.run_name = config.run_config.run_name + if output.info: + output.info.project = cast('RunConfigurationProtocol', config.run_config).project + output.info.project_stage = cast('RunConfigurationProtocol', config.run_config).project_stage + output.info.run_name = cast('RunConfigurationProtocol', config.run_config).run_name executor = Executor() executor.execute(config, output) - def set_up_output_directory(self, config, args): + def set_up_output_directory(self, config: 'ConfigManager', args: Namespace) -> RunOutput: + """ + set up the run output directory + """ if args.output_directory: output_directory = args.output_directory else: diff --git a/wa/commands/show.py b/wa/commands/show.py index 071b627e1..ca2c5a9e3 100644 --- a/wa/commands/show.py +++ b/wa/commands/show.py @@ -27,42 +27,49 @@ from wa.framework import pluginloader from wa.framework.configuration.core import MetaConfiguration, RunConfiguration from wa.framework.exception import NotFoundError -from wa.framework.target.descriptor import list_target_descriptions +from wa.framework.target.descriptor import list_target_descriptions, TargetDescriptionProtocol from wa.utils.types import caseless_string, identifier from wa.utils.doc import (strip_inlined_text, get_rst_from_plugin, get_params_rst, underline) from wa.utils.misc import which +from typing import TYPE_CHECKING, cast, Optional, List, Type +from argparse import Namespace +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext, ConfigManager + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.plugin import Plugin + from wa.framework.configuration.core import ConfigurationPoint class ShowCommand(Command): - name = 'show' - description = 'Display documentation for the specified plugin (workload, instrument, etc.).' + name: str = 'show' + description: str = 'Display documentation for the specified plugin (workload, instrument, etc.).' 
- def initialize(self, context): + def initialize(self, context: Optional['ExecutionContext']) -> None: self.parser.add_argument('plugin', metavar='PLUGIN', help='The name of the plugin to display documentation for.') - def execute(self, state, args): - name = identifier(args.plugin) - rst_output = None + def execute(self, state: 'ConfigManager', args: Namespace) -> None: + name: str = identifier(args.plugin) + rst_output: Optional[str] = None if name == caseless_string('settings'): rst_output = get_rst_for_global_config() rst_output += get_rst_for_envars() - plugin_name = name.lower() - kind = 'global:' + plugin_name: str = name.lower() + kind: str = 'global:' else: try: - plugin = pluginloader.get_plugin_class(name) + plugin: Optional[Type['Plugin']] = cast('__LoaderWrapper', pluginloader).get_plugin_class(name) except NotFoundError: plugin = None if plugin: rst_output = get_rst_from_plugin(plugin) - plugin_name = plugin.name + plugin_name = plugin.name or '' kind = '{}:'.format(plugin.kind) else: - target = get_target_description(name) + target: Optional[TargetDescriptionProtocol] = get_target_description(name) if target: rst_output = get_rst_from_target(target) plugin_name = target.name @@ -73,8 +80,8 @@ def execute(self, state, args): if which('pandoc'): p = Popen(['pandoc', '-f', 'rst', '-t', 'man'], stdin=PIPE, stdout=PIPE, stderr=PIPE) - output, _ = p.communicate(rst_output.encode(sys.stdin.encoding)) - output = output.decode(sys.stdout.encoding) + output_, _ = p.communicate(rst_output.encode(sys.stdin.encoding)) + output = output_.decode(sys.stdout.encoding) # Make sure to double escape back slashes output = output.replace('\\', '\\\\\\') @@ -89,17 +96,24 @@ def execute(self, state, args): print(rst_output) # pylint: disable=superfluous-parens -def get_target_description(name): +def get_target_description(name: str) -> Optional[TargetDescriptionProtocol]: + """ + get target description + """ targets = list_target_descriptions() for target in targets: if name == identifier(target.name): return target + return None -def get_rst_from_target(target): - text = underline(target.name, '~') +def get_rst_from_target(target: TargetDescriptionProtocol) -> str: + """ + get restructured text from target description + """ + text: str = underline(target.name, '~') if hasattr(target, 'description'): - desc = strip_inlined_text(target.description or '') + desc: str = strip_inlined_text(target.description or '') text += desc text += underline('Device Parameters:', '-') text += get_params_rst(target.conn_params) @@ -110,19 +124,25 @@ def get_rst_from_target(target): return text + '\n' -def get_rst_for_global_config(): - text = underline('Global Configuration') +def get_rst_for_global_config() -> str: + """ + get restructured text for global configuration + """ + text: str = underline('Global Configuration') text += 'These parameters control the behaviour of WA/run as a whole, they ' \ 'should be set inside a config file (either located in ' \ '$WA_USER_DIRECTORY/config.yaml or one which is specified with -c), ' \ 'or into config/global section of the agenda.\n\n' - cfg_points = MetaConfiguration.config_points + RunConfiguration.config_points + cfg_points: List['ConfigurationPoint'] = MetaConfiguration.config_points + RunConfiguration.config_points text += get_params_rst(cfg_points) return text -def get_rst_for_envars(): +def get_rst_for_envars() -> str: + """ + get restructured text for environment variables + """ text = underline('Environment Variables') text += '''WA_USER_DIRECTORY: str This is 
the location WA will look for config.yaml, plugins, dependencies,
diff --git a/wa/framework/command.py b/wa/framework/command.py
index 856d81255..118087846 100644
--- a/wa/framework/command.py
+++ b/wa/framework/command.py
@@ -19,9 +19,17 @@
 from wa.framework.plugin import Plugin
 from wa.framework.version import get_wa_version
 from wa.utils.doc import format_body
+from typing import Optional, List, Type, Dict, cast, Any, TYPE_CHECKING
+from argparse import ArgumentParser, _SubParsersAction, Namespace, HelpFormatter
+import logging
+if TYPE_CHECKING:
+    from wa.framework.execution import ExecutionContext, ConfigManager


-def init_argument_parser(parser):
+def init_argument_parser(parser: ArgumentParser):
+    """
+    initialize argument parser
+    """
     parser.add_argument('-c', '--config', action='append', default=[],
                         help='specify an additional config.yaml')
     parser.add_argument('-v', '--verbose', action='count',
@@ -40,33 +48,33 @@ class SubCommand(object):
     command line arguments.
     """

-    name = None
-    help = None
-    usage = None
-    description = None
-    epilog = None
-    formatter_class = None
-
-    def __init__(self, logger, subparsers):
+    name: Optional[str] = None
+    help: Optional[str] = None
+    usage: Optional[str] = None
+    description: Optional[str] = None
+    epilog: Optional[str] = None
+    formatter_class: Optional[Type[HelpFormatter]] = None
+
+    def __init__(self, logger: logging.Logger, subparsers: _SubParsersAction):
         self.logger = logger
         self.group = subparsers
-        desc = format_body(textwrap.dedent(self.description), 80)
-        parser_params = dict(help=(self.help or self.description), usage=self.usage,
-                             description=desc, epilog=self.epilog)
+        desc = format_body(textwrap.dedent(self.description or ''), 80)
+        parser_params: Dict[str, Any] = dict(help=(self.help or self.description), usage=self.usage,
+                                             description=desc, epilog=self.epilog)
         if self.formatter_class:
             parser_params['formatter_class'] = self.formatter_class
-        self.parser = subparsers.add_parser(self.name, **parser_params)
+        self.parser: ArgumentParser = subparsers.add_parser(self.name or '', **parser_params)
         init_argument_parser(self.parser)  # propagate top-level options
         self.initialize(None)

-    def initialize(self, context):
+    def initialize(self, context: Optional['ExecutionContext']) -> None:
         """
         Perform command-specific initialisation (e.g. adding command-specific
         options to the command's parser). ``context`` is always ``None``.
         """

-    def execute(self, state, args):
+    def execute(self, state: 'ConfigManager', args: Namespace) -> None:
         """
         Execute this command.

@@ -90,9 +98,9 @@ class Command(Plugin, SubCommand):  # pylint: disable=abstract-method
     command line arguments.
""" - kind = "command" + kind: str = "command" - def __init__(self, subparsers): + def __init__(self, subparsers: _SubParsersAction): Plugin.__init__(self) SubCommand.__init__(self, self.logger, subparsers) @@ -103,20 +111,20 @@ class ComplexCommand(Command): """ - subcmd_classes = [] + subcmd_classes: List[Type[SubCommand]] = [] - def __init__(self, subparsers): - self.subcommands = [] + def __init__(self, subparsers: _SubParsersAction): + self.subcommands: List[SubCommand] = [] super(ComplexCommand, self).__init__(subparsers) - def initialize(self, context): - subparsers = self.parser.add_subparsers(dest='what', metavar='SUBCMD') + def initialize(self, context: Optional['ExecutionContext']) -> None: + subparsers: _SubParsersAction[ArgumentParser] = self.parser.add_subparsers(dest='what', metavar='SUBCMD') subparsers.required = True for subcmd_cls in self.subcmd_classes: - subcmd = subcmd_cls(self.logger, subparsers) + subcmd: SubCommand = subcmd_cls(self.logger, subparsers) self.subcommands.append(subcmd) - def execute(self, state, args): + def execute(self, state: 'ConfigManager', args: Namespace) -> None: for subcmd in self.subcommands: if subcmd.name == args.what: subcmd.execute(state, args) diff --git a/wa/framework/configuration/core.py b/wa/framework/configuration/core.py index 85b771bfb..ba217cf61 100644 --- a/wa/framework/configuration/core.py +++ b/wa/framework/configuration/core.py @@ -18,26 +18,60 @@ from collections import OrderedDict, defaultdict from wa.framework.exception import ConfigError, NotFoundError -from wa.framework.configuration.tree import SectionNode +from wa.framework.configuration.tree import SectionNode, JobSpecSource from wa.utils import log from wa.utils.misc import (get_article, merge_config_values) from wa.utils.types import (identifier, integer, boolean, list_of_strings, list_of, toggle_set, obj_dict, enum) from wa.utils.serializer import is_pod, Podable - +from typing import (Optional, Any, List, Dict, Union, Callable, cast, + Tuple, TYPE_CHECKING, DefaultDict, OrderedDict as od, + Set) +from typing_extensions import Protocol +if TYPE_CHECKING: + from wa.framework.configuration.plugin_cache import PluginCache + from wa.framework.plugin import Plugin + from wa.framework.target.manager import TargetManager + from wa.framework.workload import Workload + from wa.framework.instrument import Instrument + from wa.framework.output_processor import OutputProcessor + # When type-checking, pretend StatusType is a standard Enum + # (or anything you want the type checker to see). + import enum as en + + class StatusType(en.Enum): + UNKNOWN = 0 + NEW = 1 + PENDING = 2 + STARTED = 3 + CONNECTED = 4 + INITIALIZED = 5 + RUNNING = 6 + OK = 7 + PARTIAL = 8 + FAILED = 9 + ABORTED = 10 + SKIPPED = 11 + + @classmethod + def from_pod(cls, pod: Dict[str, Any]) -> 'Podable': + ... + + def to_pod(self) -> Dict[str, Any]: + ... 
# Mapping for kind conversion; see docs for convert_types below -KIND_MAP = { +KIND_MAP: Dict[Callable, Callable] = { int: integer, bool: boolean, - dict: OrderedDict, + dict: od, } Status = enum(['UNKNOWN', 'NEW', 'PENDING', 'STARTED', 'CONNECTED', 'INITIALIZED', 'RUNNING', 'OK', 'PARTIAL', 'FAILED', 'ABORTED', 'SKIPPED']) -logger = logging.getLogger('config') +logger: logging.Logger = logging.getLogger('config') ########################## @@ -62,39 +96,54 @@ class RebootPolicy(object): """ - valid_policies = ['never', 'as_needed', 'initial', 'each_spec', 'each_job', 'run_completion'] + valid_policies: List[str] = ['never', 'as_needed', 'initial', 'each_spec', 'each_job', 'run_completion'] @staticmethod - def from_pod(pod): + def from_pod(pod: Podable) -> 'RebootPolicy': return RebootPolicy(pod) - def __init__(self, policy): + def __init__(self, policy: Union['RebootPolicy', str, Podable]): if isinstance(policy, RebootPolicy): policy = policy.policy - policy = policy.strip().lower().replace(' ', '_') + policy = cast(str, policy).strip().lower().replace(' ', '_') if policy not in self.valid_policies: message = 'Invalid reboot policy {}; must be one of {}'.format(policy, ', '.join(self.valid_policies)) raise ConfigError(message) - self.policy = policy + self.policy: str = policy @property - def can_reboot(self): + def can_reboot(self) -> bool: + """ + True if reboot policy is not 'never' + """ return self.policy != 'never' @property - def perform_initial_reboot(self): + def perform_initial_reboot(self) -> bool: + """ + True if reboot policy is 'initial' + """ return self.policy == 'initial' @property - def reboot_on_each_job(self): + def reboot_on_each_job(self) -> bool: + """ + True if reboot policy is 'each_job' + """ return self.policy == 'each_job' @property - def reboot_on_each_spec(self): + def reboot_on_each_spec(self) -> bool: + """ + True if reboot policy is 'each_spec' + """ return self.policy == 'each_spec' @property def reboot_on_run_completion(self): + """ + True if reboot policy is 'run_completion' + """ return self.policy == 'run_completion' def __str__(self): @@ -108,21 +157,32 @@ def __eq__(self, other): else: return self.policy == other - def to_pod(self): + def to_pod(self) -> str: return self.policy class status_list(list): - def append(self, item): + def append(self, item: Any): list.append(self, str(item).upper()) class LoggingConfig(Podable, dict): + """ + WA logging configuration. This should be a dict with a subset + of the following keys:: + + :normal_format: Logging format used for console output + :verbose_format: Logging format used for verbose console output + :file_format: Logging format used for run.log + :color: If ``True`` (the default), console logging output will + contain bash color escape codes. Set this to ``False`` if + console output will be piped somewhere that does not know + how to handle those. 
+ """ + _pod_serialization_version: int = 1 - _pod_serialization_version = 1 - - defaults = { + defaults: Dict[str, Union[str, bool]] = { 'file_format': '%(asctime)s %(levelname)-8s %(name)s: %(message)s', 'verbose_format': '%(asctime)s %(levelname)-8s %(name)s: %(message)s', 'regular_format': '%(levelname)-8s %(message)s', @@ -130,14 +190,14 @@ class LoggingConfig(Podable, dict): } @staticmethod - def from_pod(pod): + def from_pod(pod: Dict) -> 'LoggingConfig': pod = LoggingConfig._upgrade_pod(pod) - pod_version = pod.pop('_pod_version') + pod_version: int = cast(Dict[str, int], pod).pop('_pod_version') instance = LoggingConfig(pod) instance._pod_version = pod_version # pylint: disable=protected-access return instance - def __init__(self, config=None): + def __init__(self, config: Optional[Dict[str, Any]] = None): super(LoggingConfig, self).__init__() dict.__init__(self) if isinstance(config, dict): @@ -156,25 +216,28 @@ def __init__(self, config=None): else: raise ValueError(config) - def to_pod(self): + def to_pod(self) -> Dict: pod = super(LoggingConfig, self).to_pod() pod.update(self) return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict) -> Dict: pod['_pod_version'] = pod.get('_pod_version', 1) return pod -def expanded_path(path): +def expanded_path(path: str) -> str: """ Ensure that the provided path has been expanded if applicable """ return os.path.expanduser(str(path)) -def get_type_name(kind): +def get_type_name(kind) -> str: + """ + get the type name + """ typename = str(kind) if '\'' in typename: typename = typename.split('\'')[1] @@ -190,18 +253,18 @@ class ConfigurationPoint(object): """ - def __init__(self, name, - kind=None, - mandatory=None, - default=None, - override=False, - allowed_values=None, - description=None, - constraint=None, - merge=False, - aliases=None, - global_alias=None, - deprecated=False): + def __init__(self, name: str, + kind: Optional[Callable] = None, + mandatory: Optional[bool] = None, + default: Any = None, + override: bool = False, + allowed_values: Optional[List[Any]] = None, + description: Optional[str] = None, + constraint: Optional[Union[Callable[..., bool], Tuple[Callable[..., bool], str]]] = None, + merge: bool = False, + aliases: Optional[List[str]] = None, + global_alias: Optional[str] = None, + deprecated: bool = False): """ Create a new Parameter object. @@ -257,7 +320,7 @@ def __init__(self, name, a warning to the user however will continue execution. """ self.name = identifier(name) - kind = KIND_MAP.get(kind, kind) + kind = KIND_MAP.get(kind, kind) if kind else None if kind is not None and not callable(kind): raise ValueError('Kind must be callable.') self.kind = kind @@ -285,30 +348,39 @@ def __init__(self, name, except ConfigError: raise ValueError('Default value "{}" is not valid'.format(self.default)) - def match(self, name): + def match(self, name: str) -> bool: + """ + check whether name is matching the config point name or its aliases + """ if name == self.name or name in self.aliases: return True elif name == self.global_alias: return True return False - def set_value(self, obj, value=None, check_mandatory=True): + def set_value(self, obj: Union['Configuration', obj_dict, 'Plugin'], value: Any = None, + check_mandatory: bool = True) -> None: + """ + set the value to the configuration + """ if self.deprecated: if value is not None: - msg = 'Depreciated parameter supplied for "{}" in "{}". The value will be ignored.' + msg: str = 'Depreciated parameter supplied for "{}" in "{}". 
-    def validate(self, obj, check_mandatory=True):
+    def validate(self, obj: Union['Configuration', 'Plugin'], check_mandatory: bool = True) -> None:
+        """
+        validate the value, including its deprecated and mandatory status
+        """
         if self.deprecated:
             return
         value = getattr(obj, self.name, None)
         if value is not None:
-            self.validate_value(obj.name, value)
+            self.validate_value(obj.name or '', value)
         else:
             if check_mandatory and self.mandatory:
                 msg = 'No value specified for mandatory parameter "{}" in {}.'
                 raise ConfigError(msg.format(self.name, obj.name))

-    def validate_value(self, name, value):
+    def validate_value(self, name: str, value: Any) -> None:
+        """
+        validate the value against allowed values or constraints
+        """
         if self.allowed_values:
             self.validate_allowed_values(name, value)
         if self.constraint:
             self.validate_constraint(name, value)

-    def validate_allowed_values(self, name, value):
+    def validate_allowed_values(self, name: str, value: Any) -> None:
+        """
+        validate against allowed values
+        """
         if 'list' in str(self.kind):
             for v in value:
-                if v not in self.allowed_values:
+                if self.allowed_values and v not in self.allowed_values:
                     msg = 'Invalid value {} for {} in {}; must be in {}'
                     raise ConfigError(msg.format(v, self.name, name, self.allowed_values))
         else:
-            if value not in self.allowed_values:
+            if self.allowed_values and value not in self.allowed_values:
                 msg = 'Invalid value {} for {} in {}; must be in {}'
                 raise ConfigError(msg.format(value, self.name, name, self.allowed_values))

-    def validate_constraint(self, name, value):
-        msg_vals = {'value': value, 'param': self.name, 'plugin': name}
+    def validate_constraint(self, name: str, value: Any) -> None:
+        """
+        validate against the constraints
+        """
+        msg_vals: Dict[str, Any] = {'value': value, 'param': self.name, 'plugin': name}
         if isinstance(self.constraint, tuple) and len(self.constraint) == 2:
             constraint, msg = self.constraint  # pylint: disable=unpacking-non-sequence
         elif callable(self.constraint):
@@ -371,32 +455,58 @@ def __repr__(self):

 #####################


-def _to_pod(cfg_point, value):
+def _to_pod(cfg_point: ConfigurationPoint, value: Any) -> Any:
+    """
+    convert value to a plain old datatype (pod)
+    """
     if is_pod(value):
         return value
     if hasattr(cfg_point.kind, 'to_pod'):
-        return value.to_pod()
-    msg = '{} value "{}" is not serializable'
+        return cast(Podable, value).to_pod()
+    msg: str = '{} value "{}" is not serializable'
     raise ValueError(msg.format(cfg_point.name, value))
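As a sketch (not part of the patch), this is the subclassing contract that the "must be added to all subclasses" comment in Configuration below enforces: each subclass re-declares configuration so that lookups resolve against its own config_points. CustomConfig is a hypothetical example class:

    class CustomConfig(Configuration):
        name = 'custom configuration'
        config_points = [
            ConfigurationPoint('frequency', kind=int, default=10),
        ]
        # without re-declaring this, lookups would use the base class's table
        configuration = {cp.name: cp for cp in config_points}

    cfg = CustomConfig()
    cfg.set('frequency', '20')  # coerced via the point's kind
    assert cfg.frequency == 20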
 class Configuration(Podable):
-
-    _pod_serialization_version = 1
-    config_points = []
-    name = ''
+    """
+    Configures the behaviour of WA and of a run as a whole.
+    The most common options that you may want to specify are:
+
+    :device: The name of the 'device' that you wish to perform the run
+             on. This name is a combination of a devlib
+             `Platform `_ and
+             `Target `_. To
+             see the available options please use ``wa list targets``.
+    :device_config: This is a dict mapping allowing you to configure which target
+                    to connect to (e.g. ``host`` for an SSH connection or
+                    ``device`` to specify an ADB name) as well as configure other
+                    options for the device, for example the ``working_directory``
+                    or the list of ``modules`` to be loaded onto the device. (For
+                    more information please see
+                    :ref:`here `)
+    :execution_order: Defines the order in which the agenda spec will be executed.
+    :reboot_policy: Defines when during execution of a run a Device will be rebooted.
+    :max_retries: The maximum number of times failed jobs will be retried before giving up.
+    :allow_phone_home: Prevent running any workloads that are marked with ‘phones_home’.
+    """
+    _pod_serialization_version: int = 1
+    config_points: List[ConfigurationPoint] = []
+    name: str = ''

     # The below line must be added to all subclasses
-    configuration = {cp.name: cp for cp in config_points}
+    configuration: Dict[str, ConfigurationPoint] = {cp.name: cp for cp in config_points}

     @classmethod
-    def from_pod(cls, pod):
-        instance = super(Configuration, cls).from_pod(pod)
+    def from_pod(cls, pod: Dict[str, Any]) -> 'Configuration':
+        """
+        create a Configuration object from a plain old datastructure
+        """
+        instance = cast('Configuration', super(Configuration, cls).from_pod(pod))
         for cfg_point in cls.config_points:
             if cfg_point.name in pod:
                 value = pod.pop(cfg_point.name)
                 if hasattr(cfg_point.kind, 'from_pod'):
-                    value = cfg_point.kind.from_pod(value)
+                    value = cast(Podable, cfg_point.kind).from_pod(value)
                 cfg_point.set_value(instance, value)
         if pod:
             msg = 'Invalid entry(ies) for "{}": "{}"'
@@ -408,7 +518,10 @@ def __init__(self):
         for confpoint in self.config_points:
             confpoint.set_value(self, check_mandatory=False)

-    def set(self, name, value, check_mandatory=True):
+    def set(self, name: str, value: Any, check_mandatory: bool = True) -> None:
+        """
+        set the value of the configuration point with the given name
+        """
         if name not in self.configuration:
             raise ConfigError('Unknown {} configuration "{}"'.format(self.name,
                                                                      name))
@@ -419,15 +532,24 @@ def set(self, name, value, check_mandatory=True):
             msg = 'Invalid value "{}" for "{}": {}'
             raise ConfigError(msg.format(value, name, e))

-    def update_config(self, values, check_mandatory=True):
+    def update_config(self, values: Any, check_mandatory: bool = True) -> None:
+        """
+        update configuration with new values
+        """
         for k, v in values.items():
             self.set(k, v, check_mandatory=check_mandatory)

-    def validate(self):
+    def validate(self) -> None:
+        """
+        validate all the configuration points in the configuration
+        """
         for cfg_point in self.config_points:
             cfg_point.validate(self)

-    def to_pod(self):
+    def to_pod(self) -> Dict[str, Any]:
+        """
+        convert the Configuration to a plain old datastructure
+        """
         pod = super(Configuration, self).to_pod()
         for cfg_point in self.config_points:
             value = getattr(self, cfg_point.name, None)
@@ -435,17 +557,49 @@ def to_pod(self):
         return pod

     @staticmethod
-    def _pod_upgrade_v1(pod):
+    def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        pod upgrade function for the Configuration class
+        """
         pod['_pod_version'] = pod.get('_pod_version', 1)
         return pod


+class MetaConfigurationProtocol(Protocol):
+    plugin_packages: List[str]
+    dependencies_directory: str
+    plugins_directory: str
+    cache_directory: str
+    plugin_paths: List[str]
+    user_config_file: str
+    additional_packages_file: str
+    target_info_cache_file: str
+    apk_info_cache_file: str
+    user_directory: str
+    assets_repository: str
+    logging: LoggingConfig
+    verbosity: int
+    default_output_directory: str
+    extra_plugin_paths: List[str]
+    configuration: Dict[str, ConfigurationPoint]
+
+    def to_pod(self) -> Dict[str, Any]:
+        ...
+
+    def set(self, name: str, value: Any, check_mandatory: bool = True) -> None:
+        ...
+
+
 # This configuration for the core WA framework
 class MetaConfiguration(Configuration):
+    """
+    There are also a couple of settings that are used to provide additional metadata
+    for a run. These may get picked up by instruments or output processors to
+    attach context to results.
+    """
+    name: str = "Meta Configuration"

-    name = "Meta Configuration"
-
-    core_plugin_packages = [
+    core_plugin_packages: List[str] = [
         'wa.commands',
         'wa.framework.getters',
         'wa.framework.target.descriptor',
@@ -454,7 +608,7 @@ class MetaConfiguration(Configuration):
         'wa.workloads',
     ]

-    config_points = [
+    config_points: List[ConfigurationPoint] = [
         ConfigurationPoint(
             'user_directory',
             description="""
@@ -512,67 +666,139 @@ class MetaConfiguration(Configuration):
             """,
         ),
     ]
-    configuration = {cp.name: cp for cp in config_points}
+    configuration: Dict[str, ConfigurationPoint] = {cp.name: cp for cp in config_points}

     @property
-    def dependencies_directory(self):
-        return os.path.join(self.user_directory, 'dependencies')
+    def dependencies_directory(self) -> str:
+        """
+        dependencies directory, typically ``~/.workload_automation/dependencies/``
+        """
+        return os.path.join(cast(MetaConfigurationProtocol, self).user_directory, 'dependencies')

     @property
-    def plugins_directory(self):
-        return os.path.join(self.user_directory, 'plugins')
+    def plugins_directory(self) -> str:
+        """
+        plugins directory
+        """
+        return os.path.join(cast(MetaConfigurationProtocol, self).user_directory, 'plugins')

     @property
-    def cache_directory(self):
-        return os.path.join(self.user_directory, 'cache')
+    def cache_directory(self) -> str:
+        """
+        cache directory
+        """
+        return os.path.join(cast(MetaConfigurationProtocol, self).user_directory, 'cache')

     @property
-    def plugin_paths(self):
-        return [self.plugins_directory] + (self.extra_plugin_paths or [])
+    def plugin_paths(self) -> List[str]:
+        """
+        list of plugin paths
+        """
+        return [self.plugins_directory] + (cast(MetaConfigurationProtocol, self).extra_plugin_paths or [])

     @property
-    def user_config_file(self):
-        return os.path.join(self.user_directory, 'config.yaml')
+    def user_config_file(self) -> str:
+        """
+        user configuration file
+        """
+        return os.path.join(cast(MetaConfigurationProtocol, self).user_directory, 'config.yaml')

     @property
-    def additional_packages_file(self):
-        return os.path.join(self.user_directory, 'packages')
+    def additional_packages_file(self) -> str:
+        """
+        additional packages file
+        """
+        return os.path.join(cast(MetaConfigurationProtocol, self).user_directory, 'packages')

     @property
-    def target_info_cache_file(self):
+    def target_info_cache_file(self) -> str:
+        """
+        target information cache file
+        """
         return os.path.join(self.cache_directory, 'targets.json')

     @property
-    def apk_info_cache_file(self):
+    def apk_info_cache_file(self) -> str:
+        """
+        apk information cache file
+        """
         return os.path.join(self.cache_directory, 'apk_info.json')
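As a sketch (not part of the patch), this is how the constructor below consumes WA-specific environment variables; note that it pops the keys out of the mapping it is given. The paths are illustrative and assume a POSIX host:

    env = {'WA_USER_DIRECTORY': '/tmp/wa_user',
           'WA_PLUGIN_PATHS': '/opt/plugins_a:/opt/plugins_b'}
    mc = MetaConfiguration(environ=env)

    assert env == {}  # both keys were popped during construction
    assert mc.user_config_file == '/tmp/wa_user/config.yaml'
    assert '/opt/plugins_a' in mc.plugin_paths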
-    def __init__(self, environ=None):
+    def __init__(self, environ: Optional[os._Environ] = None):
         super(MetaConfiguration, self).__init__()
         if environ is None:
             environ = os.environ
-        user_directory = environ.pop('WA_USER_DIRECTORY', '')
+        user_directory: str = environ.pop('WA_USER_DIRECTORY', '')
         if user_directory:
             self.set('user_directory', user_directory)

-        extra_plugin_paths = environ.pop('WA_PLUGIN_PATHS', '')
+        extra_plugin_paths: str = environ.pop('WA_PLUGIN_PATHS', '')
         if extra_plugin_paths:
             self.set('extra_plugin_paths', extra_plugin_paths.split(os.pathsep))

-        self.plugin_packages = copy(self.core_plugin_packages)
+        self.plugin_packages: List[str] = copy(self.core_plugin_packages)
         if os.path.isfile(self.additional_packages_file):
             with open(self.additional_packages_file) as fh:
                 extra_packages = [p.strip() for p in fh.read().split('\n') if p.strip()]
                 self.plugin_packages.extend(extra_packages)


+class RunConfigurationProtocol(Protocol):
+    device_config: Optional[obj_dict]
+    augmentations: Dict[str, Dict[str, Optional[ConfigurationPoint]]]
+    resource_getters: Dict[str, Dict[str, Optional[ConfigurationPoint]]]
+    run_name: str
+    project: str
+    project_stage: Union[Dict, str]
+    execution_order: str
+    reboot_policy: RebootPolicy
+    device: str
+    retry_on_status: List
+    max_retries: int
+    bail_on_init_failure: bool
+    bail_on_job_failure: bool
+    allow_phone_home: bool
+    name: str
+    meta_data: List[ConfigurationPoint]
+    config_points: List[ConfigurationPoint]
+    configuration: Dict[str, ConfigurationPoint]
+
+    def set(self, name: str, value: Any, check_mandatory: bool = True) -> None:
+        ...
+
+    def add_augmentation(self, aug: 'Plugin') -> None:
+        ...
+
+    def add_resource_getter(self, getter: 'Plugin') -> None:
+        ...
+
+    def to_pod(self) -> Dict[str, Any]:
+        ...
+
+    def merge_device_config(self, plugin_cache: 'PluginCache') -> None:
+        ...
+
+
 # This is generic top-level configuration for WA runs.
 class RunConfiguration(Configuration):
-
-    name = "Run Configuration"
+    """
+    In addition to specifying run execution parameters through an agenda, the
+    behaviour of WA can be modified through configuration file(s). The default
+    configuration file is ``~/.workload_automation/config.yaml`` (the location can
+    be changed by setting the ``WA_USER_DIRECTORY`` environment variable).
+    This file will be created when you first run WA if it does not already exist.
+    This file must always exist and will always be loaded. You can add to or override
+    the contents of that file on invocation of Workload Automation by specifying an
+    additional configuration file using the ``--config`` option. Variables with specific
+    names will be picked up by the framework and used to modify the behaviour of
+    Workload Automation, e.g. the ``iterations`` variable might be specified to tell
+    WA how many times to run each workload.
+    """
+    name: str = "Run Configuration"

     # Metadata is separated out because it is not loaded into the auto
     # generated config file
-    meta_data = [
+    meta_data: List[ConfigurationPoint] = [
         ConfigurationPoint(
             'run_name',
             kind=str,
@@ -601,7 +827,7 @@ class RunConfiguration(Configuration):
             ''',
         ),
     ]
-    config_points = [
+    config_points: List[ConfigurationPoint] = [
         ConfigurationPoint(
             'execution_order',
             kind=str,
@@ -757,18 +983,21 @@ class RunConfiguration(Configuration):
                 workloads when testing confidential devices.
'''), ] - configuration = {cp.name: cp for cp in config_points + meta_data} + configuration: Dict[str, ConfigurationPoint] = {cp.name: cp for cp in config_points + meta_data} @classmethod - def from_pod(cls, pod): - meta_pod = {} + def from_pod(cls, pod: Dict[str, Any]) -> 'RunConfiguration': + """ + create a RunConfiguration object from a plain old datastructure + """ + meta_pod: Dict[str, Any] = {} for cfg_point in cls.meta_data: meta_pod[cfg_point.name] = pod.pop(cfg_point.name, None) - device_config = pod.pop('device_config', None) - augmentations = pod.pop('augmentations', {}) - getters = pod.pop('resource_getters', {}) - instance = super(RunConfiguration, cls).from_pod(pod) + device_config: Optional[obj_dict] = pod.pop('device_config', None) + augmentations: Dict[str, Dict[str, Optional[ConfigurationPoint]]] = pod.pop('augmentations', {}) + getters: Dict[str, Dict[str, Optional[ConfigurationPoint]]] = pod.pop('resource_getters', {}) + instance = cast('RunConfiguration', super(RunConfiguration, cls).from_pod(pod)) instance.device_config = device_config instance.augmentations = augmentations instance.resource_getters = getters @@ -777,37 +1006,48 @@ def from_pod(cls, pod): return instance - def __init__(self): + def __init__(self) -> None: super(RunConfiguration, self).__init__() for confpoint in self.meta_data: confpoint.set_value(self, check_mandatory=False) - self.device_config = None - self.augmentations = {} - self.resource_getters = {} + self.device_config: Optional[obj_dict] = None + self.augmentations: Dict[str, Dict[str, Optional[ConfigurationPoint]]] = {} + self.resource_getters: Dict[str, Dict[str, Optional[ConfigurationPoint]]] = {} - def merge_device_config(self, plugin_cache): + def merge_device_config(self, plugin_cache: 'PluginCache') -> None: """ Merges global device config and validates that it is correct for the selected device. 
""" # pylint: disable=no-member - if self.device is None: + if cast(RunConfigurationProtocol, self).device is None: msg = 'Attempting to merge device config with unspecified device' raise RuntimeError(msg) - self.device_config = plugin_cache.get_plugin_config(self.device, + self.device_config = plugin_cache.get_plugin_config(cast(RunConfigurationProtocol, self).device, generic_name="device_config") - def add_augmentation(self, aug): + def add_augmentation(self, aug: 'Plugin') -> None: + """ + add an augmentation to the run configuration + """ if aug.name in self.augmentations: raise ValueError('Augmentation "{}" already added.'.format(aug.name)) - self.augmentations[aug.name] = aug.get_config() + if aug.name: + self.augmentations[aug.name] = aug.get_config() - def add_resource_getter(self, getter): + def add_resource_getter(self, getter: 'Plugin') -> None: + """ + add a resource getter to the run configuration + """ if getter.name in self.resource_getters: raise ValueError('Resource getter "{}" already added.'.format(getter.name)) - self.resource_getters[getter.name] = getter.get_config() + if getter.name: + self.resource_getters[getter.name] = getter.get_config() - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + Convert Run Configuration to a plain old datastructure + """ pod = super(RunConfiguration, self).to_pod() pod['device_config'] = dict(self.device_config or {}) pod['augmentations'] = self.augmentations @@ -815,12 +1055,37 @@ def to_pod(self): return pod +class JobSpecProtocol(Protocol): + id: Optional[str] + section_id: Optional[str] + workload_id: Optional[str] + iterations: int + workload_name: str + workload_parameters: obj_dict + runtime_parameters: obj_dict + boot_parameters: obj_dict + label: str + augmentations: toggle_set + flash: Dict + classifiers: od[str, str] + _sources: List[JobSpecSource] + + @classmethod + def from_pod(cls, pod: Dict[str, Any]) -> 'Podable': + ... + + def to_pod(self) -> Dict[str, Any]: + ... + + class JobSpec(Configuration): # pylint: disable=access-member-before-definition,attribute-defined-outside-init + """ + Job specification + """ + name: str = "Job Spec" - name = "Job Spec" - - config_points = [ + config_points: List[ConfigurationPoint] = [ ConfigurationPoint('iterations', kind=int, default=1, description=''' How many times to repeat this workload spec @@ -877,65 +1142,85 @@ class JobSpec(Configuration): for results when post processing. 
'''), ] - configuration = {cp.name: cp for cp in config_points} + configuration: Dict[str, ConfigurationPoint] = {cp.name: cp for cp in config_points} @classmethod - def from_pod(cls, pod): - job_id = pod.pop('id') - instance = super(JobSpec, cls).from_pod(pod) + def from_pod(cls, pod: Dict[str, Any]) -> 'JobSpec': + """ + Create a JobSpec object from a plain old datastructure + """ + job_id: str = pod.pop('id') + instance = cast('JobSpec', super(JobSpec, cls).from_pod(pod)) instance.id = job_id return instance @property - def section_id(self): + def section_id(self) -> Optional[str]: + """ + section id + """ if self.id is not None: return self.id.rsplit('-', 1)[0] + return None @property - def workload_id(self): + def workload_id(self) -> Optional[str]: + """ + workload id + """ if self.id is not None: return self.id.rsplit('-', 1)[-1] + return None - def __init__(self): + def __init__(self) -> None: super(JobSpec, self).__init__() if self.classifiers is None: - self.classifiers = OrderedDict() - self.to_merge = defaultdict(OrderedDict) - self._sources = [] - self.id = None + self.classifiers: od[str, str] = OrderedDict() + self.to_merge: DefaultDict[str, Dict[JobSpecSource, Any]] = defaultdict(OrderedDict) + self._sources: List[JobSpecSource] = [] + self.id: Optional[str] = None if self.boot_parameters is None: - self.boot_parameters = obj_dict() + self.boot_parameters: obj_dict = obj_dict() if self.runtime_parameters is None: - self.runtime_parameters = obj_dict() + self.runtime_parameters: obj_dict = obj_dict() - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + Convert JobSpec to a plain old datastructure + """ pod = super(JobSpec, self).to_pod() pod['id'] = self.id return pod - def update_config(self, source, check_mandatory=True): # pylint: disable=arguments-differ + # FIXME - the base class update_config seems to have the second argument as a Dict[str,Any] + # but the inherited class Jobspec implements it with second arg as a JobSpecSource type. + # currently keeping the type as Any in the base class works, but may not be the right thing to do. 
+    def update_config(self, source: JobSpecSource, check_mandatory: bool = True):  # pylint: disable=arguments-differ
         self._sources.append(source)
         values = source.config
         for k, v in values.items():
             if k == "id":
                 continue
-            elif k.endswith('_parameters'):
+            elif cast(str, k).endswith('_parameters'):
                 if v:
                     self.to_merge[k][source] = copy(v)
             else:
                 try:
                     self.set(k, v, check_mandatory=check_mandatory)
                 except ConfigError as e:
-                    msg = 'Error in {}:\n\t{}'
+                    msg: str = 'Error in {}:\n\t{}'
                     raise ConfigError(msg.format(source.name, e.message))

-    def merge_workload_parameters(self, plugin_cache):
+    def merge_workload_parameters(self, plugin_cache: 'PluginCache') -> None:
+        """
+        merge the workload parameters
+        """
         # merge global generic and specific config
-        workload_params = plugin_cache.get_plugin_config(self.workload_name,
-                                                         generic_name="workload_parameters",
-                                                         is_final=False)
+        workload_params: obj_dict = plugin_cache.get_plugin_config(cast(JobSpecProtocol, self).workload_name,
                                                                    generic_name="workload_parameters",
                                                                    is_final=False)

-        cfg_points = plugin_cache.get_plugin_parameters(self.workload_name)
+        cfg_points: Dict[str, ConfigurationPoint] = plugin_cache.get_plugin_parameters(cast(JobSpecProtocol, self).workload_name)
         for source in self._sources:
             config = dict(self.to_merge["workload_parameters"].get(source, {}))
             if not config:
@@ -948,30 +1233,35 @@ def merge_workload_parameters(self, plugin_cache):
                                         check_mandatory=False)
             if config:
                 msg = 'Unexpected config "{}" for "{}"'
-                raise ConfigError(msg.format(config, self.workload_name))
+                raise ConfigError(msg.format(config, cast(JobSpecProtocol, self).workload_name))

         self.workload_parameters = workload_params

-    def merge_runtime_parameters(self, plugin_cache, target_manager):
-
+    def merge_runtime_parameters(self, plugin_cache: 'PluginCache', target_manager: 'TargetManager') -> None:
+        """
+        merge the runtime parameters
+        """
         # Order global runtime parameters
-        runtime_parameters = OrderedDict()
+        runtime_parameters: od[JobSpecSource, Dict[str, ConfigurationPoint]] = OrderedDict()
         try:
-            global_runtime_params = plugin_cache.get_plugin_config("runtime_parameters")
+            global_runtime_params: Union[obj_dict, Dict[JobSpecSource, Any]] = plugin_cache.get_plugin_config("runtime_parameters")
         except NotFoundError:
             global_runtime_params = {}
         for source in plugin_cache.sources:
             if source in global_runtime_params:
-                runtime_parameters[source] = global_runtime_params[source]
+                runtime_parameters[source] = global_runtime_params[source]  # type:ignore

         # Add runtime parameters from JobSpec
         for source, values in self.to_merge['runtime_parameters'].items():
             runtime_parameters[source] = values

         # Merge
-        self.runtime_parameters = target_manager.merge_runtime_parameters(runtime_parameters)
+        self.runtime_parameters = cast(obj_dict, target_manager.merge_runtime_parameters(runtime_parameters))

-    def finalize(self):
+    def finalize(self) -> None:
+        """
+        finalize JobSpec creation
+        """
         self.id = "-".join([str(source.config['id'])
                             for source in self._sources[1:]])  # ignore first id, "global"

@@ -980,72 +1270,92 @@ def finalize(self):
         self.runtime_parameters = obj_dict(list((self.runtime_parameters or {}).items()))
         self.workload_parameters = obj_dict(list((self.workload_parameters or {}).items()))

-        if self.label is None:
-            self.label = self.workload_name
+        if cast(JobSpecProtocol, self).label is None:
+            self.label = cast(JobSpecProtocol, self).workload_name
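As a sketch (not part of the patch), this is how the id assembled in finalize relates to the two derived properties. The id is set by hand here for illustration; real ids are joined from the section and workload source ids:

    spec = JobSpec()
    spec.id = 's1-wk1'
    assert spec.section_id == 's1'     # everything before the last '-'
    assert spec.workload_id == 'wk1'   # everything after the last '-'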
 # This is used to construct the list of Jobs WA will run
 class JobGenerator(object):
+    """
+    construct the list of jobs WA will run
+    """
+    name: str = "Jobs Configuration"
+
+    def __init__(self, plugin_cache: 'PluginCache'):
+        self.plugin_cache = plugin_cache
+        self.ids_to_run: List[str] = []
+        self.workloads: List['Workload'] = []
+        self._enabled_augmentations = toggle_set()
+        self._enabled_instruments: Optional[List[str]] = None
+        self._enabled_processors: Optional[List[str]] = None
+        self._read_augmentations: bool = False
+        self.disabled_augmentations: Set[str] = set()
+
+        self.job_spec_template: obj_dict = obj_dict(not_in_dict=['name'])
+        self.job_spec_template.name = "globally specified job spec configuration"
+        self.job_spec_template.id = "global"
+        # Load defaults
+        for cfg_point in JobSpec.configuration.values():
+            cfg_point.set_value(self.job_spec_template, check_mandatory=False)

-    name = "Jobs Configuration"
+        self.root_node = SectionNode(self.job_spec_template)

     @property
-    def enabled_instruments(self):
+    def enabled_instruments(self) -> List[str]:
+        """
+        get the enabled instruments for the job
+        """
         self._read_augmentations = True
         if self._enabled_instruments is None:
             self._enabled_instruments = []
             for entry in list(self._enabled_augmentations.merge_with(self.disabled_augmentations).values()):
                 entry_cls = self.plugin_cache.get_plugin_class(entry)
-                if entry_cls.kind == 'instrument':
+                if entry_cls and entry_cls.kind == 'instrument':
                     self._enabled_instruments.append(entry)
         return self._enabled_instruments

     @property
-    def enabled_processors(self):
+    def enabled_processors(self) -> List[str]:
+        """
+        get the enabled output processors for the job
+        """
         self._read_augmentations = True
         if self._enabled_processors is None:
             self._enabled_processors = []
             for entry in list(self._enabled_augmentations.merge_with(self.disabled_augmentations).values()):
                 entry_cls = self.plugin_cache.get_plugin_class(entry)
-                if entry_cls.kind == 'output_processor':
+                if entry_cls and entry_cls.kind == 'output_processor':
                     self._enabled_processors.append(entry)
         return self._enabled_processors

-    def __init__(self, plugin_cache):
-        self.plugin_cache = plugin_cache
-        self.ids_to_run = []
-        self.workloads = []
-        self._enabled_augmentations = toggle_set()
-        self._enabled_instruments = None
-        self._enabled_processors = None
-        self._read_augmentations = False
-        self.disabled_augmentations = set()
-
-        self.job_spec_template = obj_dict(not_in_dict=['name'])
-        self.job_spec_template.name = "globally specified job spec configuration"
-        self.job_spec_template.id = "global"
-        # Load defaults
-        for cfg_point in JobSpec.configuration.values():
-            cfg_point.set_value(self.job_spec_template, check_mandatory=False)
-
-        self.root_node = SectionNode(self.job_spec_template)
-
-    def set_global_value(self, name, value):
+    def set_global_value(self, name: str, value: Any) -> None:
+        """
+        set a value on the global job spec configuration or augmentations
+        """
         JobSpec.configuration[name].set_value(self.job_spec_template, value,
                                               check_mandatory=False)
         if name == "augmentations":
             self.update_augmentations(value)

-    def add_section(self, section, workloads, group):
-        new_node = self.root_node.add_section(section, group)
+    def add_section(self, section: obj_dict, workloads: List[obj_dict], group: str) -> None:
+        """
+        Add a new section to the job tree
+        """
+        new_node: SectionNode = self.root_node.add_section(section, group)
         with log.indentcontext():
             for workload in workloads:
                 new_node.add_workload(workload)

-    def add_workload(self, workload):
+    def add_workload(self, workload: obj_dict) -> None:
+        """
+        add a workload to the job tree
+        """
self.root_node.add_workload(workload) - def disable_augmentations(self, augmentations): + def disable_augmentations(self, augmentations: toggle_set): + """ + disable augmentations + """ for entry in augmentations: if entry == '~~': continue @@ -1057,43 +1367,56 @@ def disable_augmentations(self, augmentations): raise ConfigError('Error disabling unknown augmentation: "{}"'.format(entry)) self.disabled_augmentations = self.disabled_augmentations.union(augmentations) - def update_augmentations(self, value): + def update_augmentations(self, value: Any) -> None: + """ + update augmentations + """ if self._read_augmentations: msg = 'Cannot update augmentations after they have been accessed' raise RuntimeError(msg) self._enabled_augmentations = self._enabled_augmentations.merge_with(value) - def only_run_ids(self, ids): + def only_run_ids(self, ids: Union[str, List[str]]) -> None: + """ + List of ids of the only jobs to run + """ if isinstance(ids, str): ids = [ids] self.ids_to_run = ids - def generate_job_specs(self, target_manager): - specs = [] + def generate_job_specs(self, target_manager: 'TargetManager') -> List[JobSpecProtocol]: + """ + generate job specifications + """ + specs: List[JobSpecProtocol] = [] for leaf in self.root_node.leaves(): workload_entries = leaf.workload_entries - sections = [leaf] + sections: List[SectionNode] = [leaf] for ancestor in leaf.ancestors(): workload_entries = ancestor.workload_entries + workload_entries sections.insert(0, ancestor) for workload_entry in workload_entries: - job_spec = create_job_spec(deepcopy(workload_entry), sections, - target_manager, self.plugin_cache, - self.disabled_augmentations) + job_spec: JobSpecProtocol = create_job_spec(deepcopy(workload_entry), sections, + target_manager, self.plugin_cache, + self.disabled_augmentations) if self.ids_to_run: for job_id in self.ids_to_run: - if job_id in job_spec.id: + if job_spec.id and job_id in job_spec.id: break else: continue - self.update_augmentations(list(job_spec.augmentations.values())) + self.update_augmentations(list(cast(JobSpecProtocol, job_spec).augmentations.values())) specs.append(job_spec) return specs -def create_job_spec(workload_entry, sections, target_manager, plugin_cache, - disabled_augmentations): +def create_job_spec(workload_entry: JobSpecSource, sections: List[SectionNode], + target_manager: 'TargetManager', plugin_cache: 'PluginCache', + disabled_augmentations: Set[str]) -> JobSpecProtocol: + """ + create the job specification + """ job_spec = JobSpec() # PHASE 2.1: Merge general job spec configuration @@ -1118,11 +1441,14 @@ def create_job_spec(workload_entry, sections, target_manager, plugin_cache, job_spec.set("augmentations", disabled_augmentations) job_spec.finalize() - return job_spec + return cast(JobSpecProtocol, job_spec) -def get_config_point_map(params): - pmap = {} +def get_config_point_map(params: List[ConfigurationPoint]) -> Dict[str, ConfigurationPoint]: + """ + get map of configuration points + """ + pmap: Dict[str, ConfigurationPoint] = {} for p in params: pmap[p.name] = p for alias in p.aliases: @@ -1130,4 +1456,4 @@ def get_config_point_map(params): return pmap -settings = MetaConfiguration(os.environ) +settings = cast(MetaConfigurationProtocol, MetaConfiguration(os.environ)) diff --git a/wa/framework/configuration/default.py b/wa/framework/configuration/default.py index 94f36f22e..d535b1d05 100644 --- a/wa/framework/configuration/default.py +++ b/wa/framework/configuration/default.py @@ -17,9 +17,11 @@ from wa.framework.configuration.plugin_cache 
import PluginCache from wa.utils.serializer import yaml from wa.utils.doc import strip_inlined_text +from typing import List, TYPE_CHECKING, TextIO, Optional, cast +if TYPE_CHECKING: + from wa.framework.configuration.core import ConfigurationPoint - -DEFAULT_AUGMENTATIONS = [ +DEFAULT_AUGMENTATIONS: List[str] = [ 'execution_time', 'interrupts', 'cpufreq', @@ -28,27 +30,36 @@ ] -def _format_yaml_comment(param, short_description=False): - comment = param.description - comment = strip_inlined_text(comment) +def _format_yaml_comment(param: 'ConfigurationPoint', short_description=False) -> str: + """ + format yaml comment + """ + comment: Optional[str] = param.description + comment = strip_inlined_text(comment or '') if short_description: - comment = comment.split('\n\n')[0] - comment = comment.replace('\n', '\n# ') + comment = comment.split('\n\n')[0] if comment else '' + comment = comment.replace('\n', '\n# ') if comment else '' comment = "# {}\n".format(comment) return comment -def _format_augmentations(output): +def _format_augmentations(output: TextIO) -> None: + """ + format augmentations + """ plugin_cache = PluginCache() output.write("augmentations:\n") for plugin in DEFAULT_AUGMENTATIONS: plugin_cls = plugin_cache.loader.get_plugin_class(plugin) - output.writelines(_format_yaml_comment(plugin_cls, short_description=True)) + output.writelines(_format_yaml_comment(cast('ConfigurationPoint', plugin_cls), short_description=True)) output.write(" - {}\n".format(plugin)) output.write("\n") -def generate_default_config(path): +def generate_default_config(path: str) -> None: + """ + generate default configuration + """ with open(path, 'w') as output: for param in MetaConfiguration.config_points + RunConfiguration.config_points: entry = {param.name: param.default} @@ -56,8 +67,11 @@ def generate_default_config(path): _format_augmentations(output) -def write_param_yaml(entry, param, output): - comment = _format_yaml_comment(param) +def write_param_yaml(entry, param: 'ConfigurationPoint', output: TextIO) -> None: + """ + write the configuration parameter into yaml file + """ + comment: str = _format_yaml_comment(param) output.writelines(comment) yaml.dump(entry, output, default_flow_style=False) output.write("\n") diff --git a/wa/framework/configuration/execution.py b/wa/framework/configuration/execution.py index d83d8ac71..0e685933c 100644 --- a/wa/framework/configuration/execution.py +++ b/wa/framework/configuration/execution.py @@ -16,44 +16,60 @@ import random from itertools import groupby, chain -from future.moves.itertools import zip_longest +from future.moves.itertools import zip_longest # type:ignore from devlib.utils.types import identifier +from devlib.target import Target from wa.framework.configuration.core import (MetaConfiguration, RunConfiguration, - JobGenerator, settings) + JobGenerator, settings, + JobSpecProtocol) from wa.framework.configuration.parsers import ConfigParser from wa.framework.configuration.plugin_cache import PluginCache from wa.framework.exception import NotFoundError, ConfigError from wa.framework.job import Job from wa.utils import log from wa.utils.serializer import Podable +from typing import (TYPE_CHECKING, cast, Dict, Any, Tuple, + Optional, List, Callable, Generator) +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext + from wa.framework.configuration.core import RunConfigurationProtocol, MetaConfigurationProtocol + from wa.framework.instrument import Instrument + from wa.framework.output_processor import OutputProcessor + from 
wa.framework.plugin import Plugin class CombinedConfig(Podable): - _pod_serialization_version = 1 + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): - instance = super(CombinedConfig, CombinedConfig).from_pod(pod) - instance.settings = MetaConfiguration.from_pod(pod.get('settings', {})) - instance.run_config = RunConfiguration.from_pod(pod.get('run_config', {})) + def from_pod(pod: Dict[str, Any]) -> 'CombinedConfig': + instance = cast('CombinedConfig', super(CombinedConfig, CombinedConfig).from_pod(pod)) + instance.settings = cast('MetaConfigurationProtocol', MetaConfiguration.from_pod(pod.get('settings', {}))) + instance.run_config = cast('RunConfigurationProtocol', RunConfiguration.from_pod(pod.get('run_config', {}))) return instance - def __init__(self, settings=None, run_config=None): # pylint: disable=redefined-outer-name + def __init__(self, settings: Optional['MetaConfigurationProtocol'] = None, + run_config: Optional['RunConfigurationProtocol'] = None): # pylint: disable=redefined-outer-name super(CombinedConfig, self).__init__() self.settings = settings self.run_config = run_config - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(CombinedConfig, self).to_pod() - pod['settings'] = self.settings.to_pod() - pod['run_config'] = self.run_config.to_pod() + if self.settings: + pod['settings'] = self.settings.to_pod() + if self.run_config: + pod['run_config'] = self.run_config.to_pod() return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function for CombinedConfig + """ pod['_pod_version'] = pod.get('_pod_version', 1) return pod @@ -68,101 +84,141 @@ class ConfigManager(object): """ @property - def enabled_instruments(self): + def enabled_instruments(self) -> List[str]: + """ + list of enabled instruments + """ return self.jobs_config.enabled_instruments @property - def enabled_processors(self): + def enabled_processors(self) -> List[str]: + """ + list of enabled output processors + """ return self.jobs_config.enabled_processors @property - def job_specs(self): + def job_specs(self) -> List[JobSpecProtocol]: + """ + list of job specifications + """ if not self._jobs_generated: - msg = 'Attempting to access job specs before '\ - 'jobs have been generated' + msg: str = 'Attempting to access job specs before '\ + 'jobs have been generated' raise RuntimeError(msg) return [j.spec for j in self._jobs] @property - def jobs(self): + def jobs(self) -> List[Job]: + """ + List of jobs generated + """ if not self._jobs_generated: - msg = 'Attempting to access jobs before '\ - 'they have been generated' + msg: str = 'Attempting to access jobs before '\ + 'they have been generated' raise RuntimeError(msg) return self._jobs - def __init__(self, settings=settings): # pylint: disable=redefined-outer-name + def __init__(self, settings: 'MetaConfigurationProtocol' = settings): # pylint: disable=redefined-outer-name self.settings = settings - self.run_config = RunConfiguration() + self.run_config: 'RunConfigurationProtocol' = cast('RunConfigurationProtocol', RunConfiguration()) self.plugin_cache = PluginCache() self.jobs_config = JobGenerator(self.plugin_cache) - self.loaded_config_sources = [] + self.loaded_config_sources: List[str] = [] self._config_parser = ConfigParser() - self._jobs = [] - self._jobs_generated = False - self.agenda = None - - def load_config_file(self, filepath): + self._jobs: List[Job] = [] + self._jobs_generated: bool = False + self.agenda: Optional[str] = 
None + + def load_config_file(self, filepath: str) -> None: + """ + Load configuration file + """ includes = self._config_parser.load_from_path(self, filepath) self.loaded_config_sources.append(filepath) self.loaded_config_sources.extend(includes) - def load_config(self, values, source): + def load_config(self, values: Dict, source: str) -> None: + """ + load configuration from source + """ self._config_parser.load(self, values, source) self.loaded_config_sources.append(source) - def get_plugin(self, name=None, kind=None, *args, **kwargs): + def get_plugin(self, name: Optional[str] = None, kind: Optional[str] = None, + *args, **kwargs) -> Optional['Plugin']: + """ + get the plugin of the specified name and kind + """ return self.plugin_cache.get_plugin(identifier(name), kind, *args, **kwargs) - def get_instruments(self, target): - instruments = [] + def get_instruments(self, target: Target) -> List['Instrument']: + """ + get the list of instruments associated with the WA + """ + instruments: List['Instrument'] = [] for name in self.enabled_instruments: try: - instruments.append(self.get_plugin(name, kind='instrument', - target=target)) + instruments.append(cast('Instrument', self.get_plugin(name, kind='instrument', + target=target))) except NotFoundError: msg = 'Instrument "{}" not found' raise NotFoundError(msg.format(name)) return instruments - def get_processors(self): - processors = [] + def get_processors(self) -> List['OutputProcessor']: + """ + get the output processors associated with the WA + """ + processors: List['OutputProcessor'] = [] for name in self.enabled_processors: try: - proc = self.plugin_cache.get_plugin(name, kind='output_processor') + proc: 'OutputProcessor' = cast('OutputProcessor', + self.plugin_cache.get_plugin(name, kind='output_processor')) except NotFoundError: msg = 'Output Processor "{}" not found' raise NotFoundError(msg.format(name)) processors.append(proc) return processors - def get_config(self): + def get_config(self) -> CombinedConfig: + """ + get the combined configuration + """ return CombinedConfig(self.settings, self.run_config) - def finalize(self): + def finalize(self) -> CombinedConfig: + """ + finalize the configuration + """ if not self.agenda: - msg = 'Attempting to finalize config before agenda has been set' + msg: str = 'Attempting to finalize config before agenda has been set' raise RuntimeError(msg) self.run_config.merge_device_config(self.plugin_cache) return self.get_config() - def generate_jobs(self, context): - job_specs = self.jobs_config.generate_job_specs(context.tm) + def generate_jobs(self, context: 'ExecutionContext') -> None: + """ + generate jobs based on job specifications based on the configuration + """ + job_specs: List[JobSpecProtocol] = self.jobs_config.generate_job_specs(context.tm) if not job_specs: - msg = 'No jobs available for running.' + msg: str = 'No jobs available for running.' 
raise ConfigError(msg) - exec_order = self.run_config.execution_order + exec_order: str = cast('RunConfigurationProtocol', self.run_config).execution_order log.indent() for spec, i in permute_iterations(job_specs, exec_order): job = Job(spec, i, context) - job.load(context.tm.target) + if context.tm.target: + job.load(context.tm.target) self._jobs.append(job) - context.run_state.add_job(job) + if context.run_state: + context.run_state.add_job(job) log.dedent() self._jobs_generated = True -def permute_by_workload(specs): +def permute_by_workload(specs: List[JobSpecProtocol]) -> Generator[Tuple[JobSpecProtocol, int], Any, None]: """ This is that "classic" implementation that executes all iterations of a workload spec before proceeding onto the next spec. @@ -173,7 +229,7 @@ def permute_by_workload(specs): yield (spec, i) -def permute_by_iteration(specs): +def permute_by_iteration(specs: List[JobSpecProtocol]) -> Generator[Tuple[JobSpecProtocol, int], Any, None]: """ Runs the first iteration for all benchmarks first, before proceeding to the next iteration, i.e. A1, B1, C1, A2, B2, C2... instead of A1, A1, B1, B2, @@ -189,9 +245,9 @@ def permute_by_iteration(specs): X.A1, Y.A1, X.B1, Y.B1, X.A2, Y.A2, X.B2, Y.B2 """ - groups = [list(g) for _, g in groupby(specs, lambda s: s.workload_id)] + groups: List[List[JobSpecProtocol]] = [list(g) for _, g in groupby(specs, lambda s: s.workload_id)] - all_tuples = [] + all_tuples: List[List[Tuple[JobSpecProtocol, int]]] = [] for spec in chain(*groups): all_tuples.append([(spec, i + 1) for i in range(spec.iterations)]) @@ -200,7 +256,7 @@ def permute_by_iteration(specs): yield t -def permute_by_section(specs): +def permute_by_section(specs: List[JobSpecProtocol]) -> Generator[Tuple[JobSpecProtocol, int], Any, None]: """ Runs the first iteration for all benchmarks first, before proceeding to the next iteration, i.e. A1, B1, C1, A2, B2, C2... instead of A1, A1, B1, B2, @@ -215,9 +271,9 @@ def permute_by_section(specs): X.A1, X.B1, Y.A1, Y.B1, X.A2, X.B2, Y.A2, Y.B2 """ - groups = [list(g) for _, g in groupby(specs, lambda s: s.section_id)] + groups: List[List[JobSpecProtocol]] = [list(g) for _, g in groupby(specs, lambda s: s.section_id)] - all_tuples = [] + all_tuples: List[List[Tuple[JobSpecProtocol, int]]] = [] for spec in chain(*groups): all_tuples.append([(spec, i + 1) for i in range(spec.iterations)]) @@ -226,12 +282,12 @@ def permute_by_section(specs): yield t -def permute_randomly(specs): +def permute_randomly(specs: List[JobSpecProtocol]) -> Generator[Tuple[JobSpecProtocol, int], Any, None]: """ This will generate a random permutation of specs/iteration tuples. 
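As an editorial aside, the difference between the execution orders implemented by the `permute_*` functions above is easiest to see on a toy input. The sketch below is illustration only (it is not part of this patch) and uses a simplified stand-in for `JobSpec`; the real `permute_by_iteration` also copes with unequal iteration counts via `zip_longest`.

```python
# Illustration only: contrast 'by_workload' and 'by_iteration' orderings.
from dataclasses import dataclass

@dataclass
class FakeSpec:          # simplified stand-in for a WA JobSpec
    workload_id: str
    iterations: int

def by_workload(specs):
    # all iterations of one spec before moving on to the next spec
    for spec in specs:
        for i in range(1, spec.iterations + 1):
            yield (spec.workload_id, i)

def by_iteration(specs):
    # first iterations of every spec, then second iterations, and so on
    pools = [[(s.workload_id, i + 1) for i in range(s.iterations)] for s in specs]
    for batch in zip(*pools):   # the real code uses zip_longest for ragged pools
        yield from batch

specs = [FakeSpec('A', 2), FakeSpec('B', 2)]
print(list(by_workload(specs)))   # [('A', 1), ('A', 2), ('B', 1), ('B', 2)]
print(list(by_iteration(specs)))  # [('A', 1), ('B', 1), ('A', 2), ('B', 2)]
```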
""" - result = [] + result: List[Tuple[JobSpecProtocol, int]] = [] for spec in specs: for i in range(1, spec.iterations + 1): result.append((spec, i)) @@ -240,7 +296,8 @@ def permute_randomly(specs): yield t -permute_map = { +permute_map: Dict[str, Callable[[List[JobSpecProtocol]], + Generator[Tuple[JobSpecProtocol, int], Any, None]]] = { 'by_iteration': permute_by_iteration, 'by_workload': permute_by_workload, 'by_section': permute_by_section, @@ -248,7 +305,10 @@ def permute_randomly(specs): } -def permute_iterations(specs, exec_order): +def permute_iterations(specs: List[JobSpecProtocol], exec_order: str): + """ + permute iterations based on the specified execution order + """ if exec_order not in permute_map: msg = 'Unknown execution order "{}"; must be in: {}' raise ValueError(msg.format(exec_order, list(permute_map.keys()))) diff --git a/wa/framework/configuration/parsers.py b/wa/framework/configuration/parsers.py index 4547db1d0..a4949e81f 100644 --- a/wa/framework/configuration/parsers.py +++ b/wa/framework/configuration/parsers.py @@ -24,28 +24,42 @@ from wa.framework.exception import ConfigError from wa.utils import log from wa.utils.serializer import json, read_pod, SerializerSyntaxError -from wa.utils.types import toggle_set, counter +from wa.utils.types import toggle_set, counter, obj_dict from wa.utils.misc import merge_config_values, isiterable +from wa.framework.configuration.tree import JobSpecSource +from typing import (TYPE_CHECKING, Dict, List, Any, cast, Union, + Tuple, Set, Optional) +if TYPE_CHECKING: + from wa.framework.configuration.execution import ConfigManager, JobGenerator + from wa.framework.configuration.core import ConfigurationPoint - -logger = logging.getLogger('config') +logger: logging.Logger = logging.getLogger('config') class ConfigParser(object): - - def load_from_path(self, state, filepath): + """ + Config file parser + """ + def load_from_path(self, state: 'ConfigManager', filepath: str) -> List[str]: + """ + Load config file from the specified path + """ raw, includes = _load_file(filepath, "Config") self.load(state, raw, filepath) return includes - def load(self, state, raw, source, wrap_exceptions=True): # pylint: disable=too-many-branches + def load(self, state: 'ConfigManager', raw: Dict, + source: str, wrap_exceptions: bool = True) -> None: # pylint: disable=too-many-branches + """ + load configuration from source file + """ logger.debug('Parsing config from "{}"'.format(source)) log.indent() try: - state.plugin_cache.add_source(source) + state.plugin_cache.add_source(cast(JobSpecSource, source)) if 'run_name' in raw: - msg = '"run_name" can only be specified in the config '\ - 'section of an agenda' + msg: str = '"run_name" can only be specified in the config '\ + 'section of an agenda' raise ConfigError(msg) if 'id' in raw: @@ -78,7 +92,7 @@ def load(self, state, raw, source, wrap_exceptions=True): # pylint: disable=too # Assume that all leftover config is for a plug-in or a global # alias it is up to PluginCache to assert this assumption logger.debug('Caching "{}" with "{}"'.format(identifier(name), values)) - state.plugin_cache.add_configs(identifier(name), values, source) + state.plugin_cache.add_configs(identifier(name), values, cast(JobSpecSource, source)) except ConfigError as e: if wrap_exceptions: @@ -90,13 +104,18 @@ def load(self, state, raw, source, wrap_exceptions=True): # pylint: disable=too class AgendaParser(object): - - def load_from_path(self, state, filepath): + """ + Agenda Parser + """ + def load_from_path(self, state: 
'ConfigManager', filepath: str) -> List[str]:
         raw, includes = _load_file(filepath, 'Agenda')
         self.load(state, raw, filepath)
         return includes

-    def load(self, state, raw, source):
+    def load(self, state: 'ConfigManager', raw: Dict[str, List[Dict]], source: str) -> None:
+        """
+        load an agenda from the given source
+        """
         logger.debug('Parsing agenda from "{}"'.format(source))
         log.indent()
         try:
@@ -104,11 +123,11 @@ def load(self, state, raw, source):
                 raise ConfigError('Invalid agenda, top level entry must be a dict')
             self._populate_and_validate_config(state, raw, source)
-            sections = self._pop_sections(raw)
-            global_workloads = self._pop_workloads(raw)
+            sections: List[Dict] = self._pop_sections(raw)
+            global_workloads: List = self._pop_workloads(raw)
             if not global_workloads:
-                msg = 'No jobs avaliable. Please ensure you have specified at '\
-                      'least one workload to run.'
+                msg: str = 'No jobs available. Please ensure you have specified at '\
+                           'least one workload to run.'
                 raise ConfigError(msg)

             if raw:
@@ -126,14 +145,19 @@ def load(self, state, raw, source):
         finally:
             log.dedent()

-    def _populate_and_validate_config(self, state, raw, source):
+    def _populate_and_validate_config(self, state: 'ConfigManager',
+                                      raw: Dict[str, Any], source: str) -> None:
+        """
+        populate the configuration from the 'config' and 'global' entries
+        and validate them (each entry must be a dict)
+        """
         for name in ['config', 'global']:
-            entry = raw.pop(name, None)
+            entry: Optional[Dict] = raw.pop(name, None)
             if entry is None:
                 continue
             if not isinstance(entry, dict):
-                msg = 'Invalid entry "{}" - must be a dict'
+                msg: str = 'Invalid entry "{}" - must be a dict'
                 raise ConfigError(msg.format(name))

             if 'run_name' in entry:
@@ -143,8 +167,12 @@ def load(self, state, raw, source):

             state.load_config(entry, '{}/{}'.format(source, name))

-    def _pop_sections(self, raw):
-        sections = raw.pop("sections", [])
+    def _pop_sections(self, raw: Dict[str, Any]) -> List[Dict]:
+        """
+        pop 'sections' from the raw agenda data;
+        'sections' must be a list of dicts
+        """
+        sections: List[Dict] = raw.pop("sections", [])
         if not isinstance(sections, list):
             raise ConfigError('Invalid entry "sections" - must be a list')
         for section in sections:
@@ -152,15 +180,22 @@
                 raise ConfigError('Invalid section "{}" - must be a dict'.format(section))
         return sections

-    def _pop_workloads(self, raw):
-        workloads = raw.pop("workloads", [])
+    def _pop_workloads(self, raw: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """
+        pop 'workloads' from the raw agenda data
+        """
+        workloads: List[Dict[str, Any]] = raw.pop("workloads", [])
         if not isinstance(workloads, list):
             raise ConfigError('Invalid entry "workloads" - must be a list')
         return workloads

-    def _collect_ids(self, sections, global_workloads):
-        seen_section_ids = set()
-        seen_workload_ids = set()
+    def _collect_ids(self, sections: List[Dict[str, Any]],
+                     global_workloads: List[Dict[str, Any]]) -> Tuple[Set[str], Set[str]]:
+        """
+        collect section and workload ids and return them as a tuple of sets
+        """
+        seen_section_ids: Set[str] = set()
+        seen_workload_ids: Set[str] = set()

         for workload in global_workloads:
             workload = _get_workload_entry(workload)
@@ -175,15 +210,23 @@

         return seen_section_ids, seen_workload_ids

-    def _process_global_workloads(self, state, global_workloads, seen_wkl_ids):
+    def _process_global_workloads(self, state: 'ConfigManager', global_workloads: List[Dict[str, Any]],
+                                  seen_wkl_ids: Set[str]) -> None:
+        """
+        process global workload entries
+        """
        for 
workload_entry in global_workloads:
             workload = _process_workload_entry(workload_entry, seen_wkl_ids, state.jobs_config)
-            state.jobs_config.add_workload(workload)
+            state.jobs_config.add_workload(cast(obj_dict, workload))

-    def _process_sections(self, state, sections, seen_sect_ids, seen_wkl_ids):
+    def _process_sections(self, state: 'ConfigManager', sections: List[Dict[str, Any]],
+                          seen_sect_ids: Set[str], seen_wkl_ids: Set[str]) -> None:
+        """
+        process the sections of the configuration
+        """
         for section in sections:
-            workloads = []
+            workloads: List[Dict[str, Any]] = []
             for workload_entry in section.pop("workloads", []):
                 workload = _process_workload_entry(workload_entry, seen_wkl_ids, state.jobs_config)
@@ -191,22 +234,22 @@
             if 'params' in section:
                 if 'runtime_params' in section:
-                    msg = 'both "params" and "runtime_params" specified in a '\
-                          'section: "{}"'
+                    msg: str = 'both "params" and "runtime_params" specified in a '\
+                               'section: "{}"'
                     raise ConfigError(msg.format(json.dumps(section, indent=None)))
                 section['runtime_params'] = section.pop('params')

-            group = section.pop('group', None)
+            group: Optional[str] = section.pop('group', None)
             section = _construct_valid_entry(section, seen_sect_ids, "s", state.jobs_config)
-            state.jobs_config.add_section(section, workloads, group)
+            state.jobs_config.add_section(cast(obj_dict, section), cast(List[obj_dict], workloads), group)

 ########################
 ### Helper functions ###
 ########################

-def pop_aliased_param(cfg_point, d, default=None):
+def pop_aliased_param(cfg_point: 'ConfigurationPoint', d: Dict[str, str], default: Any = None) -> Any:
     """
     Given a ConfigurationPoint and a dict, this function will search the dict for
     the ConfigurationPoint's name/aliases. If more than one is found it will raise
@@ -214,8 +257,8 @@
     for the ConfigurationPoint. If the name or aliases are present in the dict it will
     return the "default" parameter of this function.
     """
-    aliases = [cfg_point.name] + cfg_point.aliases
-    alias_map = [a for a in aliases if a in d]
+    aliases: List[str] = [cfg_point.name] + cfg_point.aliases
+    alias_map: List[str] = [a for a in aliases if a in d]
     if len(alias_map) > 1:
         raise ConfigError('Duplicate entry: {}'.format(aliases))
     elif alias_map:
@@ -224,23 +267,32 @@
         return default

-def _load_file(filepath, error_name):
+def _load_file(filepath: str, error_name: str) -> Tuple[Dict[str, Any], List[str]]:
+    """
+    read the raw data and include information from a file
+    """
     if not os.path.isfile(filepath):
         raise ValueError("{} does not exist".format(filepath))
     try:
-        raw = read_pod(filepath)
-        includes = _process_includes(raw, filepath, error_name)
+        raw: Dict[str, Any] = read_pod(filepath)
+        includes: List[str] = _process_includes(raw, filepath, error_name)
     except SerializerSyntaxError as e:
         raise ConfigError('Error parsing {} {}: {}'.format(error_name, filepath, e))
     if not isinstance(raw, dict):
-        message = '{} does not contain a valid {} structure; top level must be a dict.'
+        message: str = '{} does not contain a valid {} structure; top level must be a dict.' 
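As an editorial aside, the agenda structure these helpers consume is easiest to see spelled out. The sketch below is illustration only and is not part of the patch; the workload names and parameter values are hypothetical.

```python
# Illustration only: the shape of a parsed agenda before _pop_sections() /
# _pop_workloads() consume it (yaml.load of an agenda file yields this dict).
agenda = {
    'workloads': [
        {'id': 'wk1', 'name': 'dhrystone', 'iterations': 2},
        'memcpy',   # a bare string is later normalised to {'name': 'memcpy'}
    ],
    'sections': [
        {'id': 's1', 'params': {'frequency': 'min'}},   # renamed to runtime_params
        {'id': 's2', 'params': {'frequency': 'max'}},
    ],
}
sections = agenda.pop('sections', [])    # what _pop_sections() returns
workloads = agenda.pop('workloads', [])  # what _pop_workloads() returns
# Every section is combined with every workload, so this agenda produces four
# job specs; "-" joins the IDs (e.g. "s1-wk1"), which is why "-" is rejected
# inside individual IDs by _collect_valid_id().
```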
raise ConfigError(message.format(filepath, error_name))
     return raw, includes

-def _config_values_from_includes(filepath, include_path, error_name):
-    source_dir = os.path.dirname(filepath)
-    included_files = []
+def _config_values_from_includes(filepath: str,
+                                 include_path: Union[str, List[str]],
+                                 error_name: str) -> Tuple[Dict[str, Any], List[str]]:
+    """
+    get the configuration values from the files included by the current configuration.
+    recursively calls _load_file -> _process_includes for any nested includes.
+    """
+    source_dir: str = os.path.dirname(filepath)
+    included_files: List[str] = []

     if isinstance(include_path, str):
         include_path = os.path.expanduser(os.path.join(source_dir, include_path))
@@ -268,15 +320,39 @@

     return replace_value, included_files

-def _process_includes(raw, filepath, error_name):
+def _process_includes(raw: Optional[Dict], filepath: str, error_name: str) -> List[str]:
+    """
+    It is possible to include other files in your config files and agendas. This is
+    done by specifying ``include#`` (note the trailing hash) as a key in one of the
+    mappings, with the value being the path to the file to be included. The path
+    must be either absolute, or relative to the location of the file it is being
+    included from (*not* to the current working directory). The path may also
+    include ``~`` to indicate current user's home directory.
+
+    The include is performed by removing the ``include#`` entry and loading the
+    contents of the specified file into the mapping that contained it. In cases where the mapping
+    already contains the key to be loaded, values will be merged using the usual
+    merge method (for overwrites, values in the mapping take precedence over those
+    from the included files).
+    Some additional details about the implementation and its limitations:
+
+    - The ``include#`` *must* be a key in a mapping, and the contents of the
+      included file *must* be a mapping as well; it is not possible to include a
+      list
+    - Being a key in a mapping, there can only be one ``include#`` entry per block.
+    - The included file *must* have a ``.yaml`` extension.
+    - Nested inclusions *are* allowed. I.e. included files may themselves include
+      files; in such cases the included paths must be relative to *that* file, and
+      not the "main" file.
+    """
     if not raw:
         return []
-    included_files = []
-    replace_value = None
+    included_files: List[str] = []
+    replace_value: Optional[Dict[str, Any]] = None

     if hasattr(raw, 'items'):
-        for key, value in raw.items():
+        for key, value in cast(Dict, raw).items():
             if key == 'include#':
                 replace_value, includes = _config_values_from_includes(filepath, value, error_name)
                 included_files.extend(includes)
@@ -297,7 +373,7 @@

     return included_files

-def merge_augmentations(raw):
+def merge_augmentations(raw: Dict[str, Any]) -> None:
     """
     Since, from configuration perspective, output processors and instruments are
     handled identically, the configuration entries are now interchangeable. E.g. it is
@@ -309,10 +385,10 @@
     that there are no conflicts between the entries. 
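To make the ``include#`` mechanism described in the docstring concrete, here is a small editorial sketch; the file names and keys are hypothetical and it is not part of the patch.

```python
# config.yaml                          # common.yaml
# -----------                          # ----------
# include#: common.yaml                # device: juno
# device: generic_android              # max_retries: 3
#
# After _process_includes() has run, the mapping loaded from config.yaml is
# equivalent to the literal below: the included values are merged in, and on
# a clash the including mapping wins ('device' stays 'generic_android').
config = {
    'device': 'generic_android',
    'max_retries': 3,
}
```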
""" - cfg_point = JobSpec.configuration['augmentations'] - names = [cfg_point.name, ] + cfg_point.aliases + cfg_point: 'ConfigurationPoint' = JobSpec.configuration['augmentations'] + names: List[str] = [cfg_point.name, ] + cfg_point.aliases - entries = [] + entries: List[toggle_set] = [] for n in names: if n not in raw: continue @@ -320,15 +396,15 @@ def merge_augmentations(raw): try: entries.append(toggle_set(value)) except TypeError as exc: - msg = 'Invalid value "{}" for "{}": {}' + msg: str = 'Invalid value "{}" for "{}": {}' raise ConfigError(msg.format(value, n, exc)) # Make sure none of the specified aliases conflict with each other - to_check = list(entries) + to_check: List[toggle_set] = list(entries) while len(to_check) > 1: - check_entry = to_check.pop() + check_entry: toggle_set = to_check.pop() for e in to_check: - conflicts = check_entry.conflicts_with(e) + conflicts: List[str] = check_entry.conflicts_with(e) if conflicts: msg = '"{}" and "{}" have conflicting entries: {}' conflict_string = ', '.join('"{}"'.format(c.strip("~")) @@ -336,10 +412,10 @@ def merge_augmentations(raw): raise ConfigError(msg.format(check_entry, e, conflict_string)) if entries: - raw['augmentations'] = reduce(lambda x, y: x.union(y), entries) + raw['augmentations'] = reduce(lambda x, y: cast(toggle_set, x.union(y)), entries) -def _pop_aliased(d, names, entry_id): +def _pop_aliased(d: Dict[str, str], names: List[str], entry_id: str) -> Optional[str]: name_count = sum(1 for n in names if n in d) if name_count > 1: names_list = ', '.join(names) @@ -351,13 +427,16 @@ def _pop_aliased(d, names, entry_id): return None -def _construct_valid_entry(raw, seen_ids, prefix, jobs_config): - workload_entry = {} - +def _construct_valid_entry(raw: Dict[str, Any], seen_ids: Set[str], prefix: str, + jobs_config: 'JobGenerator') -> Dict[str, Any]: + workload_entry: Dict[str, Any] = {} + """ + construct a valid workload entry from raw data read from file + """ # Generate an automatic ID if the entry doesn't already have one if 'id' not in raw: while True: - new_id = '{}{}'.format(prefix, counter(name=prefix)) + new_id: str = '{}{}'.format(prefix, counter(name=prefix)) if new_id not in seen_ids: break workload_entry['id'] = new_id @@ -370,15 +449,15 @@ def _construct_valid_entry(raw, seen_ids, prefix, jobs_config): # Validate all workload_entry for name, cfg_point in JobSpec.configuration.items(): - value = pop_aliased_param(cfg_point, raw) - if value is not None: + value: Any = pop_aliased_param(cfg_point, raw) + if value is not None and cfg_point.kind: value = cfg_point.kind(value) cfg_point.validate_value(name, value) workload_entry[name] = value if "augmentations" in workload_entry: if '~~' in workload_entry['augmentations']: - msg = '"~~" can only be specfied in top-level config, and not for individual workloads/sections' + msg: str = '"~~" can only be specfied in top-level config, and not for individual workloads/sections' raise ConfigError(msg) jobs_config.update_augmentations(workload_entry['augmentations']) @@ -390,7 +469,7 @@ def _construct_valid_entry(raw, seen_ids, prefix, jobs_config): return workload_entry -def _collect_valid_id(entry_id, seen_ids, entry_type): +def _collect_valid_id(entry_id: Optional[Union[int, str]], seen_ids: Set[str], entry_type) -> None: if entry_id is None: return entry_id = str(entry_id) @@ -398,7 +477,7 @@ def _collect_valid_id(entry_id, seen_ids, entry_type): raise ConfigError('Duplicate {} ID "{}".'.format(entry_type, entry_id)) # "-" is reserved for joining section and workload 
IDs if "-" in entry_id: - msg = 'Invalid {} ID "{}"; IDs cannot contain a "-"' + msg: str = 'Invalid {} ID "{}"; IDs cannot contain a "-"' raise ConfigError(msg.format(entry_type, entry_id)) if entry_id == "global": msg = 'Invalid {} ID "global"; is a reserved ID' @@ -406,7 +485,7 @@ def _collect_valid_id(entry_id, seen_ids, entry_type): seen_ids.add(entry_id) -def _get_workload_entry(workload): +def _get_workload_entry(workload: Union[Dict[str, Any], str]) -> Dict[str, Any]: if isinstance(workload, str): workload = {'name': workload} elif not isinstance(workload, dict): @@ -414,7 +493,8 @@ def _get_workload_entry(workload): return workload -def _process_workload_entry(workload, seen_workload_ids, jobs_config): +def _process_workload_entry(workload: Dict[str, Any], seen_workload_ids: Set[str], + jobs_config: 'JobGenerator') -> Dict[str, Any]: workload = _get_workload_entry(workload) workload = _construct_valid_entry(workload, seen_workload_ids, "wk", jobs_config) diff --git a/wa/framework/configuration/plugin_cache.py b/wa/framework/configuration/plugin_cache.py index 19b8a66ed..2ee615483 100644 --- a/wa/framework/configuration/plugin_cache.py +++ b/wa/framework/configuration/plugin_cache.py @@ -23,9 +23,18 @@ from wa.framework.exception import ConfigError, NotFoundError from wa.framework.target.descriptor import list_target_descriptions from wa.utils.types import obj_dict, caseless_string +from typing import (TYPE_CHECKING, Dict, cast, Optional, List, Any, + DefaultDict, Set, Tuple, Type) +from types import ModuleType +from wa.framework.configuration.tree import JobSpecSource +if TYPE_CHECKING: + from wa.framework.configuration.core import ConfigurationPoint, Configuration + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.target.descriptor import TargetDescriptionProtocol + from wa.framework.plugin import Plugin -GENERIC_CONFIGS = ["device_config", "workload_parameters", - "boot_parameters", "runtime_parameters"] +GENERIC_CONFIGS: List[str] = ["device_config", "workload_parameters", + "boot_parameters", "runtime_parameters"] class PluginCache(object): @@ -37,31 +46,48 @@ class PluginCache(object): from, and the priority order of said sources. 
""" - def __init__(self, loader=pluginloader): - self.loader = loader - self.sources = [] - self.plugin_configs = defaultdict(lambda: defaultdict(dict)) - self.global_alias_values = defaultdict(dict) - self.targets = {td.name: td for td in list_target_descriptions()} + def __init__(self, loader: ModuleType = pluginloader): + self.loader = cast('__LoaderWrapper', loader) + self.sources: List[JobSpecSource] = [] + self.plugin_configs: Dict[str, Dict[JobSpecSource, + Dict[str, Any]]] = defaultdict(lambda: defaultdict(dict)) + self.global_alias_values: DefaultDict = defaultdict(dict) + self.targets: Dict[str, 'TargetDescriptionProtocol'] = {td.name: cast('TargetDescriptionProtocol', td) + for td in list_target_descriptions()} # Generate a mapping of what global aliases belong to - self._global_alias_map = defaultdict(dict) - self._list_of_global_aliases = set() + self._global_alias_map: DefaultDict[str, Dict] = defaultdict(dict) + self._list_of_global_aliases: Set[str] = set() for plugin in self.loader.list_plugins(): for param in plugin.parameters: if param.global_alias: - self._global_alias_map[plugin.name][param.global_alias] = param + self._global_alias_map[plugin.name or ''][param.global_alias] = param self._list_of_global_aliases.add(param.global_alias) - def add_source(self, source): + def add_source(self, source: JobSpecSource): + """ + add a source to the plugin cache + """ if source in self.sources: - msg = "Source '{}' has already been added." + msg: str = "Source '{}' has already been added." raise Exception(msg.format(source)) self.sources.append(source) - def add_global_alias(self, alias, value, source): + def add_global_alias(self, alias: str, value: Dict[str, Any], source: JobSpecSource) -> None: + """ + add global alias to a source + Typically, values for plugin parameters are specified name spaced under + the plugin's name in the configuration. A global alias is an alias that + may be specified at the top level in configuration. + + There two common reasons for this. First, several plugins might + specify the same global alias for the same parameter, thus allowing all + of them to be configured with one settings. Second, a plugin may not be + exposed directly to the user (e.g. resource getters) so it makes more + sense to treat its parameters as global configuration values. + """ if source not in self.sources: - msg = "Source '{}' has not been added to the plugin cache." + msg: str = "Source '{}' has not been added to the plugin cache." raise RuntimeError(msg.format(source)) if not self.is_global_alias(alias): @@ -70,13 +96,16 @@ def add_global_alias(self, alias, value, source): self.global_alias_values[alias][source] = value - def add_configs(self, plugin_name, values, source): + def add_configs(self, plugin_name: str, values: Dict[str, Any], source: JobSpecSource) -> None: + """ + add configurations to the plugin + """ if self.is_global_alias(plugin_name): self.add_global_alias(plugin_name, values, source) return if source not in self.sources: - msg = "Source '{}' has not been added to the plugin cache." + msg: str = "Source '{}' has not been added to the plugin cache." 
raise RuntimeError(msg.format(source)) if caseless_string(plugin_name) in ['global', 'config']: @@ -102,13 +131,23 @@ def add_configs(self, plugin_name, values, source): self.plugin_configs[plugin_name][source][name] = value - def is_global_alias(self, name): + def is_global_alias(self, name: str) -> bool: + """ + check whether the provided name is in the list of global aliases + """ return name in self._list_of_global_aliases - def list_plugins(self, kind=None): + def list_plugins(self, kind: Optional[str] = None) -> List[Type['Plugin']]: + """ + List plugins in the plugin cache + """ return self.loader.list_plugins(kind) - def get_plugin_config(self, plugin_name, generic_name=None, is_final=True): + def get_plugin_config(self, plugin_name: str, generic_name: Optional[str] = None, + is_final: bool = True) -> obj_dict: + """ + get the plugin configuration + """ config = obj_dict(not_in_dict=['name']) config.name = plugin_name @@ -119,8 +158,8 @@ def get_plugin_config(self, plugin_name, generic_name=None, is_final=True): if generic_name is None: # Perform a simple merge with the order of sources representing # priority - plugin_config = self.plugin_configs[plugin_name] - cfg_points = self.get_plugin_parameters(plugin_name) + plugin_config: Dict[JobSpecSource, Dict[str, Any]] = self.plugin_configs[plugin_name] + cfg_points: Dict[str, 'ConfigurationPoint'] = self.get_plugin_parameters(plugin_name) for source in self.sources: if source not in plugin_config: continue @@ -134,26 +173,39 @@ def get_plugin_config(self, plugin_name, generic_name=None, is_final=True): return config - def get_plugin(self, name, kind=None, *args, **kwargs): - config = self.get_plugin_config(name) + def get_plugin(self, name: str, kind: Optional[str] = None, + *args, **kwargs) -> Optional['Plugin']: + """ + get plugin from plugin cache + """ + config: obj_dict = self.get_plugin_config(name) kwargs = dict(list(config.items()) + list(kwargs.items())) - return self.loader.get_plugin(name, kind=kind, *args, **kwargs) + return self.loader.get_plugin(name, kind, *args, **kwargs) - def get_plugin_class(self, name, kind=None): + def get_plugin_class(self, name: str, kind: Optional[str] = None) -> Type['Plugin']: return self.loader.get_plugin_class(name, kind) @memoized - def get_plugin_parameters(self, name): + def get_plugin_parameters(self, name: str) -> Dict[str, 'ConfigurationPoint']: + """ + get the plugin parameters + """ if name in self.targets: return self._get_target_params(name) - params = self.loader.get_plugin_class(name).parameters + params: List[ConfigurationPoint] = self.loader.get_plugin_class(name).parameters return get_config_point_map(params) - def resolve_alias(self, name): + def resolve_alias(self, name: str) -> Tuple[str, Dict]: + """ + resolve the name aliases + """ return self.loader.resolve_alias(name) - def _set_plugin_defaults(self, plugin_name, config): - cfg_points = self.get_plugin_parameters(plugin_name) + def _set_plugin_defaults(self, plugin_name: str, config: obj_dict) -> None: + """ + set the defaults for the plugin + """ + cfg_points: Dict[str, 'ConfigurationPoint'] = self.get_plugin_parameters(plugin_name) for cfg_point in cfg_points.values(): cfg_point.set_value(config, check_mandatory=False) @@ -164,7 +216,10 @@ def _set_plugin_defaults(self, plugin_name, config): except NotFoundError: pass - def _set_from_global_aliases(self, plugin_name, config): + def _set_from_global_aliases(self, plugin_name: str, config: obj_dict) -> None: + """ + set configuration parameter values based on global 
alias values + """ for alias, param in self._global_alias_map[plugin_name].items(): if alias in self.global_alias_values: for source in self.sources: @@ -173,13 +228,17 @@ def _set_from_global_aliases(self, plugin_name, config): val = self.global_alias_values[alias][source] param.set_value(config, value=val) - def _get_target_params(self, name): - td = self.targets[name] - return get_config_point_map(chain(td.target_params, td.platform_params, td.conn_params, td.assistant_params)) + def _get_target_params(self, name: str) -> Dict[str, 'ConfigurationPoint']: + """ + get the target parameters + """ + td: 'TargetDescriptionProtocol' = cast('TargetDescriptionProtocol', self.targets[name]) + return get_config_point_map(list(chain(td.target_params, td.platform_params, td.conn_params, td.assistant_params))) # pylint: disable=too-many-nested-blocks, too-many-branches - def _merge_using_priority_specificity(self, specific_name, - generic_name, merged_config, is_final=True): + def _merge_using_priority_specificity(self, specific_name: str, + generic_name: str, merged_config: obj_dict, + is_final: bool = True) -> None: """ WA configuration can come from various sources of increasing priority, as well as being specified in a generic and specific manner (e.g @@ -230,8 +289,9 @@ def _merge_using_priority_specificity(self, specific_name, # Validate final configuration merged_config.name = specific_name - for cfg_point in ms.cfg_points.values(): - cfg_point.validate(merged_config, check_mandatory=is_final) + if ms.cfg_points: + for cfg_point in ms.cfg_points.values(): + cfg_point.validate(cast('Configuration', merged_config), check_mandatory=is_final) def __getattr__(self, name): """ @@ -256,61 +316,69 @@ def __wrapper(pname, *args, **kwargs): if name.startswith('list_'): name = name.replace('list_', '', 1).rstrip('s') if name in self.loader.kind_map: - def __wrapper(*args, **kwargs): # pylint: disable=E0102 + def __list_plugins_wrapper(*args, **kwargs): # pylint: disable=E0102 return self.list_plugins(name, *args, **kwargs) - return __wrapper + return __list_plugins_wrapper raise NotFoundError(error_msg.format(name)) if name.startswith('has_'): name = name.replace('has_', '', 1) if name in self.loader.kind_map: - def __wrapper(pname, *args, **kwargs): # pylint: disable=E0102 + def __has_plugin_wrapper(pname, *args, **kwargs): # pylint: disable=E0102 return self.loader.has_plugin(pname, name, *args, **kwargs) - return __wrapper + return __has_plugin_wrapper raise NotFoundError(error_msg.format(name)) raise AttributeError(name) class MergeState(object): - - def __init__(self): - self.generic_name = None - self.specific_name = None - self.generic_config = None - self.specific_config = None - self.cfg_points = None - self.seen_specific_config = defaultdict(list) + """ + merge configurations based on priority specificity + """ + def __init__(self) -> None: + self.generic_name: Optional[str] = None + self.specific_name: Optional[str] = None + self.generic_config: Optional[Dict[JobSpecSource, Dict[str, Any]]] = None + self.specific_config: Optional[Dict[JobSpecSource, Dict[str, Any]]] = None + self.cfg_points: Optional[Dict[str, 'ConfigurationPoint']] = None + self.seen_specific_config: DefaultDict[str, List[str]] = defaultdict(list) -def update_config_from_source(final_config, source, state): - if source in state.generic_config: +def update_config_from_source(final_config: obj_dict, source: JobSpecSource, + state: MergeState) -> None: + """ + update configuration from source and merge based on priority 
specificity + """ + if state.generic_config and source in state.generic_config: final_config.name = state.generic_name - for name, cfg_point in state.cfg_points.items(): - if name in state.generic_config[source]: - if name in state.seen_specific_config: - msg = ('"{generic_name}" configuration "{config_name}" has ' - 'already been specified more specifically for ' - '{specific_name} in:\n\t\t{sources}') - seen_sources = state.seen_specific_config[name] - msg = msg.format(generic_name=state.generic_name, - config_name=name, - specific_name=state.specific_name, - sources=", ".join(seen_sources)) - raise ConfigError(msg) - value = state.generic_config[source].pop(name) - cfg_point.set_value(final_config, value, check_mandatory=False) + if state.cfg_points: + for name, cfg_point in state.cfg_points.items(): + if name in state.generic_config[source]: + if name in state.seen_specific_config: + msg: str = ('"{generic_name}" configuration "{config_name}" has ' + 'already been specified more specifically for ' + '{specific_name} in:\n\t\t{sources}') + seen_sources: List[str] = state.seen_specific_config[name] + msg = msg.format(generic_name=state.generic_name, + config_name=name, + specific_name=state.specific_name, + sources=", ".join(seen_sources)) + raise ConfigError(msg) + value = state.generic_config[source].pop(name) + cfg_point.set_value(final_config, value, check_mandatory=False) if state.generic_config[source]: msg = 'Unexpected values for {}: {}' raise ConfigError(msg.format(state.generic_name, state.generic_config[source])) - if source in state.specific_config: + if state.specific_config and source in state.specific_config: final_config.name = state.specific_name - for name, cfg_point in state.cfg_points.items(): - if name in state.specific_config[source]: - state.seen_specific_config[name].append(str(source)) - value = state.specific_config[source].pop(name) - cfg_point.set_value(final_config, value, check_mandatory=False) + if state.cfg_points: + for name, cfg_point in state.cfg_points.items(): + if name in state.specific_config[source]: + state.seen_specific_config[name].append(str(source)) + value = state.specific_config[source].pop(name) + cfg_point.set_value(final_config, value, check_mandatory=False) if state.specific_config[source]: msg = 'Unexpected values for {}: {}' diff --git a/wa/framework/configuration/tree.py b/wa/framework/configuration/tree.py index f45c5cdf5..33067941a 100644 --- a/wa/framework/configuration/tree.py +++ b/wa/framework/configuration/tree.py @@ -15,29 +15,41 @@ import logging from wa.utils import log - +from wa.utils.types import obj_dict +from typing import Optional, List, Generator, Any logger = logging.getLogger('config') class JobSpecSource(object): + """ + class representing a job specification source. 
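The rule enforced by `update_config_from_source` is easier to see spelled out; the following is an editorial sketch with hypothetical sources and values, not part of the patch.

```python
# Illustration only. Sources are processed in increasing priority order:
#
#   source 1 (user config):  device_config: {'big_core': 'a72'}   # generic
#   source 2 (agenda):       juno:          {'big_core': 'a73'}   # specific
#
# This order is fine: the specific value 'a73' is applied last and wins.
# If the *specific* setting came from an earlier source and a *later* source
# set the same parameter generically, the ConfigError constructed above is
# raised: a generic entry may not silently override a more specific one.
```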
+ """ + kind: str = "" - kind = "" - - def __init__(self, config, parent=None): + def __init__(self, config: obj_dict, parent: Optional['SectionNode'] = None): self.config = config self.parent = parent self._log_self() @property - def id(self): + def id(self) -> str: + """ + source id + """ return self.config['id'] @property - def name(self): + def name(self) -> str: + """ + name of the specification + """ raise NotImplementedError() - def _log_self(self): + def _log_self(self) -> None: + """ + log the source structure + """ logger.debug('Creating {} node'.format(self.kind)) with log.indentcontext(): for key, value in self.config.items(): @@ -45,38 +57,60 @@ def _log_self(self): class WorkloadEntry(JobSpecSource): - kind = "workload" + """ + workloads in section nodes + """ + kind: str = "workload" @property - def name(self): - if self.parent.id == "global": + def name(self) -> str: + """ + name of the workload entry + """ + if self.parent and self.parent.id == "global": return 'workload "{}"'.format(self.id) else: - return 'workload "{}" from section "{}"'.format(self.id, self.parent.id) + return 'workload "{}" from section "{}"'.format(self.id, self.parent.id if self.parent else '') class SectionNode(JobSpecSource): - - kind = "section" + """ + a node representing a section in the job tree. + section is a set of configurations for how jobs should be run. The + settings in them take less precedence than workload-specific settings. For + every section, all jobs will be run again, with the changes + specified in the section's agenda entry. Sections + are useful for several runs in which global settings change. + """ + kind: str = "section" @property - def name(self): + def name(self) -> str: + """ + name of the section node + """ if self.id == "global": return "globally specified configuration" else: return 'section "{}"'.format(self.id) @property - def is_leaf(self): + def is_leaf(self) -> bool: + """ + true if it is a leaf node of the tree + """ return not bool(self.children) - def __init__(self, config, parent=None, group=None): + def __init__(self, config: obj_dict, parent=None, group: Optional[str] = None): super(SectionNode, self).__init__(config, parent=parent) - self.workload_entries = [] - self.children = [] + self.workload_entries: List[WorkloadEntry] = [] + self.children: List['SectionNode'] = [] self.group = group - def add_section(self, section, group=None): + def add_section(self, section: obj_dict, group: Optional[str] = None) -> 'SectionNode': + """ + add section to the job tree + """ # Each level is the same group, only need to check first if not self.children or group == self.children[0].group: new_node = SectionNode(section, parent=self, group=group) @@ -86,22 +120,34 @@ def add_section(self, section, group=None): new_node = child.add_section(section, group) return new_node - def add_workload(self, workload_config): + def add_workload(self, workload_config: obj_dict) -> None: + """ + add a workload to the section node + """ self.workload_entries.append(WorkloadEntry(workload_config, self)) - def descendants(self): + def descendants(self) -> Generator['SectionNode', Any, None]: + """ + descendants of the current section node + """ for child in self.children: for n in child.descendants(): yield n yield child - def ancestors(self): + def ancestors(self) -> Generator['SectionNode', Any, None]: + """ + ancestors of the current section node + """ if self.parent is not None: yield self.parent for ancestor in self.parent.ancestors(): yield ancestor - def leaves(self): + def 
leaves(self) -> Generator['SectionNode', Any, None]: + """ + leaf nodes of the job tree starting from current section node + """ if self.is_leaf: yield self else: diff --git a/wa/framework/entrypoint.py b/wa/framework/entrypoint.py index 2a99fdaed..436d11fab 100644 --- a/wa/framework/entrypoint.py +++ b/wa/framework/entrypoint.py @@ -20,9 +20,13 @@ import logging import os import warnings - +from typing import (Optional, TYPE_CHECKING, cast, Dict, + List) import devlib +import devlib.utils +import devlib.utils.version try: + installed_devlib_version: Optional[devlib.utils.version.Version] from devlib.utils.version import version as installed_devlib_version except ImportError: installed_devlib_version = None @@ -37,20 +41,28 @@ required_devlib_version) from wa.utils import log from wa.utils.doc import format_body +from argparse import _SubParsersAction, Namespace +if TYPE_CHECKING: + from wa.framework.pluginloader import __LoaderWrapper + from wa.framework.command import Command warnings.filterwarnings(action='ignore', category=UserWarning, module='zope') # Disable this to avoid false positive from dynamically-created attributes. # pylint: disable=no-member -logger = logging.getLogger('command_line') +logger: logging.Logger = logging.getLogger('command_line') -def load_commands(subparsers): - commands = {} - for command in pluginloader.list_commands(): - commands[command.name] = pluginloader.get_command(command.name, - subparsers=subparsers) +def load_commands(subparsers: _SubParsersAction) -> Dict[str, 'Command']: + """ + load commands + """ + commands: Dict[str, 'Command'] = {} + for command in cast('__LoaderWrapper', pluginloader).list_commands(): + commands[command.name] = cast('__LoaderWrapper', + pluginloader).get_command(command.name, + subparsers=subparsers) return commands @@ -59,8 +71,11 @@ def load_commands(subparsers): # description of the issue (with a fix attached since 2013!). To get around # this problem, this will pre-process sys.argv to detect such joined options # and split them. -def split_joined_options(argv): - output = [] +def split_joined_options(argv: List[str]) -> List[str]: + """ + split joined options + """ + output: List[str] = [] for part in argv: if len(part) > 1 and part[0] == '-' and part[1] != '-': for c in part[1:]: @@ -71,19 +86,25 @@ def split_joined_options(argv): # Instead of presenting an obscure error due to a version mismatch explicitly warn the user. -def check_devlib_version(): +def check_devlib_version() -> None: + """ + check devlib version + """ if not installed_devlib_version or installed_devlib_version[:-1] <= required_devlib_version[:-1]: # Check the 'dev' field separately to account for comparing with release versions. - if installed_devlib_version.dev and installed_devlib_version.dev < required_devlib_version.dev: - msg = 'WA requires Devlib version >={}. Please update the currently installed version {}' + if installed_devlib_version and installed_devlib_version.dev and installed_devlib_version.dev < required_devlib_version.dev: + msg: str = 'WA requires Devlib version >={}. Please update the currently installed version {}' raise HostError(msg.format(format_version(required_devlib_version), devlib.__version__)) # If the default encoding is not UTF-8 warn the user as this may cause compatibility issues # when parsing files. 
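An editorial usage note for `split_joined_options` above; the snippet is illustration only and not part of the patch.

```python
# Illustration only: joined short options are split up-front so that
# argparse with subparsers can digest them; long options pass through.
argv = ['-vd', 'run', 'agenda.yaml']
# split_joined_options(argv) == ['-v', '-d', 'run', 'agenda.yaml']
```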
-def check_system_encoding(): - system_encoding = locale.getpreferredencoding() - msg = 'System Encoding: {}'.format(system_encoding) +def check_system_encoding() -> None: + """ + check system encoding + """ + system_encoding: str = locale.getpreferredencoding() + msg: str = 'System Encoding: {}'.format(system_encoding) if 'UTF-8' not in system_encoding: logger.warning(msg) logger.warning('To prevent encoding issues please use a locale setting which supports UTF-8') @@ -91,7 +112,7 @@ def check_system_encoding(): logger.debug(msg) -def main(): +def main() -> None: if not os.path.exists(settings.user_directory): init_user_directory() if not os.path.exists(os.path.join(settings.user_directory, 'config.yaml')): @@ -99,9 +120,9 @@ def main(): try: - description = ("Execute automated workloads on a remote device and process " - "the resulting output.\n\nUse \"wa -h\" to see " - "help for individual subcommands.") + description: str = ("Execute automated workloads on a remote device and process " + "the resulting output.\n\nUse \"wa -h\" to see " + "help for individual subcommands.") parser = argparse.ArgumentParser(description=format_body(description, 80), prog='wa', formatter_class=argparse.RawDescriptionHelpFormatter, @@ -112,12 +133,12 @@ def main(): # to be enabled for that, which requires the verbosity setting; however # full argument parsing cannot be completed until the commands are loaded; so # parse just the base args for now so we can get verbosity. - argv = split_joined_options(sys.argv[1:]) + argv: List[str] = split_joined_options(sys.argv[1:]) # 'Parse_known_args' automatically displays the default help and exits # if '-h' or '--help' is detected, we want our custom help messages so # ensure these are never passed as parameters. - filtered_argv = list(argv) + filtered_argv: List[str] = list(argv) if '-h' in filtered_argv: filtered_argv.remove('-h') elif '--help' in filtered_argv: @@ -133,9 +154,9 @@ def main(): check_system_encoding() # each command will add its own subparser - subparsers = parser.add_subparsers(dest='command') + subparsers: _SubParsersAction = parser.add_subparsers(dest='command') subparsers.required = True - commands = load_commands(subparsers) + commands: Dict[str, 'Command'] = load_commands(subparsers) args = parser.parse_args(argv) config = ConfigManager() @@ -145,8 +166,8 @@ def main(): raise ConfigError("Config file {} not found".format(config_file)) config.load_config_file(config_file) - command = commands[args.command] - sys.exit(command.execute(config, args)) + command: 'Command' = commands[args.command] + sys.exit(command.execute(config, args)) # type: ignore except KeyboardInterrupt as e: log.log_error(e, logger) diff --git a/wa/framework/exception.py b/wa/framework/exception.py index 5f323e56c..d1ecdedfd 100644 --- a/wa/framework/exception.py +++ b/wa/framework/exception.py @@ -17,12 +17,15 @@ TargetError, TargetNotRespondingError) from wa.utils.misc import get_traceback +from typing import Optional, Tuple, Type +from types import TracebackType class WAError(Exception): """Base class for all Workload Automation exceptions.""" @property - def message(self): + def message(self) -> str: + """Error message""" if self.args: return self.args[0] return '' @@ -80,20 +83,22 @@ class SerializerSyntaxError(Exception): Error loading a serialized structure from/to a file handle. 
""" @property - def message(self): + def message(self) -> str: + """Error message""" if self.args: return self.args[0] return '' - def __init__(self, message, line=None, column=None): + def __init__(self, message: str, line: Optional[int] = None, + column: Optional[int] = None): super(SerializerSyntaxError, self).__init__(message) self.line = line self.column = column - def __str__(self): - linestring = ' on line {}'.format(self.line) if self.line else '' - colstring = ' in column {}'.format(self.column) if self.column else '' - message = 'Syntax Error{}: {}' + def __str__(self) -> str: + linestring: str = ' on line {}'.format(self.line) if self.line else '' + colstring: str = ' in column {}'.format(self.column) if self.column else '' + message: str = 'Syntax Error{}: {}' return message.format(''.join([linestring, colstring]), self.message) @@ -104,18 +109,20 @@ class PluginLoaderError(WAError): sys.exc_info() for the original exception (if any) that caused the error.""" - def __init__(self, message, exc_info=None): + def __init__(self, message: str, + exc_info: Optional[Tuple[Optional[Type[BaseException]], + Optional[BaseException], Optional[TracebackType]]] = None): super(PluginLoaderError, self).__init__(message) self.exc_info = exc_info - def __str__(self): + def __str__(self) -> str: if self.exc_info: - orig = self.exc_info[1] - orig_name = type(orig).__name__ + orig: Optional[BaseException] = self.exc_info[1] + orig_name: str = type(orig).__name__ if isinstance(orig, WAError): - reason = 'because of:\n{}: {}'.format(orig_name, orig) + reason: str = 'because of:\n{}: {}'.format(orig_name, orig) else: - text = 'because of:\n{}\n{}: {}' + text: str = 'because of:\n{}\n{}: {}' reason = text.format(get_traceback(self.exc_info), orig_name, orig) return '\n'.join([self.message, reason]) else: @@ -133,7 +140,9 @@ class WorkerThreadError(WAError): """ - def __init__(self, thread, exc_info): + def __init__(self, thread: str, + exc_info: Tuple[Optional[Type[BaseException]], + Optional[BaseException], Optional[TracebackType]]): self.thread = thread self.exc_info = exc_info orig = self.exc_info[1] diff --git a/wa/framework/execution.py b/wa/framework/execution.py index 6daf2c647..8adf6baa3 100644 --- a/wa/framework/execution.py +++ b/wa/framework/execution.py @@ -21,6 +21,7 @@ import shutil from copy import copy from datetime import datetime +from collections import Counter import wa.framework.signal as signal from wa.framework import instrument as instrumentation @@ -28,114 +29,174 @@ from wa.framework.exception import TargetError, HostError, WorkloadError, ExecutionError from wa.framework.exception import TargetNotRespondingError, TimeoutError # pylint: disable=redefined-builtin from wa.framework.job import Job -from wa.framework.output import init_job_output +from wa.framework.output import init_job_output, Output, RunOutput, Metric, Artifact from wa.framework.output_processor import ProcessorManager -from wa.framework.resource import ResourceResolver +from wa.framework.resource import ResourceResolver, Resource from wa.framework.target.manager import TargetManager from wa.utils import log from wa.utils.misc import merge_config_values, format_duration +from typing import (Optional, Dict, Any, TYPE_CHECKING, cast, List, Type) +from devlib.target import AndroidTarget, Target +from types import ModuleType +if TYPE_CHECKING: + from wa.framework.configuration.execution import ConfigManager + from wa.framework.configuration.core import RunConfigurationProtocol, RebootPolicy, ConfigurationPoint + 
from wa.framework.signal import Signal + from wa.framework.workload import Workload + from wa.framework.target.info import TargetInfo + from wa.framework.plugin import Plugin + from wa.framework.configuration.core import StatusType + from louie.dispatcher import Anonymous # type:ignore class ExecutionContext(object): @property - def previous_job(self): + def previous_job(self) -> Optional[Job]: + """ + previous job that was executed and completed + """ if not self.completed_jobs: return None return self.completed_jobs[-1] @property - def next_job(self): + def next_job(self) -> Optional[Job]: + """ + next job in the job queue to be executed + """ if not self.job_queue: return None return self.job_queue[0] @property - def spec_changed(self): + def spec_changed(self) -> bool: + """ + checks whether job spec has changed between previous and current job + """ if self.previous_job is None and self.current_job is not None: # Start of run return True if self.previous_job is not None and self.current_job is None: # End of run return True - return self.current_job.spec.id != self.previous_job.spec.id + return (self.previous_job is not None and self.current_job is not None) and (self.current_job.spec.id != self.previous_job.spec.id) @property - def spec_will_change(self): + def spec_will_change(self) -> bool: + """ + checks whether job spec will change between current and next job + """ if self.current_job is None and self.next_job is not None: # Start of run return True if self.current_job is not None and self.next_job is None: # End of run return True - return self.current_job.spec.id != self.next_job.spec.id + return (self.next_job is not None and self.current_job is not None) and self.current_job.spec.id != self.next_job.spec.id @property - def workload(self): + def workload(self) -> Optional['Workload']: + """ + workload being executed in the current job + """ if self.current_job: return self.current_job.workload + return None @property - def job_output(self): + def job_output(self) -> Optional[Output]: + """ + output from the current job + """ if self.current_job: return self.current_job.output + return None @property - def output(self): + def output(self) -> RunOutput: + """ + output from the current run + """ if self.current_job: - return self.job_output + return self.job_output # type:ignore return self.run_output @property - def output_directory(self): + def output_directory(self) -> str: + """ + output directory of the job execution + """ return self.output.basepath @property - def reboot_policy(self): - return self.cm.run_config.reboot_policy + def reboot_policy(self) -> 'RebootPolicy': + """ + reboot policy + """ + return cast('RunConfigurationProtocol', self.cm.run_config).reboot_policy @property - def target_info(self): + def target_info(self) -> Optional['TargetInfo']: + """ + get target info + """ return self.run_output.target_info - def __init__(self, cm, tm, output): - self.logger = logging.getLogger('context') + def __init__(self, cm: 'ConfigManager', tm: 'TargetManager', output: RunOutput): + self.logger: logging.Logger = logging.getLogger('context') self.cm = cm self.tm = tm self.run_output = output self.run_state = output.state - self.job_queue = None - self.completed_jobs = None - self.current_job = None - self.successful_jobs = 0 - self.failed_jobs = 0 - self.run_interrupted = False + self.job_queue: Optional[List[Job]] = None + self.completed_jobs: Optional[List[Job]] = None + self.current_job: Optional[Job] = None + self.successful_jobs: int = 0 + self.failed_jobs: int = 0 
+ self.run_interrupted: bool = False
         self._load_resource_getters()

-    def start_run(self):
-        self.output.info.start_time = datetime.utcnow()
+    def start_run(self) -> None:
+        """
+        start execution
+        """
+        if self.output.info:
+            self.output.info.start_time = datetime.utcnow()
         self.output.write_info()
         self.job_queue = copy(self.cm.jobs)
         self.completed_jobs = []
-        self.run_state.status = Status.STARTED
+        if self.run_state:
+            self.run_state.status = Status.STARTED
         self.output.status = Status.STARTED
         self.output.write_state()

-    def end_run(self):
+    def end_run(self) -> None:
+        """
+        end the execution
+        """
         if self.successful_jobs:
             if self.failed_jobs:
-                status = Status.PARTIAL
+                status: 'StatusType' = Status.PARTIAL
             else:
                 status = Status.OK
         else:
             status = Status.FAILED
-        self.run_state.status = status
+        if self.run_state:
+            self.run_state.status = status
         self.run_output.status = status
-        self.run_output.info.end_time = datetime.utcnow()
-        self.run_output.info.duration = (self.run_output.info.end_time -
-                                         self.run_output.info.start_time)
+        if self.run_output.info:
+            self.run_output.info.end_time = datetime.utcnow()
+            self.run_output.info.duration = (cast(datetime, self.run_output.info.end_time)
+                                             - cast(datetime, self.run_output.info.start_time))
         self.write_output()

-    def finalize(self):
+    def finalize(self) -> None:
+        """
+        finalize the execution
+        """
         self.tm.finalize()

-    def start_job(self):
+    def start_job(self) -> 'Job':
+        """
+        start a job from the job queue
+        """
         if not self.job_queue:
             raise RuntimeError('No jobs to run')
         self.current_job = self.job_queue.pop(0)
@@ -143,58 +204,102 @@ def start_job(self):
         self.current_job.set_output(job_output)
         return self.current_job

-    def end_job(self):
+    def end_job(self) -> None:
+        """
+        end the currently running job
+        """
         if not self.current_job:
             raise RuntimeError('No jobs in progress')
-        self.completed_jobs.append(self.current_job)
+        if self.completed_jobs is not None:
+            self.completed_jobs.append(self.current_job)
         self.output.write_result()
         self.current_job = None

-    def set_status(self, status, force=False, write=True):
+    def set_status(self, status: 'StatusType',
+                   force: bool = False, write: bool = True) -> None:
+        """
+        set status for the current job
+        """
         if not self.current_job:
             raise RuntimeError('No jobs in progress')
         self.set_job_status(self.current_job, status, force, write)

-    def set_job_status(self, job, status, force=False, write=True):
+    def set_job_status(self, job: Job, status: 'StatusType',
+                       force: bool = False, write: bool = True) -> None:
+        """
+        set status for the specified job
+        """
         job.set_status(status, force)
         if write:
             self.run_output.write_state()

-    def extract_results(self):
+    def extract_results(self) -> None:
+        """
+        extract results of the execution
+        """
         self.tm.extract_results(self)

-    def move_failed(self, job):
-        self.run_output.move_failed(job.output)
+    def move_failed(self, job: Job):
+        """
+        move output of failed jobs to a separate directory
+        """
+        if job.output:
+            self.run_output.move_failed(job.output)

-    def skip_job(self, job):
+    def skip_job(self, job: Job) -> None:
+        """
+        skip execution of the specified job
+        """
        self.set_job_status(job, Status.SKIPPED, force=True)
-        self.completed_jobs.append(job)
+        if self.completed_jobs is not None:
+            self.completed_jobs.append(job)

-    def skip_remaining_jobs(self):
+    def skip_remaining_jobs(self) -> None:
+        """
+        skip execution of all remaining jobs
+        """
         while self.job_queue:
-            job = self.job_queue.pop(0)
+            job: Job = self.job_queue.pop(0)
             self.skip_job(job)
         self.write_state()

-    def write_config(self):
+    def 
write_config(self) -> None: + """ + write config into a config json file + """ self.run_output.write_config(self.cm.get_config()) - def write_state(self): + def write_state(self) -> None: + """ + write execution state into .run_state json file + """ self.run_output.write_state() - def write_output(self): + def write_output(self) -> None: + """ + write info into run_info file, state into .run_state file and result into result file + """ self.run_output.write_info() self.run_output.write_state() self.run_output.write_result() - def write_job_specs(self): + def write_job_specs(self) -> None: + """ + write job specs into jobs.json file + """ self.run_output.write_job_specs(self.cm.job_specs) - def add_augmentation(self, aug): + def add_augmentation(self, aug: 'Plugin') -> None: + """ + add augmentation to the run configuration + """ self.cm.run_config.add_augmentation(aug) - def get_resource(self, resource, strict=True): - result = self.resolver.get(resource, strict) + def get_resource(self, resource: Resource, strict: bool = True) -> Optional[str]: + """ + get path to the resource + """ + result: Optional[str] = self.resolver.get(resource, strict) if result is None: return result if os.path.isfile(result): @@ -206,7 +311,10 @@ def get_resource(self, resource, strict=True): get = get_resource # alias to allow a context to act as a resolver - def get_metric(self, name): + def get_metric(self, name: str) -> Optional[Metric]: + """ + get metric of the specified name from the workload execution output + """ try: return self.output.get_metric(name) except HostError: @@ -214,14 +322,21 @@ def get_metric(self, name): raise return self.run_output.get_metric(name) - def add_metric(self, name, value, units=None, lower_is_better=False, - classifiers=None): + def add_metric(self, name: str, value: Any, units: Optional[str] = None, + lower_is_better: bool = False, + classifiers: Optional[Dict[str, Any]] = None): + """ + Add a metric to the workload execution output + """ if self.current_job: classifiers = merge_config_values(self.current_job.classifiers, classifiers) self.output.add_metric(name, value, units, lower_is_better, classifiers) - def get_artifact(self, name): + def get_artifact(self, name: str) -> Optional[Artifact]: + """ + get artifact from workload execution output + """ try: return self.output.get_artifact(name) except HostError: @@ -229,7 +344,10 @@ def get_artifact(self, name): raise return self.run_output.get_artifact(name) - def get_artifact_path(self, name): + def get_artifact_path(self, name: str) -> str: + """ + get path to the specified artifact + """ try: return self.output.get_artifact_path(name) except HostError: @@ -237,84 +355,126 @@ def get_artifact_path(self, name): raise return self.run_output.get_artifact_path(name) - def add_artifact(self, name, path, kind, description=None, classifiers=None): + def add_artifact(self, name: str, path: str, kind: str, + description: Optional[str] = None, + classifiers: Optional[Dict[str, Any]] = None) -> None: + """ + add artifact to the workload job execution output + """ self.output.add_artifact(name, path, kind, description, classifiers) - def add_run_artifact(self, name, path, kind, description=None, - classifiers=None): + def add_run_artifact(self, name: str, path: str, kind: str, + description: Optional[str] = None, + classifiers: Optional[Dict[str, Any]] = None) -> None: + """ + add artifact to the workload run execution output + """ self.run_output.add_artifact(name, path, kind, description, classifiers) - def add_event(self, 
message): + def add_event(self, message: str) -> None: + """ + add event to the workload job execution output + """ self.output.add_event(message) - def add_classifier(self, name, value, overwrite=False): + def add_classifier(self, name: str, value: Any, overwrite: bool = False) -> None: + """ + add classifier to workload execution output + """ self.output.add_classifier(name, value, overwrite) if self.current_job: self.current_job.add_classifier(name, value, overwrite) - def add_metadata(self, key, *args, **kwargs): + def add_metadata(self, key: str, *args, **kwargs) -> None: + """ + add metadata to workload execution output + """ self.output.add_metadata(key, *args, **kwargs) - def update_metadata(self, key, *args): - self.output.update_metadata(key, *args) + def update_metadata(self, key: str, *args) -> None: + """ + update existing metadata in the workload execution output + """ + if self.output: + self.output.update_metadata(key, *args) - def take_screenshot(self, filename): - filepath = self._get_unique_filepath(filename) - self.tm.target.capture_screen(filepath) + def take_screenshot(self, filename: str) -> None: + """ + take a screenshot + """ + filepath: str = self._get_unique_filepath(filename) + if self.tm.target: + self.tm.target.capture_screen(filepath) if os.path.isfile(filepath): self.add_artifact('screenshot', filepath, kind='log') - def take_uiautomator_dump(self, filename): + def take_uiautomator_dump(self, filename: str) -> None: + """ + take a UI Automator dump + """ filepath = self._get_unique_filepath(filename) - self.tm.target.capture_ui_hierarchy(filepath) + cast(AndroidTarget, self.tm.target).capture_ui_hierarchy(filepath) self.add_artifact('uitree', filepath, kind='log') - def record_ui_state(self, basename): + def record_ui_state(self, basename: str) -> None: + """ + record the UI state of the target + """ self.logger.info('Recording screen state...') self.take_screenshot('{}.png'.format(basename)) - target = self.tm.target - if target.os == 'android' or\ - (target.os == 'chromeos' and target.has('android_container')): + target: Optional[Target] = self.tm.target + if target and (target.os == 'android' or + (target.os == 'chromeos' and target.has('android_container'))): self.take_uiautomator_dump('{}.uix'.format(basename)) - def initialize_jobs(self): - new_queue = [] - failed_ids = [] - for job in self.job_queue: - if job.id in failed_ids: - # Don't try to initialize a job if another job with the same ID - # (i.e. same job spec) has failed - we can assume it will fail - # too. - self.skip_job(job) - continue - - try: - job.initialize(self) - except WorkloadError as e: - self.set_job_status(job, Status.FAILED, write=False) - log.log_error(e, self.logger) - failed_ids.append(job.id) + def initialize_jobs(self) -> None: + """ + initialize jobs + """ + new_queue: List[Job] = [] + failed_ids: List[str] = [] + if self.job_queue: + for job in self.job_queue: + if job.id in failed_ids: + # Don't try to initialize a job if another job with the same ID + # (i.e. same job spec) has failed - we can assume it will fail + # too. 
+ self.skip_job(job) + continue - if self.cm.run_config.bail_on_init_failure: - raise - else: - new_queue.append(job) + try: + job.initialize(self) + except WorkloadError as e: + self.set_job_status(job, Status.FAILED, write=False) + log.log_error(e, self.logger) + failed_ids.append(job.id or '') + + if self.cm.run_config.bail_on_init_failure: + raise + else: + new_queue.append(job) self.job_queue = new_queue self.write_state() - def _load_resource_getters(self): + def _load_resource_getters(self) -> None: + """ + load resource getters + """ self.logger.debug('Loading resource discoverers') - self.resolver = ResourceResolver(self.cm.plugin_cache) + self.resolver = ResourceResolver(cast(ModuleType, self.cm.plugin_cache)) self.resolver.load() for getter in self.resolver.getters: self.cm.run_config.add_resource_getter(getter) - def _get_unique_filepath(self, filename): - filepath = os.path.join(self.output_directory, filename) + def _get_unique_filepath(self, filename: str) -> str: + """ + get a unique filepath for the specified file in output directory + """ + filepath: str = os.path.join(self.output_directory, filename) rest, ext = os.path.splitext(filepath) - i = 1 - new_filepath = '{}-{}{}'.format(rest, i, ext) + i: int = 1 + new_filepath: str = '{}-{}{}'.format(rest, i, ext) if not os.path.exists(filepath) and not os.path.exists(new_filepath): return filepath @@ -344,13 +504,13 @@ class Executor(object): """ # pylint: disable=R0915 - def __init__(self): - self.logger = logging.getLogger('executor') - self.error_logged = False - self.warning_logged = False - self.target_manager = None + def __init__(self) -> None: + self.logger: logging.Logger = logging.getLogger('executor') + self.error_logged: bool = False + self.warning_logged: bool = False + self.target_manager: Optional[TargetManager] = None - def execute(self, config_manager, output): + def execute(self, config_manager: 'ConfigManager', output: RunOutput): """ Execute the run specified by an agenda. Optionally, selectors may be used to only execute a subset of the specified agenda. 
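A note on the Optional-narrowing pattern this patch leans on: a bare truthiness test such as "if self.completed_jobs:" satisfies the type checker, but it is also False for an empty list, which is why the guards around the append() calls earlier in this file compare against None instead. A minimal, self-contained sketch of the distinction (the Context class and job values here are hypothetical, purely for illustration):

    from typing import List, Optional

    class Context:
        def __init__(self) -> None:
            # None until start_run() replaces it with a real list
            self.completed_jobs: Optional[List[str]] = None

        def end_job(self, job: str) -> None:
            # "if self.completed_jobs:" is False for [] as well as None,
            # which would silently drop the first completed job; an
            # identity check narrows the Optional without changing
            # behaviour.
            if self.completed_jobs is not None:
                self.completed_jobs.append(job)

    ctx = Context()
    ctx.completed_jobs = []    # as start_run() does
    ctx.end_job('wk1-1')       # first job is recorded, not dropped
    assert ctx.completed_jobs == ['wk1-1']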
@@ -370,12 +530,13 @@ def execute(self, config_manager, output): config = config_manager.finalize() output.write_config(config) - self.target_manager = TargetManager(config.run_config.device, - config.run_config.device_config, - output.basepath) - - self.logger.info('Initializing execution context') - context = ExecutionContext(config_manager, self.target_manager, output) + if config.run_config: + self.target_manager = TargetManager(config.run_config.device, + cast(Dict[str, 'ConfigurationPoint'], config.run_config.device_config), + output.basepath) + if self.target_manager: + self.logger.info('Initializing execution context') + context = ExecutionContext(config_manager, self.target_manager, output) try: self.do_execute(context) @@ -392,18 +553,22 @@ def execute(self, config_manager, output): finally: context.finalize() self.execute_postamble(context, output) - signal.send(signal.RUN_COMPLETED, self, context) + signal.send(signal.RUN_COMPLETED, cast(Type['Anonymous'], self), context) - def do_execute(self, context): + def do_execute(self, context: ExecutionContext) -> None: + """ + connect to target, do initializations and run the jobs + """ self.logger.info('Connecting to target') context.tm.initialize() if context.cm.run_config.reboot_policy.perform_initial_reboot: self.logger.info('Performing initial reboot.') - attempts = context.cm.run_config.max_retries + attempts: int = context.cm.run_config.max_retries while attempts: try: - self.target_manager.reboot(context) + if self.target_manager: + self.target_manager.reboot(context) except TargetError as e: if attempts: attempts -= 1 @@ -411,18 +576,19 @@ def do_execute(self, context): raise e else: break - - context.output.set_target_info(self.target_manager.get_target_info()) + if self.target_manager: + context.output.set_target_info(self.target_manager.get_target_info()) self.logger.info('Generating jobs') context.cm.generate_jobs(context) context.write_job_specs() context.output.write_state() - self.logger.info('Installing instruments') - for instrument in context.cm.get_instruments(self.target_manager.target): - instrumentation.install(instrument, context) - instrumentation.validate() + if self.target_manager and self.target_manager.target: + self.logger.info('Installing instruments') + for instrument in context.cm.get_instruments(self.target_manager.target): + instrumentation.install(instrument, context) + instrumentation.validate() self.logger.info('Installing output processors') pm = ProcessorManager() @@ -434,18 +600,22 @@ def do_execute(self, context): self.logger.info('Starting run') runner = Runner(context, pm) - signal.send(signal.RUN_STARTED, self, context) + signal.send(signal.RUN_STARTED, cast(Type['Anonymous'], self), context) runner.run() - def execute_postamble(self, context, output): + def execute_postamble(self, context: ExecutionContext, output: RunOutput) -> None: + """ + execute postamble + """ self.logger.info('Done.') - duration = format_duration(output.info.duration) + duration: str = format_duration(output.info.duration or 0) if output.info else '' self.logger.info('Run duration: {}'.format(duration)) - num_ran = context.run_state.num_completed_jobs - status_summary = 'Ran a total of {} iterations: '.format(num_ran) + num_ran: int = context.run_state.num_completed_jobs if context.run_state else 0 + status_summary: str = 'Ran a total of {} iterations: '.format(num_ran) - counter = context.run_state.get_status_counts() - parts = [] + if context.run_state: + counter: Counter = context.run_state.get_status_counts() + 
parts: List[str] = [] for status in reversed(Status.levels): if status in counter: parts.append('{} {}'.format(counter[status], status)) @@ -460,11 +630,17 @@ def execute_postamble(self, context, output): self.logger.warning('There were warnings during execution.') self.logger.warning('Please see {}'.format(output.logfile)) - def _error_signalled_callback(self, _): + def _error_signalled_callback(self, _) -> None: + """ + error signalled callback + """ self.error_logged = True signal.disconnect(self._error_signalled_callback, signal.ERROR_LOGGED) - def _warning_signalled_callback(self, _): + def _warning_signalled_callback(self, _) -> None: + """ + warning signalled callback + """ self.warning_logged = True signal.disconnect(self._warning_signalled_callback, signal.WARNING_LOGGED) @@ -483,19 +659,22 @@ class Runner(object): processing job and run results. """ - def __init__(self, context, pm): - self.logger = logging.getLogger('runner') + def __init__(self, context: ExecutionContext, pm: ProcessorManager): + self.logger: logging.Logger = logging.getLogger('runner') self.context = context self.pm = pm self.output = self.context.output self.config = self.context.cm - def run(self): + def run(self) -> None: + """ + run the jobs + """ try: self.initialize_run() self.send(signal.RUN_INITIALIZED) - with signal.wrap('JOB_QUEUE_EXECUTION', self, self.context): + with signal.wrap('JOB_QUEUE_EXECUTION', cast(Type['Anonymous'], self), self.context): while self.context.job_queue: if self.context.run_interrupted: raise KeyboardInterrupt() @@ -506,7 +685,7 @@ def run(self): self.logger.info('Skipping remaining jobs.') self.context.skip_remaining_jobs() except Exception as e: - message = e.args[0] if e.args else str(e) + message: str = e.args[0] if e.args else str(e) log.log_error(e, self.logger) self.logger.error('Skipping remaining jobs due to "{}".'.format(message)) self.context.skip_remaining_jobs() @@ -515,7 +694,10 @@ def run(self): self.finalize_run() self.send(signal.RUN_FINALIZED) - def initialize_run(self): + def initialize_run(self) -> None: + """ + initialize run of execution + """ self.logger.info('Initializing run') signal.connect(self._error_signalled_callback, signal.ERROR_LOGGED) signal.connect(self._warning_signalled_callback, signal.WARNING_LOGGED) @@ -525,15 +707,19 @@ def initialize_run(self): self.context.initialize_jobs() self.context.write_state() - def finalize_run(self): + def finalize_run(self) -> None: + """ + finalize run of execution + """ self.logger.info('Run completed') with log.indentcontext(): - for job in self.context.completed_jobs: - job.finalize(self.context) + if self.context.completed_jobs: + for job in self.context.completed_jobs: + job.finalize(self.context) self.logger.info('Finalizing run') self.context.end_run() self.pm.enable_all() - with signal.wrap('RUN_OUTPUT_PROCESSED', self): + with signal.wrap('RUN_OUTPUT_PROCESSED', cast(Type['Anonymous'], self)): self.pm.process_run_output(self.context) self.pm.export_run_output(self.context) self.pm.finalize(self.context) @@ -543,8 +729,11 @@ def finalize_run(self): signal.disconnect(self._error_signalled_callback, signal.ERROR_LOGGED) signal.disconnect(self._warning_signalled_callback, signal.WARNING_LOGGED) - def run_next_job(self, context): - job = context.start_job() + def run_next_job(self, context: ExecutionContext) -> None: + """ + run next job + """ + job: Job = context.start_job() self.logger.info('Running job {}'.format(job.id)) try: @@ -556,7 +745,7 @@ def run_next_job(self, context): 
self.logger.info('Rebooting on new spec.') self.context.tm.reboot(context) - with signal.wrap('JOB', self, context): + with signal.wrap('JOB', cast(Type['Anonymous'], self), context): context.tm.start() self.do_run_job(job, context) context.set_job_status(job, Status.OK) @@ -581,10 +770,13 @@ def run_next_job(self, context): log.dedent() self.check_job(job) - def do_run_job(self, job, context): + def do_run_job(self, job: Job, context: ExecutionContext) -> None: + """ + do run job + """ # pylint: disable=too-many-branches,too-many-statements - rc = self.context.cm.run_config - if job.workload.phones_home and not rc.allow_phone_home: + rc: 'RunConfigurationProtocol' = self.context.cm.run_config + if job.workload and job.workload.phones_home and not rc.allow_phone_home: self.logger.warning('Skipping job {} ({}) due to allow_phone_home=False' .format(job.id, job.workload.name)) self.context.skip_job(job) @@ -595,7 +787,7 @@ def do_run_job(self, job, context): job.configure_augmentations(context, self.pm) - with signal.wrap('JOB_TARGET_CONFIG', self, context): + with signal.wrap('JOB_TARGET_CONFIG', cast(Type['Anonymous'], self), context): job.configure_target(context) try: @@ -625,7 +817,7 @@ def do_run_job(self, job, context): raise e finally: try: - with signal.wrap('JOB_OUTPUT_PROCESSED', self, context): + with signal.wrap('JOB_OUTPUT_PROCESSED', cast(Type['Anonymous'], self), context): job.process_output(context) self.pm.process_job_output(context) self.pm.export_job_output(context) @@ -645,11 +837,14 @@ def do_run_job(self, job, context): # run even if the job failed job.teardown(context) - def check_job(self, job): - rc = self.context.cm.run_config + def check_job(self, job: Job) -> None: + """ + check job + """ + rc: 'RunConfigurationProtocol' = self.context.cm.run_config if job.status in rc.retry_on_status: if job.retries < rc.max_retries: - msg = 'Job {} iteration {} completed with status {}. retrying...' + msg: str = 'Job {} iteration {} completed with status {}. retrying...' 
self.logger.error(msg.format(job.id, job.iteration, job.status)) self.retry_job(job) self.context.move_failed(job) @@ -671,22 +866,36 @@ def check_job(self, job): self.context.failed_jobs += 1 self.send(signal.JOB_ABORTED) - def retry_job(self, job): + def retry_job(self, job: Job) -> None: + """ + retry job + """ retry_job = Job(job.spec, job.iteration, self.context) retry_job.workload = job.workload retry_job.state = job.state + # FIXME - the setter type is not getting recognized by the type checker retry_job.retries = job.retries + 1 self.context.set_job_status(retry_job, Status.PENDING, force=True) - self.context.job_queue.insert(0, retry_job) + if self.context.job_queue is not None: + self.context.job_queue.insert(0, retry_job) self.send(signal.JOB_RESTARTED) - def send(self, s): - signal.send(s, self, self.context) + def send(self, s: 'Signal') -> None: + """ + send a signal with the runner as the sender + """ + signal.send(s, cast(Type['Anonymous'], self), self.context) - def _error_signalled_callback(self, record): + def _error_signalled_callback(self, record: logging.LogRecord) -> None: + """ + error signalled callback + """ self.context.add_event(record.getMessage()) - def _warning_signalled_callback(self, record): + def _warning_signalled_callback(self, record: logging.LogRecord) -> None: + """ + warning signalled callback + """ self.context.add_event(record.getMessage()) def __str__(self): diff --git a/wa/framework/getters.py b/wa/framework/getters.py index 9ca63bf7e..120a89ce0 100644 --- a/wa/framework/getters.py +++ b/wa/framework/getters.py @@ -25,11 +25,14 @@ import shutil import sys -import requests - +import requests # type:ignore +from typing import cast, List, Optional, Dict, Any +from typing_extensions import Protocol +from requests.models import Response # type:ignore from wa import Parameter, settings, __file__ as _base_filepath -from wa.framework.resource import ResourceGetter, SourcePriority, NO_ONE +from wa.framework.resource import (ResourceGetter, SourcePriority, NO_ONE, Resource, + File, Executable, ResourceResolver) from wa.framework.exception import ResourceError from wa.utils.misc import (ensure_directory_exists as _d, atomic_write_path, ensure_file_directory_exists as _f, sha256, urljoin) @@ -41,15 +44,23 @@ logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) -logger = logging.getLogger('resource') +logger: logging.Logger = logging.getLogger('resource') + + +class OwnerProtocol(Protocol): name: str dependencies_directory: str -def get_by_extension(path, ext): +def get_by_extension(path: str, ext: str) -> List[str]: + """ + get files with the specified extension under the path + """ if not ext.startswith('.'): ext = '.' 
+ ext ext = caseless_string(ext) - found = [] + found: List[str] = [] for entry in os.listdir(path): entry_ext = os.path.splitext(entry)[1] if entry_ext == ext: @@ -57,8 +68,11 @@ def get_by_extension(path, ext): return found -def get_generic_resource(resource, files): - matches = [] +def get_generic_resource(resource: Resource, files: List[str]) -> Optional[str]: + """ + get generic resource + """ + matches: List[str] = [] for f in files: if resource.match(f): matches.append(f) @@ -70,8 +84,11 @@ def get_generic_resource(resource, files): return matches[0] -def get_path_matches(resource, files): - matches = [] +def get_path_matches(resource: Resource, files: List[str]) -> List[str]: + """ + get path matches + """ + matches: List[str] = [] for f in files: if resource.match_path(f): matches.append(f) @@ -79,13 +96,16 @@ def get_path_matches(resource, files): # pylint: disable=too-many-return-statements -def get_from_location(basepath, resource): +def get_from_location(basepath: str, resource: Resource) -> Optional[str]: + """ + get resource from location + """ if resource.kind == 'file': - path = os.path.join(basepath, resource.path) + path = os.path.join(basepath, cast(File, resource).path) if os.path.exists(path): return path elif resource.kind == 'executable': - bin_dir = os.path.join(basepath, 'bin', resource.abi) + bin_dir = os.path.join(basepath, 'bin', cast(Executable, resource).abi) if not os.path.exists(bin_dir): return None for entry in os.listdir(bin_dir): @@ -110,39 +130,63 @@ def get_from_location(basepath, resource): class Package(ResourceGetter): - name = 'package' + name: str = 'package' - def register(self, resolver): + def register(self, resolver: ResourceResolver) -> None: + """ + register the package with resolver + """ resolver.register(self.get, SourcePriority.package) # pylint: disable=no-self-use - def get(self, resource): + def get(self, resource: Resource) -> Optional[str]: + """ + get resource + """ if resource.owner == NO_ONE: - basepath = os.path.join(os.path.dirname(_base_filepath), 'assets') + basepath: str = os.path.join(os.path.dirname(_base_filepath), 'assets') else: - modname = resource.owner.__module__ - basepath = os.path.dirname(sys.modules[modname].__file__) + modname: str = resource.owner.__module__ + basepath = os.path.dirname(sys.modules[modname].__file__ or '') return get_from_location(basepath, resource) class UserDirectory(ResourceGetter): - name = 'user' + name: str = 'user' - def register(self, resolver): + def register(self, resolver: ResourceResolver) -> None: + """ + register the user directory with the resolver + """ resolver.register(self.get, SourcePriority.local) # pylint: disable=no-self-use - def get(self, resource): - basepath = settings.dependencies_directory - directory = _d(os.path.join(basepath, resource.owner.name)) + def get(self, resource: Resource) -> Optional[str]: + """ + get resource + """ + basepath: str = settings.dependencies_directory + directory: str = _d(os.path.join(basepath, cast(OwnerProtocol, resource.owner).name)) return get_from_location(directory, resource) +class HttpProtocol(Protocol): name: str description: str url: str username: str password: str always_fetch: bool chunk_size: int logger: logging.Logger index: Dict + + class Http(ResourceGetter): - name = 'http' - description = """ + name: str = 'http' + description: str = """ Downloads resources from a server based on an index fetched from the specified URL. 
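As the fetch_index() and download_asset() methods below show, this getter retrieves index.json from the configured URL and skips re-downloading any asset whose local SHA256 already matches the index entry. A standalone sketch of that flow, assuming an index that maps owner names to lists of {'path': ..., 'sha256': ...} records (the helper names here are illustrative, not WA API):

    import hashlib
    import json

    import requests

    def fetch_index(base_url: str) -> dict:
        # the index is expected at <base_url>/index.json
        response = requests.get(base_url.rstrip('/') + '/index.json')
        response.raise_for_status()
        return json.loads(response.content.decode('utf-8'))

    def needs_download(local_path: str, expected_sha256: str) -> bool:
        # re-fetch only when the cached copy is missing or stale
        try:
            with open(local_path, 'rb') as fh:
                return hashlib.sha256(fh.read()).hexdigest() != expected_sha256
        except FileNotFoundError:
            return True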
@@ -180,7 +224,7 @@ class Http(ResourceGetter): provided that hasn't changed, it won't try to download the file again. """ - parameters = [ + parameters: List[Parameter] = [ Parameter('url', global_alias='remote_assets_url', description=""" URL of the index file for assets on an HTTP server. @@ -205,24 +249,30 @@ class Http(ResourceGetter): """), ] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(Http, self).__init__(**kwargs) self.logger = logger - self.index = {} + self.index: Dict[str, Dict] = {} - def register(self, resolver): + def register(self, resolver: ResourceResolver) -> None: + """ + register Http with resolver + """ resolver.register(self.get, SourcePriority.remote) - def get(self, resource): + def get(self, resource: Resource) -> Optional[str]: + """ + get resource + """ if not resource.owner: - return # TODO: add support for unowned resources + return None # TODO: add support for unowned resources if not self.index: try: self.index = self.fetch_index() except requests.exceptions.RequestException as e: msg = 'Skipping HTTP getter due to connection error: {}' - self.logger.debug(msg.format(e.message)) - return + self.logger.debug(msg.format(str(e))) + return None if resource.kind == 'apk': # APKs must always be downloaded to run ApkInfo for version # information. @@ -230,16 +280,19 @@ def get(self, resource): else: asset = self.resolve_resource(resource) if not asset: - return - return self.download_asset(asset, resource.owner.name) - - def fetch_index(self): - if not self.url: + return None + return self.download_asset(asset, cast(OwnerProtocol, resource.owner).name) + + def fetch_index(self) -> Dict: + """ + fetch index page + """ + if not cast(HttpProtocol, self).url: return {} - index_url = urljoin(self.url, 'index.json') - response = self.geturl(index_url) + index_url: str = urljoin(cast(HttpProtocol, self).url, 'index.json') + response: Response = self.geturl(index_url) if response.status_code != http.client.OK: - message = 'Could not fetch "{}"; received "{} {}"' + message: str = 'Could not fetch "{}"; received "{} {}"' self.logger.error(message.format(index_url, response.status_code, response.reason)) @@ -247,77 +300,98 @@ def fetch_index(self): content = response.content.decode('utf-8') return json.loads(content) - def download_asset(self, asset, owner_name): - url = urljoin(self.url, owner_name, asset['path']) - local_path = _f(os.path.join(settings.dependencies_directory, '__remote', - owner_name, asset['path'].replace('/', os.sep))) + def download_asset(self, asset: Dict[str, Any], owner_name: str) -> Optional[str]: + """ + download asset + """ + url: str = urljoin(cast(HttpProtocol, self).url, owner_name, asset['path']) + local_path: str = _f(os.path.join(settings.dependencies_directory, '__remote', + owner_name, asset['path'].replace('/', os.sep))) if os.path.exists(local_path) and not self.always_fetch: - local_sha = sha256(local_path) + local_sha: str = sha256(local_path) if local_sha == asset['sha256']: self.logger.debug('Local SHA256 matches; not re-downloading') return local_path self.logger.debug('Downloading {}'.format(url)) - response = self.geturl(url, stream=True) + response: Response = self.geturl(url, stream=True) if response.status_code != http.client.OK: - message = 'Could not download asset "{}"; received "{} {}"' + message: str = 'Could not download asset "{}"; received "{} {}"' self.logger.warning(message.format(url, response.status_code, response.reason)) - return + return None with atomic_write_path(local_path) as 
at_path: with open(at_path, 'wb') as wfh: for chunk in response.iter_content(chunk_size=self.chunk_size): wfh.write(chunk) return local_path - def geturl(self, url, stream=False): - if self.username: - auth = (self.username, self.password) + def geturl(self, url: str, stream: bool = False) -> Response: + """ + get response from url request via http + """ + if cast(HttpProtocol, self).username: + auth = (cast(HttpProtocol, self).username, cast(HttpProtocol, self).password) else: auth = None return requests.get(url, auth=auth, stream=stream) - def resolve_apk(self, resource): - assets = self.index.get(resource.owner.name, {}) + def resolve_apk(self, resource: Resource) -> Optional[str]: + """ + resolve apk + """ + assets: Dict = self.index.get(cast(OwnerProtocol, resource.owner).name, {}) if not assets: return None - asset_map = {a['path']: a for a in assets} - paths = get_path_matches(resource, list(asset_map.keys())) - local_paths = [] + asset_map: Dict = {a['path']: a for a in assets} + paths: List[str] = get_path_matches(resource, list(asset_map.keys())) + local_paths: List[Optional[str]] = [] for path in paths: local_paths.append(self.download_asset(asset_map[path], - resource.owner.name)) - for path in local_paths: + cast(OwnerProtocol, resource.owner).name)) + for path in cast(List[str], local_paths): if resource.match(path): return path + return None - def resolve_resource(self, resource): + def resolve_resource(self, resource: Resource) -> Optional[Dict]: + """ + resolve resource + """ # pylint: disable=too-many-branches,too-many-locals - assets = self.index.get(resource.owner.name, {}) + assets: Dict = self.index.get(cast(OwnerProtocol, resource.owner).name, {}) if not assets: return {} - asset_map = {a['path']: a for a in assets} + asset_map: Dict = {a['path']: a for a in assets} if resource.kind in ['jar', 'revent']: - path = get_generic_resource(resource, list(asset_map.keys())) + path: Optional[str] = get_generic_resource(resource, list(asset_map.keys())) if path: return asset_map[path] elif resource.kind == 'executable': - path = '/'.join(['bin', resource.abi, resource.filename]) + path = '/'.join(['bin', cast(Executable, resource).abi, cast(Executable, resource).filename]) for asset in assets: if asset['path'].lower() == path.lower(): return asset else: # file for asset in assets: - if asset['path'].lower() == resource.path.lower(): + if asset['path'].lower() == cast(File, resource).path.lower(): return asset + return None + + +class FilerProtocol(Protocol): + name: str + description: str + remote_path: str + always_fetch: bool class Filer(ResourceGetter): - name = 'filer' - description = """ + name: str = 'filer' + description: str = """ Finds resources on a (locally mounted) remote filer and caches them locally. @@ -325,7 +399,7 @@ class Filer(ResourceGetter): samba share). 
""" - parameters = [ + parameters: List[Parameter] = [ Parameter('remote_path', global_alias='remote_assets_path', default=settings.assets_repository, description=""" @@ -339,17 +413,24 @@ class Filer(ResourceGetter): """), ] - def register(self, resolver): + def register(self, resolver: ResourceResolver) -> None: + """ + register Filer with resource resolver + """ resolver.register(self.get, SourcePriority.lan) - def get(self, resource): + def get(self, resource: Resource) -> Optional[str]: + """ + get filer + """ + remote_path: str = '' if resource.owner: - remote_path = os.path.join(self.remote_path, resource.owner.name) - local_path = os.path.join(settings.dependencies_directory, '__filer', - resource.owner.dependencies_directory) + remote_path = os.path.join(cast(FilerProtocol, self).remote_path, cast(OwnerProtocol, resource.owner).name) + local_path: str = os.path.join(settings.dependencies_directory, '__filer', + cast(OwnerProtocol, resource.owner).dependencies_directory) return self.try_get_resource(resource, remote_path, local_path) else: # No owner - result = None + result: Optional[str] = None for entry in os.listdir(remote_path): remote_path = os.path.join(self.remote_path, entry) local_path = os.path.join(settings.dependencies_directory, '__filer', @@ -359,9 +440,13 @@ def get(self, resource): break return result - def try_get_resource(self, resource, remote_path, local_path): - if not self.always_fetch: - result = get_from_location(local_path, resource) + def try_get_resource(self, resource: Resource, remote_path: str, + local_path: str) -> Optional[str]: + """ + try get resource + """ + if not cast(FilerProtocol, self).always_fetch: + result: Optional[str] = get_from_location(local_path, resource) if result: return result if not os.path.exists(local_path): @@ -374,7 +459,7 @@ def try_get_resource(self, resource, remote_path, local_path): else: # remote path is not set return None # Found it remotely, cache locally, then return it - local_full_path = os.path.join(_d(local_path), os.path.basename(result)) + local_full_path: str = os.path.join(_d(local_path), os.path.basename(result)) self.logger.debug('cp {} {}'.format(result, local_full_path)) shutil.copy(result, local_full_path) return result diff --git a/wa/framework/host.py b/wa/framework/host.py index 973a253fe..34c96e0ef 100644 --- a/wa/framework/host.py +++ b/wa/framework/host.py @@ -25,12 +25,13 @@ from wa.utils.misc import load_struct_from_python from wa.utils.serializer import yaml from wa.utils.types import identifier - +from typing import Dict, List, Any, Union, cast, Optional # Have to disable this due to dynamic attributes # pylint: disable=no-member -def init_user_directory(overwrite_existing=False): # pylint: disable=R0914 + +def init_user_directory(overwrite_existing: bool = False): # pylint: disable=R0914 """ Initialise a fresh user directory. """ @@ -48,11 +49,11 @@ def init_user_directory(overwrite_existing=False): # pylint: disable=R0914 if os.getenv('USER') == 'root': # If running with sudo on POSIX, change the ownership to the real user. - real_user = os.getenv('SUDO_USER') + real_user: Optional[str] = os.getenv('SUDO_USER') if real_user: # pylint: disable=import-outside-toplevel import pwd # done here as module won't import on win32 - user_entry = pwd.getpwnam(real_user) + user_entry: pwd.struct_passwd = pwd.getpwnam(real_user) uid, gid = user_entry.pw_uid, user_entry.pw_gid os.chown(settings.user_directory, uid, gid) # why, oh why isn't there a recusive=True option for os.chown? 
@@ -63,27 +64,27 @@ def init_user_directory(overwrite_existing=False): # pylint: disable=R0914 os.chown(os.path.join(root, f), uid, gid) -def init_config(): +def init_config() -> None: """ If configuration file is missing try to convert WA2 config if present otherwise initialize fresh config file """ - wa2_config_file = os.path.join(settings.user_directory, 'config.py') - wa3_config_file = os.path.join(settings.user_directory, 'config.yaml') + wa2_config_file: str = os.path.join(settings.user_directory, 'config.py') + wa3_config_file: str = os.path.join(settings.user_directory, 'config.yaml') if os.path.exists(wa2_config_file): convert_wa2_agenda(wa2_config_file, wa3_config_file) else: generate_default_config(wa3_config_file) -def convert_wa2_agenda(filepath, output_path): +def convert_wa2_agenda(filepath: str, output_path: str) -> None: """ Convert WA2 .py config file to a WA3 .yaml config file. """ - orig_agenda = load_struct_from_python(filepath) - new_agenda = {'augmentations': []} - config_points = MetaConfiguration.config_points + RunConfiguration.config_points + orig_agenda: Dict[str, Any] = load_struct_from_python(filepath) + new_agenda: Dict[str, Union[List, Dict]] = {'augmentations': []} + config_points: List[ConfigurationPoint] = MetaConfiguration.config_points + RunConfiguration.config_points # Add additional config points to extract from config file. # Also allows for aliasing of renamed parameters @@ -115,19 +116,19 @@ def convert_wa2_agenda(filepath, output_path): for cfg_point in config_points: if param == cfg_point.name or param in cfg_point.aliases: if cfg_point.name == 'augmentations': - new_agenda['augmentations'].extend(orig_agenda.pop(param)) + cast(List, new_agenda['augmentations']).extend(orig_agenda.pop(param)) else: new_agenda[cfg_point.name] = format_parameter(orig_agenda.pop(param)) with open(output_path, 'w') as output: - for param in config_points: - entry = {param.name: new_agenda.get(param.name, param.default)} - write_param_yaml(entry, param, output) + for param_ in config_points: + entry: Dict[str, Union[List, Dict]] = {param_.name: new_agenda.get(param_.name, param_.default)} + write_param_yaml(entry, param_, output) # Convert plugin configuration output.write("# Plugin Configuration\n") for param in list(orig_agenda.keys()): - if pluginloader.has_plugin(param): + if cast(pluginloader.__LoaderWrapper, pluginloader).has_plugin(param): entry = {param: orig_agenda.pop(param)} yaml.dump(format_parameter(entry), output, default_flow_style=False) output.write("\n") @@ -142,7 +143,7 @@ def convert_wa2_agenda(filepath, output_path): output.write("\n") -def format_parameter(param): +def format_parameter(param: Any): if isinstance(param, dict): return {identifier(k): v for k, v in param.items()} else: diff --git a/wa/framework/instrument.py b/wa/framework/instrument.py index 663361662..d3fd69e38 100644 --- a/wa/framework/instrument.py +++ b/wa/framework/instrument.py @@ -24,7 +24,7 @@ Once a signal is broadcasted, the corresponding registered method is invoked. Each method in Instrument must take two arguments, which are self and context. -Supported signals can be found in [... link to signals ...] To make +Supported signals can be found in [wa.framework.signal.py] To make implementations easier and common, the basic steps to add new instrument is similar to the steps to add new workload. 
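To make the method-to-signal mapping concrete: on install(), every attribute of an instrument whose name appears in SIGNAL_MAP is connected to the corresponding signal, and the priority decorators defined later in this module (fast, slow, very_fast, and so on) control callback ordering. A minimal sketch of an instrument written against that scheme, assuming Instrument is importable from the wa package and very_fast from this module (the class itself is illustrative, not part of the patch):

    from wa import Instrument
    from wa.framework.instrument import very_fast

    class TraceMarker(Instrument):
        # hypothetical example instrument

        name = 'trace-marker'

        def setup(self, context):
            # connected to the workload-setup signal via SIGNAL_MAP
            self.logger.info('setting up')

        @very_fast
        def start(self, context):
            # runs just before workload execution; the very_fast
            # priority places it close to the measured region
            self.logger.info('start marker')

        def stop(self, context):
            # runs just after workload execution
            self.logger.info('stop marker')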
@@ -109,9 +109,26 @@ def teardown(self, context): from wa.utils.log import log_error from wa.utils.misc import isiterable from wa.utils.types import identifier, level +from typing import (List, OrderedDict as od, TYPE_CHECKING, Callable, Any, + Union, cast, Optional) +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext + from wa.framework.configuration.core import StatusType + # When type-checking, pretend CallbackPriority is a standard Enum + # (or anything you want the type checker to see). + import enum + class CallbackPriority(enum.Enum): + EXTREMELY_LOW = -30 + VERY_LOW = -20 + LOW = -10 + NORMAL = 0 + HIGH = 10 + VERY_HIGH = 20 + EXTREMELY_HIGH = 30 -logger = logging.getLogger('instruments') + +logger: logging.Logger = logging.getLogger('instruments') # Maps method names onto signals the should be registered to. @@ -119,7 +136,7 @@ def teardown(self, context): # then the corresponding end_ signal is guaranteed to also be sent. # Note: using OrderedDict to preserve logical ordering for the table generated # in the documentation -SIGNAL_MAP = OrderedDict([ +SIGNAL_MAP: od[str, signal.Signal] = OrderedDict([ # Below are "aliases" for some of the more common signals to allow # instruments to have similar structure to workloads ('initialize', signal.RUN_INITIALIZED), @@ -160,23 +177,29 @@ def teardown(self, context): ]) -def get_priority(func): +def get_priority(func) -> "CallbackPriority": + """ + get priority + """ return getattr(getattr(func, 'im_func', func), 'priority', signal.CallbackPriority.normal) -def priority(priority): # pylint: disable=redefined-outer-name - def decorate(func): - def wrapper(*args, **kwargs): +def priority(priority: 'CallbackPriority') -> Any: # pylint: disable=redefined-outer-name + """ + priority for the instrument signal callback + """ + def decorate(func: Callable) -> Callable: + def wrapper(*args, **kwargs) -> Any: return func(*args, **kwargs) wrapper.__name__ = func.__name__ if priority in signal.CallbackPriority.levels: - wrapper.priority = signal.CallbackPriority(priority) + wrapper.priority = signal.CallbackPriority(priority) # type:ignore else: if not isinstance(priority, int): msg = 'Invalid priorty "{}"; must be an int or one of {}' raise ValueError(msg.format(priority, signal.CallbackPriority.values)) - wrapper.priority = level('custom', priority) + wrapper.priority = level('custom', priority) # type:ignore return wrapper return decorate @@ -190,7 +213,7 @@ def wrapper(*args, **kwargs): extremely_fast = priority(signal.CallbackPriority.extremely_high) -def hostside(func): +def hostside(func: Callable) -> Callable: """ Used as a hint that the callback only performs actions on the host and does not rely on an active connection to the target. @@ -198,18 +221,24 @@ def hostside(func): thought to be unresponsive. 
""" - func.is_hostside = True + func.is_hostside = True # type:ignore return func -def is_hostside(func): +def is_hostside(func: Callable) -> bool: + """ + whether callback is only relying on host + """ return getattr(func, 'is_hostside', False) -installed = [] +installed: List['Instrument'] = [] -def is_installed(instrument): +def is_installed(instrument: Union['Instrument', type, str]) -> bool: + """ + whether instrument is already installed + """ if isinstance(instrument, Instrument): if instrument in installed: return True @@ -224,27 +253,36 @@ def is_installed(instrument): return False -def is_enabled(instrument): +def is_enabled(instrument: Union['Instrument', type, str]) -> bool: + """ + whether instrument is enabled + """ if isinstance(instrument, (Instrument, type)): - name = instrument.name + name: Optional[str] = cast('Instrument', instrument).name else: # assume string name = instrument try: - installed_instrument = get_instrument(name) + installed_instrument: 'Instrument' = get_instrument(cast(str, name)) return installed_instrument.is_enabled except ValueError: return False -failures_detected = False +failures_detected: bool = False -def reset_failures(): +def reset_failures() -> None: + """ + reset failures + """ global failures_detected # pylint: disable=W0603 failures_detected = False -def check_failures(): +def check_failures() -> bool: + """ + check failures + """ result = failures_detected reset_failures() return result @@ -257,12 +295,12 @@ class ManagedCallback(object): """ - def __init__(self, instrument, callback): + def __init__(self, instrument: 'Instrument', callback: Callable): self.instrument = instrument self.callback = callback self.is_hostside = is_hostside(callback) - def __call__(self, context): + def __call__(self, context: 'ExecutionContext'): if self.instrument.is_enabled: try: if not context.tm.is_responsive and not self.is_hostside: @@ -278,18 +316,18 @@ def __call__(self, context): log_error(e, logger) context.add_event(e.args[0] if e.args else str(e)) if isinstance(e, WorkloadError): - context.set_status('FAILED') + context.set_status(cast('StatusType', 'FAILED')) elif isinstance(e, (TargetError, TimeoutError)): context.tm.verify_target_responsive(context) else: if context.current_job: - context.set_status('PARTIAL') + context.set_status(cast('StatusType', 'PARTIAL')) else: raise - def __repr__(self): - text = 'ManagedCallback({}, {})' - return text.format(self.instrument.name, self.callback.__func__.__name__) + def __repr__(self) -> str: + text: str = 'ManagedCallback({}, {})' + return text.format(self.instrument.name, self.callback.__func__.__name__) # type:ignore __str__ = __repr__ @@ -297,10 +335,10 @@ def __repr__(self): # Need this to keep track of callbacks, because the dispatcher only keeps # weak references, so if the callbacks aren't referenced elsewhere, they will # be deallocated before they've had a chance to be invoked. -_callbacks = [] +_callbacks: List[ManagedCallback] = [] -def install(instrument, context): +def install(instrument: 'Instrument', context: 'ExecutionContext'): """ This will look for methods (or any callable members) with specific names in the instrument and hook them up to the corresponding signals. @@ -312,30 +350,30 @@ def install(instrument, context): logger.debug('Installing instrument %s.', instrument) if is_installed(instrument): - msg = 'Instrument {} is already installed.' + msg: str = 'Instrument {} is already installed.' 
raise ValueError(msg.format(instrument.name)) for attr_name in dir(instrument): if attr_name not in SIGNAL_MAP: continue - attr = getattr(instrument, attr_name) + attr: Any = getattr(instrument, attr_name) if not callable(attr): msg = 'Attribute {} not callable in {}.' raise ValueError(msg.format(attr_name, instrument)) - argspec = inspect.getfullargspec(attr) - arg_num = len(argspec.args) + argspec: inspect.FullArgSpec = inspect.getfullargspec(attr) + arg_num: int = len(argspec.args) # Instrument callbacks will be passed exactly two arguments: self # (the instrument instance to which the callback is bound) and # context. However, we also allow callbacks to capture the context # in variable arguments (declared as "*args" in the definition). if arg_num > 2 or (arg_num < 2 and argspec.varargs is None): - message = '{} must take exactly 2 positional arguments; {} given.' + message: str = '{} must take exactly 2 positional arguments; {} given.' raise ValueError(message.format(attr_name, arg_num)) - priority = get_priority(attr) - hostside = ' [hostside]' if is_hostside(attr) else '' + priority: 'CallbackPriority' = get_priority(attr) + hostside: str = ' [hostside]' if is_hostside(attr) else '' logger.debug('\tConnecting %s to %s with priority %s(%d)%s', attr.__name__, SIGNAL_MAP[attr_name], priority.name, priority.value, hostside) @@ -343,22 +381,31 @@ def install(instrument, context): _callbacks.append(mc) signal.connect(mc, SIGNAL_MAP[attr_name], priority=priority.value) - instrument.logger.context = context + instrument.logger.context = context # type:ignore installed.append(instrument) context.add_augmentation(instrument) -def uninstall(instrument): +def uninstall(instrument: 'Instrument') -> None: + """ + uninstall the instrument. + """ instrument = get_instrument(instrument) installed.remove(instrument) -def validate(): +def validate() -> None: + """ + validate the instrument + """ for instrument in installed: instrument.validate() -def get_instrument(inst): +def get_instrument(inst: Union['Instrument', str]) -> 'Instrument': + """ + get instrument + """ if isinstance(inst, Instrument): return inst for installed_inst in installed: @@ -367,33 +414,48 @@ def get_instrument(inst): raise ValueError('Instrument {} is not installed'.format(inst)) -def disable_all(): +def disable_all() -> None: + """ + disable all instruments + """ for instrument in installed: _disable_instrument(instrument) -def enable_all(): +def enable_all() -> None: + """ + enable all instruments + """ for instrument in installed: _enable_instrument(instrument) -def enable(to_enable): +def enable(to_enable: Union['Instrument', List['Instrument'], str]) -> None: + """ + enable the specified instruments + """ if isiterable(to_enable): - for inst in to_enable: + for inst in to_enable: # type:ignore _enable_instrument(inst) else: - _enable_instrument(to_enable) + _enable_instrument(cast('Instrument', to_enable)) -def disable(to_disable): +def disable(to_disable: Union['Instrument', List['Instrument'], str]) -> None: + """ + disable the specified instruments + """ if isiterable(to_disable): - for inst in to_disable: + for inst in to_disable: # type:ignore _disable_instrument(inst) else: - _disable_instrument(to_disable) + _disable_instrument(cast('Instrument', to_disable)) -def _enable_instrument(inst): +def _enable_instrument(inst: Union['Instrument', str]) -> None: + """ + enable the specified instrument + """ inst = get_instrument(inst) if not inst.is_broken: logger.debug('Enabling instrument {}'.format(inst.name)) @@ -402,28 
+464,42 @@ def _enable_instrument(inst): logger.debug('Not enabling broken instrument {}'.format(inst.name)) -def _disable_instrument(inst): +def _disable_instrument(inst: Union['Instrument', str]) -> None: + """ + disable the specified instrument + """ inst = get_instrument(inst) if inst.is_enabled: logger.debug('Disabling instrument {}'.format(inst.name)) inst.is_enabled = False -def get_enabled(): +def get_enabled() -> List['Instrument']: + """ + get list of enabled instruments + """ return [i for i in installed if i.is_enabled] -def get_disabled(): +def get_disabled() -> List['Instrument']: + """ + get list of disabled instruments + """ return [i for i in installed if not i.is_enabled] class Instrument(TargetedPlugin): """ Base class for instrument implementations. + Instruments are installed into a WA run to change its behaviour (e.g. by + introducing delays between successive job executions), or to collect + additional measurements (e.g. energy usage). Some instruments may depend + on particular features being enabled on the target (e.g. cpufreq), or + on additional hardware (e.g. energy probes). """ - kind = "instrument" + kind: str = "instrument" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(Instrument, self).__init__(*args, **kwargs) - self.is_enabled = True - self.is_broken = False + self.is_enabled: bool = True + self.is_broken: bool = False diff --git a/wa/framework/job.py b/wa/framework/job.py index dee9a7e96..d0070b143 100644 --- a/wa/framework/job.py +++ b/wa/framework/job.py @@ -18,96 +18,151 @@ import logging from copy import copy -from datetime import datetime +from datetime import datetime, timedelta from wa.framework import pluginloader, signal, instrument from wa.framework.configuration.core import Status from wa.utils.log import indentcontext from wa.framework.run import JobState +from typing import (TYPE_CHECKING, Optional, Any, Dict, + Type, cast, OrderedDict, Set) +from devlib.target import Target +from types import ModuleType +if TYPE_CHECKING: + from wa.framework.configuration.core import JobSpecProtocol + from wa.framework.execution import ExecutionContext + from wa.framework.workload import Workload + from wa.framework.output import JobOutput + from wa.framework.output_processor import ProcessorManager + from wa.framework.configuration.core import StatusType + from wa.framework.plugin import Plugin + from louie.dispatcher import Anonymous # type:ignore class Job(object): + """ + A single execution of a workload. A job is defined by an associated + :term:`spec`. However, multiple jobs can share the same spec; + e.g. even if only one workload is to be run but five iterations are + requested, five individual jobs will be generated to be run. 
+ """ + _workload_cache: Dict[str, 'Workload'] = {} - _workload_cache = {} + def __init__(self, spec: 'JobSpecProtocol', + iteration: int, context: 'ExecutionContext'): + self.logger = logging.getLogger('job') + self.spec = spec + self.iteration = iteration + self.context = context + self.workload: Optional['Workload'] = None + self.output: Optional['JobOutput'] = None + self.run_time: Optional[timedelta] = None + self.classifiers: OrderedDict[str, str] = copy(self.spec.classifiers) + self._has_been_initialized: bool = False + self.state = JobState(self.id or '', self.label, self.iteration, Status.NEW) @property - def id(self): + def id(self) -> Optional[str]: + """ + job id + """ return self.spec.id @property - def label(self): + def label(self) -> str: + """ + job label + """ return self.spec.label @property - def status(self): + def status(self) -> 'StatusType': + """ + job status getter + """ return self.state.status - @property - def has_been_initialized(self): - return self._has_been_initialized - - @property - def retries(self): - return self.state.retries - @status.setter - def status(self, value): + def status(self, value: 'StatusType') -> None: + """ + job status setter + """ self.state.status = value self.state.timestamp = datetime.utcnow() if self.output: self.output.status = value + @property + def has_been_initialized(self) -> bool: + """ + check if the job has been initialized + """ + return self._has_been_initialized + + @property + def retries(self) -> int: + """ + number of retries for the job execution + """ + return self.state.retries + @retries.setter - def retries(self, value): + def retries(self, value: int) -> None: + """ + setter for retries + """ self.state.retries = value - def __init__(self, spec, iteration, context): - self.logger = logging.getLogger('job') - self.spec = spec - self.iteration = iteration - self.context = context - self.workload = None - self.output = None - self.run_time = None - self.classifiers = copy(self.spec.classifiers) - self._has_been_initialized = False - self.state = JobState(self.id, self.label, self.iteration, Status.NEW) - - def load(self, target, loader=pluginloader): + def load(self, target: Target, loader: ModuleType = pluginloader) -> None: + """ + load workload for the job + """ self.logger.info('Loading job {}'.format(self)) if self.id not in self._workload_cache: self.workload = loader.get_workload(self.spec.workload_name, target, **self.spec.workload_parameters) - self.workload.init_resources(self.context) - self.workload.validate() - self._workload_cache[self.id] = self.workload + if self.workload: + self.workload.init_resources(self.context) + self.workload.validate() + self._workload_cache[self.id or ''] = self.workload else: self.workload = self._workload_cache[self.id] - def set_output(self, output): + def set_output(self, output: 'JobOutput') -> None: + """ + set output + """ output.classifiers = copy(self.classifiers) self.output = output - def initialize(self, context): + def initialize(self, context: 'ExecutionContext') -> None: + """ + initialize the job execution + """ self.logger.info('Initializing job {}'.format(self)) with indentcontext(): - with signal.wrap('WORKLOAD_INITIALIZED', self, context): - self.workload.logger.context = context - self.workload.initialize(context) + with signal.wrap('WORKLOAD_INITIALIZED', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.logger.context = context # type:ignore + self.workload.initialize(context) self.set_status(Status.PENDING) 
self._has_been_initialized = True - def configure_augmentations(self, context, pm): + def configure_augmentations(self, context: 'ExecutionContext', + pm: 'ProcessorManager') -> None: + """ + configure augmentations + """ self.logger.info('Configuring augmentations') with indentcontext(): - instruments_to_enable = set() - output_processors_to_enable = set() - enabled_instruments = set(i.name for i in instrument.get_enabled()) - enabled_output_processors = set(p.name for p in pm.get_enabled()) + instruments_to_enable: Set[str] = set() + output_processors_to_enable: Set[str] = set() + enabled_instruments: Set[Optional[str]] = set(i.name for i in instrument.get_enabled()) + enabled_output_processors: Set[Optional[str]] = set(p.name for p in pm.get_enabled()) for augmentation in list(self.spec.augmentations.values()): - augmentation_cls = context.cm.plugin_cache.get_plugin_class(augmentation) + augmentation_cls: Type['Plugin'] = context.cm.plugin_cache.get_plugin_class(augmentation) if augmentation_cls.kind == 'instrument': instruments_to_enable.add(augmentation) elif augmentation_cls.kind == 'output_processor': @@ -115,62 +170,85 @@ def configure_augmentations(self, context, pm): # Disable unrequired instruments for instrument_name in enabled_instruments.difference(instruments_to_enable): - instrument.disable(instrument_name) + instrument.disable(instrument_name or '') # Enable additional instruments for instrument_name in instruments_to_enable.difference(enabled_instruments): - instrument.enable(instrument_name) + instrument.enable(instrument_name or '') # Disable unrequired output_processors for processor in enabled_output_processors.difference(output_processors_to_enable): - pm.disable(processor) + pm.disable(processor or '') # Enable additional output_processors for processor in output_processors_to_enable.difference(enabled_output_processors): - pm.enable(processor) + pm.enable(processor or '') - def configure_target(self, context): + def configure_target(self, context: 'ExecutionContext') -> None: + """ + configure target + """ self.logger.info('Configuring target for job {}'.format(self)) with indentcontext(): context.tm.commit_runtime_parameters(self.spec.runtime_parameters) - def setup(self, context): + def setup(self, context: 'ExecutionContext') -> None: + """ + setup the job + """ self.logger.info('Setting up job {}'.format(self)) with indentcontext(): - with signal.wrap('WORKLOAD_SETUP', self, context): - self.workload.setup(context) + with signal.wrap('WORKLOAD_SETUP', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.setup(context) - def run(self, context): + def run(self, context: 'ExecutionContext') -> None: + """ + run the job + """ self.logger.info('Running job {}'.format(self)) with indentcontext(): - with signal.wrap('WORKLOAD_EXECUTION', self, context): - start_time = datetime.utcnow() + with signal.wrap('WORKLOAD_EXECUTION', cast(Type['Anonymous'], self), context): + start_time: datetime = datetime.utcnow() try: - self.workload.run(context) + if self.workload: + self.workload.run(context) finally: self.run_time = datetime.utcnow() - start_time - def process_output(self, context): + def process_output(self, context: 'ExecutionContext') -> None: + """ + process job output + """ if not context.tm.is_responsive: self.logger.info('Target unresponsive; not processing job output.') return self.logger.info('Processing output for job {}'.format(self)) with indentcontext(): if self.status != Status.FAILED: - with 
signal.wrap('WORKLOAD_RESULT_EXTRACTION', self, context): - self.workload.extract_results(context) + with signal.wrap('WORKLOAD_RESULT_EXTRACTION', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.extract_results(context) context.extract_results() - with signal.wrap('WORKLOAD_OUTPUT_UPDATE', self, context): - self.workload.update_output(context) + with signal.wrap('WORKLOAD_OUTPUT_UPDATE', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.update_output(context) - def teardown(self, context): + def teardown(self, context: 'ExecutionContext') -> None: + """ + teardown the job run + """ if not context.tm.is_responsive: self.logger.info('Target unresponsive; not tearing down.') return self.logger.info('Tearing down job {}'.format(self)) with indentcontext(): - with signal.wrap('WORKLOAD_TEARDOWN', self, context): - self.workload.teardown(context) + with signal.wrap('WORKLOAD_TEARDOWN', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.teardown(context) - def finalize(self, context): + def finalize(self, context: 'ExecutionContext') -> None: + """ + finalize the job run + """ if not self._has_been_initialized: return if not context.tm.is_responsive: @@ -178,15 +256,24 @@ def finalize(self, context): return self.logger.info('Finalizing job {} '.format(self)) with indentcontext(): - with signal.wrap('WORKLOAD_FINALIZED', self, context): - self.workload.finalize(context) + with signal.wrap('WORKLOAD_FINALIZED', cast(Type['Anonymous'], self), context): + if self.workload: + self.workload.finalize(context) - def set_status(self, status, force=False): + def set_status(self, status: 'StatusType', + force: bool = False) -> None: + """ + set job status + """ status = Status(status) - if force or self.status < status: + if force or cast(int, self.status) < cast(int, status): self.status = status - def add_classifier(self, name, value, overwrite=False): + def add_classifier(self, name: str, value: Any, + overwrite: bool = False) -> None: + """ + add classifier to the job + """ if name in self.classifiers and not overwrite: raise ValueError('Cannot overwrite "{}" classifier.'.format(name)) self.classifiers[name] = value diff --git a/wa/framework/output.py b/wa/framework/output.py index 814a84d96..2fba399ea 100644 --- a/wa/framework/output.py +++ b/wa/framework/output.py @@ -14,8 +14,9 @@ # try: - import psycopg2 - from psycopg2 import Error as Psycopg2Error + import psycopg2 # type:ignore + from psycopg2 import Error as Psycopg2Error # type:ignore + from psycopg2 import _psycopg # type:ignore except ImportError: psycopg2 = None Psycopg2Error = None @@ -32,8 +33,9 @@ import devlib -from wa.framework.configuration.core import JobSpec, Status -from wa.framework.configuration.execution import CombinedConfig +from wa.framework.configuration.core import (JobSpec, JobSpecProtocol, Status, + RunConfigurationProtocol, MetaConfigurationProtocol) +from wa.framework.configuration.execution import CombinedConfig, ConfigManager from wa.framework.exception import HostError, SerializerSyntaxError, ConfigError from wa.framework.run import RunState, RunInfo from wa.framework.target.info import TargetInfo @@ -44,82 +46,124 @@ from wa.utils.postgres import get_schema_versions from wa.utils.serializer import write_pod, read_pod, Podable, json from wa.utils.types import enum, numeric +from uuid import UUID +from typing import (Optional, List, cast, Dict, Any, TYPE_CHECKING, Set, + Union, Generator, Tuple, DefaultDict) +if TYPE_CHECKING: + from 
wa.framework.configuration.core import StatusType + from wa.framework.job import Job -logger = logging.getLogger('output') +logger: logging.Logger = logging.getLogger('output') class Output(object): - - kind = None + """ + base class for run output and job output + """ + kind: Optional[str] = None @property - def resultfile(self): + def resultfile(self) -> str: + """ + result file + """ return os.path.join(self.basepath, 'result.json') @property - def event_summary(self): - num_events = len(self.events) + def event_summary(self) -> str: + """ + event summary + """ + num_events: int = len(self.events) if num_events: - lines = self.events[0].message.split('\n') - message = '({} event(s)): {}' + lines: List[str] = self.events[0].message.split('\n') + message: str = '({} event(s)): {}' if num_events > 1 or len(lines) > 1: message += '[...]' return message.format(num_events, lines[0]) return '' @property - def status(self): + def status(self) -> Optional['StatusType']: + """ + output status + """ if self.result is None: return None return self.result.status @status.setter - def status(self, value): - self.result.status = value + def status(self, value: 'StatusType') -> None: + """ + output status setter + """ + if self.result: + self.result.status = value @property - def metrics(self): + def metrics(self) -> List['Metric']: + """ + list of metrics + """ if self.result is None: return [] return self.result.metrics @property - def artifacts(self): + def artifacts(self) -> List['Artifact']: + """ + list of artifacts + """ if self.result is None: return [] return self.result.artifacts @property - def classifiers(self): + def classifiers(self) -> OrderedDict: + """ + dict of classifiers + """ if self.result is None: return OrderedDict() return self.result.classifiers @classifiers.setter - def classifiers(self, value): + def classifiers(self, value: OrderedDict) -> None: + """ + classifiers setter + """ if self.result is None: msg = 'Attempting to set classifiers before output has been set' raise RuntimeError(msg) self.result.classifiers = value @property - def events(self): + def events(self) -> List['Event']: + """ + list of events + """ if self.result is None: return [] return self.result.events @property - def metadata(self): + def metadata(self) -> OrderedDict: + """ + output metadata + """ if self.result is None: - return {} + return cast(OrderedDict, {}) return self.result.metadata - def __init__(self, path): + def __init__(self, path: str): self.basepath = path - self.result = None + self.result: Optional[Result] = None - def reload(self): + def reload(self) -> None: + """ + reload result + """ try: if os.path.isdir(self.basepath): pod = read_pod(self.resultfile) @@ -132,48 +176,93 @@ def reload(self): self.result.status = Status.UNKNOWN self.add_event(str(e)) - def write_result(self): - write_pod(self.result.to_pod(), self.resultfile) + def write_result(self) -> None: + """ + write result + """ + if self.result: + write_pod(self.result.to_pod(), self.resultfile) - def get_path(self, subpath): + def get_path(self, subpath: str) -> str: + """ + get path from subpath + """ return os.path.join(self.basepath, subpath.strip(os.sep)) - def add_metric(self, name, value, units=None, lower_is_better=False, - classifiers=None): - self.result.add_metric(name, value, units, lower_is_better, classifiers) + def add_metric(self, name: str, value: Any, + units: Optional[str] = None, lower_is_better: bool = False, + classifiers: Optional[Dict[str, Any]] = None) -> None: + """ + add metric to workload 
execution result + """ + if self.result: + self.result.add_metric(name, value, units, lower_is_better, classifiers) - def add_artifact(self, name, path, kind, description=None, classifiers=None): + def add_artifact(self, name: str, path: str, kind: str, description: Optional[str] = None, + classifiers: Optional[Dict[str, Any]] = None): + """ + add artifact to workload execution result + """ if not os.path.exists(path): path = self.get_path(path) if not os.path.exists(path): msg = 'Attempting to add non-existing artifact: {}' raise HostError(msg.format(path)) - is_dir = os.path.isdir(path) + is_dir: bool = os.path.isdir(path) path = os.path.relpath(path, self.basepath) + if self.result: + self.result.add_artifact(name, path, kind, description, classifiers, is_dir) - self.result.add_artifact(name, path, kind, description, classifiers, is_dir) - - def add_event(self, message): - self.result.add_event(message) + def add_event(self, message: str) -> None: + """ + add event + """ + if self.result: + self.result.add_event(message) - def get_metric(self, name): - return self.result.get_metric(name) + def get_metric(self, name: str) -> Optional['Metric']: + """ + get metric + """ + if self.result: + return self.result.get_metric(name) + return None - def get_artifact(self, name): - return self.result.get_artifact(name) + def get_artifact(self, name: str) -> Optional['Artifact']: + """ + get the specified artifact from workload execution result + """ + if self.result: + return self.result.get_artifact(name) + return None - def get_artifact_path(self, name): + def get_artifact_path(self, name: str) -> str: + """ + get path to the specified artifact + """ artifact = self.get_artifact(name) - return self.get_path(artifact.path) + return self.get_path(artifact.path if artifact else '') - def add_classifier(self, name, value, overwrite=False): - self.result.add_classifier(name, value, overwrite) + def add_classifier(self, name: str, value: Any, overwrite: bool = False) -> None: + """ + add classifier to workload execution result + """ + if self.result: + self.result.add_classifier(name, value, overwrite) - def add_metadata(self, key, *args, **kwargs): - self.result.add_metadata(key, *args, **kwargs) + def add_metadata(self, key: str, *args, **kwargs) -> None: + """ + add metadata to workload execution result + """ + if self.result: + self.result.add_metadata(key, *args, **kwargs) - def update_metadata(self, key, *args): - self.result.update_metadata(key, *args) + def update_metadata(self, key: str, *args) -> None: + """ + update an existing metadata in workload execution result + """ + if self.result: + self.result.update_metadata(key, *args) def __repr__(self): return '<{} {}>'.format(self.__class__.__name__, @@ -188,24 +277,38 @@ class RunOutputCommon(object): the RunOutput classes ''' @property - def run_config(self): - if self._combined_config: - return self._combined_config.run_config + def run_config(self) -> Optional[RunConfigurationProtocol]: + """ + get run configuration + """ + if cast('RunOutput', self)._combined_config: + return cast(CombinedConfig, cast('RunOutput', self)._combined_config).run_config + return None @property - def settings(self): - if self._combined_config: - return self._combined_config.settings + def settings(self) -> Optional[MetaConfigurationProtocol]: + """ + metadata configurations of the run + """ + if cast('RunOutput', self)._combined_config: + return cast(CombinedConfig, cast('RunOutput', self)._combined_config).settings + return None - def get_job_spec(self, 
spec_id): - for spec in self.job_specs: + def get_job_spec(self, spec_id: str) -> Optional[JobSpecProtocol]: + """ + get the job specifications + """ + for spec in cast('RunOutput', self).job_specs: if spec.id == spec_id: return spec return None - def list_workloads(self): - workloads = [] - for job in self.jobs: + def list_workloads(self) -> List[str]: + """ + list the workloads + """ + workloads: List[str] = [] + for job in cast('RunOutput', self).jobs: if job.label not in workloads: workloads.append(job.label) return workloads @@ -213,69 +316,99 @@ def list_workloads(self): class RunOutput(Output, RunOutputCommon): - kind = 'run' + kind: Optional[str] = 'run' @property - def logfile(self): + def logfile(self) -> str: + """ + log file path + """ return os.path.join(self.basepath, 'run.log') @property - def metadir(self): + def metadir(self) -> str: + """ + metadata directory + """ return os.path.join(self.basepath, '__meta') @property - def infofile(self): + def infofile(self) -> str: + """ + run info file + """ return os.path.join(self.metadir, 'run_info.json') @property - def statefile(self): + def statefile(self) -> str: + """ + run state file + """ return os.path.join(self.basepath, '.run_state.json') @property - def configfile(self): + def configfile(self) -> str: + """ + configuration file + """ return os.path.join(self.metadir, 'config.json') @property - def targetfile(self): + def targetfile(self) -> str: + """ + target information file + """ return os.path.join(self.metadir, 'target_info.json') @property - def jobsfile(self): + def jobsfile(self) -> str: + """ + jobs file + """ return os.path.join(self.metadir, 'jobs.json') @property - def raw_config_dir(self): + def raw_config_dir(self) -> str: + """ + raw configuration file + """ return os.path.join(self.metadir, 'raw_config') @property - def failed_dir(self): + def failed_dir(self) -> str: + """ + failed info file + """ path = os.path.join(self.basepath, '__failed') return ensure_directory_exists(path) @property - def augmentations(self): - run_augs = set([]) + def augmentations(self) -> List: + """ + augmentations on the run + """ + run_augs: Set = set([]) for job in self.jobs: - for aug in job.spec.augmentations: + for aug in cast(JobSpecProtocol, job.spec).augmentations: run_augs.add(aug) return list(run_augs) - def __init__(self, path): + def __init__(self, path: str): super(RunOutput, self).__init__(path) - self.info = None - self.state = None - self.result = None - self.target_info = None - self._combined_config = None - self.jobs = [] - self.job_specs = [] + self.info: Optional[RunInfo] = None + self.state: Optional[RunState] = None + self.result: Optional['Result'] = None + self.target_info: Optional['TargetInfo'] = None + self._combined_config: Optional[CombinedConfig] = None + self.jobs: List[JobOutput] = [] + self.job_specs: List[JobSpecProtocol] = [] if (not os.path.isfile(self.statefile) or not os.path.isfile(self.infofile)): - msg = '"{}" does not exist or is not a valid WA output directory.' + msg: str = '"{}" does not exist or is not a valid WA output directory.' 
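# Editorial note: a directory only counts as WA run output if both
# .run_state.json (statefile) and __meta/run_info.json (infofile) are
# present; the ValueError below rejects anything else before reload() runs.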
raise ValueError(msg.format(self.basepath)) self.reload() - def reload(self): + def reload(self) -> None: super(RunOutput, self).reload() self.info = RunInfo.from_pod(read_pod(self.infofile)) self.state = RunState.from_pod(read_pod(self.statefile)) @@ -284,10 +417,10 @@ def reload(self): if os.path.isfile(self.targetfile): self.target_info = TargetInfo.from_pod(read_pod(self.targetfile)) if os.path.isfile(self.jobsfile): - self.job_specs = self.read_job_specs() + self.job_specs = self.read_job_specs() or [] for job_state in self.state.jobs.values(): - job_path = os.path.join(self.basepath, job_state.output_name) + job_path: str = os.path.join(self.basepath, job_state.output_name) job = JobOutput(job_path, job_state.id, job_state.label, job_state.iteration, job_state.retries) @@ -297,41 +430,67 @@ def reload(self): logger.warning('Could not find spec for job {}'.format(job.id)) self.jobs.append(job) - def write_info(self): - write_pod(self.info.to_pod(), self.infofile) + def write_info(self) -> None: + """ + write run info to infofile + """ + if self.info: + write_pod(self.info.to_pod(), self.infofile) - def write_state(self): - write_pod(self.state.to_pod(), self.statefile) + def write_state(self) -> None: + """ + write run state into statefile + """ + if self.state: + write_pod(self.state.to_pod(), self.statefile) - def write_config(self, config): + def write_config(self, config: CombinedConfig) -> None: + """ + write config into config file + """ self._combined_config = config write_pod(config.to_pod(), self.configfile) - def read_config(self): + def read_config(self) -> Optional[CombinedConfig]: + """ + read combined config file + """ if not os.path.isfile(self.configfile): return None return CombinedConfig.from_pod(read_pod(self.configfile)) - def set_target_info(self, ti): + def set_target_info(self, ti: TargetInfo) -> None: + """ + set target info + """ self.target_info = ti write_pod(ti.to_pod(), self.targetfile) - def write_job_specs(self, job_specs): + def write_job_specs(self, job_specs: List[JobSpecProtocol]) -> None: + """ + write job specifications + """ job_specs[0].to_pod() js_pod = {'jobs': [js.to_pod() for js in job_specs]} write_pod(js_pod, self.jobsfile) - def read_job_specs(self): + def read_job_specs(self) -> Optional[List[JobSpecProtocol]]: + """ + read job specifications + """ if not os.path.isfile(self.jobsfile): return None pod = read_pod(self.jobsfile) - return [JobSpec.from_pod(jp) for jp in pod['jobs']] + return cast(List[JobSpecProtocol], [JobSpec.from_pod(jp) for jp in pod['jobs']]) - def move_failed(self, job_output): - name = os.path.basename(job_output.basepath) - attempt = job_output.retry + 1 - failed_name = '{}-attempt{:02}'.format(name, attempt) - failed_path = os.path.join(self.failed_dir, failed_name) + def move_failed(self, job_output: 'JobOutput') -> None: + """ + move output of failed jobs to failed file + """ + name: str = os.path.basename(job_output.basepath) + attempt: int = job_output.retry + 1 + failed_name: str = '{}-attempt{:02}'.format(name, attempt) + failed_path: str = os.path.join(self.failed_dir, failed_name) if os.path.exists(failed_path): raise ValueError('Path {} already exists'.format(failed_path)) shutil.move(job_output.basepath, failed_path) @@ -343,31 +502,35 @@ class JobOutput(Output): kind = 'job' # pylint: disable=redefined-builtin - def __init__(self, path, id, label, iteration, retry): + def __init__(self, path: str, id: str, label: str, iteration: int, retry: int): super(JobOutput, self).__init__(path) self.id = id 
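# Editorial note: id, label, iteration and retry mirror the JobState entry
# this output was created from (see the JobOutput construction in
# RunOutput.reload() above); together they name the on-disk directory
# '{id}-{label}-{iteration}'.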
self.label = label self.iteration = iteration self.retry = retry - self.result = None - self.spec = None + self.result: Optional['Result'] = None + self.spec: Optional[JobSpecProtocol] = None self.reload() @property - def augmentations(self): + def augmentations(self) -> List: + """ + list of augmentations + """ job_augs = set([]) - for aug in self.spec.augmentations: - job_augs.add(aug) + if self.spec: + for aug in self.spec.augmentations: + job_augs.add(aug) return list(job_augs) class Result(Podable): - _pod_serialization_version = 1 + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): - instance = super(Result, Result).from_pod(pod) + def from_pod(pod) -> 'Result': + instance = cast('Result', super(Result, Result).from_pod(pod)) instance.status = Status.from_pod(pod['status']) instance.metrics = [Metric.from_pod(m) for m in pod['metrics']] instance.artifacts = [Artifact.from_pod(a) for a in pod['artifacts']] @@ -376,45 +539,67 @@ def from_pod(pod): instance.metadata = pod.get('metadata', OrderedDict()) return instance - def __init__(self): + def __init__(self) -> None: # pylint: disable=no-member super(Result, self).__init__() - self.status = Status.NEW - self.metrics = [] - self.artifacts = [] - self.events = [] - self.classifiers = OrderedDict() - self.metadata = OrderedDict() - - def add_metric(self, name, value, units=None, lower_is_better=False, - classifiers=None): + self.status: 'StatusType' = Status.NEW + self.metrics: List['Metric'] = [] + self.artifacts: List['Artifact'] = [] + self.events: List['Event'] = [] + self.classifiers: OrderedDict = OrderedDict() + self.metadata: OrderedDict = OrderedDict() + + def add_metric(self, name: str, value: Any, + units: Optional[str] = None, lower_is_better: bool = False, + classifiers: Optional[Dict[str, Any]] = None) -> None: + """ + add a metric to the workload execution result + """ metric = Metric(name, value, units, lower_is_better, classifiers) logger.debug('Adding metric: {}'.format(metric)) self.metrics.append(metric) - def add_artifact(self, name, path, kind, description=None, classifiers=None, - is_dir=False): + def add_artifact(self, name: str, path: str, kind: str, + description: Optional[str] = None, + classifiers: Optional[Dict[str, Any]] = None, + is_dir: bool = False) -> None: + """ + add artifact to the workload execution result + """ artifact = Artifact(name, path, kind, description=description, classifiers=classifiers, is_dir=is_dir) logger.debug('Adding artifact: {}'.format(artifact)) self.artifacts.append(artifact) - def add_event(self, message): + def add_event(self, message: str): + """ + add new event with the message into result + """ self.events.append(Event(message)) - def get_metric(self, name): + def get_metric(self, name: str) -> Optional['Metric']: + """ + get the specified metric from workload execution result + """ for metric in self.metrics: if metric.name == name: return metric return None - def get_artifact(self, name): + def get_artifact(self, name: str) -> Optional['Artifact']: + """ + get the specified artifact from workload execution result + """ for artifact in self.artifacts: if artifact.name == name: return artifact raise HostError('Artifact "{}" not found'.format(name)) - def add_classifier(self, name, value, overwrite=False): + def add_classifier(self, name: str, value: Any, + overwrite: bool = False) -> None: + """ + add classifier to the workload execution result and update the metrics and artifacts + """ if name in self.classifiers and not overwrite: raise 
ValueError('Cannot overwrite "{}" classifier.'.format(name)) self.classifiers[name] = value @@ -429,10 +614,13 @@ def add_classifier(self, name, value, overwrite=False): raise ValueError('Cannot overwrite "{}" classifier; clashes with {}.'.format(name, artifact)) artifact.classifiers[name] = value - def add_metadata(self, key, *args, **kwargs): - force = kwargs.pop('force', False) + def add_metadata(self, key: str, *args, **kwargs) -> None: + """ + add metadata to workload execution result + """ + force: bool = kwargs.pop('force', False) if kwargs: - msg = 'Unexpected keyword arguments: {}' + msg: str = 'Unexpected keyword arguments: {}' raise ValueError(msg.format(kwargs)) if key in self.metadata and not force: @@ -450,8 +638,11 @@ def add_metadata(self, key, *args, **kwargs): self.metadata[key] = value - def update_metadata(self, key, *args): - if not args: + def update_metadata(self, key: str, *args) -> None: + """ + update an existing metadata in workload execution output + """ + if not args or len(args) == 0: del self.metadata[key] return @@ -494,7 +685,7 @@ def _pod_upgrade_v1(pod): return pod -ARTIFACT_TYPES = ['log', 'meta', 'data', 'export', 'raw'] +ARTIFACT_TYPES: List[str] = ['log', 'meta', 'data', 'export', 'raw'] ArtifactType = enum(ARTIFACT_TYPES) @@ -546,20 +737,22 @@ class Artifact(Podable): """ - _pod_serialization_version = 2 + _pod_serialization_version: int = 2 @staticmethod - def from_pod(pod): + def from_pod(pod) -> 'Artifact': pod = Artifact._upgrade_pod(pod) - pod_version = pod.pop('_pod_version') + pod_version: int = pod.pop('_pod_version') pod['kind'] = ArtifactType(pod['kind']) instance = Artifact(**pod) instance._pod_version = pod_version # pylint: disable =protected-access instance.is_dir = pod.pop('is_dir') return instance - def __init__(self, name, path, kind, description=None, classifiers=None, - is_dir=False): + def __init__(self, name: str, path: str, kind: str, + description: Optional[str] = None, + classifiers: Optional[Dict[str, Any]] = None, + is_dir: bool = False): """" :param name: Name that uniquely identifies this artifact. :param path: The *relative* path of the artifact. Depending on the @@ -582,13 +775,13 @@ def __init__(self, name, path, kind, description=None, classifiers=None, try: self.kind = ArtifactType(kind) except ValueError: - msg = 'Invalid Artifact kind: {}; must be in {}' + msg: str = 'Invalid Artifact kind: {}; must be in {}' raise ValueError(msg.format(kind, ARTIFACT_TYPES)) self.description = description self.classifiers = classifiers or {} self.is_dir = is_dir - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(Artifact, self).to_pod() pod.update(self.__dict__) pod['kind'] = str(self.kind) @@ -596,12 +789,18 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) return pod @staticmethod - def _pod_upgrade_v2(pod): + def _pod_upgrade_v2(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 2 + """ pod['is_dir'] = pod.get('is_dir', False) return pod @@ -630,33 +829,37 @@ class Metric(Podable): to identify sub-tests). 
""" - __slots__ = ['name', 'value', 'units', 'lower_is_better', 'classifiers'] - _pod_serialization_version = 1 + __slots__: List[str] = ['name', 'value', 'units', 'lower_is_better', 'classifiers'] + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod) -> 'Metric': pod = Metric._upgrade_pod(pod) - pod_version = pod.pop('_pod_version') + pod_version: int = pod.pop('_pod_version') instance = Metric(**pod) instance._pod_version = pod_version # pylint: disable =protected-access return instance @property - def label(self): + def label(self) -> str: + """ + label of the metric + """ parts = ['{}={}'.format(n, v) for n, v in self.classifiers.items()] parts.insert(0, self.name) return '/'.join(parts) - def __init__(self, name, value, units=None, lower_is_better=False, - classifiers=None): + def __init__(self, name: str, value: Any, units: Optional[str] = None, + lower_is_better: bool = False, + classifiers: Optional[Dict[str, Any]] = None): super(Metric, self).__init__() self.name = name - self.value = numeric(value) + self.value: Union[int, float] = numeric(value) self.units = units self.lower_is_better = lower_is_better self.classifiers = classifiers or {} - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(Metric, self).to_pod() pod['name'] = self.name pod['value'] = self.value @@ -666,19 +869,22 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) return pod - def __str__(self): - result = '{}: {}'.format(self.name, self.value) + def __str__(self) -> str: + result: str = '{}: {}'.format(self.name, self.value) if self.units: result += ' ' + self.units result += ' ({})'.format('-' if self.lower_is_better else '+') return result - def __repr__(self): - text = self.__str__() + def __repr__(self) -> str: + text: str = self.__str__() if self.classifiers: return '<{} {}>'.format(text, format_ordered_dict(self.classifiers)) else: @@ -688,42 +894,47 @@ def __repr__(self): class Event(Podable): """ An event that occured during a run. 
- """ - __slots__ = ['timestamp', 'message'] - _pod_serialization_version = 1 + __slots__: List[str] = ['timestamp', 'message'] + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod) -> 'Event': pod = Event._upgrade_pod(pod) - pod_version = pod.pop('_pod_version') + pod_version: int = pod.pop('_pod_version') instance = Event(pod['message']) instance.timestamp = pod['timestamp'] instance._pod_version = pod_version # pylint: disable =protected-access return instance @property - def summary(self): - lines = self.message.split('\n') - result = lines[0] + def summary(self) -> str: + """ + summary of the event + """ + lines: List[str] = self.message.split('\n') + result: str = lines[0] if len(lines) > 1: result += '[...]' return result - def __init__(self, message): + def __init__(self, message: str): super(Event, self).__init__() - self.timestamp = datetime.utcnow() + self.timestamp: datetime = datetime.utcnow() self.message = str(message) - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(Event, self).to_pod() pod['timestamp'] = self.timestamp pod['message'] = self.message return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) return pod @@ -733,7 +944,11 @@ def __str__(self): __repr__ = __str__ -def init_run_output(path, wa_state, force=False): +def init_run_output(path: str, wa_state: ConfigManager, + force: bool = False) -> RunOutput: + """ + initialize run output + """ if os.path.exists(path): if force: logger.info('Removing existing output directory.') @@ -743,7 +958,7 @@ def init_run_output(path, wa_state, force=False): logger.info('Creating output directory.') os.makedirs(path) - meta_dir = os.path.join(path, '__meta') + meta_dir: str = os.path.join(path, '__meta') os.makedirs(meta_dir) _save_raw_config(meta_dir, wa_state) touch(os.path.join(path, 'run.log')) @@ -751,7 +966,7 @@ def init_run_output(path, wa_state, force=False): info = RunInfo( run_name=wa_state.run_config.run_name, project=wa_state.run_config.project, - project_stage=wa_state.run_config.project_stage, + project_stage=wa_state.run_config.project_stage or '', ) write_pod(info.to_pod(), os.path.join(meta_dir, 'run_info.json')) write_pod(RunState().to_pod(), os.path.join(path, '.run_state.json')) @@ -764,19 +979,25 @@ def init_run_output(path, wa_state, force=False): return ro -def init_job_output(run_output, job): - output_name = '{}-{}-{}'.format(job.id, job.spec.label, job.iteration) - path = os.path.join(run_output.basepath, output_name) +def init_job_output(run_output: RunOutput, job: 'Job') -> JobOutput: + """ + initialize job output + """ + output_name: str = '{}-{}-{}'.format(job.id, job.spec.label, job.iteration) + path: str = os.path.join(run_output.basepath, output_name) ensure_directory_exists(path) write_pod(Result().to_pod(), os.path.join(path, 'result.json')) - job_output = JobOutput(path, job.id, job.label, job.iteration, job.retries) + job_output = JobOutput(path, job.id or '', job.label, job.iteration, job.retries) job_output.spec = job.spec job_output.status = job.status run_output.jobs.append(job_output) return job_output -def discover_wa_outputs(path): +def discover_wa_outputs(path: str) -> Generator[RunOutput, Any, None]: + """ + discover workload automation outputs + """ # Use topdown=True to allow pruning dirs for root, dirs, _ in os.walk(path, topdown=True): if '__meta' in dirs: @@ -786,37 
+1007,46 @@ def discover_wa_outputs(path):
             dirs.clear()


-def _save_raw_config(meta_dir, state):
-    raw_config_dir = os.path.join(meta_dir, 'raw_config')
+def _save_raw_config(meta_dir: str, state: 'ConfigManager') -> None:
+    """
+    save raw configuration
+    """
+    raw_config_dir: str = os.path.join(meta_dir, 'raw_config')
     os.makedirs(raw_config_dir)
     for i, source in enumerate(state.loaded_config_sources):
         if not os.path.isfile(source):
             continue
-        basename = os.path.basename(source)
-        dest_path = os.path.join(raw_config_dir, 'cfg{}-{}'.format(i, basename))
+        basename: str = os.path.basename(source)
+        dest_path: str = os.path.join(raw_config_dir, 'cfg{}-{}'.format(i, basename))
         shutil.copy(source, dest_path)


 class DatabaseOutput(Output):

-    kind = None
+    kind: Optional[str] = None

     @property
-    def resultfile(self):
+    def resultfile(self) -> Optional[Dict[str, Any]]:  # type:ignore
         if self.conn is None or self.oid is None:
             return {}
         pod = self._get_pod_version()
-        pod['metrics'] = self._get_metrics()
-        pod['status'] = self._get_status()
-        pod['classifiers'] = self._get_classifiers(self.oid, 'run')
-        pod['events'] = self._get_events()
-        pod['artifacts'] = self._get_artifacts()
+        if pod:
+            pod['metrics'] = self._get_metrics()
+            pod['status'] = self._get_status()
+            pod['classifiers'] = self._get_classifiers(self.oid, 'run')
+            pod['events'] = self._get_events()
+            pod['artifacts'] = self._get_artifacts()
         return pod

     @staticmethod
-    def _build_command(columns, tables, conditions=None, joins=None):
-        cmd = '''SELECT\n\t{}\nFROM\n\t{}'''.format(',\n\t'.join(columns), ',\n\t'.join(tables))
+    def _build_command(columns: List[str], tables: List[str],
+                       conditions: Optional[List[str]] = None,
+                       joins: Optional[List[Tuple[str, str]]] = None) -> str:
+        """
+        build command
+        """
+        cmd: str = '''SELECT\n\t{}\nFROM\n\t{}'''.format(',\n\t'.join(columns), ',\n\t'.join(tables))
         if joins:
             for join in joins:
                 cmd += '''\nLEFT JOIN {} ON {}'''.format(join[0], join[1])
@@ -824,10 +1054,11 @@ def _build_command(columns, tables, conditions=None, joins=None):
             cmd += '''\nWHERE\n\t{}'''.format('\nAND\n\t'.join(conditions))
         return cmd + ';'

-    def __init__(self, conn, oid=None, reload=True):  # pylint: disable=super-init-not-called
+    def __init__(self, conn: Optional['_psycopg.connection'],
+                 oid: Optional[UUID] = None, reload: bool = True):  # pylint: disable=super-init-not-called
         self.conn = conn
         self.oid = oid
-        self.result = None
+        self.result: Optional[Result] = None
         if reload:
             self.reload()
@@ -845,32 +1076,43 @@ def reload(self):
             self.result.status = Status.UNKNOWN
             self.add_event(str(e))

-    def get_artifact_path(self, name):
+    def get_artifact_path(self, name: str) -> str:
         artifact = self.get_artifact(name)
-        if artifact.is_dir:
+        if artifact and artifact.is_dir:
             return self._read_dir_artifact(artifact)
         else:
-            return self._read_file_artifact(artifact)
+            return cast(str, self._read_file_artifact(artifact))

-    def _read_dir_artifact(self, artifact):
+    def _read_dir_artifact(self, artifact: Artifact) -> str:
+        """
+        read directory artifact
+        """
         artifact_path = tempfile.mkdtemp(prefix='wa_')
-        with tarfile.open(fileobj=self.conn.lobject(int(artifact.path), mode='b'), mode='r|gz') as tar_file:
-            safe_extract(tar_file, artifact_path)
-        self.conn.commit()
+        if self.conn:
+            with tarfile.open(fileobj=self.conn.lobject(int(artifact.path), mode='b'),
+                              mode='r|gz') as tar_file:  # type:ignore
+                safe_extract(tar_file, artifact_path)
+            self.conn.commit()
         return artifact_path

-    def _read_file_artifact(self,
artifact): - artifact = StringIO(self.conn.lobject(int(artifact.path)).read()) - self.conn.commit() - return artifact + def _read_file_artifact(self, artifact: Optional[Artifact]) -> StringIO: + """ + read file artifact + """ + artifact_ = StringIO(self.conn.lobject(int(artifact.path if artifact else '')).read() if self.conn else '') + if self.conn: + self.conn.commit() + return artifact_ # pylint: disable=too-many-locals - def _read_db(self, columns, tables, conditions=None, join=None, as_dict=True): + def _read_db(self, columns: List[Union[str, Tuple[str, str]]], tables: List[str], + conditions: Optional[List[str]] = None, join: Optional[List[Tuple[str, str]]] = None, + as_dict: bool = True) -> List[Dict]: # Automatically remove table name from column when using column names as keys or # allow for column names to be aliases when retrieving the data, # (db_column_name, alias) - db_columns = [] - aliases_colunms = [] + db_columns: List[str] = [] + aliases_colunms: List[str] = [] for column in columns: if isinstance(column, tuple): db_columns.append(column[0]) @@ -879,19 +1121,20 @@ def _read_db(self, columns, tables, conditions=None, join=None, as_dict=True): db_columns.append(column) aliases_colunms.append(column.rsplit('.', 1)[-1]) - cmd = self._build_command(db_columns, tables, conditions, join) + cmd: str = self._build_command(db_columns, tables, conditions, join) logger.debug(cmd) - with self.conn.cursor() as cursor: - cursor.execute(cmd) - results = cursor.fetchall() - self.conn.commit() + if self.conn: + with self.conn.cursor() as cursor: + cursor.execute(cmd) + results = cursor.fetchall() + self.conn.commit() if not as_dict: - return results + return cast(List[Dict[Any, Any]], results) # Format the output dict using column names as keys - output = [] + output: List[Dict] = [] for result in results: entry = {} for k, v in zip(aliases_colunms, result): @@ -899,9 +1142,12 @@ def _read_db(self, columns, tables, conditions=None, join=None, as_dict=True): output.append(entry) return output - def _get_pod_version(self): - columns = ['_pod_version', '_pod_serialization_version'] - tables = ['{}s'.format(self.kind)] + def _get_pod_version(self) -> Optional[Dict]: + """ + get pod version from database + """ + columns: List[Union[str, Tuple[str, str]]] = ['_pod_version', '_pod_serialization_version'] + tables: List[str] = ['{}s'.format(self.kind)] conditions = ['{}s.oid = \'{}\''.format(self.kind, self.oid)] results = self._read_db(columns, tables, conditions) if results: @@ -909,68 +1155,91 @@ def _get_pod_version(self): else: return None - def _populate_classifers(self, pod, kind): + def _populate_classifers(self, pod: List[Dict], kind: str) -> List[Dict]: + """ + populate classifiers + """ for entry in pod: oid = entry.pop('oid') entry['classifiers'] = self._get_classifiers(oid, kind) return pod - def _get_classifiers(self, oid, kind): - columns = ['classifiers.key', 'classifiers.value'] - tables = ['classifiers'] - conditions = ['{}_oid = \'{}\''.format(kind, oid)] + def _get_classifiers(self, oid: UUID, kind: str) -> Dict: + """ + get classifiers from database. Classifiers are used to annotate generated + metrics and artifacts in order to assist post-processing tools in sorting + through them. 
+ """ + columns: List[Union[str, Tuple[str, str]]] = ['classifiers.key', 'classifiers.value'] + tables: List[str] = ['classifiers'] + conditions: List[str] = ['{}_oid = \'{}\''.format(kind, oid)] results = self._read_db(columns, tables, conditions, as_dict=False) classifiers = {} for (k, v) in results: classifiers[k] = v return classifiers - def _get_metrics(self): - columns = ['metrics.name', 'metrics.value', 'metrics.units', - 'metrics.lower_is_better', - 'metrics.oid', 'metrics._pod_version', - 'metrics._pod_serialization_version'] - tables = ['metrics'] - joins = [('classifiers', 'classifiers.metric_oid = metrics.oid')] - conditions = ['metrics.{}_oid = \'{}\''.format(self.kind, self.oid)] - pod = self._read_db(columns, tables, conditions, joins) + def _get_metrics(self) -> List[Dict]: + """ + get metrics from database + """ + columns: List[Union[str, Tuple[str, str]]] = ['metrics.name', 'metrics.value', 'metrics.units', + 'metrics.lower_is_better', + 'metrics.oid', 'metrics._pod_version', + 'metrics._pod_serialization_version'] + tables: List[str] = ['metrics'] + joins: List[Tuple[str, str]] = [('classifiers', 'classifiers.metric_oid = metrics.oid')] + conditions: List[str] = ['metrics.{}_oid = \'{}\''.format(self.kind, self.oid)] + pod: List[Dict] = self._read_db(columns, tables, conditions, joins) return self._populate_classifers(pod, 'metric') - def _get_status(self): - columns = ['{}s.status'.format(self.kind)] - tables = ['{}s'.format(self.kind)] - conditions = ['{}s.oid = \'{}\''.format(self.kind, self.oid)] - results = self._read_db(columns, tables, conditions, as_dict=False) + def _get_status(self) -> Any: + """ + get status from database + """ + columns: List[Union[str, Tuple[str, str]]] = ['{}s.status'.format(self.kind)] + tables: List[str] = ['{}s'.format(self.kind)] + conditions: List[str] = ['{}s.oid = \'{}\''.format(self.kind, self.oid)] + results: List[Dict] = self._read_db(columns, tables, conditions, as_dict=False) if results: return results[0][0] else: return None - def _get_artifacts(self): - columns = ['artifacts.name', 'artifacts.description', 'artifacts.kind', - ('largeobjects.lo_oid', 'path'), 'artifacts.oid', 'artifacts.is_dir', - 'artifacts._pod_version', 'artifacts._pod_serialization_version'] - tables = ['largeobjects', 'artifacts'] - joins = [('classifiers', 'classifiers.artifact_oid = artifacts.oid')] - conditions = ['artifacts.{}_oid = \'{}\''.format(self.kind, self.oid), - 'artifacts.large_object_uuid = largeobjects.oid'] + def _get_artifacts(self) -> List[Dict]: + """ + get artifacts from database + """ + columns: List[Union[str, Tuple[str, str]]] = ['artifacts.name', 'artifacts.description', 'artifacts.kind', + ('largeobjects.lo_oid', 'path'), 'artifacts.oid', 'artifacts.is_dir', + 'artifacts._pod_version', 'artifacts._pod_serialization_version'] + tables: List[str] = ['largeobjects', 'artifacts'] + joins: List[Tuple[str, str]] = [('classifiers', 'classifiers.artifact_oid = artifacts.oid')] + conditions: List[str] = ['artifacts.{}_oid = \'{}\''.format(self.kind, self.oid), + 'artifacts.large_object_uuid = largeobjects.oid'] # If retrieving run level artifacts we want those that don't also belong to a job if self.kind == 'run': conditions.append('artifacts.job_oid IS NULL') - pod = self._read_db(columns, tables, conditions, joins) + pod: List[Dict] = self._read_db(columns, tables, conditions, joins) for artifact in pod: artifact['path'] = str(artifact['path']) return self._populate_classifers(pod, 'metric') - def _get_events(self): - columns = 
['events.message', 'events.timestamp'] - tables = ['events'] - conditions = ['events.{}_oid = \'{}\''.format(self.kind, self.oid)] + def _get_events(self) -> List[Dict]: + """ + get events from database + """ + columns: List[Union[str, Tuple[str, str]]] = ['events.message', 'events.timestamp'] + tables: List[str] = ['events'] + conditions: List[str] = ['events.{}_oid = \'{}\''.format(self.kind, self.oid)] return self._read_db(columns, tables, conditions) -def kernel_config_from_db(raw): - kernel_config = {} +def kernel_config_from_db(raw: Any) -> Dict: + """ + get kernel configuration from database + """ + kernel_config: Dict = {} if raw: for k, v in zip(raw[0], raw[1]): kernel_config[k] = v @@ -979,25 +1248,34 @@ def kernel_config_from_db(raw): class RunDatabaseOutput(DatabaseOutput, RunOutputCommon): - kind = 'run' + kind: str = 'run' @property def basepath(self): + """ + get base path of database + """ return 'db:({})-{}@{}:{}'.format(self.dbname, self.user, self.host, self.port) @property - def augmentations(self): - columns = ['augmentations.name'] - tables = ['augmentations'] - conditions = ['augmentations.run_oid = \'{}\''.format(self.oid)] - results = self._read_db(columns, tables, conditions, as_dict=False) + def augmentations(self) -> List: + """ + get augmentations for run output from databse + """ + columns: List[Union[str, Tuple[str, str]]] = ['augmentations.name'] + tables: List[str] = ['augmentations'] + conditions: List[str] = ['augmentations.run_oid = \'{}\''.format(self.oid)] + results: List[Dict] = self._read_db(columns, tables, conditions, as_dict=False) return [a for augs in results for a in augs] @property - def _db_infofile(self): - columns = ['start_time', 'project', ('run_uuid', 'uuid'), 'end_time', - 'run_name', 'duration', '_pod_version', '_pod_serialization_version'] + def _db_infofile(self) -> Dict: + """ + get run info file from db + """ + columns: List[Union[str, Tuple[str, str]]] = ['start_time', 'project', ('run_uuid', 'uuid'), 'end_time', + 'run_name', 'duration', '_pod_version', '_pod_serialization_version'] tables = ['runs'] conditions = ['runs.run_uuid = \'{}\''.format(self.run_uuid)] pod = self._read_db(columns, tables, conditions) @@ -1006,18 +1284,21 @@ def _db_infofile(self): return pod[0] @property - def _db_targetfile(self): - columns = ['os', 'is_rooted', 'target', 'modules', 'abi', 'cpus', 'os_version', - 'hostid', 'hostname', 'kernel_version', 'kernel_release', - 'kernel_sha1', 'kernel_config', 'sched_features', 'page_size_kb', - 'system_id', 'screen_resolution', 'prop', 'android_id', - '_pod_version', '_pod_serialization_version'] - tables = ['targets'] - conditions = ['targets.run_oid = \'{}\''.format(self.oid)] - pod = self._read_db(columns, tables, conditions) - if not pod: + def _db_targetfile(self) -> Dict: + """ + get database target file + """ + columns: List[Union[str, Tuple[str, str]]] = ['os', 'is_rooted', 'target', 'modules', 'abi', 'cpus', 'os_version', + 'hostid', 'hostname', 'kernel_version', 'kernel_release', + 'kernel_sha1', 'kernel_config', 'sched_features', 'page_size_kb', + 'system_id', 'screen_resolution', 'prop', 'android_id', + '_pod_version', '_pod_serialization_version'] + tables: List[str] = ['targets'] + conditions: List[str] = ['targets.run_oid = \'{}\''.format(self.oid)] + pod_: List[Dict] = self._read_db(columns, tables, conditions) + if not pod_: return {} - pod = pod[0] + pod: Dict = pod_[0] try: pod['cpus'] = [json.loads(cpu) for cpu in pod.pop('cpus')] except SerializerSyntaxError: @@ -1027,13 +1308,16 @@ def 
_db_targetfile(self): return pod @property - def _db_statefile(self): + def _db_statefile(self) -> Dict: + """ + get state file from database + """ # Read overall run information - columns = ['runs.state'] - tables = ['runs'] - conditions = ['runs.run_uuid = \'{}\''.format(self.run_uuid)] - pod = self._read_db(columns, tables, conditions) - pod = pod[0].get('state') + columns: List[Union[str, Tuple[str, str]]] = ['runs.state'] + tables: List[str] = ['runs'] + conditions: List[str] = ['runs.run_uuid = \'{}\''.format(self.run_uuid)] + pod_: List[Dict] = self._read_db(columns, tables, conditions) + pod = pod_[0].get('state') if not pod: return {} @@ -1041,7 +1325,7 @@ def _db_statefile(self): columns = ['jobs.job_id', 'jobs.oid'] tables = ['jobs'] conditions = ['jobs.run_oid = \'{}\''.format(self.oid)] - job_oids = self._read_db(columns, tables, conditions) + job_oids: List[Dict] = self._read_db(columns, tables, conditions) # Match job oid with jobs from state file for job in pod.get('jobs', []): @@ -1052,15 +1336,18 @@ def _db_statefile(self): return pod @property - def _db_jobsfile(self): - workload_params = self._get_parameters('workload') - runtime_params = self._get_parameters('runtime') + def _db_jobsfile(self) -> List[Dict]: + """ + get jobs file from database + """ + workload_params: Dict = self._get_parameters('workload') + runtime_params: Dict = self._get_parameters('runtime') - columns = [('jobs.job_id', 'id'), 'jobs.label', 'jobs.workload_name', - 'jobs.oid', 'jobs._pod_version', 'jobs._pod_serialization_version'] - tables = ['jobs'] - conditions = ['jobs.run_oid = \'{}\''.format(self.oid)] - jobs = self._read_db(columns, tables, conditions) + columns: List[Union[str, Tuple[str, str]]] = [('jobs.job_id', 'id'), 'jobs.label', 'jobs.workload_name', + 'jobs.oid', 'jobs._pod_version', 'jobs._pod_serialization_version'] + tables: List[str] = ['jobs'] + conditions: List[str] = ['jobs.run_oid = \'{}\''.format(self.oid)] + jobs: List[Dict] = self._read_db(columns, tables, conditions) for job in jobs: job['augmentations'] = self._get_job_augmentations(job['oid']) @@ -1070,48 +1357,52 @@ def _db_jobsfile(self): return jobs @property - def _db_run_config(self): - pod = defaultdict(dict) - parameter_types = ['augmentation', 'resource_getter'] + def _db_run_config(self) -> Union[Dict, DefaultDict]: + """ + get run configuration from database + """ + pod: DefaultDict = defaultdict(dict) + parameter_types: List[str] = ['augmentation', 'resource_getter'] for parameter_type in parameter_types: - columns = ['parameters.name', 'parameters.value', - 'parameters.value_type', - ('{}s.name'.format(parameter_type), '{}'.format(parameter_type))] - tables = ['parameters', '{}s'.format(parameter_type)] - conditions = ['parameters.run_oid = \'{}\''.format(self.oid), - 'parameters.type = \'{}\''.format(parameter_type), - 'parameters.{0}_oid = {0}s.oid'.format(parameter_type)] - configs = self._read_db(columns, tables, conditions) - for config in configs: - entry = {config['name']: json.loads(config['value'])} - pod['{}s'.format(parameter_type)][config.pop(parameter_type)] = entry + columns: List[Union[str, Tuple[str, str]]] = ['parameters.name', 'parameters.value', + 'parameters.value_type', + ('{}s.name'.format(parameter_type), + '{}'.format(parameter_type))] + tables: List[str] = ['parameters', '{}s'.format(parameter_type)] + conditions: List[str] = ['parameters.run_oid = \'{}\''.format(self.oid), + 'parameters.type = \'{}\''.format(parameter_type), + 'parameters.{0}_oid = {0}s.oid'.format(parameter_type)] 
+            configs: List[Dict] = self._read_db(columns, tables, conditions)
+            for config_t in configs:
+                entry: Dict = {config_t['name']: json.loads(config_t['value'])}
+                pod['{}s'.format(parameter_type)][config_t.pop(parameter_type)] = entry

         # run config
         columns = ['runs.max_retries', 'runs.allow_phone_home',
                    'runs.bail_on_init_failure', 'runs.retry_on_status']
         tables = ['runs']
         conditions = ['runs.oid = \'{}\''.format(self.oid)]
-        config = self._read_db(columns, tables, conditions)
-        if not config:
+        config_ = self._read_db(columns, tables, conditions)
+        if not config_:
             return {}
-        config = config[0]
+        config = config_[0]

         # Convert back into a string representation of an enum list
         config['retry_on_status'] = config['retry_on_status'][1:-1].split(',')
         pod.update(config)
         return pod

     def __init__(self,
-                 password=None,
-                 dbname='wa',
-                 host='localhost',
-                 port='5432',
-                 user='postgres',
-                 run_uuid=None,
-                 list_runs=False):
+                 password: Optional[str] = None,
+                 dbname: str = 'wa',
+                 host: str = 'localhost',
+                 port: str = '5432',
+                 user: str = 'postgres',
+                 run_uuid: Optional[UUID] = None,
+                 list_runs: bool = False):
         if psycopg2 is None:
-            msg = 'Please install the psycopg2 in order to connect to postgres databases'
+            msg: str = 'Please install psycopg2 in order to connect to postgres databases'
             raise HostError(msg)

         self.dbname = dbname
@@ -1120,15 +1411,15 @@ def __init__(self,
         self.user = user
         self.password = password
         self.run_uuid = run_uuid
-        self.conn = None
+        self.conn: Optional[_psycopg.connection] = None

-        self.info = None
-        self.state = None
-        self.result = None
-        self.target_info = None
-        self._combined_config = None
-        self.jobs = []
-        self.job_specs = []
+        self.info: Optional[RunInfo] = None
+        self.state: Optional[RunState] = None
+        self.result: Optional[Result] = None
+        self.target_info: Optional[TargetInfo] = None
+        self._combined_config: Optional[CombinedConfig] = None
+        self.jobs: List['JobDatabaseOutput'] = []
+        self.job_specs: List[JobSpecProtocol] = []

         self.connect()
         super(RunDatabaseOutput, self).__init__(conn=self.conn, reload=False)
@@ -1156,13 +1447,19 @@ def __init__(self,
         self.oid = self._get_oid()
         self.reload()

-    def read_job_specs(self):
-        job_specs = []
+    def read_job_specs(self) -> List[JobSpecProtocol]:
+        """
+        read job specifications
+        """
+        job_specs: List[JobSpecProtocol] = []
         for job in self._db_jobsfile:
-            job_specs.append(JobSpec.from_pod(job))
+            job_specs.append(cast(JobSpecProtocol, JobSpec.from_pod(job)))
         return job_specs

-    def connect(self):
+    def connect(self) -> None:
+        """
+        connect to database
+        """
         if self.conn and not self.conn.closed:
             return
         try:
@@ -1170,20 +1467,24 @@ def connect(self):
                                 user=self.user,
                                 host=self.host,
                                 password=self.password,
-                                port=self.port)
-        except Psycopg2Error as e:
+                                port=self.port) if psycopg2 else None
+        except (Psycopg2Error or Exception) as e:
             raise HostError('Unable to connect to the Database: "{}'.format(e.args[0]))

-    def disconnect(self):
-        self.conn.commit()
-        self.conn.close()
+    def disconnect(self) -> None:
+        """
+        disconnect from database
+        """
+        if self.conn:
+            self.conn.commit()
+            self.conn.close()

-    def reload(self):
+    def reload(self) -> None:
         super(RunDatabaseOutput, self).reload()
-        info_pod = self._db_infofile
-        state_pod = self._db_statefile
+        info_pod: Dict = self._db_infofile
+        state_pod: Dict = self._db_statefile
         if not info_pod or not state_pod:
-            msg = '"{}" does not appear to be a valid WA Database Output.'
+            msg: str = '"{}" does not appear to be a valid WA Database Output.'
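# Editorial note: reload() needs both the info pod (a row in the runs table)
# and the state pod; if either query came back empty, the run_uuid did not
# match a WA run, hence the ValueError below.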
raise ValueError(msg.format(self.oid)) self.info = RunInfo.from_pod(info_pod) @@ -1202,24 +1503,30 @@ def reload(self): logger.warning('Could not find spec for job {}'.format(job.id)) self.jobs.append(job) - def _get_oid(self): - columns = ['{}s.oid'.format(self.kind)] - tables = ['{}s'.format(self.kind)] - conditions = ['runs.run_uuid = \'{}\''.format(self.run_uuid)] - oid = self._read_db(columns, tables, conditions, as_dict=False) + def _get_oid(self) -> UUID: + """ + get database oid + """ + columns: List[Union[str, Tuple[str, str]]] = ['{}s.oid'.format(self.kind)] + tables: List[str] = ['{}s'.format(self.kind)] + conditions: List[str] = ['runs.run_uuid = \'{}\''.format(self.run_uuid)] + oid: List[Dict] = self._read_db(columns, tables, conditions, as_dict=False) if not oid: raise ConfigError('No matching run entries found for run_uuid {}'.format(self.run_uuid)) if len(oid) > 1: raise ConfigError('Multiple entries found for run_uuid: {}'.format(self.run_uuid)) return oid[0][0] - def _get_parameters(self, param_type): - columns = ['parameters.job_oid', 'parameters.name', 'parameters.value'] - tables = ['parameters'] - conditions = ['parameters.type = \'{}\''.format(param_type), - 'parameters.run_oid = \'{}\''.format(self.oid)] - params = self._read_db(columns, tables, conditions, as_dict=False) - parm_dict = defaultdict(dict) + def _get_parameters(self, param_type: str) -> Dict: + """ + get database parameters + """ + columns: List[Union[str, Tuple[str, str]]] = ['parameters.job_oid', 'parameters.name', 'parameters.value'] + tables: List[str] = ['parameters'] + conditions: List[str] = ['parameters.type = \'{}\''.format(param_type), + 'parameters.run_oid = \'{}\''.format(self.oid)] + params: List[Dict] = self._read_db(columns, tables, conditions, as_dict=False) + parm_dict: DefaultDict = defaultdict(dict) for (job_oid, k, v) in params: try: parm_dict[job_oid][k] = json.loads(v) @@ -1227,24 +1534,31 @@ def _get_parameters(self, param_type): logger.debug('Failed to deserialize job_oid:{}-"{}":"{}"'.format(job_oid, k, v)) return parm_dict - def _get_job_augmentations(self, job_oid): - columns = ['jobs_augs.augmentation_oid', 'augmentations.name', - 'augmentations.oid', 'jobs_augs.job_oid'] - tables = ['jobs_augs', 'augmentations'] - conditions = ['jobs_augs.job_oid = \'{}\''.format(job_oid), - 'jobs_augs.augmentation_oid = augmentations.oid'] - augmentations = self._read_db(columns, tables, conditions) + def _get_job_augmentations(self, job_oid: UUID) -> List: + """ + get job augmentations + """ + columns: List[Union[str, Tuple[str, str]]] = ['jobs_augs.augmentation_oid', 'augmentations.name', + 'augmentations.oid', 'jobs_augs.job_oid'] + tables: List[str] = ['jobs_augs', 'augmentations'] + conditions: List[str] = ['jobs_augs.job_oid = \'{}\''.format(job_oid), + 'jobs_augs.augmentation_oid = augmentations.oid'] + augmentations: List[Dict] = self._read_db(columns, tables, conditions) return [aug['name'] for aug in augmentations] - def _list_runs(self): - columns = ['runs.run_uuid', 'runs.run_name', 'runs.project', - 'runs.project_stage', 'runs.status', 'runs.start_time', 'runs.end_time'] - tables = ['runs'] - pod = self._read_db(columns, tables) + def _list_runs(self) -> None: + """ + list runs + """ + columns: List[Union[str, Tuple[str, str]]] = ['runs.run_uuid', 'runs.run_name', 'runs.project', + 'runs.project_stage', 'runs.status', + 'runs.start_time', 'runs.end_time'] + tables: List[str] = ['runs'] + pod: List[Dict] = self._read_db(columns, tables) if pod: - headers = ['Run Name', 
'Project', 'Project Stage', 'Start Time', 'End Time', - 'run_uuid'] - run_list = [] + headers: List[str] = ['Run Name', 'Project', 'Project Stage', 'Start Time', 'End Time', + 'run_uuid'] + run_list: List = [] for entry in pod: # Format times to display better start_time = entry['start_time'] @@ -1269,16 +1583,17 @@ def _list_runs(self): class JobDatabaseOutput(DatabaseOutput): - kind = 'job' + kind: str = 'job' - def __init__(self, conn, oid, job_id, label, iteration, retry): + def __init__(self, conn: Optional['_psycopg.connection'], oid: UUID, + job_id: str, label: str, iteration: int, retry: int): super(JobDatabaseOutput, self).__init__(conn, oid=oid) self.id = job_id self.label = label self.iteration = iteration self.retry = retry - self.result = None - self.spec = None + self.result: Optional[Result] = None + self.spec: Optional[JobSpecProtocol] = None self.reload() def __repr__(self): @@ -1289,8 +1604,11 @@ def __str__(self): return '{}-{}-{}'.format(self.id, self.label, self.iteration) @property - def augmentations(self): - job_augs = set([]) + def augmentations(self) -> List: + """ + augmentations + """ + job_augs: Set = set([]) if self.spec: for aug in self.spec.augmentations: job_augs.add(aug) diff --git a/wa/framework/output_processor.py b/wa/framework/output_processor.py index 6fe80f7e9..f1b5c6405 100644 --- a/wa/framework/output_processor.py +++ b/wa/framework/output_processor.py @@ -17,74 +17,108 @@ from wa.framework import pluginloader from wa.framework.exception import ConfigError -from wa.framework.instrument import is_installed +from wa.framework.instrument import is_installed, Instrument from wa.framework.plugin import Plugin from wa.utils.log import log_error, indentcontext from wa.utils.misc import isiterable from wa.utils.types import identifier +from typing import List, Union, TYPE_CHECKING, cast +from types import ModuleType +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext + from wa.commands.process import ProcessContext class OutputProcessor(Plugin): - - kind = 'output_processor' - requires = [] + """ + These post-process metrics and artifacts generated by workloads or + instruments, as well as target metadata collected by WA, in order to + generate additional metrics and/or artifacts (e.g. generating statistics + or reports). Output processors are also used to export WA output + externally (e.g. upload to a database). + """ + kind: str = 'output_processor' + requires: List[Union[str, Instrument]] = [] def __init__(self, **kwargs): super(OutputProcessor, self).__init__(**kwargs) self.is_enabled = True - def validate(self): + def validate(self) -> None: + """ + validate whether requirements are satisfied + """ super(OutputProcessor, self).validate() for instrument in self.requires: if not is_installed(instrument): msg = 'Instrument "{}" is required by {}, but is not installed.' 
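# Editorial note: requires lists the instruments this processor depends on;
# validate() walks that list and raises ConfigError (below) for any entry
# that is_installed() does not find.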
raise ConfigError(msg.format(instrument, self.name)) - def initialize(self, context): + def initialize(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: pass - def finalize(self, context): + def finalize(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: pass class ProcessorManager(object): - - def __init__(self, loader=pluginloader): + """ + manage the output processors + """ + def __init__(self, loader: ModuleType = pluginloader): self.loader = loader self.logger = logging.getLogger('processor') - self.processors = [] + self.processors: List[OutputProcessor] = [] - def install(self, processor, context): + def install(self, processor: OutputProcessor, context: Union['ExecutionContext', 'ProcessContext']): + """ + install output processor + """ if not isinstance(processor, OutputProcessor): processor = self.loader.get_output_processor(processor) self.logger.debug('Installing {}'.format(processor.name)) - processor.logger.context = context + processor.logger.context = context # type:ignore self.processors.append(processor) context.add_augmentation(processor) - def disable_all(self): + def disable_all(self) -> None: + """ + disable all output processors + """ for output_processor in self.processors: self._disable_output_processor(output_processor) - def enable_all(self): + def enable_all(self) -> None: + """ + enable all output processors + """ for output_processor in self.processors: self._enable_output_processor(output_processor) - def enable(self, to_enable): + def enable(self, to_enable: Union[str, List[OutputProcessor], OutputProcessor]) -> None: + """ + enable the specified output processor + """ if isiterable(to_enable): - for inst in to_enable: - self._enable_output_processor(inst) + for inst in to_enable: # type:ignore + self._enable_output_processor(cast(OutputProcessor, inst)) else: - self._enable_output_processor(to_enable) + self._enable_output_processor(cast(OutputProcessor, to_enable)) - def disable(self, to_disable): + def disable(self, to_disable: Union[str, List[OutputProcessor], OutputProcessor]) -> None: + """ + disable the specified output processor + """ if isiterable(to_disable): - for inst in to_disable: - self._disable_output_processor(inst) + for inst in to_disable: # type:ignore + self._disable_output_processor(cast(OutputProcessor, inst)) else: - self._disable_output_processor(to_disable) + self._disable_output_processor(cast(OutputProcessor, to_disable)) - def get_output_processor(self, processor): + def get_output_processor(self, processor: Union[OutputProcessor, str]) -> OutputProcessor: + """ + get output processor + """ if isinstance(processor, OutputProcessor): return processor @@ -94,43 +128,73 @@ def get_output_processor(self, processor): return p raise ValueError('Output processor {} is not installed'.format(processor)) - def get_enabled(self): + def get_enabled(self) -> List[OutputProcessor]: + """ + get enabled output processors + """ return [p for p in self.processors if p.is_enabled] - def get_disabled(self): + def get_disabled(self) -> List[OutputProcessor]: + """ + get disabled output processors + """ return [p for p in self.processors if not p.is_enabled] - def validate(self): + def validate(self) -> None: + """ + validate the output processors + """ for proc in self.processors: proc.validate() - def initialize(self, context): + def initialize(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + initialize the output processor + """ for proc in self.processors: proc.initialize(context) - 
def finalize(self, context): + def finalize(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + finalize the output processor + """ for proc in self.processors: proc.finalize(context) - def process_job_output(self, context): + def process_job_output(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + process job output + """ self.do_for_each_proc('process_job_output', 'Processing using "{}"', context.job_output, context.target_info, context.run_output) - def export_job_output(self, context): + def export_job_output(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + export job output + """ self.do_for_each_proc('export_job_output', 'Exporting using "{}"', context.job_output, context.target_info, context.run_output) - def process_run_output(self, context): + def process_run_output(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + process run output + """ self.do_for_each_proc('process_run_output', 'Processing using "{}"', context.run_output, context.target_info) - def export_run_output(self, context): + def export_run_output(self, context: Union['ExecutionContext', 'ProcessContext']) -> None: + """ + export run output + """ self.do_for_each_proc('export_run_output', 'Exporting using "{}"', context.run_output, context.target_info) - def do_for_each_proc(self, method_name, message, *args): + def do_for_each_proc(self, method_name: str, message: str, *args) -> None: + """ + run the method for each processor + """ with indentcontext(): for proc in self.processors: if proc.is_enabled: @@ -145,13 +209,19 @@ def do_for_each_proc(self, method_name, message, *args): raise log_error(e, self.logger) - def _enable_output_processor(self, inst): + def _enable_output_processor(self, inst: OutputProcessor) -> None: + """ + enable output processor + """ inst = self.get_output_processor(inst) self.logger.debug('Enabling output processor {}'.format(inst.name)) if not inst.is_enabled: inst.is_enabled = True - def _disable_output_processor(self, inst): + def _disable_output_processor(self, inst: OutputProcessor) -> None: + """ + disable output processor + """ inst = self.get_output_processor(inst) self.logger.debug('Disabling output processor {}'.format(inst.name)) if inst.is_enabled: diff --git a/wa/framework/plugin.py b/wa/framework/plugin.py index a59ded5e4..2f0e7ebce 100644 --- a/wa/framework/plugin.py +++ b/wa/framework/plugin.py @@ -23,7 +23,7 @@ from itertools import chain from copy import copy -from future.utils import with_metaclass +from future.utils import with_metaclass # type:ignore from wa.framework.configuration.core import settings, ConfigurationPoint as Parameter from wa.framework.exception import (NotFoundError, PluginLoaderError, TargetError, @@ -32,6 +32,13 @@ from wa.utils.misc import (ensure_directory_exists as _d, walk_modules, load_class, merge_dicts_simple, get_article, import_path) from wa.utils.types import identifier +from typing import (Optional, List, Dict, DefaultDict, Tuple, cast, + TYPE_CHECKING, Union, Type, Any) +from devlib.target import Target +from types import ModuleType +from typing_extensions import Protocol +if TYPE_CHECKING: + from wa.framework.output import Artifact class AttributeCollection(object): @@ -39,31 +46,37 @@ class AttributeCollection(object): Accumulator for plugin attribute objects (such as Parameters or Artifacts). This will replace any class member list accumulating such attributes - through the magic of metaprogramming\ [*]_. 
+ through the magic of metaprogramming [*]_. .. [*] which is totally safe and not going backfire in any way... """ @property - def values(self): + def values(self) -> List[Any]: + """ + list of attribute values + """ return list(self._attrs.values()) - def __init__(self, attrcls): + def __init__(self, attrcls: Type): self._attrcls = attrcls - self._attrs = OrderedDict() + self._attrs: OrderedDict[str, Parameter] = OrderedDict() - def add(self, p): + def add(self, p: Parameter) -> None: + """ + add attribute to collection + """ p = self._to_attrcls(p) if p.name in self._attrs: if p.override: - newp = copy(self._attrs[p.name]) + newp: Parameter = copy(self._attrs[p.name]) for a, v in p.__dict__.items(): if v is not None: setattr(newp, a, v) if not hasattr(newp, "_overridden"): # pylint: disable=protected-access - newp._overridden = p._owner + newp._overridden = p._owner # type:ignore self._attrs[p.name] = newp else: # Duplicate attribute condition is check elsewhere. @@ -78,14 +91,17 @@ def __str__(self): __repr__ = __str__ - def _to_attrcls(self, p): + def _to_attrcls(self, p: Parameter) -> Parameter: + """ + convert parameter to the required class type + """ if not isinstance(p, self._attrcls): raise ValueError('Invalid attribute value: {}; must be a {}'.format(p, self._attrcls)) if p.name in self._attrs and not p.override: raise ValueError('Attribute {} has already been defined.'.format(p.name)) return p - def __iadd__(self, other): + def __iadd__(self, other: List[Parameter]): for p in other: self.add(p) return self @@ -93,10 +109,10 @@ def __iadd__(self, other): def __iter__(self): return iter(self.values) - def __contains__(self, p): + def __contains__(self, p: Parameter): return p in self._attrs - def __getitem__(self, i): + def __getitem__(self, i: str) -> Parameter: return self._attrs[i] def __len__(self): @@ -104,11 +120,13 @@ def __len__(self): class AliasCollection(AttributeCollection): - + """ + collection of aliases + """ def __init__(self): super(AliasCollection, self).__init__(Alias) - def _to_attrcls(self, p): + def _to_attrcls(self, p: Parameter) -> Parameter: if isinstance(p, (list, tuple)): # must be in the form (name, {param: value, ...}) # pylint: disable=protected-access @@ -134,12 +152,15 @@ class Alias(object): """ - def __init__(self, name, **kwargs): + def __init__(self, name: str, **kwargs): self.name = name - self.params = kwargs - self.plugin_name = None # gets set by the MetaClass + self.params: Dict[str, Any] = kwargs + self.plugin_name: Optional[str] = None # gets set by the MetaClass - def validate(self, ext): + def validate(self, ext: 'Plugin') -> None: + """ + validate the alias + """ ext_params = set(p.name for p in ext.parameters) for param in self.params: if param not in ext_params: @@ -162,18 +183,20 @@ class PluginMeta(type): """ - to_propagate = [ + to_propagate: List[Tuple[str, Type[Parameter], Type[AttributeCollection]]] = [ ('parameters', Parameter, AttributeCollection), ] - def __new__(mcs, clsname, bases, attrs): + def __new__(mcs: Type['PluginMeta'], clsname: str, bases: Tuple[Type, ...], + attrs: Dict[str, Any]): mcs._propagate_attributes(bases, attrs, clsname) cls = type.__new__(mcs, clsname, bases, attrs) - mcs._setup_aliases(cls) + mcs._setup_aliases(cast('Plugin', cls)) return cls @classmethod - def _propagate_attributes(mcs, bases, attrs, clsname): # pylint: disable=too-many-locals + def _propagate_attributes(mcs: Type['PluginMeta'], bases: Tuple[Type, ...], + attrs: Dict[str, Any], clsname: str) -> None: # pylint: disable=too-many-locals 
# pylint: disable=protected-access """ For attributes specified by to_propagate, their values will be a union of @@ -194,7 +217,7 @@ def _propagate_attributes(mcs, bases, attrs, clsname): # pylint: disable=too-ma if not isinstance(pa, attr_cls): msg = 'Invalid value "{}" for attribute "{}"; must be a {}' raise ValueError(msg.format(pa, prop_attr, attr_cls)) - pa._owner = clsname + pa._owner = clsname # type:ignore propagated += pattrs should_propagate = True if should_propagate: @@ -207,7 +230,10 @@ def _propagate_attributes(mcs, bases, attrs, clsname): # pylint: disable=too-ma attrs[prop_attr] = propagated @classmethod - def _setup_aliases(mcs, cls): + def _setup_aliases(mcs: Type['PluginMeta'], cls: 'Plugin') -> None: + """ + setup aliases for plugins + """ if hasattr(cls, 'aliases'): aliases, cls.aliases = cls.aliases, AliasCollection() for alias in aliases: @@ -215,10 +241,16 @@ def _setup_aliases(mcs, cls): alias = Alias(alias) alias.validate(cls) alias.plugin_name = cls.name - cls.aliases.add(alias) + cls.aliases.add(cast(Parameter, alias)) + + +class LoaderProtocol(Protocol): + def get_module(self, name: str, owner: 'Plugin', **kwargs) -> 'Plugin': + ... -class Plugin(with_metaclass(PluginMeta, object)): + +class Plugin(metaclass=PluginMeta): """ Base class for all WA plugins. An plugin is basically a plug-in. It extends the functionality of WA in some way. Plugins are discovered and @@ -230,29 +262,38 @@ class Plugin(with_metaclass(PluginMeta, object)): """ - kind = None - name = None - parameters = [] - artifacts = [] - aliases = [] - core_modules = [] + kind: Optional[str] = None + name: Optional[str] = None + parameters: List[Parameter] = [] + artifacts: List['Artifact'] = [] + aliases: Union[List[Alias], AliasCollection] = [] + core_modules: List[Union[str, Dict]] = [] @classmethod - def get_default_config(cls): + def get_default_config(cls) -> Dict[str, Any]: + """ + get default configuration + """ return {p.name: p.default for p in cls.parameters if not p.deprecated} @property - def dependencies_directory(self): - return _d(os.path.join(settings.dependencies_directory, self.name)) + def dependencies_directory(self) -> str: + """ + dependencies directory for the plugin + """ + return _d(os.path.join(settings.dependencies_directory, self.name or '')) @property - def _classname(self): + def _classname(self) -> str: + """ + name of the plugin class + """ return self.__class__.__name__ - def __init__(self, **kwargs): - self.logger = logging.getLogger(self.name) - self._modules = [] - self.capabilities = getattr(self.__class__, 'capabilities', []) + def __init__(self, **kwargs) -> None: + self.logger: logging.Logger = logging.getLogger(self.name) + self._modules: List['Plugin'] = [] + self.capabilities: List[str] = getattr(self.__class__, 'capabilities', []) for param in self.parameters: param.set_value(self, kwargs.get(param.name)) for key in kwargs: @@ -260,17 +301,17 @@ def __init__(self, **kwargs): message = 'Unexpected parameter "{}" for {}' raise ConfigError(message.format(key, self.name)) - def get_config(self): + def get_config(self) -> Dict[str, Optional[Parameter]]: """ Returns current configuration (i.e. parameter values) of this plugin. """ - config = {} + config: Dict[str, Optional[Parameter]] = {} for param in self.parameters: config[param.name] = getattr(self, param.name, None) return config - def validate(self): + def validate(self) -> None: """ Perform basic validation to ensure that this plugin is capable of running. 
This is intended as an early check to ensure the plugin has
@@ -280,7 +321,7 @@ def validate(self):
This method may also be used to enforce (i.e. set as well as check)
inter-parameter constraints for the plugin (e.g. if valid values for
parameter A depend on the value of parameter B -- something that is not
- possible to enfroce using ``Parameter``\ 's ``constraint`` attribute.
+ possible to enforce using ``Parameter``'s ``constraint`` attribute).
"""
if self.name is None:
@@ -288,7 +329,7 @@ def validate(self):
for param in self.parameters:
param.validate(self)
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
if name == '_modules':
raise ValueError('_modules accessed too early!')
for module in self._modules:
@@ -296,7 +337,7 @@ def __getattr__(self, name):
return getattr(module, name)
raise AttributeError(name)
- def load_modules(self, loader):
+ def load_modules(self, loader: LoaderProtocol) -> None:
"""
Load the modules specified by the "modules" Parameter using the
provided loader. A loader can be any object that has an atribute called
@@ -309,17 +350,17 @@ def load_modules(self, loader):
appropriate exception.
"""
- modules = list(reversed(self.core_modules))
+ modules: List[Union[str, Dict]] = list(reversed(self.core_modules))
modules += list(reversed(self.modules or []))
if not modules:
return
for module_spec in modules:
if not module_spec:
continue
- module = self._load_module(loader, module_spec)
+ module: 'Plugin' = self._load_module(loader, module_spec)
self._install_module(module)
- def has(self, capability):
+ def has(self, capability: str):
"""
Check if this plugin has the specified capability. The alternative
method ``can`` is identical to this. Which to use is up to the caller
@@ -331,18 +372,18 @@ def has(self, capability):
can = has
- def _load_module(self, loader, module_spec):
+ def _load_module(self, loader: LoaderProtocol, module_spec: Union[str, Dict]) -> 'Plugin':
if isinstance(module_spec, str):
- name = module_spec
- params = {}
+ name: str = module_spec
+ params: Dict = {}
elif isinstance(module_spec, dict):
if len(module_spec) != 1:
- msg = 'Invalid module spec: {}; dict must have exctly one key -- '\
- 'the module name.'
+ msg: str = 'Invalid module spec: {}; dict must have exactly one key -- '\
+ 'the module name.'
raise ValueError(msg.format(module_spec))
name, params = list(module_spec.items())[0]
else:
- message = 'Invalid module spec: {}; must be a string or a one-key dict.'
+ message: str = 'Invalid module spec: {}; must be a string or a one-key dict.'
raise ValueError(message.format(module_spec)) if not isinstance(params, dict): @@ -353,16 +394,19 @@ def _load_module(self, loader, module_spec): module.initialize(None) return module - def _install_module(self, module): + def _install_module(self, module: 'Plugin') -> None: + """ + install the module + """ for capability in module.capabilities: if capability not in self.capabilities: self.capabilities.append(capability) self._modules.append(module) - def __str__(self): + def __str__(self) -> str: return str(self.name) - def __repr__(self): + def __repr__(self) -> str: params = [] for param in self.parameters: params.append('{}={}'.format(param.name, @@ -376,8 +420,8 @@ class TargetedPlugin(Plugin): """ - supported_targets = [] - parameters = [ + supported_targets: List[str] = [] + parameters: List[Parameter] = [ Parameter('cleanup_assets', kind=bool, global_alias='cleanup_assets', aliases=['clean_up'], @@ -389,13 +433,16 @@ class TargetedPlugin(Plugin): ] @classmethod - def check_compatible(cls, target): + def check_compatible(cls, target: Target) -> None: + """ + check target's compatibility with the plugin + """ if cls.supported_targets: if target.os not in cls.supported_targets: msg = 'Incompatible target OS "{}" for {}' raise TargetError(msg.format(target.os, cls.name)) - def __init__(self, target, **kwargs): + def __init__(self, target: Target, **kwargs): super(TargetedPlugin, self).__init__(**kwargs) self.check_compatible(target) self.target = target @@ -420,8 +467,9 @@ class PluginLoader(object): """ - def __init__(self, packages=None, paths=None, ignore_paths=None, - keep_going=False): + def __init__(self, packages: Optional[List[str]] = None, + paths: Optional[List[str]] = None, ignore_paths: Optional[List[str]] = None, + keep_going: bool = False): """ params:: @@ -439,17 +487,20 @@ def __init__(self, packages=None, paths=None, ignore_paths=None, self.packages = packages or [] self.paths = paths or [] self.ignore_paths = ignore_paths or [] - self.plugins = {} - self.kind_map = defaultdict(dict) - self.aliases = {} - self.global_param_aliases = {} + self.plugins: Dict[str, Type[Plugin]] = {} + self.kind_map: DefaultDict[str, + Dict[str, Type[Plugin]]] = defaultdict(dict) + self.aliases: Dict[str, Alias] = {} + self.global_param_aliases: Dict[str, Alias] = {} self._discover_from_packages(self.packages) self._discover_from_paths(self.paths, self.ignore_paths) - def update(self, packages=None, paths=None, ignore_paths=None): + def update(self, packages: Optional[List[str]] = None, + paths: Optional[List[str]] = None, + ignore_paths: Optional[List[str]] = None) -> None: """ Load plugins from the specified paths/packages without clearing or reloading existing plugin. """ - msg = 'Updating from: packages={} paths={}' + msg: str = 'Updating from: packages={} paths={}' self.logger.debug(msg.format(packages, paths)) if packages: self.packages.extend(packages) @@ -459,52 +510,54 @@ def update(self, packages=None, paths=None, ignore_paths=None): self.ignore_paths.extend(ignore_paths or []) self._discover_from_paths(paths, ignore_paths or []) - def clear(self): + def clear(self) -> None: """ Clear all discovered items. """ self.plugins = {} self.kind_map.clear() self.aliases.clear() self.global_param_aliases.clear() - def reload(self): + def reload(self) -> None: """ Clear all discovered items and re-run the discovery. 
""" self.logger.debug('Reloading') self.clear() self._discover_from_packages(self.packages) self._discover_from_paths(self.paths, self.ignore_paths) - def get_plugin_class(self, name, kind=None): + def get_plugin_class(self, name: Optional[str], + kind: Optional[str] = None) -> Type[Plugin]: """ Return the class for the specified plugin if found or raises ``ValueError``. """ - name, _ = self.resolve_alias(name) + name, _ = self.resolve_alias(name or '') if kind is None: try: - return self.plugins[name] + return self.plugins[str(name)] except KeyError: raise NotFoundError('plugins {} not found.'.format(name)) if kind not in self.kind_map: raise ValueError('Unknown plugin type: {}'.format(kind)) - store = self.kind_map[kind] + store: Dict[str, Type[Plugin]] = self.kind_map[kind] if name not in store: - msg = 'plugins {} is not {} {}.' + msg: str = 'plugins {} is not {} {}.' raise NotFoundError(msg.format(name, get_article(kind), kind)) return store[name] - def get_plugin(self, name=None, kind=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg + def get_plugin(self, name: Optional[str] = None, + kind: Optional[str] = None, *args, **kwargs) -> Plugin: # pylint: disable=keyword-arg-before-vararg """ Return plugin of the specified kind with the specified name. Any additional parameters will be passed to the plugin's __init__. """ - name, base_kwargs = self.resolve_alias(name) + name, base_kwargs = self.resolve_alias(name or '') kwargs = OrderedDict(chain(iter(base_kwargs.items()), iter(kwargs.items()))) cls = self.get_plugin_class(name, kind) plugin = cls(*args, **kwargs) return plugin - def get_default_config(self, name): + def get_default_config(self, name: str) -> Dict: """ Returns the default configuration for the specified plugin name. The name may be an alias, in which case, the returned config will be @@ -515,7 +568,7 @@ def get_default_config(self, name): base_default_config = self.get_plugin_class(real_name).get_default_config() return merge_dicts_simple(base_default_config, alias_config) - def list_plugins(self, kind=None): + def list_plugins(self, kind: Optional[str] = None) -> List[Type[Plugin]]: """ List discovered plugin classes. Optionally, only list plugins of a particular type. @@ -527,7 +580,7 @@ def list_plugins(self, kind=None): raise ValueError('Unknown plugin type: {}'.format(kind)) return list(self.kind_map[kind].values()) - def has_plugin(self, name, kind=None): + def has_plugin(self, name: str, kind: Optional[str] = None) -> bool: """ Returns ``True`` if an plugins with the specified ``name`` has been discovered by the loader. If ``kind`` was specified, only returns ``True`` @@ -540,11 +593,11 @@ def has_plugin(self, name, kind=None): except NotFoundError: return False - def resolve_alias(self, alias_name): + def resolve_alias(self, alias_name: str) -> Tuple[str, Dict]: """ Try to resolve the specified name as an plugin alias. Returns a two-tuple, the first value of which is actual plugin name, and the - iisecond is a dict of parameter values for this alias. If the name passed + second is a dict of parameter values for this alias. If the name passed is already an plugin name, then the result is ``(alias_name, {})``. 
""" @@ -552,13 +605,13 @@ def resolve_alias(self, alias_name): if alias_name in self.plugins: return (alias_name, {}) if alias_name in self.aliases: - alias = self.aliases[alias_name] - return (alias.plugin_name, copy(alias.params)) + alias: Alias = self.aliases[alias_name] + return (cast(str, alias.plugin_name), copy(alias.params)) raise NotFoundError('Could not find plugin or alias "{}"'.format(alias_name)) # Internal methods. - def __getattr__(self, name): + def __getattr__(self, name: str): """ This resolves methods for specific plugins types based on corresponding generic plugin methods. So it's possible to say things like :: @@ -570,31 +623,34 @@ def __getattr__(self, name): loader.get_plugin('foo', kind='device') """ - error_msg = 'No plugins of type "{}" discovered' + error_msg: str = 'No plugins of type "{}" discovered' if name.startswith('get_'): name = name.replace('get_', '', 1) if name in self.kind_map: - def __wrapper(pname, *args, **kwargs): + def __wrapper(pname: str, *args, **kwargs): return self.get_plugin(pname, name, *args, **kwargs) return __wrapper raise NotFoundError(error_msg.format(name)) if name.startswith('list_'): name = name.replace('list_', '', 1).rstrip('s') if name in self.kind_map: - def __wrapper(*args, **kwargs): # pylint: disable=E0102 + def __list_plugins_wrapper(*args, **kwargs): # pylint: disable=E0102 return self.list_plugins(name, *args, **kwargs) - return __wrapper + return __list_plugins_wrapper raise NotFoundError(error_msg.format(name)) if name.startswith('has_'): name = name.replace('has_', '', 1) if name in self.kind_map: - def __wrapper(pname, *args, **kwargs): # pylint: disable=E0102 + def __has_plugin_wrapper(pname, *args, **kwargs): # pylint: disable=E0102 return self.has_plugin(pname, name, *args, **kwargs) - return __wrapper + return __has_plugin_wrapper raise NotFoundError(error_msg.format(name)) raise AttributeError(name) - def _discover_from_packages(self, packages): + def _discover_from_packages(self, packages: List[str]) -> None: + """ + discover plugins from packages + """ self.logger.debug('Discovering plugins in packages') try: for package in packages: @@ -602,10 +658,13 @@ def _discover_from_packages(self, packages): self._discover_in_module(module) except HostError as e: message = 'Problem loading plugins from {}: {}' - raise PluginLoaderError(message.format(e.module, str(e.orig_exc)), - e.exc_info) + raise PluginLoaderError(message.format(e.module, str(e.orig_exc)), e.exc_info) # type:ignore - def _discover_from_paths(self, paths, ignore_paths): + def _discover_from_paths(self, paths: Optional[List[str]], + ignore_paths: Optional[List[str]]) -> None: + """ + discover plugins in the specified paths + """ paths = paths or [] ignore_paths = ignore_paths or [] @@ -616,7 +675,7 @@ def _discover_from_paths(self, paths, ignore_paths): self._discover_from_file(path) elif os.path.exists(path): for root, _, files in os.walk(path, followlinks=True): - should_skip = False + should_skip: bool = False for igpath in ignore_paths: if root.startswith(igpath): should_skip = True @@ -635,22 +694,28 @@ def _discover_from_paths(self, paths, ignore_paths): except Exception: # NOQA pylint: disable=broad-except pass - def _discover_from_file(self, filepath): + def _discover_from_file(self, filepath: str) -> None: + """ + discover plugins from file + """ try: - module = import_path(filepath) + module: ModuleType = import_path(filepath) self._discover_in_module(module) except (SystemExit, ImportError) as e: if self.keep_going: 
self.logger.warning('Failed to load {}'.format(filepath))
self.logger.warning('Got: {}'.format(e))
else:
- msg = 'Failed to load {}'
+ msg: str = 'Failed to load {}'
raise PluginLoaderError(msg.format(filepath), sys.exc_info())
except Exception as e:
- message = 'Problem loading plugins from {}: {}'
+ message: str = 'Problem loading plugins from {}: {}'
raise PluginLoaderError(message.format(filepath, e))
- def _discover_in_module(self, module): # NOQA pylint: disable=too-many-branches
+ def _discover_in_module(self, module: ModuleType): # NOQA pylint: disable=too-many-branches
+ """
+ discover plugins in a module
+ """
self.logger.debug('Checking module %s', module.__name__)
with log.indentcontext():
for obj in vars(module).values():
@@ -675,13 +740,13 @@ def _discover_in_module(self, module): # NOQA pylint: disable=too-many-branches
else:
raise e
- def _add_found_plugin(self, obj):
+ def _add_found_plugin(self, obj: Type[Plugin]) -> None:
"""
:obj: Found plugin class
:ext: matching plugin item.
"""
self.logger.debug('Adding %s %s', obj.kind, obj.name)
- key = identifier(obj.name.lower())
+ key = identifier(obj.name.lower() if obj.name else '')
if key in self.plugins or key in self.aliases:
msg = '{} "{}" already exists.'
raise PluginLoaderError(msg.format(obj.kind, obj.name))
@@ -689,7 +754,7 @@ def _add_found_plugin(self, obj):
# dict, and in per-plugin kind dict (as retrieving
# plugins by kind is a common use case.
self.plugins[key] = obj
- self.kind_map[obj.kind][key] = obj
+ self.kind_map[obj.kind or ''][key] = obj
for alias in obj.aliases:
alias_id = identifier(alias.name.lower())
diff --git a/wa/framework/pluginloader.py b/wa/framework/pluginloader.py
index 45b1a027e..ceb6e0239 100644
--- a/wa/framework/pluginloader.py
+++ b/wa/framework/pluginloader.py
@@ -13,26 +13,41 @@
# limitations under the License.
#
import sys
+from typing import (Optional, List, Tuple, Dict, cast, Type,
+ DefaultDict, Any)
+from types import ModuleType
+from wa.framework import plugin
class __LoaderWrapper(object):
-
+ """
+ wrapper around plugin loader
+ """
@property
- def kinds(self):
+ def kinds(self) -> List[str]:
+ """
+ kinds of plugins
+ """
if not self._loader:
self.reset()
- return list(self._loader.kind_map.keys())
+ return list(self._loader.kind_map.keys()) if self._loader else []
@property
- def kind_map(self):
+ def kind_map(self) -> DefaultDict[str, Dict[str, Type[plugin.Plugin]]]:
+ """
+ map from plugin kind to the plugins of that kind
+ """
if not self._loader:
self.reset()
- return self._loader.kind_map
+ return self._loader.kind_map if self._loader else cast(DefaultDict, {})
- def __init__(self):
- self._loader = None
+ def __init__(self) -> None:
+ self._loader: Optional[plugin.PluginLoader] = None
def reset(self):
+ """
+ reset the plugin loader
+ """
# These imports cannot be done at top level, because of
# sys.modules manipulation below
# pylint: disable=import-outside-toplevel
@@ -41,50 +56,97 @@ def reset(self):
self._loader = PluginLoader(settings.plugin_packages,
settings.plugin_paths, [])
- def update(self, packages=None, paths=None, ignore_paths=None):
+ def update(self, packages: Optional[List[str]] = None,
+ paths: Optional[List[str]] = None,
+ ignore_paths: Optional[List[str]] = None) -> None:
+ """
+ update the loader with plugins from the given packages and paths
+ """
if not self._loader:
self.reset()
- self._loader.update(packages, paths, ignore_paths)
+ if self._loader:
+ self._loader.update(packages, paths, ignore_paths)
- def reload(self):
+ def reload(self) -> None:
+ """
+ reload the plugins
+ """
if not self._loader:
self.reset()
- self._loader.reload()
+ if self._loader:
+ self._loader.reload()
- def list_plugins(self, kind=None):
+ def list_plugins(self, kind: Optional[str] = None) -> List[Type[plugin.Plugin]]:
+ """
+ List the plugins loaded
+ """
if not self._loader:
self.reset()
- return self._loader.list_plugins(kind)
-
- def has_plugin(self, name, kind=None):
+ if self._loader:
+ return self._loader.list_plugins(kind)
+ else:
+ return []
+
+ def has_plugin(self, name: str, kind: Optional[str] = None) -> bool:
+ """
+ True if the plugin of the given name and kind is already loaded
+ """
if not self._loader:
self.reset()
- return self._loader.has_plugin(name, kind)
-
- def get_plugin_class(self, name, kind=None):
+ if self._loader:
+ return self._loader.has_plugin(name, kind)
+ return False
+
+ def get_plugin_class(self, name: str, kind: Optional[str] = None) -> Type[plugin.Plugin]:
+ """
+ get the class type of the plugin
+ """
if not self._loader:
self.reset()
- return self._loader.get_plugin_class(name, kind)
-
- def get_plugin(self, name=None, kind=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
+ if self._loader:
+ return self._loader.get_plugin_class(name, kind)
+ else:
+ return plugin.Plugin # dummy to satisfy type checker.
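The `if self._loader:` guards and the `plugin.Plugin` fallback above follow from the module-replacement trick at the bottom of this file: `pluginloader` swaps itself in `sys.modules` for a `__LoaderWrapper` instance, so static checkers see a `ModuleType` while `_loader` is only populated lazily by `reset()`. A minimal self-contained sketch of the same pattern (illustrative names only, not part of WA):

import sys
from types import ModuleType
from typing import Dict, Optional, cast


class _LazyWrapper:
    """Stands in for the module; real state is created on first access."""

    def __init__(self) -> None:
        self._impl: Optional[Dict[str, object]] = None

    def reset(self) -> None:
        self._impl = {}

    def get(self, key: str) -> Optional[object]:
        if self._impl is None:
            self.reset()
        # mypy cannot see that reset() always assigns _impl, hence the
        # second guard -- the same shape as the wrapper methods above.
        return self._impl.get(key) if self._impl is not None else None


# sys.modules maps names to ModuleType, hence the cast on the instance.
sys.modules[__name__] = cast(ModuleType, _LazyWrapper())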
+ + def get_plugin(self, name: Optional[str] = None, + kind: Optional[str] = None, *args, **kwargs) -> Optional[plugin.Plugin]: # pylint: disable=keyword-arg-before-vararg + """ + get plugin of the specified name + """ if not self._loader: self.reset() - return self._loader.get_plugin(name=name, kind=kind, *args, **kwargs) - - def get_default_config(self, name): + if self._loader: + return self._loader.get_plugin(name, kind, *args, **kwargs) + return None + + def get_default_config(self, name: str) -> Optional[Dict[str, Any]]: + """ + get default configuration + """ if not self._loader: self.reset() - return self._loader.get_default_config(name) - - def resolve_alias(self, name): + if self._loader: + return self._loader.get_default_config(name) + return None + + def resolve_alias(self, name: str) -> Tuple[str, Dict]: + """ + resolve Alias of the plugin + """ if not self._loader: self.reset() - return self._loader.resolve_alias(name) - - def __getattr__(self, name): + if self._loader: + return self._loader.resolve_alias(name) + else: + return ('', {}) # dummy to satisfy type checker + + def __getattr__(self, name: str) -> Any: + """ + get attribute with the specified name + """ if not self._loader: self.reset() return getattr(self._loader, name) -sys.modules[__name__] = __LoaderWrapper() +sys.modules[__name__] = cast(ModuleType, __LoaderWrapper()) diff --git a/wa/framework/resource.py b/wa/framework/resource.py index c69ded8cd..c95206ebd 100644 --- a/wa/framework/resource.py +++ b/wa/framework/resource.py @@ -21,10 +21,11 @@ from wa.framework.exception import ResourceError from wa.framework.configuration import settings from wa.utils import log -from wa.utils.android import get_cacheable_apk_info +from wa.utils.android import get_cacheable_apk_info, ApkInfo from wa.utils.misc import get_object_name from wa.utils.types import enum, list_or_string, prioritylist, version_tuple - +from typing import Optional, List, Union +from types import ModuleType SourcePriority = enum(['package', 'remote', 'lan', 'local', 'perferred'], start=0, step=10) @@ -33,10 +34,10 @@ class __NullOwner(object): """Represents an owner for a resource not owned by anyone.""" - name = 'noone' - dependencies_directory = settings.dependencies_directory + name: str = 'noone' + dependencies_directory: str = settings.dependencies_directory - def __getattr__(self, name): + def __getattr__(self, name: str): return None def __str__(self): @@ -59,15 +60,21 @@ class Resource(object): """ - kind = None + kind: Optional[str] = None - def __init__(self, owner=NO_ONE): + def __init__(self, owner: object = NO_ONE): self.owner = owner - def match(self, path): + def match(self, path: str): + """ + match the resource path + """ return self.match_path(path) - def match_path(self, path): + def match_path(self, path: str) -> bool: + """ + match the resource path + """ raise NotImplementedError() def __str__(self): @@ -75,14 +82,16 @@ def __str__(self): class File(Resource): + """ + File resource + """ + kind: str = 'file' - kind = 'file' - - def __init__(self, owner, path): + def __init__(self, owner: object, path: str): super(File, self).__init__(owner) self.path = path - def match_path(self, path): + def match_path(self, path: str) -> bool: return self.path == path def __str__(self): @@ -90,15 +99,17 @@ def __str__(self): class Executable(Resource): + """ + Executable resource + """ + kind: str = 'executable' - kind = 'executable' - - def __init__(self, owner, abi, filename): + def __init__(self, owner: object, abi: str, filename: str): 
super(Executable, self).__init__(owner)
self.abi = abi
self.filename = filename
- def match_path(self, path):
+ def match_path(self, path: str) -> bool:
return self.filename == os.path.basename(path)
def __str__(self):
@@ -106,15 +117,17 @@ def __str__(self):
class ReventFile(Resource):
+ """
+ Revent File resource
+ """
+ kind: str = 'revent'
- kind = 'revent'
-
- def __init__(self, owner, stage, target):
+ def __init__(self, owner: object, stage: str, target: Optional[str]):
super(ReventFile, self).__init__(owner)
self.stage = stage
self.target = target
- def match_path(self, path):
+ def match_path(self, path: str) -> bool:
filename = os.path.basename(path)
parts = filename.split('.')
if len(parts) > 2:
@@ -126,22 +139,27 @@ def match_path(self, path):
class JarFile(Resource):
+ """
+ Jar file resource
+ """
+ kind: str = 'jar'
- kind = 'jar'
-
- def match_path(self, path):
+ def match_path(self, path: str) -> bool:
# An owner always has at most one jar file, so
# always match
return True
class ApkFile(Resource):
+ """
+ Apk file resource
+ """
+ kind: str = 'apk'
- kind = 'apk'
-
- def __init__(self, owner, variant=None, version=None,
- package=None, uiauto=False, exact_abi=False,
- supported_abi=None, min_version=None, max_version=None):
+ def __init__(self, owner: object, variant: Optional[str] = None,
+ version: Optional[Union[str, List[str]]] = None, package: Optional[str] = None,
+ uiauto: bool = False, exact_abi: bool = False, supported_abi: Optional[List[Optional[str]]] = None,
+ min_version: Optional[str] = None, max_version: Optional[str] = None):
super(ApkFile, self).__init__(owner)
self.variant = variant
self.version = version
@@ -152,17 +170,17 @@ def __init__(self, owner, variant=None, version=None,
self.exact_abi = exact_abi
self.supported_abi = supported_abi
- def match_path(self, path):
+ def match_path(self, path: str) -> bool:
ext = os.path.splitext(path)[1].lower()
return ext == '.apk'
- def match(self, path):
- name_matches = True
- version_matches = True
- version_range_matches = True
- package_matches = True
- abi_matches = True
- uiauto_matches = uiauto_test_matches(path, self.uiauto)
+ def match(self, path: str) -> bool:
+ name_matches: bool = True
+ version_matches: bool = True
+ version_range_matches: bool = True
+ package_matches: bool = True
+ abi_matches: bool = True
+ uiauto_matches: bool = uiauto_test_matches(path, self.uiauto)
if self.version:
version_matches = apk_version_matches(path, self.version)
if self.max_version or self.min_version:
@@ -179,7 +197,7 @@ def match(self, path):
version_range_matches and uiauto_matches \
and package_matches and abi_matches
- def __str__(self):
+ def __str__(self) -> str:
text = '<{}\'s apk'.format(self.owner)
if self.variant:
text += ' {}'.format(self.variant)
@@ -211,16 +229,22 @@ class ResourceGetter(Plugin):
"""
- name = None
- kind = 'resource_getter'
+ name: Optional[str] = None
+ kind: str = 'resource_getter'
- def register(self, resolver):
+ def register(self, resolver: 'ResourceResolver'):
+ """
+ register this getter with the resolver
+ """
raise NotImplementedError()
- def initialize(self):
+ def initialize(self) -> None:
+ """
+ initialize the getter
+ """
pass
- def __str__(self):
+ def __str__(self) -> str:
return '<ResourceGetter {}>'.format(self.name)
@@ -231,28 +255,34 @@ class ResourceResolver(object):
"""
- def __init__(self, loader=pluginloader):
+ def __init__(self, loader: ModuleType = pluginloader):
self.loader = loader
self.logger = logging.getLogger('resolver')
- self.getters = []
+ self.getters:
List[ResourceGetter] = [] self.sources = prioritylist() - def load(self): + def load(self) -> None: + """ + load the resource getters to the resolver + """ for gettercls in self.loader.list_plugins('resource_getter'): self.logger.debug('Loading getter {}'.format(gettercls.name)) - getter = self.loader.get_plugin(name=gettercls.name, - kind="resource_getter") + getter: ResourceGetter = self.loader.get_plugin(name=gettercls.name, + kind="resource_getter") with log.indentcontext(): getter.initialize() getter.register(self) self.getters.append(getter) - def register(self, source, priority=SourcePriority.local): - msg = 'Registering "{}" with priority "{}"' + def register(self, source: object, priority=SourcePriority.local) -> None: + """ + register the source + """ + msg: str = 'Registering "{}" with priority "{}"' self.logger.debug(msg.format(get_object_name(source), priority)) self.sources.add(source, priority) - def get(self, resource, strict=True): + def get(self, resource: Resource, strict: bool = True) -> Optional[str]: """ Uses registered getters to attempt to discover a resource of the specified kind and matching the specified criteria. Returns path to the resource that @@ -263,11 +293,11 @@ def get(self, resource, strict=True): """ self.logger.debug('Resolving {}'.format(resource)) for source in self.sources: - source_name = get_object_name(source) + source_name: Optional[str] = get_object_name(source) self.logger.debug('Trying {}'.format(source_name)) - result = source(resource) + result: str = source(resource) if result is not None: - msg = 'Resource {} found using {}:' + msg: str = 'Resource {} found using {}:' self.logger.debug(msg.format(resource, source_name)) self.logger.debug('\t{}'.format(result)) return result @@ -277,52 +307,70 @@ def get(self, resource, strict=True): return None -def apk_version_matches(path, version): - version = list_or_string(version) - info = get_cacheable_apk_info(path) - for v in version: - if v in (info.version_name, info.version_code): - return True - if loose_version_matching(v, info.version_name): - return True +def apk_version_matches(path: str, version: Union[str, List[str]]): + """ + check apk version matches + """ + version_ = list_or_string(version) + info: Optional[ApkInfo] = get_cacheable_apk_info(path) + for v in version_: + if info is not None: + if v in (info.version_name, info.version_code): + return True + if loose_version_matching(v, info.version_name): + return True return False -def apk_version_matches_range(path, min_version=None, max_version=None): +def apk_version_matches_range(path: str, min_version: Optional[str] = None, + max_version: Optional[str] = None) -> bool: + """ + check if the apk version matches the range of versions + """ info = get_cacheable_apk_info(path) - return range_version_matching(info.version_name, min_version, max_version) + return range_version_matching(info.version_name if info else '', min_version, max_version) -def range_version_matching(apk_version, min_version=None, max_version=None): +def range_version_matching(apk_version: Optional[str], min_version: Optional[str] = None, + max_version: Optional[str] = None): + """ + check if the apk version matches the range of versions + """ if not apk_version: return False - apk_version = version_tuple(apk_version) + apk_version_tuple = version_tuple(apk_version or '') if max_version: - max_version = version_tuple(max_version) - if apk_version > max_version: + max_version_tuple = version_tuple(max_version) + if apk_version_tuple > max_version_tuple: return 
False if min_version: - min_version = version_tuple(min_version) - if apk_version < min_version: + min_version_tuple = version_tuple(min_version) + if apk_version_tuple < min_version_tuple: return False return True -def loose_version_matching(config_version, apk_version): - config_version = version_tuple(config_version) - apk_version = version_tuple(apk_version) +def loose_version_matching(config_version: str, apk_version: Optional[str]) -> bool: + """ + check version matching loosely + """ + config_version_tuple = version_tuple(config_version) + apk_version_tuple = version_tuple(apk_version or '') - if len(apk_version) < len(config_version): + if len(apk_version_tuple) < len(config_version_tuple): return False # More specific version requested than available - for i in range(len(config_version)): - if config_version[i] != apk_version[i]: + for i in range(len(config_version_tuple)): + if config_version_tuple[i] != apk_version_tuple[i]: return False return True -def file_name_matches(path, pattern): +def file_name_matches(path: str, pattern: str) -> bool: + """ + check file name matches pattern + """ filename = os.path.basename(path) if pattern in filename: return True @@ -331,27 +379,43 @@ def file_name_matches(path, pattern): return False -def uiauto_test_matches(path, uiauto): +def uiauto_test_matches(path: str, uiauto: bool) -> bool: + """ + check uiauto matches + """ info = get_cacheable_apk_info(path) - return uiauto == ('com.arm.wa.uiauto' in info.package) + if info is None: + return False + return uiauto == ('com.arm.wa.uiauto' in (info.package or '')) -def package_name_matches(path, package): +def package_name_matches(path: str, package: str) -> bool: + """ + check if package name matches + """ info = get_cacheable_apk_info(path) + if info is None: + return False return info.package == package -def apk_abi_matches(path, supported_abi, exact_abi=False): - supported_abi = list_or_string(supported_abi) +def apk_abi_matches(path: str, supported_abi: Union[str, List[Optional[str]]], + exact_abi: bool = False) -> bool: + """ + check apk abi matches + """ + supported_abi_ = list_or_string(supported_abi) info = get_cacheable_apk_info(path) + if info is None: + return False # If no native code present, suitable for all devices. if not info.native_code: return True if exact_abi: # Only check primary - return supported_abi[0] in info.native_code + return supported_abi_[0] in info.native_code else: - for abi in supported_abi: + for abi in supported_abi_: if abi in info.native_code: return True return False diff --git a/wa/framework/run.py b/wa/framework/run.py index 059ed45b8..49e91ac60 100644 --- a/wa/framework/run.py +++ b/wa/framework/run.py @@ -23,6 +23,11 @@ from wa.framework.configuration.core import Status from wa.utils.serializer import Podable +from typing import (cast, TYPE_CHECKING, OrderedDict as od, Tuple, + Optional, Dict, Any, Union) +if TYPE_CHECKING: + from wa.framework.job import Job + from wa.framework.configuration.core import StatusType class RunInfo(Podable): @@ -30,25 +35,53 @@ class RunInfo(Podable): Information about the current run, such as its unique ID, run time, etc. + The :class:`RunInfo` provides general run information. It has the following + attributes: + + ``uuid`` + A unique identifier for that particular run. 
+
+ ``run_name``
+ The name of the run (if provided)
+
+ ``project``
+ The name of the project the run belongs to (if provided)
+
+ ``project_stage``
+ The project stage the run is associated with (if provided)
+
+ ``duration``
+ The length of time the run took to complete.
+
+ ``start_time``
+ The time the run was started.
+
+ ``end_time``
+ The time at which the run finished.
+ """
- _pod_serialization_version = 1
+ _pod_serialization_version: int = 1
@staticmethod
- def from_pod(pod):
+ def from_pod(pod: Dict[str, Any]) -> 'RunInfo':
+ """
+ create RunInfo from pod
+ """
pod = RunInfo._upgrade_pod(pod)
- uid = pod.pop('uuid')
+ uid: str = pod.pop('uuid')
_pod_version = pod.pop('_pod_version')
duration = pod.pop('duration')
if uid is not None:
- uid = uuid.UUID(uid)
+ uid_ = uuid.UUID(uid)
+ else:
+ uid_ = None
instance = RunInfo(**pod)
instance._pod_version = _pod_version # pylint: disable=protected-access
- instance.uuid = uid
+ instance.uuid = uid_
instance.duration = duration if duration is None else timedelta(seconds=duration)
return instance
- def __init__(self, run_name=None, project=None, project_stage=None,
- start_time=None, end_time=None, duration=None):
+ def __init__(self, run_name: Optional[str] = None, project: Optional[str] = None,
+ project_stage: Optional[Union[Dict, str]] = None, start_time: Optional[datetime] = None,
+ end_time: Optional[datetime] = None, duration: Optional[timedelta] = None):
super(RunInfo, self).__init__()
self.uuid = uuid.uuid4()
self.run_name = run_name
@@ -58,8 +91,11 @@ def __init__(self, run_name=None, project=None, project_stage=None,
self.end_time = end_time
self.duration = duration
- def to_pod(self):
- d = super(RunInfo, self).to_pod()
+ def to_pod(self) -> Dict[str, Any]:
+ """
+ create pod from RunInfo
+ """
+ d: Dict[str, Any] = super(RunInfo, self).to_pod()
d.update(copy(self.__dict__))
d['uuid'] = str(self.uuid)
if self.duration is None:
@@ -69,7 +105,10 @@ def to_pod(self):
return d
@staticmethod
- def _pod_upgrade_v1(pod):
+ def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ pod upgrade function version 1
+ """
pod['_pod_version'] = pod.get('_pod_version', 1)
return pod
@@ -79,11 +118,14 @@ class RunState(Podable):
Represents the state of a WA run.
""" - _pod_serialization_version = 1 + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): - instance = super(RunState, RunState).from_pod(pod) + def from_pod(pod) -> 'RunState': + """ + create RunState from pod + """ + instance = cast('RunState', super(RunState, RunState).from_pod(pod)) instance.status = Status.from_pod(pod['status']) instance.timestamp = pod['timestamp'] jss = [JobState.from_pod(j) for j in pod['jobs']] @@ -91,26 +133,38 @@ def from_pod(pod): return instance @property - def num_completed_jobs(self): + def num_completed_jobs(self) -> int: + """ + number of completed jobs in the current run + """ return sum(1 for js in self.jobs.values() if js.status > Status.RUNNING) - def __init__(self): + def __init__(self) -> None: super(RunState, self).__init__() - self.jobs = OrderedDict() - self.status = Status.NEW + self.jobs: od[Tuple[str, int], 'JobState'] = OrderedDict() + self.status: 'StatusType' = Status.NEW self.timestamp = datetime.utcnow() - def add_job(self, job): + def add_job(self, job: 'Job') -> None: + """ + add job to the run state + """ self.jobs[(job.state.id, job.state.iteration)] = job.state - def get_status_counts(self): - counter = Counter() + def get_status_counts(self) -> Counter: + """ + get status counter + """ + counter: Counter = Counter() for job_state in self.jobs.values(): counter[job_state.status] += 1 return counter - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + convert RunState to pod + """ pod = super(RunState, self).to_pod() pod['status'] = self.status.to_pod() pod['timestamp'] = self.timestamp @@ -118,18 +172,26 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) pod['status'] = Status(pod['status']).to_pod() return pod class JobState(Podable): - - _pod_serialization_version = 1 + """ + state of the running job + """ + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod: Dict[str, Any]) -> 'JobState': + """ + create a JobState from pod + """ pod = JobState._upgrade_pod(pod) instance = JobState(pod['id'], pod['label'], pod['iteration'], Status.from_pod(pod['status'])) @@ -138,10 +200,10 @@ def from_pod(pod): return instance @property - def output_name(self): + def output_name(self) -> str: return '{}-{}-{}'.format(self.id, self.label, self.iteration) - def __init__(self, id, label, iteration, status): + def __init__(self, id: str, label: str, iteration: int, status: 'StatusType'): # pylint: disable=redefined-builtin super(JobState, self).__init__() self.id = id @@ -151,18 +213,24 @@ def __init__(self, id, label, iteration, status): self.retries = 0 self.timestamp = datetime.utcnow() - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + convert JobState to pod + """ pod = super(JobState, self).to_pod() pod['id'] = self.id pod['label'] = self.label pod['iteration'] = self.iteration - pod['status'] = self.status.to_pod() + pod['status'] = cast(Podable, self.status).to_pod() pod['retries'] = self.retries pod['timestamp'] = self.timestamp return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) pod['status'] = Status(pod['status']).to_pod() return pod diff --git a/wa/framework/signal.py b/wa/framework/signal.py index 71f13c64b..4de415063 
100644 --- a/wa/framework/signal.py +++ b/wa/framework/signal.py @@ -24,13 +24,14 @@ from contextlib import contextmanager from louie import dispatcher, saferef # pylint: disable=wrong-import-order -from louie.dispatcher import _remove_receiver -import wrapt +from louie.dispatcher import _remove_receiver # type:ignore +from louie.signal import All # type: ignore +import wrapt # type: ignore from wa.utils.types import prioritylist, enum +from typing import cast, Type, Dict, Callable, List, Tuple, Optional - -logger = logging.getLogger('signal') +logger: logging.Logger = logging.getLogger('signal') class Signal(object): @@ -41,7 +42,7 @@ class Signal(object): """ - def __init__(self, name, description='no description', invert_priority=False): + def __init__(self, name: str, description: str = 'no description', invert_priority: bool = False): """ Instantiates a Signal. @@ -61,12 +62,12 @@ def __init__(self, name, description='no description', invert_priority=False): self.description = description self.invert_priority = invert_priority - def __str__(self): + def __str__(self) -> str: return self.name __repr__ = __str__ - def __hash__(self): + def __hash__(self) -> int: return id(self.name) @@ -199,7 +200,8 @@ def append(self, *args, **kwargs): pass -def connect(handler, signal, sender=dispatcher.Any, priority=0): +def connect(handler: Callable, signal: Signal, sender: Type[dispatcher.Any] = dispatcher.Any, + priority: int = 0) -> None: """ Connects a callback to a signal, so that the callback will be automatically invoked when that signal is sent. @@ -234,20 +236,20 @@ def connect(handler, signal, sender=dispatcher.Any, priority=0): logger.debug('Connecting {} to {}({}) with priority {}'.format(handler, signal, sender, priority)) if getattr(signal, 'invert_priority', False): priority = -priority - senderkey = id(sender) + senderkey: int = id(sender) if senderkey in dispatcher.connections: - signals = dispatcher.connections[senderkey] + signals: Dict[Signal, prioritylist] = dispatcher.connections[senderkey] else: dispatcher.connections[senderkey] = signals = {} if signal in signals: receivers = signals[signal] else: receivers = signals[signal] = _prioritylist_wrapper() - dispatcher.connect(handler, signal, sender) + dispatcher.connect(handler, cast(Type[All], signal), sender) receivers.add(saferef.safe_ref(handler, on_delete=_remove_receiver), priority) -def disconnect(handler, signal, sender=dispatcher.Any): +def disconnect(handler: Callable, signal: Signal, sender: Type[dispatcher.Any] = dispatcher.Any) -> None: """ Disconnect a previously connected handler form the specified signal, optionally, only for the specified sender. @@ -262,10 +264,11 @@ def disconnect(handler, signal, sender=dispatcher.Any): """ logger.debug('Disconnecting {} from {}({})'.format(handler, signal, sender)) - dispatcher.disconnect(handler, signal, sender) + dispatcher.disconnect(handler, cast(Type[All], signal), sender) -def send(signal, sender=dispatcher.Anonymous, *args, **kwargs): +def send(signal: Signal, sender: Type[dispatcher.Anonymous] = dispatcher.Anonymous, + *args, **kwargs) -> List[Tuple]: """ Sends a signal, causing connected handlers to be invoked. 
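As a usage sketch of the typed `connect()`/`send()` pair above: handlers are plain callables, senders default to louie's `Any`/`Anonymous`, and extra positional arguments to `send()` are forwarded to each handler. This assumes workload-automation is importable; the handler name and the string argument are invented for illustration, while `BEFORE_WORKLOAD_EXECUTION` is a real signal defined in this module (the assistant code later in this patch connects to it too).

import wa.framework.signal as signal


def announce(context) -> None:
    # Receives send()'s extra positional argument.
    print('about to execute a workload')


# Handlers with higher priority are invoked earlier (reversed for
# signals created with invert_priority=True).
signal.connect(announce, signal.BEFORE_WORKLOAD_EXECUTION, priority=10)
signal.send(signal.BEFORE_WORKLOAD_EXECUTION, object, 'dummy-context')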
@@ -280,16 +283,16 @@ def send(signal, sender=dispatcher.Anonymous, *args, **kwargs): """ logger.debug('Sending {} from {}'.format(signal, sender)) - return dispatcher.send(signal, sender, *args, **kwargs) + return dispatcher.send(cast(Type[All], signal), sender, *args, **kwargs) # This will normally be set to log_error() by init_logging(); see wa.utils.log # Done this way to prevent a circular import dependency. -log_error_func = logger.error +log_error_func: Callable = logger.error -def safe_send(signal, sender=dispatcher.Anonymous, - propagate=None, *args, **kwargs): +def safe_send(signal: Signal, sender: Type[dispatcher.Anonymous] = dispatcher.Anonymous, + propagate: Optional[List[Type[BaseException]]] = None, *args, **kwargs) -> None: """ Same as ``send``, except this will catch and log all exceptions raised by handlers, except those specified in ``propagate`` argument (defaults @@ -307,16 +310,16 @@ def safe_send(signal, sender=dispatcher.Anonymous, @contextmanager -def wrap(signal_name, sender=dispatcher.Anonymous, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg +def wrap(signal_name: str, sender: Type[dispatcher.Anonymous] = dispatcher.Anonymous, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg """Wraps the suite in before/after signals, ensuring that after signal is always sent.""" - safe = kwargs.pop('safe', False) + safe: bool = kwargs.pop('safe', False) signal_name = signal_name.upper().replace('-', '_') - send_func = safe_send if safe else send + send_func: Callable = safe_send if safe else send try: - before_signal = globals()['BEFORE_' + signal_name] - success_signal = globals()['SUCCESSFUL_' + signal_name] - after_signal = globals()['AFTER_' + signal_name] + before_signal: Signal = globals()['BEFORE_' + signal_name] + success_signal: Signal = globals()['SUCCESSFUL_' + signal_name] + after_signal: Signal = globals()['AFTER_' + signal_name] except KeyError: raise ValueError('Invalid wrapped signal name: {}'.format(signal_name)) try: @@ -330,7 +333,7 @@ def wrap(signal_name, sender=dispatcher.Anonymous, *args, **kwargs): # pylint: send_func(after_signal, sender, *args, **kwargs) -def wrapped(signal_name, sender=dispatcher.Anonymous, safe=False): +def wrapped(signal_name: str, sender: Type[dispatcher.Anonymous] = dispatcher.Anonymous, safe: bool = False) -> Callable: """A decorator for wrapping function in signal dispatch.""" @wrapt.decorator def signal_wrapped(wrapped_func, _, args, kwargs): diff --git a/wa/framework/target/assistant.py b/wa/framework/target/assistant.py index 3fb51a5d3..69fe15cfb 100644 --- a/wa/framework/target/assistant.py +++ b/wa/framework/target/assistant.py @@ -26,34 +26,57 @@ from wa.utils.android import LogcatParser from wa.utils.misc import touch import wa.framework.signal as signal +from typing import List, Optional, TYPE_CHECKING +from devlib.target import LinuxTarget, AndroidTarget, ChromeOsTarget +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext class LinuxAssistant(object): + """ + assistant to connect to a linux target + """ + parameters: List[Parameter] = [] - parameters = [] - - def __init__(self, target): + def __init__(self, target: LinuxTarget): self.target = target def initialize(self): + """ + initialize the target + """ pass def start(self): + """ + start the target + """ pass - def extract_results(self, context): + def extract_results(self, context: 'ExecutionContext'): + """ + extract results of execution from the target + """ pass def stop(self): + """ + stop the target + """ pass def 
finalize(self): + """ + finalize the target + """ pass class AndroidAssistant(object): - - parameters = [ + """ + assistant to connect to an android target + """ + parameters: List[Parameter] = [ Parameter('disable_selinux', kind=bool, default=True, description=""" If ``True``, the default, and the target is rooted, an attempt will @@ -91,27 +114,34 @@ class AndroidAssistant(object): """), ] - def __init__(self, target, logcat_poll_period=None, disable_selinux=True, stay_on_mode=None): + def __init__(self, target: AndroidTarget, logcat_poll_period: Optional[int] = None, + disable_selinux: bool = True, stay_on_mode: Optional[int] = None): self.target = target self.logcat_poll_period = logcat_poll_period self.disable_selinux = disable_selinux self.stay_on_mode = stay_on_mode - self.orig_stay_on_mode = self.target.get_stay_on_mode() if stay_on_mode is not None else None - self.logcat_poller = None - self.logger = logging.getLogger('logcat') - self._logcat_marker_msg = None - self._logcat_marker_tag = None + self.orig_stay_on_mode: Optional[int] = self.target.get_stay_on_mode() if stay_on_mode is not None else None + self.logcat_poller: Optional[LogcatPoller] = None + self.logger: logging.Logger = logging.getLogger('logcat') + self._logcat_marker_msg: Optional[str] = None + self._logcat_marker_tag: Optional[str] = None signal.connect(self._before_workload, signal.BEFORE_WORKLOAD_EXECUTION) if self.logcat_poll_period: signal.connect(self._after_workload, signal.AFTER_WORKLOAD_EXECUTION) - def initialize(self): + def initialize(self) -> None: + """ + initialize the android target + """ if self.target.is_rooted and self.disable_selinux: self.do_disable_selinux() if self.stay_on_mode is not None: self.target.set_stay_on_mode(self.stay_on_mode) - def start(self): + def start(self) -> None: + """ + start the android target + """ if self.logcat_poll_period: self.logcat_poller = LogcatPoller(self.target, self.logcat_poll_period) self.logcat_poller.start() @@ -120,15 +150,24 @@ def start(self): self._logcat_marker_msg = 'WA logcat marker for wrap detection' self._logcat_marker_tag = 'WAlog' - def stop(self): + def stop(self) -> None: + """ + stop the android target + """ if self.logcat_poller: self.logcat_poller.stop() def finalize(self): + """ + finalize the android target + """ if self.stay_on_mode is not None: self.target.set_stay_on_mode(self.orig_stay_on_mode) - def extract_results(self, context): + def extract_results(self, context: 'ExecutionContext'): + """ + extract execution results from android target + """ logcat_file = os.path.join(context.output_directory, 'logcat.log') self.dump_logcat(logcat_file) context.add_artifact('logcat', logcat_file, kind='log') @@ -139,28 +178,44 @@ def extract_results(self, context): ' inaccurate or incomplete.' 
) - def dump_logcat(self, outfile): + def dump_logcat(self, outfile: str) -> None: + """ + dump logcat buffer into output file + """ if self.logcat_poller: self.logcat_poller.write_log(outfile) else: self.target.dump_logcat(outfile, logcat_format='threadtime') - def clear_logcat(self): + def clear_logcat(self) -> None: + """ + clear the logcat buffer + """ if self.logcat_poller: self.logcat_poller.clear_buffer() else: self.target.clear_logcat() - def _before_workload(self, _): + def _before_workload(self, _) -> None: + """ + things to do before start of workload + """ if self.logcat_poller: self.logcat_poller.start_logcat_wrap_detect() else: self.insert_logcat_marker() - def _after_workload(self, _): - self.logcat_poller.stop_logcat_wrap_detect() + def _after_workload(self, _) -> None: + """ + things to do after the end of workload run + """ + if self.logcat_poller: + self.logcat_poller.stop_logcat_wrap_detect() - def _check_logcat_nowrap(self, outfile): + def _check_logcat_nowrap(self, outfile: str) -> bool: + """ + check whether the logcat buffer is wrapping around or not + """ if self.logcat_poller: return self.logcat_poller.check_logcat_nowrap(outfile) else: @@ -172,7 +227,10 @@ def _check_logcat_nowrap(self, outfile): return False - def insert_logcat_marker(self): + def insert_logcat_marker(self) -> None: + """ + insert logcat marker for wrap detection + """ self.logger.debug('Inserting logcat marker') self.target.execute( 'log -t "{}" "{}"'.format( @@ -180,7 +238,10 @@ def insert_logcat_marker(self): ) ) - def do_disable_selinux(self): + def do_disable_selinux(self) -> None: + """ + disable SELinux + """ # SELinux was added in Android 4.3 (API level 18). Trying to # 'getenforce' in earlier versions will produce an error. if self.target.get_sdk_version() >= 18: @@ -190,26 +251,32 @@ def do_disable_selinux(self): class LogcatPoller(threading.Thread): - - def __init__(self, target, period=60, timeout=30): + """ + to poll logcat periodically and store the buffer + """ + def __init__(self, target: AndroidTarget, period: int = 60, + timeout: int = 30): super(LogcatPoller, self).__init__() self.target = target - self.logger = logging.getLogger('logcat') + self.logger: logging.Logger = logging.getLogger('logcat') self.period = period self.timeout = timeout self.stop_signal = threading.Event() self.lock = threading.RLock() self.buffer_file = tempfile.mktemp() - self.last_poll = 0 - self.daemon = True - self.exc = None + self.last_poll: float = 0 + self.daemon: bool = True + self.exc: Optional[Exception] = None self._logcat_marker_tag = 'WALog' self._logcat_marker_msg = 'WA logcat marker for wrap detection:{}' self._marker_count = 0 - self._start_marker = None - self._end_marker = None + self._start_marker: Optional[int] = None + self._end_marker: Optional[int] = None - def run(self): + def run(self) -> None: + """ + start polling logcat + """ self.logger.debug('Starting polling') try: self.insert_logcat_marker() @@ -226,7 +293,10 @@ def run(self): self.exc = WorkerThreadError(self.name, sys.exc_info()) self.logger.debug('Polling stopped') - def stop(self): + def stop(self) -> None: + """ + stop logcat polling + """ self.logger.debug('Stopping logcat polling') self.stop_signal.set() self.join(self.timeout) @@ -235,13 +305,19 @@ def stop(self): if self.exc: raise self.exc # pylint: disable=E0702 - def clear_buffer(self): + def clear_buffer(self) -> None: + """ + clear logcat buffer + """ self.logger.debug('Clearing logcat buffer') with self.lock: self.target.clear_logcat() 
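
The `exc: Optional[Exception]` annotation above pins down the poller's error-propagation idiom: `run()` never lets an exception escape the background thread, it records the failure, and `stop()` re-raises it on the calling thread after `join()`. Below is a minimal, self-contained sketch of that idiom with generic names (this is not WA's `LogcatPoller` API); the `clear_buffer` hunk continues right after it.

```python
import threading
from typing import Callable, Optional


class PeriodicWorker(threading.Thread):
    """Call `work` every `period` seconds until stop() is requested."""

    def __init__(self, work: Callable[[], None], period: float = 1.0,
                 timeout: float = 5.0):
        super().__init__(daemon=True)
        self.work = work
        self.period = period
        self.timeout = timeout
        self.stop_signal = threading.Event()
        self.exc: Optional[Exception] = None  # error captured on the worker thread

    def run(self) -> None:
        try:
            # Event.wait() returns True once stop() sets the event.
            while not self.stop_signal.wait(self.period):
                self.work()
        except Exception as e:  # pylint: disable=broad-except
            self.exc = e  # hold it for the caller

    def stop(self) -> None:
        self.stop_signal.set()
        self.join(self.timeout)
        if self.exc:
            raise self.exc  # surface the worker's failure on this thread
```
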
touch(self.buffer_file)
 
-    def write_log(self, outfile):
+    def write_log(self, outfile: str) -> None:
+        """
+        write log into output file
+        """
         with self.lock:
             self.poll()
             if os.path.isfile(self.buffer_file):
@@ -249,17 +325,26 @@ def write_log(self, outfile):
             else:  # there was no logcat trace at this time
                 touch(outfile)
 
-    def close(self):
+    def close(self) -> None:
+        """
+        close the logcat poller and remove the temp log file
+        """
         self.logger.debug('Closing poller')
         if os.path.isfile(self.buffer_file):
             os.remove(self.buffer_file)
 
-    def poll(self):
+    def poll(self) -> None:
+        """
+        poll logcat buffer and dump it to the log file
+        """
         self.last_poll = time.time()
         self.target.dump_logcat(self.buffer_file, append=True,
                                 timeout=self.timeout, logcat_format='threadtime')
         self.target.clear_logcat()
 
-    def insert_logcat_marker(self):
+    def insert_logcat_marker(self) -> None:
+        """
+        insert logcat marker for wrap detection
+        """
         self.logger.debug('Inserting logcat marker')
         with self.lock:
             self.target.execute(
@@ -270,11 +355,16 @@ def insert_logcat_marker(self):
             )
             self._marker_count += 1
 
-    def check_logcat_nowrap(self, outfile):
+    def check_logcat_nowrap(self, outfile: str) -> bool:
+        """
+        check whether the logcat buffer is wrapping around or not
+        """
         parser = LogcatParser()
-        counter = self._start_marker
+        counter: Optional[int] = self._start_marker
+        if counter is None:
+            return False
         for event in parser.parse(outfile):
-            message = self._logcat_marker_msg.split(':')[0]
+            message: str = self._logcat_marker_msg.split(':')[0]
             if not (event.tag == self._logcat_marker_tag and
                     event.message.split(':')[0] == message):
                 continue
self.android_assistant: self.android_assistant.stop() diff --git a/wa/framework/target/config.py b/wa/framework/target/config.py index 1e6eed0af..5135fab8f 100644 --- a/wa/framework/target/config.py +++ b/wa/framework/target/config.py @@ -14,22 +14,24 @@ # from copy import copy +from typing import TYPE_CHECKING, Union, Optional, cast, Any +if TYPE_CHECKING: + from wa.utils.types import ParameterDict class TargetConfig(dict): """ Represents a configuration for a target. - """ - def __init__(self, config=None): + def __init__(self, config: Optional[Union['TargetConfig', 'ParameterDict']] = None): dict.__init__(self) if isinstance(config, TargetConfig): self.__dict__ = copy(config.__dict__) elif hasattr(config, 'iteritems'): - for k, v in config.iteritems: + for k, v in cast('ParameterDict', config).iteritems(): self.set(k, v) elif config: raise ValueError(config) - def set(self, name, value): + def set(self, name: str, value: Any): setattr(self, name, value) diff --git a/wa/framework/target/descriptor.py b/wa/framework/target/descriptor.py index aeb46ccf0..58f8c492c 100644 --- a/wa/framework/target/descriptor.py +++ b/wa/framework/target/descriptor.py @@ -15,13 +15,17 @@ import inspect -from devlib import (LinuxTarget, AndroidTarget, LocalLinuxTarget, - ChromeOsTarget, Platform, Juno, TC2, Gem5SimulationPlatform, - AdbConnection, SshConnection, LocalConnection, - TelnetConnection, Gem5Connection) -from devlib.target import DEFAULT_SHELL_PROMPT +from devlib.target import (DEFAULT_SHELL_PROMPT, LinuxTarget, AndroidTarget, + LocalLinuxTarget, ChromeOsTarget, Target) +from devlib.platform import Platform +from devlib.platform.arm import Juno, TC2 +from devlib.platform.gem5 import Gem5SimulationPlatform +from devlib.utils.android import AdbConnection +from devlib.utils.ssh import SshConnection, TelnetConnection, Gem5Connection +from devlib.host import LocalConnection +from devlib.utils.annotation_helpers import SupportedConnections from devlib.utils.ssh import DEFAULT_SSH_SUDO_COMMAND - +from devlib.utils.misc import InitCheckpointMeta from wa.framework import pluginloader from wa.framework.configuration.core import get_config_point_map from wa.framework.exception import PluginLoaderError @@ -29,36 +33,53 @@ from wa.framework.target.assistant import LinuxAssistant, AndroidAssistant, ChromeOsAssistant from wa.utils.types import list_of_strings, list_of_ints, regex, identifier, caseless_string from wa.utils.misc import isiterable +from types import ModuleType +from typing import (List, Dict, Union, cast, Tuple, + Optional, Type, Any, Iterable) +from typing_extensions import Protocol -def list_target_descriptions(loader=pluginloader): - targets = {} +def list_target_descriptions(loader: ModuleType = pluginloader) -> List['TargetDescriptionProtocol']: + """ + get list of all the target descriptions + """ + targets: Dict[str, 'TargetDescriptionProtocol'] = {} for cls in loader.list_target_descriptors(): - descriptor = cls() + descriptor: 'TargetDescriptor' = cls() for desc in descriptor.get_descriptions(): if desc.name in targets: msg = 'Duplicate target "{}" returned by {} and {}' - prev_dtor = targets[desc.name].source + # FIXME - not sure how the source which is a string, can be treated as a targetdescription + prev_dtor = cast('TargetDescriptionProtocol', targets[desc.name].source) raise PluginLoaderError(msg.format(desc.name, prev_dtor.name, descriptor.name)) targets[desc.name] = desc return list(targets.values()) -def get_target_description(name, loader=pluginloader): +def 
get_target_description(name: str, loader: ModuleType = pluginloader) -> 'TargetDescriptionProtocol': + """ + get a specific target description + """ for tdesc in list_target_descriptions(loader): if tdesc.name == name: return tdesc raise ValueError('Could not find target descriptor "{}"'.format(name)) -def instantiate_target(tdesc, params, connect=None, extra_platform_params=None): +def instantiate_target(tdesc: 'TargetDescriptionProtocol', params: Dict[str, Parameter], + connect: Optional[bool] = None, extra_platform_params: Optional[Dict[str, Any]] = None) -> Target: + """ + instantiate a target based on the target description and parameters + """ # pylint: disable=too-many-locals,too-many-branches target_params = get_config_point_map(tdesc.target_params) platform_params = get_config_point_map(tdesc.platform_params) conn_params = get_config_point_map(tdesc.conn_params) assistant_params = get_config_point_map(tdesc.assistant_params) - + tp: Dict[str, Any] + pp: Dict[str, Any] + cp: Dict[str, Any] tp, pp, cp = {}, {}, {} for supported_params, new_params in (target_params, tp), (platform_params, pp), (conn_params, cp): @@ -86,7 +107,7 @@ def instantiate_target(tdesc, params, connect=None, extra_platform_params=None): if pname in pp: raise RuntimeError('Platform parameter clash: {}'.format(pname)) pp[pname] = pval - + # FIXME - Platform is not callable tp['platform'] = (tdesc.platform or Platform)(**pp) if cp: tp['connection_settings'] = cp @@ -98,21 +119,49 @@ def instantiate_target(tdesc, params, connect=None, extra_platform_params=None): return tdesc.target(**tp) -def instantiate_assistant(tdesc, params, target): - assistant_params = {} +def instantiate_assistant(tdesc: 'TargetDescriptionProtocol', params: Dict[str, Parameter], + target: Target) -> Union[LinuxAssistant, AndroidAssistant]: + """ + instantiate assistant to connect to the target + """ + assistant_params: Dict[str, Any] = {} for param in tdesc.assistant_params: if param.name in params: assistant_params[param.name] = params[param.name] elif param.default: assistant_params[param.name] = param.default - return tdesc.assistant(target, **assistant_params) + # FIXME - casting target to Any because, assistant can be linuxassistant or androidassistant. They need linuxtarget or androidtarget + # respectively. Not sure how to annotate this + return tdesc.assistant(cast(Any, target), **assistant_params) -class TargetDescription(object): +class TargetDescriptionProtocol(Protocol): + name: str + source: str + description: str + target: Type[Target] + platform: Type[Platform] + connection: SupportedConnections + assistant: Union[Type[LinuxAssistant], Type[AndroidAssistant]] + target_params: List[Parameter] + platform_params: List[Parameter] + conn_params: List[Parameter] + assistant_params: List[Parameter] + conn: InitCheckpointMeta + + def get_default_config(self) -> Dict[str, Any]: + ... 
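
`TargetDescriptionProtocol` is the central typing device in this file: `TargetDescription` attaches most of its attributes dynamically via `_set()`, so a `typing_extensions.Protocol` describes the expected shape and instances are `cast` to it. A small sketch of how such structural matching type-checks, using illustrative names rather than WA's classes:

```python
from typing import Dict, List, Optional
from typing_extensions import Protocol


class DescriptionLike(Protocol):
    """Structural stand-in for objects carrying a name and parameters."""
    name: str
    params: List[str]

    def get_default_config(self) -> Dict[str, Optional[str]]:
        ...


class AdhocDescription:  # note: no inheritance from DescriptionLike
    def __init__(self, name: str) -> None:
        self.name = name
        self.params: List[str] = []

    def get_default_config(self) -> Dict[str, Optional[str]]:
        return {p: None for p in self.params}


def summarise(desc: DescriptionLike) -> str:
    # mypy accepts AdhocDescription here purely because its shape matches
    return '{} ({} params)'.format(desc.name, len(desc.params))


print(summarise(AdhocDescription('juno_linux')))
```

Because the matching is structural, `AdhocDescription` needs no base class; the explicit `cast` seen later in `get_descriptions()` is only required because `TargetDescription` gains its attributes after construction, which the checker cannot see.
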
+ - def __init__(self, name, source, description=None, target=None, platform=None, - conn=None, assistant=None, target_params=None, platform_params=None, - conn_params=None, assistant_params=None): +class TargetDescription(object): + """ + description of the target with target, platform, and assistant configurations + """ + def __init__(self, name: str, source: Any, description: Optional[str] = None, + target: Optional[Type[Target]] = None, platform: Optional[Type[Platform]] = None, + conn: Optional[InitCheckpointMeta] = None, assistant: Optional[Union[LinuxAssistant, AndroidAssistant]] = None, + target_params: Optional[Dict[str, Parameter]] = None, platform_params: Optional[Dict[str, Parameter]] = None, + conn_params: Optional[Dict[str, Parameter]] = None, assistant_params: Optional[Dict[str, Parameter]] = None): self.name = name self.source = source self.description = description @@ -125,22 +174,28 @@ def __init__(self, name, source, description=None, target=None, platform=None, self._set('conn_params', conn_params) self._set('assistant_params', assistant_params) - def get_default_config(self): - param_attrs = ['target_params', 'platform_params', - 'conn_params', 'assistant_params'] - config = {} + def get_default_config(self) -> Dict[str, Any]: + """ + get default configuration for the target + """ + param_attrs: List[str] = ['target_params', 'platform_params', + 'conn_params', 'assistant_params'] + config: Dict[str, Any] = {} for pattr in param_attrs: for p in getattr(self, pattr): - if not p.deprecated: - config[p.name] = p.default + if not cast(Parameter, p).deprecated: + config[cast(Parameter, p).name] = cast(Parameter, p).default return config - def _set(self, attr, vals): + def _set(self, attr: str, vals: Optional[Iterable]) -> None: + """ + set values to the attributes + """ if vals is None: vals = [] elif isiterable(vals): if hasattr(vals, 'values'): - vals = list(vals.values()) + vals = list(vals.values()) # type: ignore else: msg = '{} must be iterable; got "{}"' raise ValueError(msg.format(attr, vals)) @@ -148,14 +203,19 @@ def _set(self, attr, vals): class TargetDescriptor(Plugin): - + """ + Target descriptor Plugin + """ kind = 'target_descriptor' - def get_descriptions(self): # pylint: disable=no-self-use + def get_descriptions(self) -> List[TargetDescriptionProtocol]: # pylint: disable=no-self-use + """ + get list of target description (class:'TargetDescription') + """ return [] -COMMON_TARGET_PARAMS = [ +COMMON_TARGET_PARAMS: List[Parameter] = [ Parameter('working_directory', kind=str, description=''' On-target working directory that will be used by WA. This @@ -200,13 +260,13 @@ def get_descriptions(self): # pylint: disable=no-self-use '''), Parameter('max_async', kind=int, default=50, - description=''' + description=''' The maximum number of concurent asynchronous connections to the target maintained at any time. 
'''), ] -COMMON_PLATFORM_PARAMS = [ +COMMON_PLATFORM_PARAMS: List[Parameter] = [ Parameter('core_names', kind=list_of_strings, description=''' List of names of CPU cores in the order that they appear to the @@ -237,7 +297,7 @@ def get_descriptions(self): # pylint: disable=no-self-use '''), ] -VEXPRESS_PLATFORM_PARAMS = [ +VEXPRESS_PLATFORM_PARAMS: List[Parameter] = [ Parameter('serial_port', kind=str, description=''' The serial device/port on the host for the initial connection to @@ -274,7 +334,7 @@ def get_descriptions(self): # pylint: disable=no-self-use '''), ] -GEM5_PLATFORM_PARAMS = [ +GEM5_PLATFORM_PARAMS: List[Parameter] = [ Parameter('gem5_bin', kind=str, mandatory=True, description=''' Path to the gem5 binary @@ -295,7 +355,7 @@ def get_descriptions(self): # pylint: disable=no-self-use ] -CONNECTION_PARAMS = { +CONNECTION_PARAMS: Dict[InitCheckpointMeta, List[Parameter]] = { AdbConnection: [ Parameter( 'device', kind=str, @@ -546,32 +606,36 @@ def get_descriptions(self): # pylint: disable=no-self-use ], } -CONNECTION_PARAMS['ChromeOsConnection'] = \ +CONNECTION_PARAMS[cast(InitCheckpointMeta, 'ChromeOsConnection')] = \ CONNECTION_PARAMS[AdbConnection] + CONNECTION_PARAMS[SshConnection] +TargetTuple = Tuple[Tuple[Union[Type[LinuxTarget], Type[AndroidTarget], Type[ChromeOsTarget]], + InitCheckpointMeta, List[Type[Platform]]], + List[Parameter], Optional[List[Parameter]]] # name --> ((target_class, conn_class, unsupported_platforms), params_list, defaults) -TARGETS = { +TARGETS: Dict[str, TargetTuple] = { 'linux': ((LinuxTarget, SshConnection, []), COMMON_TARGET_PARAMS, None), 'android': ((AndroidTarget, AdbConnection, []), COMMON_TARGET_PARAMS + - [Parameter('package_data_directory', kind=str, default='/data/data', - description=''' - Directory containing Android data - '''), - ], None), - 'chromeos': ((ChromeOsTarget, 'ChromeOsConnection', []), COMMON_TARGET_PARAMS + [Parameter('package_data_directory', kind=str, default='/data/data', description=''' + Directory containing Android data + '''), + ], None), + 'chromeos': ((ChromeOsTarget, cast(InitCheckpointMeta, 'ChromeOsConnection'), []), + COMMON_TARGET_PARAMS + + [Parameter('package_data_directory', kind=str, default='/data/data', + description=''' Directory containing Android data '''), - Parameter('android_working_directory', kind=str, - description=''' + Parameter('android_working_directory', kind=str, + description=''' On-target working directory that will be used by WA for the android container. This directory must be writable by the user WA logs in as without the need for privilege elevation. '''), - Parameter('android_executables_directory', kind=str, - description=''' + Parameter('android_executables_directory', kind=str, + description=''' On-target directory where WA will install its executable binaries for the android container. This location must allow execution. This location does *not* need to be writable by unprivileged users or @@ -579,13 +643,13 @@ def get_descriptions(self): # pylint: disable=no-self-use directory must be writable by the user WA logs in as without the need for privilege elevation. 
'''), - ], None), + ], None), 'local': ((LocalLinuxTarget, LocalConnection, [Juno, Gem5SimulationPlatform, TC2]), COMMON_TARGET_PARAMS, None), } # name --> assistant -ASSISTANTS = { +ASSISTANTS: Dict[str, Union[Type[LinuxAssistant], Type[AndroidAssistant]]] = { 'linux': LinuxAssistant, 'android': AndroidAssistant, 'local': LinuxAssistant, @@ -593,29 +657,29 @@ def get_descriptions(self): # pylint: disable=no-self-use } # Platform specific parameter overrides. -JUNO_PLATFORM_OVERRIDES = [ - Parameter('baudrate', kind=int, default=115200, - description=''' +JUNO_PLATFORM_OVERRIDES: List[Parameter] = [ + Parameter('baudrate', kind=int, default=115200, + description=''' Baud rate for the serial connection. '''), - Parameter('vemsd_mount', kind=str, default='/media/JUNO', - description=''' + Parameter('vemsd_mount', kind=str, default='/media/JUNO', + description=''' VExpress MicroSD card mount location. This is a MicroSD card in the VExpress device that is mounted on the host via USB. The card contains configuration files for the platform and firmware and kernel images to be flashed. '''), - Parameter('bootloader', kind=str, default='u-boot', - allowed_values=['uefi', 'uefi-shell', 'u-boot', 'bootmon'], - description=''' + Parameter('bootloader', kind=str, default='u-boot', + allowed_values=['uefi', 'uefi-shell', 'u-boot', 'bootmon'], + description=''' Selects the bootloader mechanism used by the board. Depending on firmware version, a number of possible boot mechanisms may be use. Please see ``devlib`` documentation for descriptions. '''), - Parameter('hard_reset_method', kind=str, default='dtr', - allowed_values=['dtr', 'reboottxt'], - description=''' + Parameter('hard_reset_method', kind=str, default='dtr', + allowed_values=['dtr', 'reboottxt'], + description=''' There are a couple of ways to reset VersatileExpress board if the software running on the board becomes unresponsive. Both require configuration to be enabled (please see ``devlib`` documentation). @@ -624,29 +688,29 @@ def get_descriptions(self): # pylint: disable=no-self-use ``reboottxt``: create ``reboot.txt`` in the root of the VEMSD mount. '''), ] -TC2_PLATFORM_OVERRIDES = [ - Parameter('baudrate', kind=int, default=38400, - description=''' +TC2_PLATFORM_OVERRIDES: List[Parameter] = [ + Parameter('baudrate', kind=int, default=38400, + description=''' Baud rate for the serial connection. '''), - Parameter('vemsd_mount', kind=str, default='/media/VEMSD', - description=''' + Parameter('vemsd_mount', kind=str, default='/media/VEMSD', + description=''' VExpress MicroSD card mount location. This is a MicroSD card in the VExpress device that is mounted on the host via USB. The card contains configuration files for the platform and firmware and kernel images to be flashed. '''), - Parameter('bootloader', kind=str, default='bootmon', - allowed_values=['uefi', 'uefi-shell', 'u-boot', 'bootmon'], - description=''' + Parameter('bootloader', kind=str, default='bootmon', + allowed_values=['uefi', 'uefi-shell', 'u-boot', 'bootmon'], + description=''' Selects the bootloader mechanism used by the board. Depending on firmware version, a number of possible boot mechanisms may be use. Please see ``devlib`` documentation for descriptions. 
'''),
-    Parameter('hard_reset_method', kind=str, default='reboottxt',
-              allowed_values=['dtr', 'reboottxt'],
-              description='''
+    Parameter('hard_reset_method', kind=str, default='reboottxt',
+              allowed_values=['dtr', 'reboottxt'],
+              description='''
              There are a couple of ways to reset VersatileExpress board if the
              software running on the board becomes unresponsive. Both require
              configuration to be enabled (please see ``devlib`` documentation).
@@ -663,13 +727,14 @@ def get_descriptions(self):  # pylint: disable=no-self-use
 # particular platform. Parameters you can override are in COMMON_TARGET_PARAMS
 # Example of overriding one of the target parameters: Replace last `None` with
 # a list of `Parameter` objects to be used instead.
-PLATFORMS = {
+PLATFORMS: Dict[str, Tuple[Tuple[Type[Platform], Optional[InitCheckpointMeta], Optional[List[Parameter]]],
+                           Optional[List[Parameter]], Optional[List[Parameter]], Optional[List[Parameter]]]] = {
     'generic': ((Platform, None, None), COMMON_PLATFORM_PARAMS, None, None),
     'juno': ((Juno, None, [
-        Parameter('host', kind=str, mandatory=False,
-                  description="Host name or IP address of the target."),
-    ]
-    ), COMMON_PLATFORM_PARAMS + VEXPRESS_PLATFORM_PARAMS, JUNO_PLATFORM_OVERRIDES, None),
+              Parameter('host', kind=str, mandatory=False,
+                        description="Host name or IP address of the target."),
+              ]
+              ), COMMON_PLATFORM_PARAMS + VEXPRESS_PLATFORM_PARAMS, JUNO_PLATFORM_OVERRIDES, None),
     'tc2': ((TC2, None, None), COMMON_PLATFORM_PARAMS + VEXPRESS_PLATFORM_PARAMS,
             TC2_PLATFORM_OVERRIDES, None),
     'gem5': ((Gem5SimulationPlatform, Gem5Connection, None), GEM5_PLATFORM_PARAMS, None, None),
@@ -677,34 +742,36 @@
 class DefaultTargetDescriptor(TargetDescriptor):
+    """
+    default target descriptor plugin
+    """
+    name: str = 'devlib_targets'
 
-    name = 'devlib_targets'
-
-    description = """
+    description: str = """
    The default target descriptor that provides descriptions in the form _.
 
-    These map directly onto ``Target``\ s and ``Platform``\ s supplied by ``devlib``.
+    These map directly onto the ``Target`` and ``Platform`` classes supplied by ``devlib``.
""" - def get_descriptions(self): + def get_descriptions(self) -> List[TargetDescriptionProtocol]: # pylint: disable=attribute-defined-outside-init,too-many-locals - result = [] + result: List[TargetDescriptionProtocol] = [] for target_name, target_tuple in TARGETS.items(): (target, conn, unsupported_platforms), target_params = self._get_item(target_tuple) - assistant = ASSISTANTS[target_name] - conn_params = CONNECTION_PARAMS[conn] + assistant: Union[Type[LinuxAssistant], Type[AndroidAssistant]] = ASSISTANTS[target_name] + conn_params: List[Parameter] = CONNECTION_PARAMS[conn] for platform_name, platform_tuple in PLATFORMS.items(): - platform_target_defaults = platform_tuple[-1] - platform_tuple = platform_tuple[0:-1] - (platform, plat_conn, conn_defaults), platform_params = self._get_item(platform_tuple) + platform_target_defaults: Optional[List[Parameter]] = platform_tuple[-1] + platform_tuple_slice = platform_tuple[0:-1] + (platform, plat_conn, conn_defaults), platform_params = self._get_item(platform_tuple_slice) if platform in unsupported_platforms: continue # Add target defaults specified in the Platform tuple target_params = self._override_params(target_params, platform_target_defaults) name = '{}_{}'.format(platform_name, target_name) - td = TargetDescription(name, self) + td: TargetDescriptionProtocol = cast(TargetDescriptionProtocol, TargetDescription(name, self)) td.target = target td.platform = platform td.assistant = assistant @@ -719,39 +786,47 @@ def get_descriptions(self): else: td.conn = conn td.conn_params = self._override_params(conn_params, conn_defaults) - result.append(td) return result - def _override_params(self, params, overrides): # pylint: disable=no-self-use + def _override_params(self, params: List[Parameter], + overrides: Optional[List[Parameter]]) -> List[Parameter]: # pylint: disable=no-self-use ''' Returns a new list of parameters replacing any parameter with the corresponding parameter in overrides''' if not overrides: return params - param_map = {p.name: p for p in params} + param_map: Dict[str, Parameter] = {p.name: p for p in params} for override in overrides: if override.name in param_map: param_map[override.name] = override # Return the list of overriden parameters return list(param_map.values()) - def _get_item(self, item_tuple): + # FIXME - cannot make exact type annotation for the cls_tuple being returned as the two usages + # (target tuple vs platform tuple) are not having same types + def _get_item(self, item_tuple: Tuple) -> Tuple[Tuple, List[Parameter]]: + """ + get the item tuple of target descriptor or platform descriptor + """ + cls_tuple: Union[Tuple[Union[Type[LinuxTarget], Type[AndroidTarget], Type[ChromeOsTarget]], + InitCheckpointMeta, List[Type[Platform]]], + Tuple[Type[Platform], Optional[InitCheckpointMeta], Optional[List[Parameter]]]] cls_tuple, params, defaults = item_tuple updated_params = self._override_params(params, defaults) return cls_tuple, updated_params -_adhoc_target_descriptions = [] +_adhoc_target_descriptions: List[TargetDescription] = [] -def create_target_description(name, *args, **kwargs): +def create_target_description(name: str, *args, **kwargs) -> None: name = identifier(name) for td in _adhoc_target_descriptions: if caseless_string(name) == td.name: - msg = 'Target with name "{}" already exists (from source: {})' + msg: str = 'Target with name "{}" already exists (from source: {})' raise ValueError(msg.format(name, td.source)) - stack = inspect.stack() + stack: List[inspect.FrameInfo] = inspect.stack() # 
inspect.stack() returns a list of call frame records for the current thread # in reverse call order. So the first entry is for the current frame and next one # for the immediate caller. Each entry is a tuple in the format @@ -761,57 +836,35 @@ def create_target_description(name, *args, **kwargs): # because this might be invoked via the add_scription_for_target wrapper, we need to # check for that, and make sure that we get the info for *its* caller in that case. if stack[1][3] == 'add_description_for_target': - source = stack[2][1] + source: str = stack[2][1] else: source = stack[1][1] _adhoc_target_descriptions.append(TargetDescription(name, source, *args, **kwargs)) -def _get_target_defaults(target): - specificity = 0 - res = ('linux', TARGETS['linux']) # fallback to a generic linux target +def _get_target_defaults(target: Type[Target]) -> Tuple[str, TargetTuple]: + """ + get defaults for target + """ + specificity: int = 0 + res: Tuple[str, TargetTuple] = ('linux', TARGETS['linux']) # fallback to a generic linux target for name, ttup in TARGETS.items(): if issubclass(target, ttup[0][0]): - new_spec = len(inspect.getmro(ttup[0][0])) + new_spec: int = len(inspect.getmro(ttup[0][0])) if new_spec > specificity: res = (name, ttup) specificity = new_spec return res -def add_description_for_target(target, description=None, **kwargs): - (base_name, ((_, base_conn, _), base_params, _)) = _get_target_defaults(target) - - if 'target_params' not in kwargs: - kwargs['target_params'] = base_params - - if 'platform' not in kwargs: - kwargs['platform'] = Platform - if 'platform_params' not in kwargs: - for (plat, conn, _), params, _, _ in PLATFORMS.values(): - if plat == kwargs['platform']: - kwargs['platform_params'] = params - if conn is not None and kwargs['conn'] is None: - kwargs['conn'] = conn - break - - if 'conn' not in kwargs: - kwargs['conn'] = base_conn - if 'conn_params' not in kwargs: - kwargs['conn_params'] = CONNECTION_PARAMS.get(kwargs['conn']) - - if 'assistant' not in kwargs: - kwargs['assistant'] = ASSISTANTS.get(base_name) - - create_target_description(target.name, target=target, description=description, **kwargs) - - class SimpleTargetDescriptor(TargetDescriptor): + """ + a simple target descriptor + """ + name: str = 'adhoc_targets' - name = 'adhoc_targets' - - description = """ + description: str = """ Returns target descriptions added with ``create_target_description``. 
""" diff --git a/wa/framework/target/info.py b/wa/framework/target/info.py index a7a8dd15b..c20903ab6 100644 --- a/wa/framework/target/info.py +++ b/wa/framework/target/info.py @@ -16,20 +16,27 @@ import os -from devlib import AndroidTarget, TargetError -from devlib.target import KernelConfig, KernelVersion, Cpuinfo +from devlib.exception import TargetError +from devlib.target import (KernelConfig, KernelVersion, Cpuinfo, + AndroidTarget, Target) from devlib.utils.android import AndroidProperties from wa.framework.configuration.core import settings from wa.framework.exception import ConfigError from wa.utils.serializer import read_pod, write_pod, Podable from wa.utils.misc import atomic_write_path +from typing import cast, Optional, List, Dict, Tuple, Any +from devlib.module.cpufreq import CpufreqModule +from devlib.module.cpuidle import Cpuidle -def cpuinfo_from_pod(pod): +def cpuinfo_from_pod(pod: Dict[str, Any]) -> Cpuinfo: + """ + get cpu info (devlib) from a plain old datastructure + """ cpuinfo = Cpuinfo('') cpuinfo.sections = pod['cpuinfo'] - lines = [] + lines: List[str] = [] for section in cpuinfo.sections: for key, value in section.items(): line = '{}: {}'.format(key, value) @@ -39,9 +46,12 @@ def cpuinfo_from_pod(pod): return cpuinfo -def kernel_version_from_pod(pod): - release_string = pod['kernel_release'] - version_string = pod['kernel_version'] +def kernel_version_from_pod(pod) -> KernelVersion: + """ + get kernel version from plain old datastructure + """ + release_string: str = pod['kernel_release'] + version_string: str = pod['kernel_version'] if release_string: if version_string: kernel_string = '{} #{}'.format(release_string, version_string) @@ -52,10 +62,13 @@ def kernel_version_from_pod(pod): return KernelVersion(kernel_string) -def kernel_config_from_pod(pod): +def kernel_config_from_pod(pod: Dict[str, Any]) -> KernelConfig: + """ + get kernel configuration from plain old datastructure + """ config = KernelConfig('') config.typed_config._config = pod['kernel_config'] - lines = [] + lines: List[str] = [] for key, value in config.items(): if value == 'n': lines.append('# {} is not set'.format(key)) @@ -66,29 +79,31 @@ def kernel_config_from_pod(pod): class CpufreqInfo(Podable): - - _pod_serialization_version = 1 + """ + cpu frequency information + """ + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod: Dict[str, Any]) -> 'CpufreqInfo': pod = CpufreqInfo._upgrade_pod(pod) return CpufreqInfo(**pod) - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(CpufreqInfo, self).__init__() - self.available_frequencies = kwargs.pop('available_frequencies', []) - self.available_governors = kwargs.pop('available_governors', []) - self.related_cpus = kwargs.pop('related_cpus', []) - self.driver = kwargs.pop('driver', None) - self._pod_version = kwargs.pop('_pod_version', self._pod_serialization_version) + self.available_frequencies: List[int] = kwargs.pop('available_frequencies', []) + self.available_governors: List[str] = kwargs.pop('available_governors', []) + self.related_cpus: List[int] = kwargs.pop('related_cpus', []) + self.driver: Optional[str] = kwargs.pop('driver', None) + self._pod_version: int = kwargs.pop('_pod_version', self._pod_serialization_version) - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(CpufreqInfo, self).to_pod() pod.update(self.__dict__) return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: pod['_pod_version'] 
= pod.get('_pod_version', 1) return pod @@ -99,29 +114,31 @@ def __repr__(self): class IdleStateInfo(Podable): - - _pod_serialization_version = 1 + """ + idle state information + """ + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod: Dict[str, Any]) -> 'IdleStateInfo': pod = IdleStateInfo._upgrade_pod(pod) return IdleStateInfo(**pod) - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super(IdleStateInfo, self).__init__() - self.name = kwargs.pop('name', None) - self.desc = kwargs.pop('desc', None) - self.power = kwargs.pop('power', None) - self.latency = kwargs.pop('latency', None) - self._pod_version = kwargs.pop('_pod_version', self._pod_serialization_version) + self.name: Optional[str] = kwargs.pop('name', None) + self.desc: Optional[str] = kwargs.pop('desc', None) + self.power: Optional[int] = kwargs.pop('power', None) + self.latency: Optional[int] = kwargs.pop('latency', None) + self._pod_version: int = kwargs.pop('_pod_version', self._pod_serialization_version) - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(IdleStateInfo, self).to_pod() pod.update(self.__dict__) return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: pod['_pod_version'] = pod.get('_pod_version', 1) return pod @@ -132,11 +149,13 @@ def __repr__(self): class CpuidleInfo(Podable): - - _pod_serialization_version = 1 + """ + cpu idle information + """ + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod: Dict[str, Any]) -> 'CpuidleInfo': pod = CpuidleInfo._upgrade_pod(pod) instance = CpuidleInfo() instance._pod_version = pod['_pod_version'] @@ -146,16 +165,19 @@ def from_pod(pod): return instance @property - def num_states(self): + def num_states(self) -> int: + """ + number of cpu idle states + """ return len(self.states) - def __init__(self): + def __init__(self) -> None: super(CpuidleInfo, self).__init__() - self.governor = None - self.driver = None - self.states = [] + self.governor: Optional[str] = None + self.driver: Optional[str] = None + self.states: List[IdleStateInfo] = [] - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(CpuidleInfo, self).to_pod() pod['governor'] = self.governor pod['driver'] = self.driver @@ -163,7 +185,7 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: pod['_pod_version'] = pod.get('_pod_version', 1) return pod @@ -175,12 +197,14 @@ def __repr__(self): class CpuInfo(Podable): - - _pod_serialization_version = 1 + """ + Cpu information + """ + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): - instance = super(CpuInfo, CpuInfo).from_pod(pod) + def from_pod(pod) -> 'CpuInfo': + instance = cast('CpuInfo', super(CpuInfo, CpuInfo).from_pod(pod)) instance.id = pod['id'] instance.name = pod['name'] instance.architecture = pod['architecture'] @@ -189,16 +213,16 @@ def from_pod(pod): instance.cpuidle = CpuidleInfo.from_pod(pod['cpuidle']) return instance - def __init__(self): + def __init__(self) -> None: super(CpuInfo, self).__init__() - self.id = None - self.name = None - self.architecture = None - self.features = [] + self.id: Optional[int] = None + self.name: Optional[str] = None + self.architecture: Optional[str] = None + self.features: List[str] = [] self.cpufreq = CpufreqInfo() self.cpuidle = CpuidleInfo() - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: pod = super(CpuInfo, 
self).to_pod() pod['id'] = self.id pod['name'] = self.name @@ -209,7 +233,7 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: pod['_pod_version'] = pod.get('_pod_version', 1) return pod @@ -219,7 +243,10 @@ def __repr__(self): __str__ = __repr__ -def get_target_info(target): +def get_target_info(target: Target) -> 'TargetInfo': + """ + get information about the target + """ info = TargetInfo() info.target = target.__class__.__name__ info.modules = target.modules @@ -234,7 +261,7 @@ def get_target_info(target): info.hostid = target.hostid try: - info.sched_features = target.read_value('/sys/kernel/debug/sched_features').split() + info.sched_features = cast(str, target.read_value('/sys/kernel/debug/sched_features')).split() except TargetError: # best effort -- debugfs might not be mounted pass @@ -247,15 +274,15 @@ def get_target_info(target): cpu.architecture = target.cpuinfo.architecture if target.has('cpufreq'): - cpu.cpufreq.available_governors = target.cpufreq.list_governors(i) - cpu.cpufreq.available_frequencies = target.cpufreq.list_frequencies(i) - cpu.cpufreq.related_cpus = target.cpufreq.get_related_cpus(i) - cpu.cpufreq.driver = target.cpufreq.get_driver(i) + cpu.cpufreq.available_governors = cast(CpufreqModule, target.cpufreq).list_governors(i) + cpu.cpufreq.available_frequencies = cast(CpufreqModule, target.cpufreq).list_frequencies(i) + cpu.cpufreq.related_cpus = cast(CpufreqModule, target.cpufreq).get_related_cpus(i) + cpu.cpufreq.driver = cast(CpufreqModule, target.cpufreq).get_driver(i) if target.has('cpuidle'): - cpu.cpuidle.driver = target.cpuidle.get_driver() - cpu.cpuidle.governor = target.cpuidle.get_governor() - for state in target.cpuidle.get_states(i): + cpu.cpuidle.driver = cast(Cpuidle, target.cpuidle).get_driver() + cpu.cpuidle.governor = cast(Cpuidle, target.cpuidle).get_governor() + for state in cast(Cpuidle, target.cpuidle).get_states(i): state_info = IdleStateInfo() state_info.name = state.name state_info.desc = state.desc @@ -275,7 +302,10 @@ def get_target_info(target): return info -def read_target_info_cache(): +def read_target_info_cache() -> Dict[str, Any]: + """ + read cached target information + """ if not os.path.exists(settings.cache_directory): os.makedirs(settings.cache_directory) if not os.path.isfile(settings.target_info_cache_file): @@ -283,14 +313,21 @@ def read_target_info_cache(): return read_pod(settings.target_info_cache_file) -def write_target_info_cache(cache): +def write_target_info_cache(cache: Dict[str, Any]) -> None: + """ + cache the target information + """ if not os.path.exists(settings.cache_directory): os.makedirs(settings.cache_directory) with atomic_write_path(settings.target_info_cache_file) as at_path: write_pod(cache, at_path) -def get_target_info_from_cache(system_id, cache=None): +def get_target_info_from_cache(system_id: str, + cache: Optional[Dict[str, Any]] = None) -> Optional['TargetInfo']: + """ + get target information from cache + """ if cache is None: cache = read_target_info_cache() pod = cache.get(system_id, None) @@ -298,7 +335,7 @@ def get_target_info_from_cache(system_id, cache=None): if not pod: return None - _pod_version = pod.get('_pod_version', 0) + _pod_version: int = pod.get('_pod_version', 0) if _pod_version != TargetInfo._pod_serialization_version: msg = 'Target info version mismatch. 
Expected {}, but found {}.\nTry deleting {}' raise ConfigError(msg.format(TargetInfo._pod_serialization_version, _pod_version, @@ -306,22 +343,29 @@ def get_target_info_from_cache(system_id, cache=None): return TargetInfo.from_pod(pod) -def cache_target_info(target_info, overwrite=False, cache=None): +def cache_target_info(target_info: 'TargetInfo', overwrite: bool = False, + cache: Optional[Dict[str, Any]] = None): + """ + store target information into the cache + """ if cache is None: cache = read_target_info_cache() if target_info.system_id in cache and not overwrite: raise ValueError('TargetInfo for {} is already in cache.'.format(target_info.system_id)) - cache[target_info.system_id] = target_info.to_pod() + if target_info.system_id: + cache[target_info.system_id] = target_info.to_pod() write_target_info_cache(cache) class TargetInfo(Podable): - - _pod_serialization_version = 5 + """ + target information + """ + _pod_serialization_version: int = 5 @staticmethod - def from_pod(pod): - instance = super(TargetInfo, TargetInfo).from_pod(pod) + def from_pod(pod) -> 'TargetInfo': + instance = cast('TargetInfo', super(TargetInfo, TargetInfo).from_pod(pod)) instance.target = pod['target'] instance.modules = pod['modules'] instance.abi = pod['abi'] @@ -345,27 +389,27 @@ def from_pod(pod): return instance - def __init__(self): + def __init__(self) -> None: super(TargetInfo, self).__init__() - self.target = None - self.modules = [] - self.cpus = [] - self.os = None - self.os_version = None - self.system_id = None - self.hostid = None - self.hostname = None - self.abi = None - self.is_rooted = None - self.kernel_version = None - self.kernel_config = None - self.sched_features = None - self.screen_resolution = None - self.prop = None - self.android_id = None - self.page_size_kb = None - - def to_pod(self): + self.target: Optional[str] = None + self.modules: List[str] = [] + self.cpus: List[CpuInfo] = [] + self.os: Optional[str] = None + self.os_version: Optional[Dict[str, str]] = None + self.system_id: Optional[str] = None + self.hostid: Optional[int] = None + self.hostname: Optional[str] = None + self.abi: Optional[str] = None + self.is_rooted: Optional[bool] = None + self.kernel_version: Optional[KernelVersion] = None + self.kernel_config: Optional[KernelConfig] = None + self.sched_features: Optional[List[str]] = None + self.screen_resolution: Optional[Tuple[int, int]] = None + self.prop: Optional[AndroidProperties] = None + self.android_id: Optional[str] = None + self.page_size_kb: Optional[int] = None + + def to_pod(self) -> Dict[str, Any]: pod = super(TargetInfo, self).to_pod() pod['target'] = self.target pod['modules'] = self.modules @@ -378,20 +422,20 @@ def to_pod(self): pod['hostname'] = self.hostname pod['abi'] = self.abi pod['is_rooted'] = self.is_rooted - pod['kernel_release'] = self.kernel_version.release - pod['kernel_version'] = self.kernel_version.version - pod['kernel_config'] = dict(self.kernel_config.iteritems()) + pod['kernel_release'] = self.kernel_version.release if self.kernel_version else '' + pod['kernel_version'] = self.kernel_version.version if self.kernel_version else '' + pod['kernel_config'] = dict(self.kernel_config.iteritems()) if self.kernel_config else {} pod['sched_features'] = self.sched_features pod['page_size_kb'] = self.page_size_kb if self.os == 'android': pod['screen_resolution'] = self.screen_resolution - pod['prop'] = self.prop._properties + pod['prop'] = self.prop._properties if self.prop else {} pod['android_id'] = self.android_id return pod 
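
The `_pod_upgrade_vN` staticmethods that follow implement the versioned plain-old-data scheme used throughout this file: each pod records a `_pod_version`, and `from_pod` applies the missing upgrade steps in order before the object is reconstructed. A rough, dict-only sketch of that scheme (illustrative helpers, not WA's `Podable` base class):

```python
from typing import Any, Callable, Dict, List

Pod = Dict[str, Any]

# Upgrade steps in order: step 0 brings a v0 pod to v1, step 1 to v2, ...
UPGRADES: List[Callable[[Pod], Pod]] = [
    lambda pod: {**pod, 'cpus': pod.get('cpus', []), '_pod_version': 1},
    lambda pod: {**pod, 'page_size_kb': pod.get('page_size_kb'), '_pod_version': 2},
]

CURRENT_VERSION = len(UPGRADES)


def upgrade_pod(pod: Pod) -> Pod:
    """Apply only the upgrade steps this pod has not seen yet."""
    version: int = pod.get('_pod_version', 0)
    for step in UPGRADES[version:]:
        pod = step(pod)
    return pod


old_pod = {'target': 'AndroidTarget', '_pod_version': 0}
print(upgrade_pod(old_pod))  # now at _pod_version == 2, with defaulted fields
```

The payoff of the scheme is that cached pods written by an old WA release can still be loaded: only the steps between the stored version and the current one run, so each upgrade function stays small and append-only.
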
@staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> Dict[str, Any]: pod['_pod_version'] = pod.get('_pod_version', 1) pod['cpus'] = pod.get('cpus', []) pod['system_id'] = pod.get('system_id') @@ -404,24 +448,24 @@ def _pod_upgrade_v1(pod): return pod @staticmethod - def _pod_upgrade_v2(pod): + def _pod_upgrade_v2(pod: Dict[str, Any]) -> Dict[str, Any]: pod['page_size_kb'] = pod.get('page_size_kb') pod['_pod_version'] = pod.get('format_version', 0) return pod @staticmethod - def _pod_upgrade_v3(pod): - config = {} + def _pod_upgrade_v3(pod: Dict[str, Any]) -> Dict[str, Any]: + config: Dict[str, Any] = {} for key, value in pod['kernel_config'].items(): config[key.upper()] = value pod['kernel_config'] = config return pod @staticmethod - def _pod_upgrade_v4(pod): + def _pod_upgrade_v4(pod: Dict[str, Any]) -> Dict[str, Any]: return TargetInfo._pod_upgrade_v3(pod) @staticmethod - def _pod_upgrade_v5(pod): + def _pod_upgrade_v5(pod: Dict[str, Any]) -> Dict[str, Any]: pod['modules'] = pod.get('modules') or [] return pod diff --git a/wa/framework/target/manager.py b/wa/framework/target/manager.py index 387d7537b..40052d39a 100644 --- a/wa/framework/target/manager.py +++ b/wa/framework/target/manager.py @@ -15,9 +15,10 @@ import logging -from devlib import Gem5SimulationPlatform +from devlib.platform.gem5 import Gem5SimulationPlatform from devlib.utils.misc import memoized - +from devlib.target import Target +from devlib.module.hotplug import HotplugModule from wa.framework import signal from wa.framework.exception import ExecutionError, TargetError, TargetNotRespondingError from wa.framework.plugin import Parameter @@ -27,7 +28,16 @@ from wa.framework.target.info import (get_target_info, get_target_info_from_cache, cache_target_info, read_target_info_cache) from wa.framework.target.runtime_parameter_manager import RuntimeParameterManager -from wa.utils.types import module_name_set +from wa.framework.target.assistant import LinuxAssistant, AndroidAssistant, ChromeOsAssistant +from wa.utils.types import module_name_set, obj_dict +from typing import TYPE_CHECKING, Dict, Optional, Union, cast, Type, List, Any +from wa.framework.configuration.tree import JobSpecSource +if TYPE_CHECKING: + from wa.framework.configuration.core import ConfigurationPoint + from louie import dispatcher # type:ignore + from wa.framework.execution import ExecutionContext + from wa.framework.target.info import TargetInfo + from wa.framework.target.descriptor import TargetDescriptionProtocol class TargetManager(object): @@ -35,47 +45,57 @@ class TargetManager(object): Instantiate the required target and perform configuration and validation of the device. """ - parameters = [ - Parameter('disconnect', kind=bool, default=False, - description=""" - Specifies whether the target should be disconnected from - at the end of the run. - """), - ] + parameters: Dict[str, Parameter] = {'disconnect': + Parameter('disconnect', kind=bool, default=False, + description=""" + Specifies whether the target should be disconnected from + at the end of the run. 
+                                          """),
+                  }
 
-    def __init__(self, name, parameters, outdir):
+    def __init__(self, name: str, parameters: Dict[str, Parameter],
+                 outdir: Optional[str]):
         self.outdir = outdir
         self.logger = logging.getLogger('tm')
         self.target_name = name
-        self.target = None
-        self.assistant = None
-        self.platform_name = None
-        self.is_responsive = None
-        self.rpm = None
+        self.target: Optional[Target] = None
+        self.assistant: Optional[Union[LinuxAssistant, AndroidAssistant, ChromeOsAssistant]] = None
+        self.platform_name: Optional[str] = None
+        self.is_responsive: Optional[bool] = None
+        self.rpm: Optional[RuntimeParameterManager] = None
         self.parameters = parameters
+        # FIXME - seems like improper access; a list of Parameter objects doesn't have a get() method
         self.disconnect = parameters.get('disconnect')
 
-    def initialize(self):
+    def initialize(self) -> None:
+        """
+        initialize target and assistant
+        """
         self._init_target()
-        self.assistant.initialize()
+        if self.assistant:
+            self.assistant.initialize()
         # If target supports hotplugging, online all cpus before perform discovery
         # and restore original configuration after completed.
-        if self.target.has('hotplug'):
-            online_cpus = self.target.list_online_cpus()
-            try:
-                self.target.hotplug.online_all()
-            except TargetError:
-                msg = 'Failed to online all CPUS - some information may not be '\
-                      'able to be retrieved.'
-                self.logger.debug(msg)
-            self.rpm = RuntimeParameterManager(self.target)
-            all_cpus = set(range(self.target.number_of_cpus))
-            self.target.hotplug.offline(*all_cpus.difference(online_cpus))
-        else:
-            self.rpm = RuntimeParameterManager(self.target)
+        if self.target:
+            if self.target.has('hotplug'):
+                online_cpus: List[int] = self.target.list_online_cpus()
+                try:
+                    cast(HotplugModule, self.target.hotplug).online_all()
+                except TargetError:
+                    msg: str = 'Failed to online all CPUS - some information may not be '\
+                               'able to be retrieved.'
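
The hotplug handling being rewritten here (it continues just below) follows a snapshot/restore pattern: record which CPUs are online, online everything so discovery sees all cores, then offline exactly the CPUs that were offline at the start. The set arithmetic in isolation, with a hypothetical `target` object standing in for devlib's:

```python
from typing import Callable, List, Set


def with_all_cpus_online(target, discover: Callable[[], None]) -> None:
    """Run `discover` with every CPU online, then restore the original state.

    Assumes `target` exposes list_online_cpus(), number_of_cpus and a
    hotplug module with online_all()/offline(), as in the diff above.
    """
    online_cpus: List[int] = target.list_online_cpus()  # snapshot current state
    target.hotplug.online_all()                         # bring every core up
    try:
        discover()
    finally:
        all_cpus: Set[int] = set(range(target.number_of_cpus))
        # offline only the CPUs that were offline when we started
        target.hotplug.offline(*all_cpus.difference(online_cpus))
```

WA's version restores unconditionally after building the `RuntimeParameterManager`; the `try`/`finally` above is just a defensive variant of the same idea.
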
+ self.logger.debug(msg) + self.rpm = RuntimeParameterManager(self.target) + all_cpus = set(range(self.target.number_of_cpus)) + cast(HotplugModule, self.target.hotplug).offline(*all_cpus.difference(online_cpus)) + else: + self.rpm = RuntimeParameterManager(self.target) - def finalize(self): + def finalize(self) -> None: + """ + finalize target and assistant + """ if not self.target: return if self.assistant: @@ -85,20 +105,37 @@ def finalize(self): with signal.wrap('TARGET_DISCONNECT'): self.target.disconnect() - def start(self): - self.assistant.start() + def start(self) -> None: + """ + start assistant + """ + if self.assistant: + self.assistant.start() - def stop(self): - self.assistant.stop() + def stop(self) -> None: + """ + stop assistant + """ + if self.assistant: + self.assistant.stop() - def extract_results(self, context): - self.assistant.extract_results(context) + def extract_results(self, context: 'ExecutionContext') -> None: + """ + extract results from target + """ + if self.assistant: + self.assistant.extract_results(context) @memoized - def get_target_info(self): - cache = read_target_info_cache() - info = get_target_info_from_cache(self.target.system_id, cache=cache) - + def get_target_info(self) -> 'TargetInfo': + """ + get target information + """ + cache: Dict[str, Any] = read_target_info_cache() + info: Optional['TargetInfo'] = get_target_info_from_cache(self.target.system_id, + cache=cache) if self.target and self.target.system_id else None + if self.target is None: + raise TargetError("Target is None") if info is None: info = get_target_info(self.target) cache_target_info(info, cache=cache) @@ -112,21 +149,46 @@ def get_target_info(self): return info - def reboot(self, context, hard=False): - with signal.wrap('REBOOT', self, context): - self.target.reboot(hard) - - def merge_runtime_parameters(self, parameters): + def reboot(self, context: 'ExecutionContext', hard: bool = False) -> None: + """ + reboot the target + """ + with signal.wrap('REBOOT', cast(Type[dispatcher.Anonymous], self), context): + if self.target: + self.target.reboot(hard) + + def merge_runtime_parameters(self, + parameters: Dict[JobSpecSource, Dict[str, 'ConfigurationPoint']]) -> Dict: + """ + merge runtime parameters from different sources + """ + if self.rpm is None: + raise ExecutionError('rpm is not set') return self.rpm.merge_runtime_parameters(parameters) - def validate_runtime_parameters(self, parameters): + def validate_runtime_parameters(self, parameters: obj_dict) -> None: + """ + validate the runtime parameters + """ + if self.rpm is None: + raise ExecutionError('rpm is not set') self.rpm.validate_runtime_parameters(parameters) - def commit_runtime_parameters(self, parameters): + def commit_runtime_parameters(self, parameters: obj_dict) -> None: + """ + commit the runtime parameters to runtime parameter manager + """ + if self.rpm is None: + raise ExecutionError('rpm is not set') self.rpm.commit_runtime_parameters(parameters) - def verify_target_responsive(self, context): - can_reboot = context.reboot_policy.can_reboot + def verify_target_responsive(self, context: 'ExecutionContext') -> None: + """ + verify that the target is responsive + """ + can_reboot: bool = context.reboot_policy.can_reboot + if self.target is None: + raise TargetError("Target is None") if not self.target.check_responsive(explode=False): self.is_responsive = False if not can_reboot: @@ -139,12 +201,15 @@ def verify_target_responsive(self, context): else: raise TargetNotRespondingError('Target unresponsive and hard 
reset not supported; bailing.')
 
-    def _init_target(self):
-        tdesc = get_target_description(self.target_name)
+    def _init_target(self) -> None:
+        """
+        initialize target
+        """
+        tdesc: 'TargetDescriptionProtocol' = get_target_description(self.target_name)
 
-        extra_plat_params = {}
+        extra_plat_params: Dict[str, str] = {}
         if tdesc.platform is Gem5SimulationPlatform:
-            extra_plat_params['host_output_dir'] = self.outdir
+            extra_plat_params['host_output_dir'] = self.outdir or ''
 
         self.logger.debug('Creating {} target'.format(self.target_name))
         self.target = instantiate_target(tdesc, self.parameters, connect=False,
diff --git a/wa/framework/target/runtime_config.py b/wa/framework/target/runtime_config.py
index 471c2ff40..3beb33269 100644
--- a/wa/framework/target/runtime_config.py
+++ b/wa/framework/target/runtime_config.py
@@ -26,40 +26,60 @@
 from wa.framework.plugin import Plugin, Parameter
 from wa.utils.misc import resolve_cpus, resolve_unique_domain_cpus
 from wa.utils.types import caseless_string, enum
-
+from devlib.target import Target, AndroidTarget
+from devlib.module.hotplug import HotplugModule
+from devlib.module.cpufreq import CpufreqModule
+from devlib.module.cpuidle import Cpuidle, CpuidleState
+from typing import (Optional, Set, List, Tuple, Callable, Dict,
+                    Any, cast, DefaultDict, OrderedDict as od, Union)
 
 logger = logging.getLogger('RuntimeConfig')
 
 
 class RuntimeParameter(Parameter):
-    def __init__(self, name, setter, setter_params=None, **kwargs):
+    """
+    represents a runtime parameter
+    """
+    def __init__(self, name: str, setter: Callable,
+                 setter_params: Optional[Dict] = None, **kwargs):
         super(RuntimeParameter, self).__init__(name, **kwargs)
         self.setter = setter
         self.setter_params = setter_params or {}
 
-    def set(self, obj, value):
+    def set(self, obj: Any, value: Any) -> None:
+        """
+        set value to object
+        """
         self.validate_value(self.name, value)
         self.setter(obj, value, **self.setter_params)
 
 
 class RuntimeConfig(Plugin):
-
-    name = None
-    kind = 'runtime-config'
+    """
+    represents a runtime configuration
+    """
+    name: Optional[str] = None
+    kind: str = 'runtime-config'
 
     @property
-    def supported_parameters(self):
+    def supported_parameters(self) -> List[Parameter]:
+        """
+        supported parameters
+        """
         return list(self._runtime_params.values())
 
     @property
-    def core_names(self):
+    def core_names(self) -> List[str]:
+        """
+        list of core names
+        """
         return unique(self.target.core_names)
 
-    def __init__(self, target, **kwargs):
+    def __init__(self, target: Target, **kwargs):
         super(RuntimeConfig, self).__init__(**kwargs)
         self.target = target
-        self._target_checked = False
-        self._runtime_params = {}
+        self._target_checked: bool = False
+        self._runtime_params: Dict[str, RuntimeParameter] = {}
         try:
             self.initialize()
         except TargetError:
@@ -67,30 +87,51 @@ def __init__(self, target, **kwargs):
             self.logger.debug(msg.format(self.name))
             self._runtime_params = {}
 
-    def initialize(self):
+    def initialize(self) -> None:
+        """
+        initialize runtime configuration
+        """
         raise NotImplementedError()
 
-    def commit(self):
+    def commit(self) -> None:
+        """
+        commit runtime configuration
+        """
         raise NotImplementedError()
 
-    def set_runtime_parameter(self, name, value):
+    def set_runtime_parameter(self, name: str, value: Any) -> None:
+        """
+        set runtime parameters
+        """
         if not self._target_checked:
             self.check_target()
             self._target_checked = True
         self._runtime_params[name].set(self, value)
 
-    def set_defaults(self):
+    def set_defaults(self) -> None:
+        """
+        set default runtime configuration parameters
+        """
for p in self.supported_parameters: if p.default: self.set_runtime_parameter(p.name, p.default) - def validate_parameters(self): + def validate_parameters(self) -> None: + """ + validate runtime configuration parameters + """ raise NotImplementedError() - def check_target(self): + def check_target(self) -> Optional[bool]: + """ + check target for runtime configuration + """ raise NotImplementedError() - def clear(self): + def clear(self) -> None: + """ + clear the runtime configuration + """ raise NotImplementedError() @@ -100,37 +141,40 @@ class HotplugRuntimeConfig(RuntimeConfig): was hotplugged out when the devlib target was created. ''' - name = 'rt-hotplug' + name: str = 'rt-hotplug' @staticmethod - def set_num_cores(obj, value, core): - cpus = resolve_cpus(core, obj.target) - max_cores = len(cpus) - value = integer(value) - if value > max_cores: + def set_num_cores(obj: 'HotplugRuntimeConfig', value: Any, core: str) -> None: + """ + set number of cores to be enabled + """ + cpus: List[int] = resolve_cpus(core, obj.target) + max_cores: int = len(cpus) + value_int = integer(value) + if value_int > max_cores: msg = 'Cannot set number of {}\'s to {}; max is {}' - raise ValueError(msg.format(core, value, max_cores)) + raise ValueError(msg.format(core, value_int, max_cores)) msg = 'CPU{} Hotplugging already configured' # Set cpus to be enabled - for cpu in cpus[:value]: + for cpu in cpus[:value_int]: if cpu in obj.num_cores: raise ConfigError(msg.format(cpu)) obj.num_cores[cpu] = True # Set the remaining cpus to be disabled. - for cpu in cpus[value:]: + for cpu in cpus[value_int:]: if cpu in obj.num_cores: raise ConfigError(msg.format(cpu)) obj.num_cores[cpu] = False - def __init__(self, target): - self.num_cores = defaultdict(dict) + def __init__(self, target: Target): + self.num_cores: DefaultDict = defaultdict(dict) super(HotplugRuntimeConfig, self).__init__(target) - def initialize(self): + def initialize(self) -> None: if not self.target.has('hotplug'): return - param_name = 'num_cores' + param_name: str = 'num_cores' self._runtime_params[param_name] = \ RuntimeParameter(param_name, kind=int, constraint=lambda x: 0 <= x <= self.target.number_of_cpus, @@ -173,38 +217,51 @@ def initialize(self): setter=self.set_num_cores, setter_params={'core': cluster}) - def check_target(self): + def check_target(self) -> Optional[bool]: + """ + check whether target supports hotplugging + """ if not self.target.has('hotplug'): raise TargetError('Target does not appear to support hotplug') + return True - def validate_parameters(self): + def validate_parameters(self) -> None: + """ + validate parameters of hotplug + """ if self.num_cores and len(self.num_cores) == self.target.number_of_cpus: if all(v is False for v in list(self.num_cores.values())): raise ValueError('Cannot set number of all cores to 0') - def commit(self): + def commit(self) -> None: '''Online all CPUs required in order before then off-lining''' - num_cores = sorted(self.num_cores.items()) + num_cores: List[Tuple[int, bool]] = sorted(self.num_cores.items()) for cpu, online in num_cores: if online: - self.target.hotplug.online(cpu) + cast(HotplugModule, self.target.hotplug).online(cpu) for cpu, online in reversed(num_cores): if not online: - self.target.hotplug.offline(cpu) + cast(HotplugModule, self.target.hotplug).offline(cpu) - def clear(self): + def clear(self) -> None: self.num_cores = defaultdict(dict) class SysfileValuesRuntimeConfig(RuntimeConfig): - - name = 'rt-sysfiles' + """ + sys file values runtime configuration + """ + 
name: str = 'rt-sysfiles' # pylint: disable=unused-argument @staticmethod - def set_sysfile(obj, values, core): + def set_sysfile(obj: 'SysfileValuesRuntimeConfig', + values: Dict[str, Any], core: str) -> None: + """ + set sys file + """ for path, value in values.items(): - verify = True + verify: bool = True if path.endswith('!'): verify = False path = path[:-1] @@ -215,11 +272,11 @@ def set_sysfile(obj, values, core): obj.sysfile_values[path] = (value, verify) - def __init__(self, target): - self.sysfile_values = OrderedDict() + def __init__(self, target: Target): + self.sysfile_values: od[str, Tuple[Any, bool]] = OrderedDict() super(SysfileValuesRuntimeConfig, self).__init__(target) - def initialize(self): + def initialize(self) -> None: self._runtime_params['sysfile_values'] = \ RuntimeParameter('sysfile_values', kind=dict, merge=True, setter=self.set_sysfile, @@ -228,33 +285,38 @@ def initialize(self): Sysfile path to be set """) - def check_target(self): + def check_target(self) -> Optional[bool]: return True - def validate_parameters(self): + def validate_parameters(self) -> None: return - def commit(self): + def commit(self) -> None: for path, (value, verify) in self.sysfile_values.items(): self.target.write_value(path, value, verify=verify) - def clear(self): + def clear(self) -> None: self.sysfile_values = OrderedDict() - def check_exists(self, path): + def check_exists(self, path: str) -> None: + """ + check if file exists in the path + """ if not self.target.file_exists(path): raise ConfigError('Sysfile "{}" does not exist.'.format(path)) class FreqValue(object): - - def __init__(self, values): + """ + frequency values + """ + def __init__(self, values: Optional[Set[int]]): if values is None: - self.values = values + self.values: Optional[List[int]] = values else: self.values = sorted(values) - def __call__(self, value): + def __call__(self, value: Union[int, str]): ''' `self.values` can be `None` if the device's supported values could not be retrieved for some reason e.g. 
the cluster was offline, in this case we assume @@ -264,7 +326,7 @@ def __call__(self, value): if isinstance(value, int): return value else: - msg = 'CPU frequency values could not be retrieved, cannot resolve "{}"' + msg: str = 'CPU frequency values could not be retrieved, cannot resolve "{}"' raise TargetError(msg.format(value)) elif isinstance(value, int) and value in self.values: return value @@ -281,46 +343,48 @@ def __str__(self): class CpufreqRuntimeConfig(RuntimeConfig): - - name = 'rt-cpufreq' + """ + cpu frequency runtime configuration + """ + name: str = 'rt-cpufreq' @staticmethod - def set_frequency(obj, value, core): + def set_frequency(obj: 'CpufreqRuntimeConfig', value: Any, core: str): obj.set_param(obj, value, core, 'frequency') @staticmethod - def set_max_frequency(obj, value, core): + def set_max_frequency(obj: 'CpufreqRuntimeConfig', value: Any, core: str): obj.set_param(obj, value, core, 'max_frequency') @staticmethod - def set_min_frequency(obj, value, core): + def set_min_frequency(obj: 'CpufreqRuntimeConfig', value: Any, core: str): obj.set_param(obj, value, core, 'min_frequency') @staticmethod - def set_governor(obj, value, core): + def set_governor(obj: 'CpufreqRuntimeConfig', value: Any, core: str): obj.set_param(obj, value, core, 'governor') @staticmethod - def set_governor_tunables(obj, value, core): + def set_governor_tunables(obj: 'CpufreqRuntimeConfig', value: Any, core: str): obj.set_param(obj, value, core, 'governor_tunables') @staticmethod - def set_param(obj, value, core, parameter): + def set_param(obj: 'CpufreqRuntimeConfig', value: Any, core: str, parameter: str): '''Method to store passed parameter if it is not already specified for that cpu''' - cpus = resolve_unique_domain_cpus(core, obj.target) + cpus: List[int] = resolve_unique_domain_cpus(core, obj.target) for cpu in cpus: if parameter in obj.config[cpu]: - msg = 'Cannot set "{}" for core "{}"; Parameter for CPU{} has already been set' + msg: str = 'Cannot set "{}" for core "{}"; Parameter for CPU{} has already been set' raise ConfigError(msg.format(parameter, core, cpu)) obj.config[cpu][parameter] = value - def __init__(self, target): - self.config = defaultdict(dict) - self.supported_cpu_freqs = {} - self.supported_cpu_governors = {} + def __init__(self, target: Target): + self.config: DefaultDict[int, Dict[str, Any]] = defaultdict(dict) + self.supported_cpu_freqs: Dict[int, Set[int]] = {} + self.supported_cpu_governors: Dict[int, Set[str]] = {} super(CpufreqRuntimeConfig, self).__init__(target) - def initialize(self): + def initialize(self) -> None: # pylint: disable=too-many-statements if not self.target.has('cpufreq'): return @@ -330,7 +394,7 @@ def initialize(self): # Add common parameters if available. 
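+        # 'Common' here means supported by every frequency domain; the values
+        # come from the intersections computed in _get_common_values().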
freq_val = FreqValue(common_freqs) - param_name = 'frequency' + param_name: str = 'frequency' self._runtime_params[param_name] = \ RuntimeParameter( param_name, kind=freq_val, @@ -384,9 +448,9 @@ def initialize(self): # Add core name parameters for name in unique(self.target.platform.core_names): - cpu = resolve_unique_domain_cpus(name, self.target)[0] + cpu: int = resolve_unique_domain_cpus(name, self.target)[0] freq_val = FreqValue(self.supported_cpu_freqs.get(cpu)) - avail_govs = self.supported_cpu_governors.get(cpu) + avail_govs: Optional[Set[str]] = self.supported_cpu_governors.get(cpu) param_name = '{}_frequency'.format(name) self._runtime_params[param_name] = \ @@ -544,20 +608,21 @@ def initialize(self): The governor tunables to be set for the {} cores """.format(cluster)) - def check_target(self): + def check_target(self) -> Optional[bool]: if not self.target.has('cpufreq'): raise TargetError('Target does not appear to support cpufreq') + return True - def validate_parameters(self): + def validate_parameters(self) -> None: '''Method to validate parameters against each other''' for cpu in self.config: - config = self.config[cpu] - minf = config.get('min_frequency') - maxf = config.get('max_frequency') - freq = config.get('frequency') + config: Dict[str, Any] = self.config[cpu] + minf: int = config.get('min_frequency') or 0 + maxf: int = config.get('max_frequency') or 0 + freq: int = config.get('frequency') or 0 if freq and minf: - msg = 'CPU{}: Can\'t set both cpu frequency and minimum frequency' + msg: str = 'CPU{}: Can\'t set both cpu frequency and minimum frequency' raise ConfigError(msg.format(cpu)) if freq and maxf: msg = 'CPU{}: Can\'t set both cpu frequency and maximum frequency' @@ -567,12 +632,12 @@ def validate_parameters(self): msg = 'CPU{}: min_frequency "{}" cannot be greater than max_frequency "{}"' raise ConfigError(msg.format(cpu, minf, maxf)) - def commit(self): + def commit(self) -> None: for cpu in self.config: - config = self.config[cpu] - freq = self._resolve_freq(config.get('frequency'), cpu) - minf = self._resolve_freq(config.get('min_frequency'), cpu) - maxf = self._resolve_freq(config.get('max_frequency'), cpu) + config: Dict[str, Any] = self.config[cpu] + freq: int = self._resolve_freq(config.get('frequency') or 0, cpu) + minf: int = self._resolve_freq(config.get('min_frequency') or 0, cpu) + maxf: int = self._resolve_freq(config.get('max_frequency') or 0, cpu) self.configure_governor(cpu, config.get('governor'), @@ -582,21 +647,29 @@ def commit(self): def clear(self): self.config = defaultdict(dict) - def configure_governor(self, cpu, governor=None, gov_tunables=None): + def configure_governor(self, cpu: int, governor: Optional[str] = None, + gov_tunables: Optional[Dict] = None) -> None: + """ + configure governor + """ if not governor and not gov_tunables: return if cpu not in self.target.list_online_cpus(): - msg = 'Cannot configure governor for {} as no CPUs are online.' + msg: str = 'Cannot configure governor for {} as no CPUs are online.' 
raise TargetError(msg.format(cpu)) if not governor: - governor = self.target.get_governor(cpu) + governor = cast(CpufreqModule, self.target.cpufreq).get_governor(cpu) if not gov_tunables: gov_tunables = {} - self.target.cpufreq.set_governor(cpu, governor, **gov_tunables) + cast(CpufreqModule, self.target.cpufreq).set_governor(cpu, governor, **gov_tunables) - def configure_frequency(self, cpu, freq=None, min_freq=None, max_freq=None, governor=None): + def configure_frequency(self, cpu: int, freq: Optional[int] = None, min_freq: Optional[int] = None, + max_freq: Optional[int] = None, governor: Optional[str] = None) -> None: + """ + configure frequency + """ if freq and (min_freq or max_freq): - msg = 'Cannot specify both frequency and min/max frequency' + msg: str = 'Cannot specify both frequency and min/max frequency' raise ConfigError(msg) if cpu not in self.target.list_online_cpus(): @@ -608,53 +681,59 @@ def configure_frequency(self, cpu, freq=None, min_freq=None, max_freq=None, gove else: self._set_min_max_frequencies(cpu, min_freq, max_freq) - def _resolve_freq(self, value, cpu): + def _resolve_freq(self, value: Union[str, int], cpu: int) -> int: if value == 'min': - value = self.target.cpufreq.get_min_available_frequency(cpu) + value = cast(CpufreqModule, self.target.cpufreq).get_min_available_frequency(cpu) or 0 elif value == 'max': - value = self.target.cpufreq.get_max_available_frequency(cpu) - return value + value = cast(CpufreqModule, self.target.cpufreq).get_max_available_frequency(cpu) or 0 + return cast(int, value) - def _set_frequency(self, cpu, freq, governor): + def _set_frequency(self, cpu: int, freq: int, governor: Optional[str]) -> None: + """ + set frequency to the cpu under the specified governor + """ if not governor: - governor = self.target.cpufreq.get_governor(cpu) + governor = cast(CpufreqModule, self.target.cpufreq).get_governor(cpu) has_userspace = governor == 'userspace' # Sets all frequency to be to desired frequency - if freq < self.target.cpufreq.get_frequency(cpu): - self.target.cpufreq.set_min_frequency(cpu, freq) + if freq < cast(CpufreqModule, self.target.cpufreq).get_frequency(cpu): + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(cpu, freq) if has_userspace: - self.target.cpufreq.set_frequency(cpu, freq) - self.target.cpufreq.set_max_frequency(cpu, freq) + cast(CpufreqModule, self.target.cpufreq).set_frequency(cpu, freq) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(cpu, freq) else: - self.target.cpufreq.set_max_frequency(cpu, freq) + cast(CpufreqModule, self.target.cpufreq).set_max_frequency(cpu, freq) if has_userspace: - self.target.cpufreq.set_frequency(cpu, freq) - self.target.cpufreq.set_min_frequency(cpu, freq) - - def _set_min_max_frequencies(self, cpu, min_freq, max_freq): - min_freq_set = False - current_min_freq = self.target.cpufreq.get_min_frequency(cpu) - current_max_freq = self.target.cpufreq.get_max_frequency(cpu) + cast(CpufreqModule, self.target.cpufreq).set_frequency(cpu, freq) + cast(CpufreqModule, self.target.cpufreq).set_min_frequency(cpu, freq) + + def _set_min_max_frequencies(self, cpu: int, min_freq: Optional[int], max_freq: Optional[int]) -> None: + """ + set minimum and maximum frequencies + """ + min_freq_set: bool = False + current_min_freq: int = cast(CpufreqModule, self.target.cpufreq).get_min_frequency(cpu) + current_max_freq: int = cast(CpufreqModule, self.target.cpufreq).get_max_frequency(cpu) if max_freq: if max_freq < current_min_freq: if min_freq: - 
self.target.cpufreq.set_min_frequency(cpu, min_freq)
-                    self.target.cpufreq.set_max_frequency(cpu, max_freq)
+                    cast(CpufreqModule, self.target.cpufreq).set_min_frequency(cpu, min_freq)
+                    cast(CpufreqModule, self.target.cpufreq).set_max_frequency(cpu, max_freq)
                     min_freq_set = True
                 else:
-                    msg = 'CPU {}: Cannot set max_frequency ({}) below current min frequency ({}).'
+                    msg: str = 'CPU {}: Cannot set max_frequency ({}) below current min frequency ({}).'
                     raise ConfigError(msg.format(cpu, max_freq, current_min_freq))
             else:
-                self.target.cpufreq.set_max_frequency(cpu, max_freq)
+                cast(CpufreqModule, self.target.cpufreq).set_max_frequency(cpu, max_freq)
         if min_freq and not min_freq_set:
             current_max_freq = max_freq or current_max_freq
             if min_freq > current_max_freq:
                 msg = 'CPU {}: Cannot set min_frequency ({}) above current max frequency ({}).'
                 raise ConfigError(msg.format(cpu, min_freq, current_max_freq))
-            self.target.cpufreq.set_min_frequency(cpu, min_freq)
+            cast(CpufreqModule, self.target.cpufreq).set_min_frequency(cpu, min_freq)

-    def _retrive_cpufreq_info(self):
+    def _retrive_cpufreq_info(self) -> None:
         '''
         Tries to retrieve cpu freq information for all cpus on device.
         For each cpu domain, only one cpu is queried for information and
@@ -663,15 +742,15 @@ def _retrive_cpufreq_info(self):
         can still be populated.
         '''
         for cluster_cpu in resolve_unique_domain_cpus('all', self.target):
-            domain_cpus = self.target.cpufreq.get_related_cpus(cluster_cpu)
+            domain_cpus: List[int] = cast(CpufreqModule, self.target.cpufreq).get_related_cpus(cluster_cpu)
             for cpu in domain_cpus:
                 if cpu in self.target.list_online_cpus():
-                    supported_cpu_freqs = self.target.cpufreq.list_frequencies(cpu)
-                    supported_cpu_governors = self.target.cpufreq.list_governors(cpu)
+                    supported_cpu_freqs: Set[int] = cast(CpufreqModule, self.target.cpufreq).list_frequencies(cpu)
+                    supported_cpu_governors: Set[str] = cast(CpufreqModule, self.target.cpufreq).list_governors(cpu)
                     break
             else:
-                msg = 'CPUFreq information could not be retrieved for{};'\
-                      'Will not be validated against device.'
+                msg: str = 'CPUFreq information could not be retrieved for {}; '\
+                           'will not be validated against device.'
logger.debug(msg.format(' CPU{},'.format(cpu for cpu in domain_cpus))) return @@ -679,12 +758,12 @@ def _retrive_cpufreq_info(self): self.supported_cpu_freqs[cpu] = supported_cpu_freqs self.supported_cpu_governors[cpu] = supported_cpu_governors - def _get_common_values(self): + def _get_common_values(self) -> Tuple[Optional[Set[int]], Optional[Set[int]], Optional[Set[str]]]: ''' Find common values for frequency and governors across all cores''' - common_freqs = None - common_gov = None - all_freqs = None - initialized = False + common_freqs: Optional[Set[int]] = None + common_gov: Optional[Set[str]] = None + all_freqs: Optional[Set[int]] = None + initialized: bool = False for cpu in resolve_unique_domain_cpus('all', self.target): if not initialized: initialized = True @@ -692,22 +771,24 @@ def _get_common_values(self): all_freqs = copy(common_freqs) common_gov = set(self.supported_cpu_governors.get(cpu) or []) else: - common_freqs = common_freqs.intersection(self.supported_cpu_freqs.get(cpu) or set()) - all_freqs = all_freqs.union(self.supported_cpu_freqs.get(cpu) or set()) - common_gov = common_gov.intersection(self.supported_cpu_governors.get(cpu) or set()) + common_freqs = common_freqs.intersection(self.supported_cpu_freqs.get(cpu) or set()) if common_freqs else set() + all_freqs = all_freqs.union(self.supported_cpu_freqs.get(cpu) or set()) if all_freqs else set() + common_gov = common_gov.intersection(self.supported_cpu_governors.get(cpu) or set()) if common_gov else set() return all_freqs, common_freqs, common_gov class IdleStateValue(object): - - def __init__(self, values): + """ + value of idle state + """ + def __init__(self, values: Optional[List[CpuidleState]]): if values is None: self.values = values else: self.values = [(value.id, value.name, value.desc) for value in values] - def __call__(self, value): + def __call__(self, value: Union[List[str], str]): if self.values is None: return value @@ -721,16 +802,18 @@ def __call__(self, value): return [self._get_state_ID(value)] elif isinstance(value, list): - valid_states = [] + valid_states: List[str] = [] for state in value: valid_states.append(self._get_state_ID(state)) return valid_states else: raise ValueError('Invalid IdleState: "{}"'.format(value)) - def _get_state_ID(self, value): + def _get_state_ID(self, value: str) -> str: '''Checks passed state and converts to its ID''' value = caseless_string(value) + if not self.values: + raise ValueError('self.values is none') for s_id, s_name, s_desc in self.values: if value in (s_id, s_name, s_desc): return s_id @@ -742,33 +825,35 @@ def __str__(self): class CpuidleRuntimeConfig(RuntimeConfig): - - name = 'rt-cpuidle' + """ + cpu idle runtime configuration + """ + name: str = 'rt-cpuidle' @staticmethod - def set_idle_state(obj, value, core): - cpus = resolve_cpus(core, obj.target) + def set_idle_state(obj: 'CpuidleRuntimeConfig', value: Any, core: str) -> None: + cpus: List[int] = resolve_cpus(core, obj.target) for cpu in cpus: obj.config[cpu] = [] for state in value: obj.config[cpu].append(state) - def __init__(self, target): - self.config = defaultdict(dict) - self.supported_idle_states = {} + def __init__(self, target: Target): + self.config: DefaultDict = defaultdict(dict) + self.supported_idle_states: Dict[int, List[CpuidleState]] = {} super(CpuidleRuntimeConfig, self).__init__(target) - def initialize(self): + def initialize(self) -> None: if not self.target.has('cpuidle'): return self._retrieve_device_idle_info() - common_idle_states = self._get_common_idle_values() + 
common_idle_states: List[CpuidleState] = self._get_common_idle_values() idle_state_val = IdleStateValue(common_idle_states) if common_idle_states: - param_name = 'idle_states' + param_name: str = 'idle_states' self._runtime_params[param_name] = \ RuntimeParameter( param_name, kind=idle_state_val, @@ -779,7 +864,7 @@ def initialize(self): """) for name in unique(self.target.platform.core_names): - cpu = resolve_cpus(name, self.target)[0] + cpu: int = resolve_cpus(name, self.target)[0] idle_state_val = IdleStateValue(self.supported_idle_states.get(cpu)) param_name = '{}_idle_states'.format(name) self._runtime_params[param_name] = \ @@ -817,33 +902,37 @@ def initialize(self): The idle states to be set for the {} cores """.format(cluster)) - def check_target(self): + def check_target(self) -> Optional[bool]: if not self.target.has('cpuidle'): raise TargetError('Target does not appear to support cpuidle') + return True - def validate_parameters(self): + def validate_parameters(self) -> None: return - def clear(self): + def clear(self) -> None: self.config = defaultdict(dict) - def commit(self): + def commit(self) -> None: for cpu in self.config: - idle_states = set(state.id for state in self.supported_idle_states.get(cpu, [])) - enabled = self.config[cpu] - disabled = idle_states.difference(enabled) + idle_states: Set[str] = set(state.id for state in self.supported_idle_states.get(cpu, [])) + enabled: Set[str] = self.config[cpu] + disabled: Set[str] = idle_states.difference(enabled) for state in enabled: - self.target.cpuidle.enable(state, cpu) + cast(Cpuidle, self.target.cpuidle).enable(state, cpu) for state in disabled: - self.target.cpuidle.disable(state, cpu) + cast(Cpuidle, self.target.cpuidle).disable(state, cpu) - def _retrieve_device_idle_info(self): + def _retrieve_device_idle_info(self) -> None: + """ + get the device idle info + """ for cpu in range(self.target.number_of_cpus): - self.supported_idle_states[cpu] = self.target.cpuidle.get_states(cpu) + self.supported_idle_states[cpu] = cast(Cpuidle, self.target.cpuidle).get_states(cpu) - def _get_common_idle_values(self): + def _get_common_idle_values(self) -> List[CpuidleState]: '''Find common values for cpu idle states across all cores''' - common_idle_states = [] + common_idle_states: List[CpuidleState] = [] for cpu in range(self.target.number_of_cpus): for state in self.supported_idle_states.get(cpu) or []: if state.name not in common_idle_states: @@ -855,45 +944,63 @@ def _get_common_idle_values(self): class AndroidRuntimeConfig(RuntimeConfig): - - name = 'rt-android' + """ + android runtime configuration + """ + name: str = 'rt-android' @staticmethod - def set_brightness(obj, value): + def set_brightness(obj: 'AndroidRuntimeConfig', value: Optional[int]) -> None: + """ + set brightness + """ if value is not None: obj.config['brightness'] = value @staticmethod - def set_airplane_mode(obj, value): + def set_airplane_mode(obj: 'AndroidRuntimeConfig', value: Optional[bool]) -> None: + """ + set airplane mode + """ if value is not None: obj.config['airplane_mode'] = value @staticmethod - def set_rotation(obj, value): + def set_rotation(obj: 'AndroidRuntimeConfig', value: Any) -> None: + """ + set rotation + """ if value is not None: obj.config['rotation'] = value.value @staticmethod - def set_screen_state(obj, value): + def set_screen_state(obj: 'AndroidRuntimeConfig', value: Optional[bool]) -> None: + """ + set screen on or off state + """ if value is not None: obj.config['screen_on'] = value @staticmethod - def 
set_unlock_screen(obj, value):
+    def set_unlock_screen(obj: 'AndroidRuntimeConfig', value: Optional[str]) -> None:
+        """
+        set the screen unlock method
+        """
         if value is not None:
             obj.config['unlock_screen'] = value

-    def __init__(self, target):
-        self.config = defaultdict(dict)
+    def __init__(self, target: Target):
+        self.config: DefaultDict[str, Any] = defaultdict(dict)
         super(AndroidRuntimeConfig, self).__init__(target)
+        self.target = cast(AndroidTarget, target)

-    def initialize(self):
+    def initialize(self) -> None:
         if self.target.os not in ['android', 'chromeos']:
             return
         if self.target.os == 'chromeos' and not self.target.supports_android:
             return

-        param_name = 'brightness'
+        param_name: str = 'brightness'
         self._runtime_params[param_name] = \
             RuntimeParameter(
                 param_name, kind=int,
@@ -945,21 +1052,22 @@ def initialize(self):
                 Specify how the device screen should be unlocked (e.g., vertical)
                 """)

-    def check_target(self):
+    def check_target(self) -> Optional[bool]:
         if self.target.os != 'android' and self.target.os != 'chromeos':
             raise ConfigError('Target does not appear to be running Android')
         if self.target.os == 'chromeos' and not self.target.supports_android:
             raise ConfigError('Target does not appear to support Android')
+        return True

-    def validate_parameters(self):
+    def validate_parameters(self) -> None:
         pass

-    def commit(self):
+    def commit(self) -> None:
         # pylint: disable=too-many-branches
         if 'airplane_mode' in self.config:
-            new_airplane_mode = self.config['airplane_mode']
-            old_airplane_mode = self.target.get_airplane_mode()
-            self.target.set_airplane_mode(new_airplane_mode)
+            new_airplane_mode: bool = self.config['airplane_mode']
+            old_airplane_mode: bool = cast(Callable, self.target.get_airplane_mode)()
+            cast(Callable, self.target.set_airplane_mode)(new_airplane_mode)

             # If we've just switched airplane mode off, wait a few seconds to
             # enable the network state to stabilise. That's helpful if we're
@@ -967,7 +1075,7 @@ def commit(self):
             # connectivity.
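+            # The loop below polls connectivity every 5 seconds, up to four
+            # times (matching the "up to 20 seconds" in the log message).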
if old_airplane_mode and not new_airplane_mode: self.logger.info('Disabled airplane mode, waiting up to 20 seconds for network setup') - network_is_ready = False + network_is_ready: bool = False for _ in range(4): time.sleep(5) network_is_ready = self.target.is_network_connected() @@ -979,21 +1087,21 @@ def commit(self): self.logger.warning("Network unreachable") if 'brightness' in self.config: - self.target.set_brightness(self.config['brightness']) + cast(Callable, self.target.set_brightness)(self.config['brightness']) if 'rotation' in self.config: - self.target.set_rotation(self.config['rotation']) + cast(Callable, self.target.set_rotation)(self.config['rotation']) if 'screen_on' in self.config: if self.config['screen_on']: - self.target.ensure_screen_is_on() + cast(Callable, self.target.ensure_screen_is_on)() else: - self.target.ensure_screen_is_off() + cast(Callable, self.target.ensure_screen_is_off)() if self.config.get('unlock_screen'): - self.target.ensure_screen_is_on() - if self.target.is_screen_locked(): - self.target.swipe_to_unlock(self.config['unlock_screen']) + cast(Callable, self.target.ensure_screen_is_on)() + if cast(Callable, self.target.is_screen_locked)(): + cast(Callable, self.target.swipe_to_unlock)(self.config['unlock_screen']) - def clear(self): - self.config = {} + def clear(self) -> None: + self.config.clear() diff --git a/wa/framework/target/runtime_parameter_manager.py b/wa/framework/target/runtime_parameter_manager.py index 77365dd28..0dd6a5c9c 100644 --- a/wa/framework/target/runtime_parameter_manager.py +++ b/wa/framework/target/runtime_parameter_manager.py @@ -20,14 +20,21 @@ HotplugRuntimeConfig, CpufreqRuntimeConfig, CpuidleRuntimeConfig, - AndroidRuntimeConfig) + AndroidRuntimeConfig, + RuntimeConfig) from wa.utils.types import obj_dict, caseless_string from wa.framework import pluginloader +from typing import TYPE_CHECKING, Dict, List, Optional, cast, Type +from wa.framework.configuration.tree import JobSpecSource +if TYPE_CHECKING: + from wa.framework.configuration.core import ConfigurationPoint + from devlib.target import Target + from wa.framework.pluginloader import __LoaderWrapper class RuntimeParameterManager(object): - runtime_config_cls = [ + runtime_config_cls: List[Type[RuntimeConfig]] = [ # order matters SysfileValuesRuntimeConfig, HotplugRuntimeConfig, @@ -36,73 +43,94 @@ class RuntimeParameterManager(object): AndroidRuntimeConfig, ] - def __init__(self, target): + def __init__(self, target: 'Target'): self.target = target - self.runtime_params = {} + RuntimeParameter = namedtuple('RuntimeParameter', 'cfg_point, rt_config') + self.runtime_params: Dict[str, RuntimeParameter] = {} try: - for rt_cls in pluginloader.list_plugins(kind='runtime-config'): + for rt_cls in cast('__LoaderWrapper', pluginloader).list_plugins(kind='runtime-config'): if rt_cls not in self.runtime_config_cls: - self.runtime_config_cls.append(rt_cls) + self.runtime_config_cls.append(cast(Type[RuntimeConfig], rt_cls)) except ValueError: pass - self.runtime_configs = [cls(self.target) for cls in self.runtime_config_cls] + self.runtime_configs: List[RuntimeConfig] = [cls(self.target) for cls in self.runtime_config_cls] - runtime_parameter = namedtuple('RuntimeParameter', 'cfg_point, rt_config') for cfg in self.runtime_configs: for param in cfg.supported_parameters: if param.name in self.runtime_params: - msg = 'Duplicate runtime parameter name "{}": in both {} and {}' + msg: str = 'Duplicate runtime parameter name "{}": in both {} and {}' raise 
RuntimeError(msg.format(param.name, self.runtime_params[param.name].rt_config.name, cfg.name)) - self.runtime_params[param.name] = runtime_parameter(param, cfg) + self.runtime_params[param.name] = RuntimeParameter(param, cfg) # Uses corresponding config point to merge parameters - def merge_runtime_parameters(self, parameters): + def merge_runtime_parameters(self, parameters: Dict[JobSpecSource, Dict[str, 'ConfigurationPoint']]) -> Dict: + """ + merge the runtime parameters + """ merged_params = obj_dict() for source in parameters: for name, value in parameters[source].items(): - cp = self.get_cfg_point(name) + cp: 'ConfigurationPoint' = self.get_cfg_point(name) cp.set_value(merged_params, value) return dict(merged_params) # Validates runtime_parameters against each other - def validate_runtime_parameters(self, parameters): + def validate_runtime_parameters(self, parameters: obj_dict) -> None: + """ + validate the runtime parameters + """ self.clear_runtime_parameters() self.set_runtime_parameters(parameters) for cfg in self.runtime_configs: cfg.validate_parameters() # Writes the given parameters to the device. - def commit_runtime_parameters(self, parameters): + def commit_runtime_parameters(self, parameters: obj_dict) -> None: + """ + commit the runtime parameters + """ self.clear_runtime_parameters() self.set_runtime_parameters(parameters) for cfg in self.runtime_configs: cfg.commit() # Stores a set of parameters performing isolated validation when appropriate - def set_runtime_parameters(self, parameters): + def set_runtime_parameters(self, parameters: obj_dict) -> None: + """ + set the runtime parameters + """ for name, value in parameters.items(): - cfg = self.get_config_for_name(name) + cfg: Optional[RuntimeConfig] = self.get_config_for_name(name) if cfg is None: - msg = 'Unsupported runtime parameter: "{}"' + msg: str = 'Unsupported runtime parameter: "{}"' raise ConfigError(msg.format(name)) cfg.set_runtime_parameter(name, value) - def clear_runtime_parameters(self): + def clear_runtime_parameters(self) -> None: + """ + clear runtime parameters + """ for cfg in self.runtime_configs: cfg.clear() cfg.set_defaults() - def get_config_for_name(self, name): + def get_config_for_name(self, name: str) -> Optional[RuntimeConfig]: + """ + get the configuration for the provided name + """ name = caseless_string(name) for k, v in self.runtime_params.items(): if name == k: return v.rt_config return None - def get_cfg_point(self, name): + def get_cfg_point(self, name: str) -> 'ConfigurationPoint': + """ + get the configuration point + """ name = caseless_string(name) for k, v in self.runtime_params.items(): if name == k or name in v.cfg_point.aliases: diff --git a/wa/framework/version.py b/wa/framework/version.py index 1d22384fd..512c76744 100644 --- a/wa/framework/version.py +++ b/wa/framework/version.py @@ -17,37 +17,49 @@ import sys from collections import namedtuple from subprocess import Popen, PIPE +from typing import Optional - -VersionTuple = namedtuple('Version', ['major', 'minor', 'revision', 'dev']) +VersionTuple = namedtuple('VersionTuple', ['major', 'minor', 'revision', 'dev']) version = VersionTuple(3, 4, 0, 'dev1') required_devlib_version = VersionTuple(1, 4, 0, 'dev3') -def format_version(v): - version_string = '{}.{}.{}'.format( +def format_version(v: VersionTuple) -> str: + """ + create version string from version tuple + """ + version_string: str = '{}.{}.{}'.format( v.major, v.minor, v.revision) if v.dev: version_string += '.{}'.format(v.dev) return version_string -def 
get_wa_version(): +def get_wa_version() -> str: + """ + get workload automation version + """ return format_version(version) -def get_wa_version_with_commit(): - version_string = get_wa_version() - commit = get_commit() +def get_wa_version_with_commit() -> str: + """ + get workload automation version with commit id + """ + version_string: str = get_wa_version() + commit: Optional[str] = get_commit() if commit: return '{}+{}'.format(version_string, commit) else: return version_string -def get_commit(): +def get_commit() -> Optional[str]: + """ + get commit id of workload automation + """ try: p = Popen(['git', 'rev-parse', 'HEAD'], cwd=os.path.dirname(__file__), stdout=PIPE, stderr=PIPE) diff --git a/wa/framework/workload.py b/wa/framework/workload.py index 5abb90e3f..0d5a922cf 100644 --- a/wa/framework/workload.py +++ b/wa/framework/workload.py @@ -33,6 +33,11 @@ from wa.utils.revent import ReventRecorder from wa.utils.exec_control import once_per_instance from wa.utils.misc import atomic_write_path +from typing import (Any, List, TYPE_CHECKING, Optional, Dict, Union, cast) +from wa.framework.execution import ExecutionContext +if TYPE_CHECKING: + from devlib.target import Target, AndroidTarget, ChromeOsTarget, installed_package_info + from devlib.utils.android import ApkInfo class Workload(TargetedPlugin): @@ -42,9 +47,9 @@ class Workload(TargetedPlugin): by the derived classes. """ - kind = 'workload' + kind: Optional[str] = 'workload' - parameters = [ + parameters: List[Parameter] = [ Parameter('uninstall', kind=bool, default=True, description=""" @@ -59,31 +64,33 @@ class Workload(TargetedPlugin): # database owned by the maintainer. # The user can then set allow_phone_home=False in their configuration to # prevent this workload from being run accidentally. - phones_home = False + phones_home: bool = False # Set this to ``True`` to mark the the workload will fail without a network # connection, this enables it to fail early with a clear message. - requires_network = False + requires_network: bool = False # Set this to specify a custom directory for assets to be pushed to, if unset # the working directory will be used. - asset_directory = None + asset_directory: Optional[str] = None # Used to store information about workload assets. - deployable_assets = [] + deployable_assets: List[str] = [] - def __init__(self, target, **kwargs): + target: 'Target' + + def __init__(self, target: 'Target', **kwargs): super(Workload, self).__init__(target, **kwargs) - self.asset_files = [] - self.deployed_assets = [] + self.asset_files: List[str] = [] + self.deployed_assets: List[str] = [] - supported_platforms = getattr(self, 'supported_platforms', []) + supported_platforms: List[str] = getattr(self, 'supported_platforms', []) if supported_platforms and self.target.os not in supported_platforms: - msg = 'Supported platforms for "{}" are "{}", attempting to run on "{}"' + msg: str = 'Supported platforms for "{}" are "{}", attempting to run on "{}"' raise WorkloadError(msg.format(self.name, ' '.join(self.supported_platforms), self.target.os)) - def init_resources(self, context): + def init_resources(self, context: ExecutionContext) -> None: """ This method may be used to perform early resource discovery and initialization. 
This is invoked during the initial loading stage and
@@ -93,10 +100,10 @@ def init_resources(self, context):
         """
         for asset in self.deployable_assets:
-            self.asset_files.append(context.get(File(self, asset)))
+            self.asset_files.append(context.get(File(self, asset)) or '')

     @once_per_instance
-    def initialize(self, context):
+    def initialize(self, context: ExecutionContext) -> None:
         """
         This method should be used to perform once-per-run initialization of a
         workload instance, i.e., unlike ``setup()`` it will not be invoked on
@@ -105,7 +112,7 @@ def initialize(self, context):
         if self.asset_files:
             self.deploy_assets(context)

-    def setup(self, context):
+    def setup(self, context: ExecutionContext) -> None:
         """
         Perform the setup necessary to run the workload, such as copying the
         necessary files to the device, configuring the environments, etc.
@@ -119,33 +126,39 @@ def setup(self, context):
                     'Workload "{}" requires internet. Target does not appear '
                     'to be connected to the internet.'.format(self.name))

-    def run(self, context):
+    def run(self, context: ExecutionContext) -> None:
         """
         Execute the workload. This is the method that performs the actual
         "work" of the workload.
         """

-    def extract_results(self, context):
+    def extract_results(self, context: ExecutionContext) -> None:
         """
-        Extract results on the target
+        This method gets invoked after the task execution has finished and
+        should be used to extract metrics from the target.
         """

-    def update_output(self, context):
+    def update_output(self, context: ExecutionContext) -> None:
         """
         Update the output within the specified execution context with the
         metrics and artifacts for this workload iteration.
         """

-    def teardown(self, context):
+    def teardown(self, context: ExecutionContext) -> None:
         """
         Perform any final clean up for the Workload.
         """

     @once_per_instance
-    def finalize(self, context):
+    def finalize(self, context: ExecutionContext) -> None:
+        """
+        This is the complement to ``initialize``. This will be executed
+        exactly once at the end of the run. This should be used to
+        perform any final clean up (e.g. uninstalling binaries installed
+        in ``initialize``).
+        """
         if self.cleanup_assets:
             self.remove_assets(context)

-    def deploy_assets(self, context):
+    def deploy_assets(self, context: ExecutionContext):
         """ Deploy assets if available to the target """
         # pylint: disable=unused-argument
         if not self.asset_directory:
@@ -156,9 +169,9 @@ def deploy_assets(self, context):
         for asset in self.asset_files:
             self.target.push(asset, self.asset_directory)
             self.deployed_assets.append(self.target.path.join(self.asset_directory,
-                                                              os.path.basename(asset)))
+                                                              os.path.basename(asset)) if self.target.path else '')

-    def remove_assets(self, context):
+    def remove_assets(self, context: ExecutionContext):
         """ Cleanup assets deployed to the target """
         # pylint: disable=unused-argument
         for asset in self.deployed_assets:
@@ -169,24 +182,65 @@ def __str__(self):

 class ApkWorkload(Workload):
-
-    supported_platforms = ['android']
+    """
+    The :class:`ApkWorkload` derives from the base :class:`Workload` class, but
+    additionally associates the workload with a package, allowing an apk to be
+    found for the workload, set up, and run on the device before the workload
+    itself is executed.
+
+    In addition to the attributes of :class:`Workload`, it defines the following:
+
+    ``loading_time``
+        This is the time in seconds that WA will wait for the application to
+        load before continuing with the run. By default this will wait 10
+        seconds; however, if your application under test requires additional
+        time, this value should be increased.
+
+    ``package_names``
+        This attribute should be a list of APK package names that are
+        suitable for this workload. Both the host (in the relevant resource
+        locations) and the device will be searched for an application with a
+        matching package name.
+
+    ``supported_versions``
+        This attribute should be a list of apk versions that are suitable for
+        this workload; if a specific apk version is not specified, then any
+        available supported version may be chosen.
+
+    ``activity``
+        This attribute can optionally be set to override the default activity
+        that will be extracted from the selected APK file, and which will be
+        used when launching the APK.
+
+    ``view``
+        This is the "view" associated with the application. This is used by
+        instruments like ``fps`` to monitor the current framerate being
+        generated by the application.
+
+    ``apk``
+        This is a :class:`PackageHandler`, which is used to store information
+        about the apk and to manage the application itself; its methods are
+        used to manipulate the application, for example to launch and close it.
+
+    ``package``
+        This is a more convenient way to access the package name of the Apk
+        that was found and is being used for the run.
+    """
+    supported_platforms: List[str] = ['android']

     # May be optionally overwritten by subclasses
     # Times are in seconds
-    loading_time = 10
-    package_names = []
-    supported_versions = []
-    activity = None
-    view = None
-    clear_data_on_reset = True
-    apk_arguments = {}
+    loading_time: int = 10
+    package_names: List[str] = []
+    supported_versions: List[str] = []
+    activity: Optional[str] = None
+    view: Optional[str] = None
+    clear_data_on_reset: bool = True
+    apk_arguments: Dict[str, Union[str, float, bool, int]] = {}

     # Set this to True to mark that this workload requires the target apk to be run
     # for initialisation purposes before the main run is performed.
-    requires_rerun = False
+    requires_rerun: bool = False

-    parameters = [
+    parameters: List[Parameter] = [
         Parameter('package_name', kind=str,
                   description="""
                   The package name that can be used to specify
@@ -260,13 +314,13 @@ class ApkWorkload(Workload):
     ]

     @property
-    def package(self):
+    def package(self) -> Optional[str]:
         return self.apk.package

-    def __init__(self, target, **kwargs):
+    def __init__(self, target: 'Target', **kwargs):
         if target.os == 'chromeos':
-            if target.supports_android:
-                target = target.android_container
+            if cast('ChromeOsTarget', target).supports_android:
+                target = cast('AndroidTarget', cast('ChromeOsTarget', target).android_container)
             else:
                 raise ConfigError('Target does not appear to support Android')
@@ -293,24 +347,32 @@ def __init__(self, target, **kwargs):
                                   max_version=self.max_version,
                                   apk_arguments=self.apk_arguments)

-    def validate(self):
+    def validate(self) -> None:
+        """
+        This method can be used to validate any assumptions your workload
+        makes about the environment (e.g. that required files are
+        present, environment variables are set, etc.) and should raise a
+        :class:`wa.WorkloadError`
+        if that is not the case. The base class implementation only makes
+        sure that the name attribute has been set.
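+
+        A sketch of a typical override (``MyWorkload`` and the check shown are
+        illustrative, not part of the framework)::
+
+            def validate(self):
+                super(MyWorkload, self).validate()
+                if not self.package_names:   # hypothetical requirement
+                    raise WorkloadError('no package names specified')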
+ """ if self.min_version and self.max_version: if version_tuple(self.min_version) > version_tuple(self.max_version): msg = 'Cannot specify min version ({}) greater than max version ({})' raise ConfigError(msg.format(self.min_version, self.max_version)) @once_per_instance - def initialize(self, context): + def initialize(self, context: ExecutionContext) -> None: super(ApkWorkload, self).initialize(context) self.apk.initialize(context) # pylint: disable=access-member-before-definition, attribute-defined-outside-init if self.version is None: - self.version = self.apk.apk_info.version_name + self.version: Optional[str] = self.apk.apk_info.version_name if self.apk.apk_info else '' if self.view is None: self.view = 'SurfaceView - {}/{}'.format(self.apk.package, self.apk.activity) - def setup(self, context): + def setup(self, context: ExecutionContext) -> None: super(ApkWorkload, self).setup(context) self.apk.setup(context) if self.requires_rerun: @@ -318,62 +380,81 @@ def setup(self, context): self.apk.restart_activity() time.sleep(self.loading_time) - def setup_rerun(self): + def setup_rerun(self) -> None: """ Perform the setup necessary to rerun the workload. Only called if ``requires_rerun`` is set. """ - def teardown(self, context): + def teardown(self, context: ExecutionContext) -> None: super(ApkWorkload, self).teardown(context) self.apk.teardown() - def deploy_assets(self, context): + def deploy_assets(self, context: ExecutionContext) -> None: super(ApkWorkload, self).deploy_assets(context) - self.target.refresh_files(self.deployed_assets) + cast('AndroidTarget', self.target).refresh_files(self.deployed_assets) class ApkUIWorkload(ApkWorkload): - def __init__(self, target, **kwargs): + def __init__(self, target: 'Target', **kwargs): super(ApkUIWorkload, self).__init__(target, **kwargs) - self.gui = None + self.gui: Optional[Union[UiAutomatorGUI, ReventGUI]] = None - def init_resources(self, context): + def init_resources(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).init_resources(context) - self.gui.init_resources(context) + if self.gui: + self.gui.init_resources(context) @once_per_instance - def initialize(self, context): + def initialize(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).initialize(context) - self.gui.deploy() + if self.gui: + self.gui.deploy() - def setup(self, context): + def setup(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).setup(context) - self.gui.setup() + if self.gui: + self.gui.setup() - def run(self, context): + def run(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).run(context) - self.gui.run() + if self.gui: + self.gui.run() - def extract_results(self, context): + def extract_results(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).extract_results(context) - self.gui.extract_results() + if self.gui: + self.gui.extract_results() - def teardown(self, context): - self.gui.teardown() + def teardown(self, context: ExecutionContext) -> None: + if self.gui: + self.gui.teardown() super(ApkUIWorkload, self).teardown(context) @once_per_instance - def finalize(self, context): + def finalize(self, context: ExecutionContext) -> None: super(ApkUIWorkload, self).finalize(context) - if self.cleanup_assets: + if self.cleanup_assets and self.gui: self.gui.remove() class ApkUiautoWorkload(ApkUIWorkload): + """ + The :class:`ApkUiautoWorkload` derives from :class:`ApkUIWorkload` which is an + intermediate class which in turn inherits from + :class:`ApkWorkload`, however 
in addition to associating an apk with the
+    workload, this class allows for automating the application with UiAutomator.
+
+    This class defines the following additional attributes:
+
+    ``gui``
+        This attribute will be an instance of a :class:`UiAutomatorGUI`, which
+        is used to control the automation and to pass parameters to the Java
+        class, for example via ``gui.uiauto_params``.
+    """

-    parameters = [
+    parameters: List[Parameter] = [
         Parameter('markers_enabled', kind=bool, default=False,
                   description="""
                   If set to ``True``, workloads will insert markers into logs
@@ -384,27 +465,51 @@ class ApkUiautoWorkload(ApkUIWorkload):
                   """),
     ]

-    def __init__(self, target, **kwargs):
+    def __init__(self, target: 'Target', **kwargs):
         super(ApkUiautoWorkload, self).__init__(target, **kwargs)
         self.gui = UiAutomatorGUI(self)

-    def setup(self, context):
-        self.gui.uiauto_params['package_name'] = self.apk.apk_info.package
-        self.gui.uiauto_params['markers_enabled'] = self.markers_enabled
-        self.gui.init_commands()
+    def setup(self, context: ExecutionContext) -> None:
+        cast(UiAutomatorGUI, self.gui).uiauto_params['package_name'] = self.apk.apk_info.package if self.apk.apk_info else ''
+        cast(UiAutomatorGUI, self.gui).uiauto_params['markers_enabled'] = self.markers_enabled
+        cast(UiAutomatorGUI, self.gui).init_commands()
         super(ApkUiautoWorkload, self).setup(context)


 class ApkReventWorkload(ApkUIWorkload):
+    """
+    The :class:`ApkReventWorkload` derives from :class:`ApkUIWorkload`, which
+    is an intermediate class which in turn inherits from
+    :class:`ApkWorkload`; however, in addition to associating an apk with the
+    workload, this class allows for automating the application with
+    :ref:`Revent `.
+
+    This class defines the following additional attributes:
+
+    ``gui``
+        This attribute will be an instance of a :class:`ReventGUI`, which is
+        used to control the automation.
+
+    ``setup_timeout``
+        This is the time allowed for replaying a recording for the setup stage.
+
+    ``run_timeout``
+        This is the time allowed for replaying a recording for the run stage.
+
+    ``extract_results_timeout``
+        This is the time allowed for replaying a recording for the extract results stage.
+
+    ``teardown_timeout``
+        This is the time allowed for replaying a recording for the teardown stage.
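+
+    A minimal subclass sketch (the workload name, package and timeout value
+    below are illustrative)::
+
+        class MyGame(ApkReventWorkload):
+            name = 'mygame'                          # hypothetical workload name
+            package_names = ['com.example.mygame']   # hypothetical package
+            run_timeout = 15 * 60                    # override the default 10 minutes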
+ """ # May be optionally overwritten by subclasses # Times are in seconds - setup_timeout = 5 * 60 - run_timeout = 10 * 60 - extract_results_timeout = 5 * 60 - teardown_timeout = 5 * 60 + setup_timeout: int = 5 * 60 + run_timeout: int = 10 * 60 + extract_results_timeout: int = 5 * 60 + teardown_timeout: int = 5 * 60 - def __init__(self, target, **kwargs): + def __init__(self, target: 'Target', **kwargs): super(ApkReventWorkload, self).__init__(target, **kwargs) self.gui = ReventGUI(self, target, self.setup_timeout, @@ -415,47 +520,65 @@ def __init__(self, target, **kwargs): class UIWorkload(Workload): - def __init__(self, target, **kwargs): + def __init__(self, target: 'Target', **kwargs): super(UIWorkload, self).__init__(target, **kwargs) - self.gui = None + self.gui: Optional[Union[UiAutomatorGUI, ReventGUI]] = None - def init_resources(self, context): + def init_resources(self, context: ExecutionContext) -> None: super(UIWorkload, self).init_resources(context) - self.gui.init_resources(context) + if self.gui: + self.gui.init_resources(context) @once_per_instance - def initialize(self, context): + def initialize(self, context: ExecutionContext) -> None: super(UIWorkload, self).initialize(context) - self.gui.deploy() + if self.gui: + self.gui.deploy() - def setup(self, context): + def setup(self, context: ExecutionContext) -> None: super(UIWorkload, self).setup(context) - self.gui.setup() + if self.gui: + self.gui.setup() - def run(self, context): + def run(self, context: ExecutionContext) -> None: super(UIWorkload, self).run(context) - self.gui.run() + if self.gui: + self.gui.run() - def extract_results(self, context): + def extract_results(self, context: ExecutionContext) -> None: super(UIWorkload, self).extract_results(context) - self.gui.extract_results() + if self.gui: + self.gui.extract_results() - def teardown(self, context): - self.gui.teardown() + def teardown(self, context: ExecutionContext) -> None: + if self.gui: + self.gui.teardown() super(UIWorkload, self).teardown(context) @once_per_instance - def finalize(self, context): + def finalize(self, context: ExecutionContext) -> None: super(UIWorkload, self).finalize(context) - if self.cleanup_assets: + if self.cleanup_assets and self.gui: self.gui.remove() class UiautoWorkload(UIWorkload): + """ + The :class:`UiautoWorkload` derives from :class:`UIWorkload` which is an + intermediate class which in turn inherits from + :class:`Workload`, however this allows for providing generic automation using + UiAutomator without associating a particular application with the workload. + + This class define these additional attributes: - supported_platforms = ['android'] + ``gui`` + This attribute will be an instance of a :class:`UiAutmatorGUI` which is + used to control the automation, and is what is used to pass parameters to the + java class for example ``gui.uiauto_params``. 
+ """ + supported_platforms: List[str] = ['android'] - parameters = [ + parameters: List[Parameter] = [ Parameter('markers_enabled', kind=bool, default=False, description=""" If set to ``True``, workloads will insert markers into logs @@ -466,32 +589,56 @@ class UiautoWorkload(UIWorkload): """), ] - def __init__(self, target, **kwargs): + def __init__(self, target: 'Target', **kwargs): if target.os == 'chromeos': - if target.supports_android: - target = target.android_container + if cast('ChromeOsTarget', target).supports_android: + target = cast('AndroidTarget', cast('ChromeOsTarget', target).android_container) else: raise ConfigError('Target does not appear to support Android') super(UiautoWorkload, self).__init__(target, **kwargs) self.gui = UiAutomatorGUI(self) - def setup(self, context): - self.gui.uiauto_params['markers_enabled'] = self.markers_enabled - self.gui.init_commands() + def setup(self, context: ExecutionContext) -> None: + cast(UiAutomatorGUI, self.gui).uiauto_params['markers_enabled'] = self.markers_enabled + cast(UiAutomatorGUI, self.gui).init_commands() super(UiautoWorkload, self).setup(context) class ReventWorkload(UIWorkload): + """ + The :class:`ReventWorkload` derives from :class:`UIWorkload` which is an + intermediate class which in turn inherits from + :class:`Workload`, however this allows for providing generic automation + using :ref:`Revent ` without associating with the + workload. + + This class define these additional attributes: + ``gui`` + This attribute will be an instance of a :class:`ReventGUI` which is + used to control the automation + + ``setup_timeout`` + This is the time allowed for replaying a recording for the setup stage. + + ``run_timeout`` + This is the time allowed for replaying a recording for the run stage. + + ``extract_results_timeout`` + This is the time allowed for replaying a recording for the extract results stage. + + ``teardown_timeout`` + This is the time allowed for replaying a recording for the teardown stage. 
+ """ # May be optionally overwritten by subclasses # Times are in seconds - setup_timeout = 5 * 60 - run_timeout = 10 * 60 - extract_results_timeout = 5 * 60 - teardown_timeout = 5 * 60 + setup_timeout: int = 5 * 60 + run_timeout: int = 10 * 60 + extract_results_timeout: int = 5 * 60 + teardown_timeout: int = 5 * 60 - def __init__(self, target, **kwargs): + def __init__(self, target: 'Target', **kwargs): super(ReventWorkload, self).__init__(target, **kwargs) self.gui = ReventGUI(self, target, self.setup_timeout, @@ -502,77 +649,108 @@ def __init__(self, target, **kwargs): class UiAutomatorGUI(object): - stages = ['setup', 'runWorkload', 'extractResults', 'teardown'] + stages: List[str] = ['setup', 'runWorkload', 'extractResults', 'teardown'] - uiauto_runner = 'android.support.test.runner.AndroidJUnitRunner' + uiauto_runner: str = 'android.support.test.runner.AndroidJUnitRunner' - def __init__(self, owner, package=None, klass='UiAutomation', timeout=600): + def __init__(self, owner: Workload, package: Optional[str] = None, + klass: str = 'UiAutomation', timeout: int = 600): self.owner = owner - self.target = self.owner.target + self.target = cast('AndroidTarget', self.owner.target) self.uiauto_package = package self.uiauto_class = klass self.timeout = timeout self.logger = logging.getLogger('gui') - self.uiauto_file = None - self.commands = {} + self.uiauto_file: Optional[str] = None + self.commands: Dict[str, str] = {} self.uiauto_params = ParameterDict() - def init_resources(self, resolver): + def init_resources(self, resolver: ExecutionContext): + """ + initialize resources + """ self.uiauto_file = resolver.get(ApkFile(self.owner, uiauto=True)) if not self.uiauto_package: - uiauto_info = get_cacheable_apk_info(self.uiauto_file) - self.uiauto_package = uiauto_info.package + uiauto_info: Optional[ApkInfo] = get_cacheable_apk_info(self.uiauto_file) + self.uiauto_package = uiauto_info.package if uiauto_info else None - def init_commands(self): - params_dict = self.uiauto_params + def init_commands(self) -> None: + """ + initialize commands + """ + params_dict: ParameterDict = self.uiauto_params params_dict['workdir'] = self.target.working_directory - params = '' + params: str = '' for k, v in params_dict.iter_encoded_items(): params += ' -e {} {}'.format(k, v) for stage in self.stages: - class_string = '{}.{}#{}'.format(self.uiauto_package, self.uiauto_class, - stage) - instrumentation_string = '{}/{}'.format(self.uiauto_package, - self.uiauto_runner) - cmd_template = 'am instrument -w -r{} -e class {} {}' + class_string: str = '{}.{}#{}'.format(self.uiauto_package, self.uiauto_class, + stage) + instrumentation_string: str = '{}/{}'.format(self.uiauto_package, + self.uiauto_runner) + cmd_template: str = 'am instrument -w -r{} -e class {} {}' self.commands[stage] = cmd_template.format(params, class_string, instrumentation_string) - def deploy(self): + def deploy(self) -> None: + """ + install ui auto package onto target + """ if self.target.package_is_installed(self.uiauto_package): self.target.uninstall_package(self.uiauto_package) self.target.install_apk(self.uiauto_file) - def set(self, name, value): + def set(self, name: str, value: Any) -> None: + """ + set a uiauto_param to the value + """ self.uiauto_params[name] = value - def setup(self, timeout=None): + def setup(self, timeout: Optional[int] = None) -> None: + """ + execute setup stage commands + """ if not self.commands: raise RuntimeError('Commands have not been initialized') self.target.killall('uiautomator') 
self._execute('setup', timeout or self.timeout) - def run(self, timeout=None): + def run(self, timeout: Optional[int] = None) -> None: + """ + execute runWorkload stage commands + """ if not self.commands: raise RuntimeError('Commands have not been initialized') self._execute('runWorkload', timeout or self.timeout) - def extract_results(self, timeout=None): + def extract_results(self, timeout: Optional[int] = None) -> None: + """ + execute extractResults stage commands + """ if not self.commands: raise RuntimeError('Commands have not been initialized') self._execute('extractResults', timeout or self.timeout) - def teardown(self, timeout=None): + def teardown(self, timeout: Optional[int] = None) -> None: + """ + execute teardown stage commands + """ if not self.commands: raise RuntimeError('Commands have not been initialized') self._execute('teardown', timeout or self.timeout) - def remove(self): + def remove(self) -> None: + """ + uninstall the uiauto package + """ self.target.uninstall(self.uiauto_package) - def _execute(self, stage, timeout): - result = self.target.execute(self.commands[stage], timeout) + def _execute(self, stage: str, timeout: Optional[int]) -> None: + """ + execute commands for the specified stage + """ + result: str = self.target.execute(self.commands[stage], timeout) if 'FAILURE' in result: raise WorkloadError(result) else: @@ -581,9 +759,15 @@ def _execute(self, stage, timeout): class ReventGUI(object): - - def __init__(self, workload, target, setup_timeout, run_timeout, - extract_results_timeout, teardown_timeout): + """ + The revent utility can be used to record and later play back a sequence of user + input events, such as key presses and touch screen taps. This is an alternative + to Android UI Automator for providing automation for workloads. + Some workloads (pretty much all games) rely on recorded revents for their + execution. A ReventWorkload requires between one and four revent files to run. + """ + def __init__(self, workload: Workload, target: 'Target', setup_timeout: int, run_timeout: int, + extract_results_timeout: int, teardown_timeout: int): self.workload = workload self.target = target self.setup_timeout = setup_timeout @@ -597,12 +781,15 @@ def __init__(self, workload, target, setup_timeout, run_timeout, self.on_target_extract_results_revent = self.target.get_workpath('{}.extract_results.revent'.format(self.target.model)) self.on_target_teardown_revent = self.target.get_workpath('{}.teardown.revent'.format(self.target.model)) self.logger = logging.getLogger('revent') - self.revent_setup_file = None - self.revent_run_file = None - self.revent_extract_results_file = None - self.revent_teardown_file = None + self.revent_setup_file: Optional[str] = None + self.revent_run_file: Optional[str] = None + self.revent_extract_results_file: Optional[str] = None + self.revent_teardown_file: Optional[str] = None - def init_resources(self, resolver): + def init_resources(self, resolver: ExecutionContext) -> None: + """ + initialize resources + """ self.revent_setup_file = resolver.get(ReventFile(owner=self.workload, stage='setup', target=self.target.model), @@ -619,40 +806,65 @@ def init_resources(self, resolver): target=self.target.model), strict=False) - def deploy(self): + def deploy(self) -> None: + """ + deploy the revent recorder binary to the target + """ self.revent_recorder.deploy() - def setup(self): + def setup(self) -> None: + """ + ``setup`` can be used to perform the initial setup + (navigating menus, selecting game modes, etc).
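As the ``__init__`` above shows, each replay stage expects an on-target file named ``{model}.{stage}.revent`` under the target's work path. A rough sketch of that naming convention (the model name and directory below are invented):

```python
# Illustration of the '{model}.{stage}.revent' naming used by ReventGUI
# above; the model name and working directory are hypothetical.
STAGES = ['setup', 'run', 'extract_results', 'teardown']


def on_target_revent_paths(model, workdir='/data/local/tmp/wa'):
    return {stage: '{}/{}.{}.revent'.format(workdir, model, stage)
            for stage in STAGES}


print(on_target_revent_paths('generic_a55')['run'])
# -> /data/local/tmp/wa/generic_a55.run.revent
```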
+ """ self._check_revent_files() if self.revent_setup_file: - self.revent_recorder.replay(self.on_target_setup_revent, + self.revent_recorder.replay(self.on_target_setup_revent or '', timeout=self.setup_timeout) def run(self): + """ + There is one mandatory recording, ``run``, for performing the actual execution of + the workload and the remaining stages are optional + """ msg = 'Replaying {}' - self.logger.debug(msg.format(os.path.basename(self.on_target_run_revent))) - self.revent_recorder.replay(self.on_target_run_revent, + self.logger.debug(msg.format(os.path.basename(self.on_target_run_revent or ''))) + self.revent_recorder.replay(self.on_target_run_revent or '', timeout=self.run_timeout) self.logger.debug('Replay completed.') - def extract_results(self): + def extract_results(self) -> None: + """ + ``extract_results`` can be used to perform any actions after the main stage of + the workload for example to navigate a results or summary screen of the app. + """ if self.revent_extract_results_file: - self.revent_recorder.replay(self.on_target_extract_results_revent, + self.revent_recorder.replay(self.on_target_extract_results_revent or '', timeout=self.extract_results_timeout) - def teardown(self): + def teardown(self) -> None: + """ + ``teardown`` can be used to perform any final actions for example + exiting the app. + """ if self.revent_teardown_file: - self.revent_recorder.replay(self.on_target_teardown_revent, + self.revent_recorder.replay(self.on_target_teardown_revent or '', timeout=self.teardown_timeout) - def remove(self): + def remove(self) -> None: + """ + cleanup the execution files + """ self.target.remove(self.on_target_setup_revent) self.target.remove(self.on_target_run_revent) self.target.remove(self.on_target_extract_results_revent) self.target.remove(self.on_target_teardown_revent) self.revent_recorder.remove() - def _check_revent_files(self): + def _check_revent_files(self) -> None: + """ + check the revent files + """ if not self.revent_run_file: # pylint: disable=too-few-format-args message = '{0}.run.revent file does not exist, ' \ @@ -671,13 +883,13 @@ def _check_revent_files(self): class PackageHandler(object): @property - def package(self): + def package(self) -> Optional[str]: if self.apk_info is None: return None return self.apk_info.package @property - def activity(self): + def activity(self) -> Optional[str]: if self._activity: return self._activity if self.apk_info is None: @@ -685,13 +897,14 @@ def activity(self): return self.apk_info.activity # pylint: disable=too-many-locals - def __init__(self, owner, install_timeout=300, version=None, variant=None, - package_name=None, strict=False, force_install=False, uninstall=False, - exact_abi=False, prefer_host_package=True, clear_data_on_reset=True, - activity=None, min_version=None, max_version=None, apk_arguments=None): + def __init__(self, owner: ApkWorkload, install_timeout: int = 300, version: Optional[Union[str, List[str]]] = None, + variant: Optional[str] = None, package_name: Optional[str] = None, strict: bool = False, + force_install: bool = False, uninstall: bool = False, exact_abi: bool = False, prefer_host_package: bool = True, + clear_data_on_reset: bool = True, activity: Optional[str] = None, min_version: Optional[str] = None, + max_version: Optional[str] = None, apk_arguments: Optional[Dict[str, Union[str, float, bool, int]]] = None): self.logger = logging.getLogger('apk') self.owner = owner - self.target = self.owner.target + self.target: 'AndroidTarget' = cast('AndroidTarget', self.owner.target) 
self.install_timeout = install_timeout self.version = version self.min_version = min_version @@ -706,27 +919,36 @@ def __init__(self, owner, install_timeout=300, version=None, variant=None, self.clear_data_on_reset = clear_data_on_reset self._activity = activity self.supported_abi = self.target.supported_abi - self.apk_file = None - self.apk_info = None - self.apk_version = None - self.logcat_log = None - self.error_msg = None + self.apk_file: Optional[str] = None + self.apk_info: Optional['ApkInfo'] = None + self.apk_version: Optional[str] = None + self.logcat_log: Optional[str] = None + self.error_msg: Optional[str] = None self.apk_arguments = apk_arguments - def initialize(self, context): + def initialize(self, context: ExecutionContext) -> None: + """ + initialize package + """ self.resolve_package(context) - def setup(self, context): - context.update_metadata('app_version', self.apk_info.version_name) - context.update_metadata('app_name', self.apk_info.package) + def setup(self, context: ExecutionContext) -> None: + """ + setup the package + """ + context.update_metadata('app_version', self.apk_info.version_name if self.apk_info else '') + context.update_metadata('app_name', self.apk_info.package if self.apk_info else '') self.initialize_package(context) self.start_activity() self.target.execute('am kill-all') # kill all *background* activities self.target.clear_logcat() - def resolve_package(self, context): + def resolve_package(self, context: ExecutionContext) -> None: + """ + resolve package on host or target + """ if not self.owner.package_names and not self.package_name: - msg = 'Cannot Resolve package; No package name(s) specified' + msg: str = 'Cannot Resolve package; No package name(s) specified' raise WorkloadError(msg) self.error_msg = None @@ -746,8 +968,8 @@ def resolve_package(self, context): raise WorkloadError(self.error_msg) else: if self.package_name: - message = 'Package "{package}" not found for workload {name} '\ - 'on host or target.' + message: str = 'Package "{package}" not found for workload {name} '\ + 'on host or target.' elif self.version: message = 'No matching package found for workload {name} '\ '(version {version}) on host or target.' 
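One pattern recurs throughout this hunk: attributes such as ``apk_info`` start out as ``None`` and are only filled in during resolution, so the annotated code narrows them inline (``self.apk_info.package if self.apk_info else ''``). A sketch of that idiom next to the assert-based alternative, using stand-in classes rather than the real ``PackageHandler``:

```python
from typing import Optional


class Info:  # stand-in for devlib's ApkInfo
    package = 'com.example.app'


class Handler:  # stand-in, not the real PackageHandler
    def __init__(self) -> None:
        self.apk_info: Optional[Info] = None  # populated by resolution later

    def package_inline(self) -> str:
        # idiom used in this diff: inline narrowing with a '' fallback
        return self.apk_info.package if self.apk_info else ''

    def package_assert(self) -> str:
        # common alternative: fail fast rather than passing '' downstream
        assert self.apk_info is not None, 'resolve_package() has not run'
        return self.apk_info.package
```

The inline form avoids an ``AttributeError`` when resolution has not happened, at the cost of letting an empty string flow downstream; the assert form surfaces the failure earlier.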
@@ -756,7 +978,10 @@ def resolve_package(self, context): raise WorkloadError(message.format(name=self.owner.name, version=self.version, package=self.package_name)) - def resolve_package_from_host(self, context): + def resolve_package_from_host(self, context: ExecutionContext) -> None: + """ + resolve package on host system + """ self.logger.debug('Resolving package on host system') if self.package_name: self.apk_file = context.get_resource(ApkFile(self.owner, @@ -769,7 +994,7 @@ def resolve_package_from_host(self, context): max_version=self.max_version), strict=self.strict) else: - available_packages = [] + available_packages: List[str] = [] for package in self.owner.package_names: apk_file = context.get_resource(ApkFile(self.owner, variant=self.variant, @@ -787,14 +1012,17 @@ def resolve_package_from_host(self, context): elif len(available_packages) > 1: self.error_msg = self._get_package_error_msg('host') - def resolve_package_from_target(self): # pylint: disable=too-many-branches + def resolve_package_from_target(self) -> None: # pylint: disable=too-many-branches + """ + resolve package on target + """ self.logger.debug('Resolving package on target') - found_package = None + found_package: Optional[str] = None if self.package_name: if not self.target.package_is_installed(self.package_name): return else: - installed_versions = [self.package_name] + installed_versions: List[str] = [self.package_name] else: installed_versions = [] for package in self.owner.package_names: @@ -802,9 +1030,9 @@ def resolve_package_from_target(self): # pylint: disable=too-many-branches installed_versions.append(package) if self.version or self.min_version or self.max_version: - matching_packages = [] + matching_packages: List[str] = [] for package in installed_versions: - package_version = self.target.get_package_version(package) + package_version: str = self.target.get_package_version(package) if self.version: for v in list_or_string(self.version): if loose_version_matching(v, package_version): @@ -828,12 +1056,15 @@ def resolve_package_from_target(self): # pylint: disable=too-many-branches self.apk_file = self.pull_apk(found_package) self.package_name = found_package - def initialize_package(self, context): - installed_version = self.target.get_package_version(self.apk_info.package) - host_version = self.apk_info.version_name + def initialize_package(self, context: ExecutionContext) -> None: + """ + initialize package + """ + installed_version: str = self.target.get_package_version(self.apk_info.package if self.apk_info else '') + host_version: Optional[str] = self.apk_info.version_name if self.apk_info else '' if installed_version != host_version: if installed_version: - message = '{} host version: {}, device version: {}; re-installing...' + message: str = '{} host version: {}, device version: {}; re-installing...' 
self.logger.debug(message.format(self.owner.name, host_version, installed_version)) else: @@ -845,71 +1076,94 @@ def initialize_package(self, context): self.logger.debug(message.format(self.owner.name, host_version)) if self.force_install: if installed_version: - self.target.uninstall_package(self.apk_info.package) + self.target.uninstall_package(self.apk_info.package if self.apk_info else '') self.install_apk(context) else: self.reset(context) - if self.apk_info.permissions: + if self.apk_info and self.apk_info.permissions: self.logger.debug('Granting runtime permissions') for permission in self.apk_info.permissions: self.target.grant_package_permission(self.apk_info.package, permission) self.apk_version = host_version - def start_activity(self): - - cmd = build_apk_launch_command(self.apk_info.package, self.activity, - self.apk_arguments) + def start_activity(self) -> None: + """ + start activity via the activity manager + """ + cmd: str = build_apk_launch_command(self.apk_info.package if self.apk_info else '', self.activity, + self.apk_arguments) - output = self.target.execute(cmd) + output: str = self.target.execute(cmd) if 'Error:' in output: # this will dismiss any error dialogs - self.target.execute('am force-stop {}'.format(self.apk_info.package)) + self.target.execute('am force-stop {}'.format(self.apk_info.package if self.apk_info else '')) raise WorkloadError(output) self.logger.debug(output) - def restart_activity(self): - self.target.execute('am force-stop {}'.format(self.apk_info.package)) + def restart_activity(self) -> None: + """ + restart the activity via the activity manager + """ + self.target.execute('am force-stop {}'.format(self.apk_info.package if self.apk_info else '')) self.start_activity() - def reset(self, context): # pylint: disable=W0613 - self.target.execute('am force-stop {}'.format(self.apk_info.package)) + def reset(self, context: ExecutionContext) -> None: # pylint: disable=W0613 + """ + stop the activity via activity manager and clear the package via package manager + """ + self.target.execute('am force-stop {}'.format(self.apk_info.package if self.apk_info else '')) if self.clear_data_on_reset: - self.target.execute('pm clear {}'.format(self.apk_info.package)) + self.target.execute('pm clear {}'.format(self.apk_info.package if self.apk_info else '')) - def install_apk(self, context): + def install_apk(self, context: ExecutionContext) -> None: + """ + install the apk + """ # pylint: disable=unused-argument - output = self.target.install_apk(self.apk_file, self.install_timeout, - replace=True, allow_downgrade=True) + output: str = self.target.install_apk(self.apk_file, self.install_timeout, + replace=True, allow_downgrade=True) if 'Failure' in output: if 'ALREADY_EXISTS' in output: - msg = 'Using already installed APK (did not uninstall properly?)' + msg: str = 'Using already installed APK (did not uninstall properly?)' self.logger.warning(msg) else: raise WorkloadError(output) else: self.logger.debug(output) - def pull_apk(self, package): + def pull_apk(self, package: str) -> str: + """ + pull apk from the target + """ if not self.target.package_is_installed(package): message = 'Cannot retrieve "{}" as not installed on Target' raise WorkloadError(message.format(package)) - package_info = self.target.get_package_info(package) - apk_name = self._get_package_name(package_info.apk_path) - host_path = os.path.join(self.owner.dependencies_directory, apk_name) + package_info: 'installed_package_info' = self.target.get_package_info(package) + apk_name: str = 
self._get_package_name(package_info.apk_path) + host_path: str = os.path.join(self.owner.dependencies_directory, apk_name) with atomic_write_path(host_path) as at_path: self.target.pull(package_info.apk_path, at_path, timeout=self.install_timeout) return host_path - def teardown(self): - self.target.execute('am force-stop {}'.format(self.apk_info.package)) + def teardown(self) -> None: + """ + force-stop the activity and uninstall the package + """ + self.target.execute('am force-stop {}'.format(self.apk_info.package if self.apk_info else '')) if self.uninstall: - self.target.uninstall_package(self.apk_info.package) + self.target.uninstall_package(self.apk_info.package if self.apk_info else '') - def _get_package_name(self, apk_path): - return self.target.path.basename(apk_path) + def _get_package_name(self, apk_path: str) -> str: + """ + get the APK file name from the given path + """ + return self.target.path.basename(apk_path) if self.target.path else '' - def _get_package_error_msg(self, location): + def _get_package_error_msg(self, location: str) -> str: + """ + build the package resolution error message + """ if self.version: msg = 'Multiple matches for "{version}" found on {location}.' elif self.min_version and self.max_version: @@ -927,8 +1181,8 @@ class TestPackageHandler(PackageHandler): """Class wrapping an APK used through ``am instrument``. """ - def __init__(self, owner, instrument_args=None, raw_output=False, - instrument_wait=True, no_hidden_api_checks=False, + def __init__(self, owner: ApkWorkload, instrument_args: Optional[Dict] = None, raw_output: bool = False, + instrument_wait: bool = True, no_hidden_api_checks: bool = False, *args, **kwargs): if instrument_args is None: instrument_args = {} @@ -938,14 +1192,14 @@ def __init__(self, owner, instrument_args=None, raw_output=False, self.wait = instrument_wait self.no_checks = no_hidden_api_checks - self.cmd = '' - self.instrument_thread = None - self._instrument_output = None + self.cmd: str = '' + self.instrument_thread: Optional[threading.Thread] = None + self._instrument_output: Optional[str] = None - def setup(self, context): + def setup(self, context: ExecutionContext) -> None: self.initialize_package(context) - words = ['am', 'instrument', '--user', '0'] + words: List[str] = ['am', 'instrument', '--user', '0'] if self.raw: words.append('-r') if self.wait: @@ -955,32 +1209,40 @@ def setup(self, context): for k, v in self.args.items(): words.extend(['-e', str(k), str(v)]) - words.append(str(self.apk_info.package)) - if self.apk_info.activity: + words.append(str(self.apk_info.package if self.apk_info else '')) + if self.apk_info and self.apk_info.activity: words[-1] += '/{}'.format(self.apk_info.activity) self.cmd = ' '.join(quote(x) for x in words) self.instrument_thread = threading.Thread(target=self._start_instrument) - def start_activity(self): - self.instrument_thread.start() + def start_activity(self) -> None: + if self.instrument_thread: + self.instrument_thread.start() - def wait_instrument_over(self): - self.instrument_thread.join() - if 'Error:' in self._instrument_output: - cmd = 'am force-stop {}'.format(self.apk_info.package) + def wait_instrument_over(self) -> None: + """ + wait for the instrument to complete execution + """ + if self.instrument_thread: + self.instrument_thread.join() + if self._instrument_output and 'Error:' in self._instrument_output: + cmd = 'am force-stop {}'.format(self.apk_info.package if self.apk_info else '') self.target.execute(cmd) raise
WorkloadError(self._instrument_output) - def _start_instrument(self): + def _start_instrument(self) -> None: + """ + start the ``am instrument`` command in a separate thread + """ self._instrument_output = self.target.execute(self.cmd) self.logger.debug(self._instrument_output) - def _get_package_name(self, apk_path): - return 'test_{}'.format(self.target.path.basename(apk_path)) + def _get_package_name(self, apk_path: str) -> str: + return 'test_{}'.format(self.target.path.basename(apk_path) if self.target.path else '') @property - def instrument_output(self): - if self.instrument_thread.is_alive(): + def instrument_output(self) -> Optional[str]: + if self.instrument_thread and self.instrument_thread.is_alive(): self.instrument_thread.join() # writes self._instrument_output return self._instrument_output diff --git a/wa/instruments/delay.py b/wa/instruments/delay.py index b65135d5c..b675f6fb4 100644 --- a/wa/instruments/delay.py +++ b/wa/instruments/delay.py @@ -21,11 +21,45 @@ from wa.framework.exception import ConfigError, InstrumentError from wa.framework.instrument import extremely_slow from wa.utils.types import identifier +from typing import TYPE_CHECKING, Optional, List, cast +from typing_extensions import Protocol +from devlib.target import Target +import logging +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext + + +class DelayInstrumentProtocol(Protocol): + name: str + description: str + temperature_file: str + temperature_timeout: int + temperature_poll_period: int + temperature_between_specs: int + fixed_between_specs: int + temperature_between_jobs: int + fixed_between_jobs: int + fixed_before_start: int + temperature_before_start: int + active_cooling: bool + cooling: Optional[Instrument] + active_cooling_modules: List[str] + target: Target + logger: logging.Logger + + def _discover_cooling_module(self) -> Optional[Instrument]: + ... + + def wait_for_temperature(self, temperature: int) -> None: + ... + + def do_wait_for_temperature(self, temperature: int) -> None: + ... class DelayInstrument(Instrument): - name = 'delay' + name: str = 'delay' description = """ This instrument introduces a delay before beginning a new spec, a new job or before the main execution of a workload. @@ -39,7 +73,7 @@ class DelayInstrument(Instrument): """ - parameters = [ + parameters: List[Parameter] = [ Parameter('temperature_file', default='/sys/devices/virtual/thermal/thermal_zone0/temp', global_alias='thermal_temp_file', description=""" @@ -135,13 +169,16 @@ class DelayInstrument(Instrument): """), ] - active_cooling_modules = ['mbed-fan', 'odroidxu3-fan'] + active_cooling_modules: List[str] = ['mbed-fan', 'odroidxu3-fan'] - def initialize(self, context): + def initialize(self: DelayInstrumentProtocol, context: 'ExecutionContext') -> None: + """ + initialize delay instrument + """ if self.active_cooling: self.cooling = self._discover_cooling_module() if not self.cooling: - msg = 'Cooling module not found on target. Please install one of the following modules: {}' + msg: str = 'Cooling module not found on target. Please install one of the following modules: {}' raise InstrumentError(msg.format(self.active_cooling_modules)) if self.temperature_between_jobs == 0: @@ -155,9 +192,12 @@ def initialize(self, context): self.temperature_between_specs = temp @extremely_slow - def start(self, context): + def start(self: DelayInstrumentProtocol, context: 'ExecutionContext') -> None: + """ + start delay instrument + """ if self.fixed_before_start: - msg = 'Waiting for {}s before running workload...'
+ msg: str = 'Waiting for {}s before running workload...' self.logger.info(msg.format(self.fixed_before_start)) time.sleep(self.fixed_before_start) elif self.temperature_before_start: @@ -165,7 +205,10 @@ def start(self, context): self.wait_for_temperature(self.temperature_before_start) @extremely_slow - def before_job(self, context): + def before_job(self: DelayInstrumentProtocol, context: 'ExecutionContext') -> None: + """ + run before job + """ if self.fixed_between_specs and context.spec_changed: msg = 'Waiting for {}s before starting new spec...' self.logger.info(msg.format(self.fixed_between_specs)) @@ -181,17 +224,24 @@ def before_job(self, context): self.logger.info('Waiting for temperature drop before starting new job...') self.wait_for_temperature(self.temperature_between_jobs) - def wait_for_temperature(self, temperature): + def wait_for_temperature(self: DelayInstrumentProtocol, temperature: int) -> None: + """ + wait for temperature + """ if self.active_cooling: - self.cooling.start() - self.do_wait_for_temperature(temperature) - self.cooling.stop() + if self.cooling: + self.cooling.start() + self.do_wait_for_temperature(temperature) + self.cooling.stop() else: self.do_wait_for_temperature(temperature) - def do_wait_for_temperature(self, temperature): - reading = self.target.read_int(self.temperature_file) - waiting_start_time = time.time() + def do_wait_for_temperature(self: DelayInstrumentProtocol, temperature: int) -> None: + """ + wait to cool to the specified temperature + """ + reading: int = self.target.read_int(self.temperature_file) + waiting_start_time: float = time.time() while reading > temperature: self.logger.debug('target temperature: {}'.format(reading)) if time.time() - waiting_start_time > self.temperature_timeout: @@ -200,26 +250,32 @@ def do_wait_for_temperature(self, temperature): time.sleep(self.temperature_poll_period) reading = self.target.read_int(self.temperature_file) - def validate(self): + def validate(self: DelayInstrumentProtocol): + """ + validate the delay instrument + """ if (self.temperature_between_specs is not None and self.fixed_between_specs is not None): raise ConfigError('Both fixed delay and thermal threshold specified for specs.') - if (self.temperature_between_jobs is not None + if (cast(DelayInstrumentProtocol, self).temperature_between_jobs is not None and self.fixed_between_jobs is not None): raise ConfigError('Both fixed delay and thermal threshold specified for jobs.') - if (self.temperature_before_start is not None + if (cast(DelayInstrumentProtocol, self).temperature_before_start is not None and self.fixed_before_start is not None): raise ConfigError('Both fixed delay and thermal threshold specified before start.') - if not any([self.temperature_between_specs, self.fixed_between_specs, - self.temperature_between_jobs, self.fixed_between_jobs, - self.temperature_before_start, self.fixed_before_start]): + if not any([cast(DelayInstrumentProtocol, self).temperature_between_specs, cast(DelayInstrumentProtocol, self).fixed_between_specs, + cast(DelayInstrumentProtocol, self).temperature_between_jobs, cast(DelayInstrumentProtocol, self).fixed_between_jobs, + cast(DelayInstrumentProtocol, self).temperature_before_start, cast(DelayInstrumentProtocol, self).fixed_before_start]): raise ConfigError('Delay instrument is enabled, but no delay is specified.') - def _discover_cooling_module(self): - cooling_module = None + def _discover_cooling_module(self) -> Optional[Instrument]: + """ + discover the cooling module + """ + cooling_module: 
Optional[Instrument] = None for module in self.active_cooling_modules: if self.target.has(module): if not cooling_module: diff --git a/wa/instruments/misc.py b/wa/instruments/misc.py index 106c399ff..fd2e3dd81 100644 --- a/wa/instruments/misc.py +++ b/wa/instruments/misc.py @@ -285,8 +285,8 @@ def update_output(self, context): context.add_artifact('interrupts [before]', self.before_file, kind='data', classifiers={'stage': 'before'}) # If workload execution failed, the after_file may not have been created. - if os.path.isfile(self.after_file): - diff_interrupt_files(self.before_file, self.after_file, _f(self.diff_file)) + if os.path.isfile(self.after_file or ''): + diff_interrupt_files(self.before_file or '', self.after_file or '', _f(self.diff_file or '')) context.add_artifact('interrupts [after]', self.after_file, kind='data', classifiers={'stage': 'after'}) context.add_artifact('interrupts [diff]', self.diff_file, kind='data', diff --git a/wa/instruments/poller/__init__.py b/wa/instruments/poller/__init__.py index 3205f1acb..52797e5b0 100644 --- a/wa/instruments/poller/__init__.py +++ b/wa/instruments/poller/__init__.py @@ -14,18 +14,33 @@ # pylint: disable=access-member-before-definition,attribute-defined-outside-init,unused-argument import os -import pandas as pd +import pandas as pd # type:ignore from wa import Instrument, Parameter, Executable from wa.framework import signal from wa.framework.exception import ConfigError, InstrumentError from wa.utils.trace_cmd import TraceCmdParser from wa.utils.types import list_or_string +from typing import cast, List, TYPE_CHECKING, Union, Optional +from signal import Signals +from typing_extensions import Protocol +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext + + +class FilePollerProtocol(Protocol): + name: str + description: str + sample_interval: int + files: Union[str, List[str]] + labels: Union[str, List[str]] + align_with_ftrace: bool + as_root: bool class FilePoller(Instrument): - name = 'file_poller' - description = """ + name: str = 'file_poller' + description: str = """ Polls the given files at a set sample interval. The values are output in CSV format. This instrument places a file called poller.csv in each iterations result directory. @@ -36,7 +51,7 @@ class FilePoller(Instrument): before writing them. 
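``FilePollerProtocol`` here, like ``DelayInstrumentProtocol`` earlier, exists because WA injects ``Parameter`` values onto plugin instances at runtime, where a static checker cannot see them; annotating ``self`` (or ``cast()``-ing) against a ``Protocol`` that declares the attributes restores type checking. A minimal sketch of the pattern with invented names:

```python
# Minimal sketch of the Protocol-typed-self pattern used in this diff;
# PollerProto/Poller are illustrative stand-ins, not the real WA classes.
from typing_extensions import Protocol


class PollerProto(Protocol):
    sample_interval: int   # declared here purely for the type checker
    as_root: bool


class Poller:
    # attributes are injected at runtime (as WA's config machinery does),
    # so they are deliberately absent from the class body
    def validate(self: PollerProto) -> None:
        if self.sample_interval <= 0:
            raise ValueError('sample_interval must be positive')


p = Poller()
p.sample_interval = 1000   # stand-in for WA's runtime parameter injection
p.as_root = False
p.validate()               # works at runtime; mypy checks p against PollerProto
```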
""" - parameters = [ + parameters: List[Parameter] = [ Parameter('sample_interval', kind=int, default=1000, description="""The interval between samples in mS."""), Parameter('files', kind=list_or_string, mandatory=True, @@ -67,34 +82,40 @@ class FilePoller(Instrument): """), ] - def validate(self): - if not self.files: + def validate(self) -> None: + """ + validate files and labels to poll + """ + if not cast(FilePollerProtocol, self).files: raise ConfigError('You must specify atleast one file to poll') - if self.labels and any(['*' in f for f in self.files]): + if cast(FilePollerProtocol, self).labels and any(['*' in f for f in cast(FilePollerProtocol, self).files]): raise ConfigError('You cannot used manual labels with `*` wildcards') - def initialize(self, context): + def initialize(self, context: 'ExecutionContext'): + """ + initialize the file poller instrument + """ if not self.target.is_rooted and self.as_root: raise ConfigError('The target is not rooted, cannot run poller as root.') - host_poller = context.get_resource(Executable(self, self.target.abi, - "poller")) - target_poller = self.target.install(host_poller) + host_poller: Optional[str] = context.get_resource(Executable(self, self.target.abi or '', + "poller")) + target_poller: str = self.target.install(host_poller) - expanded_paths = [] - for path in self.files: + expanded_paths: List[str] = [] + for path in cast(FilePollerProtocol, self).files: if "*" in path: for p in self.target.list_directory(path): expanded_paths.append(p) else: expanded_paths.append(path) self.files = expanded_paths - if not self.labels: + if not cast(FilePollerProtocol, self).labels: self.labels = self._generate_labels() - self.target_output_path = self.target.path.join(self.target.working_directory, 'poller.csv') - self.target_log_path = self.target.path.join(self.target.working_directory, 'poller.log') + self.target_output_path = self.target.path.join(self.target.working_directory, 'poller.csv') if self.target.path else '' + self.target_log_path = self.target.path.join(self.target.working_directory, 'poller.log') if self.target.path else '' marker_option = '' - if self.align_with_ftrace: + if cast(FilePollerProtocol, self).align_with_ftrace: marker_option = '-m' signal.connect(self._adjust_timestamps, signal.AFTER_JOB_OUTPUT_PROCESSED) reopen_option = '' @@ -109,18 +130,27 @@ def initialize(self, context): self.target_output_path, self.target_log_path) - def start(self, context): + def start(self, context: 'ExecutionContext') -> None: + """ + start the file poller + """ self.target.kick_off(self.command, as_root=self.as_root) - def stop(self, context): - self.target.killall('poller', signal='TERM', as_root=self.as_root) - - def update_output(self, context): - host_output_file = os.path.join(context.output_directory, 'poller.csv') + def stop(self, context: 'ExecutionContext') -> None: + """ + stop the file poller + """ + self.target.killall('poller', signal=cast(Signals, 'TERM'), as_root=self.as_root) + + def update_output(self, context: 'ExecutionContext') -> None: + """ + update the file poller output + """ + host_output_file: str = os.path.join(context.output_directory, 'poller.csv') self.target.pull(self.target_output_path, host_output_file) context.add_artifact('poller-output', host_output_file, kind='data') - host_log_file = os.path.join(context.output_directory, 'poller.log') + host_log_file: str = os.path.join(context.output_directory, 'poller.log') self.target.pull(self.target_log_path, host_log_file) context.add_artifact('poller-log', 
host_log_file, kind='log') @@ -131,35 +161,44 @@ def update_output(self, context): if 'WARNING' in line: self.logger.warning(line.strip()) - def teardown(self, context): + def teardown(self, context: 'ExecutionContext'): + """ + teardown file poller + """ self.target.remove(self.target_output_path) self.target.remove(self.target_log_path) - def _generate_labels(self): + def _generate_labels(self) -> List[str]: + """ + generate labels + """ # Split paths into their parts - path_parts = [f.split(self.target.path.sep) for f in self.files] + path_parts: List[List[str]] = [f.split(self.target.path.sep) for f in self.files] if self.target.path else [] # Identify which parts differ between at least two of the paths - differ_map = [len(set(x)) > 1 for x in zip(*path_parts)] + differ_map: List[bool] = [len(set(x)) > 1 for x in zip(*path_parts)] # compose labels from path parts that differ - labels = [] + labels: List[str] = [] for pp in path_parts: - label_parts = [p for i, p in enumerate(pp[:-1]) - if i >= len(differ_map) or differ_map[i]] + label_parts: List[str] = [p for i, p in enumerate(pp[:-1]) + if i >= len(differ_map) or differ_map[i]] label_parts.append(pp[-1]) # always use file name even if same for all labels.append('-'.join(label_parts)) return labels - def _adjust_timestamps(self, context): - output_file = context.get_artifact_path('poller-output') - message = 'Adjusting timestamps inside "{}" to align with ftrace' + def _adjust_timestamps(self, context: 'ExecutionContext') -> None: + """ + adjust timestamps in output file to align with trace + """ + output_file: str = context.get_artifact_path('poller-output') + message: str = 'Adjusting timestamps inside "{}" to align with ftrace' self.logger.debug(message.format(output_file)) - trace_txt = context.get_artifact_path('trace-cmd-txt') + trace_txt: str = context.get_artifact_path('trace-cmd-txt') trace_parser = TraceCmdParser(filter_markers=False) - marker_timestamp = None + marker_timestamp: Optional[Union[int, float]] = None for event in trace_parser.parse(trace_txt): - if event.name == 'print' and 'POLLER_START' in event.text: + if event.name == 'print' and 'POLLER_START' in (event.text or ''): marker_timestamp = event.timestamp break diff --git a/wa/instruments/proc_stat/__init__.py b/wa/instruments/proc_stat/__init__.py index 2b07d7380..251b72d21 100644 --- a/wa/instruments/proc_stat/__init__.py +++ b/wa/instruments/proc_stat/__init__.py @@ -16,19 +16,22 @@ import time from datetime import datetime, timedelta -import pandas as pd +import pandas as pd # type: ignore from wa import Instrument, Parameter, File, InstrumentError +from typing import List, TYPE_CHECKING, Optional, cast +if TYPE_CHECKING: + from wa.framework.execution import ExecutionContext class ProcStatCollector(Instrument): - name = 'proc_stat' - description = ''' + name: str = 'proc_stat' + description: str = ''' Collect CPU load information from /proc/stat. 
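The utilisation maths this collector applies later (``util = (total - deltas.idle) / total * 100`` in ``update_output`` below) is plain /proc/stat delta arithmetic. A self-contained illustration (field layout per the kernel's ``proc(5)``; the sample numbers are invented):

```python
# Rough analogue of the utilisation computed in update_output() below:
# utilisation is the non-idle share of the delta between two samples of
# a /proc/stat "cpu" line (fields: user nice system idle iowait ...).
def cpu_util(prev_line: str, curr_line: str) -> float:
    prev = [int(x) for x in prev_line.split()[1:]]
    curr = [int(x) for x in curr_line.split()[1:]]
    deltas = [c - p for p, c in zip(prev, curr)]
    total = sum(deltas)
    idle = deltas[3]  # the 4th numeric field of /proc/stat is idle time
    return (total - idle) / total * 100 if total else 0.0


print(cpu_util('cpu 100 0 50 850 0 0 0',
               'cpu 140 0 70 890 0 0 0'))  # -> 60.0
```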
''' - parameters = [ + parameters: List[Parameter] = [ Parameter('period', int, default=5, constraint=lambda x: x > 0, description=''' @@ -36,14 +39,20 @@ class ProcStatCollector(Instrument): '''), ] - def initialize(self, context): # pylint: disable=unused-argument - self.host_script = context.get_resource(File(self, 'gather-load.sh')) - self.target_script = self.target.install(self.host_script) - self.target_output = self.target.get_workpath('proc-stat-raw.csv') - self.stop_file = self.target.get_workpath('proc-stat-stop.signal') + def initialize(self, context: 'ExecutionContext') -> None: # pylint: disable=unused-argument + """ + initialize proc stat collector + """ + self.host_script: Optional[str] = context.get_resource(File(self, 'gather-load.sh')) + self.target_script: Optional[str] = self.target.install(self.host_script) + self.target_output: Optional[str] = self.target.get_workpath('proc-stat-raw.csv') + self.stop_file: Optional[str] = self.target.get_workpath('proc-stat-stop.signal') - def setup(self, context): # pylint: disable=unused-argument - self.command = '{} sh {} {} {} {} {}'.format( + def setup(self, context: 'ExecutionContext') -> None: # pylint: disable=unused-argument + """ + setup proc stat collector + """ + self.command: str = '{} sh {} {} {} {} {}'.format( self.target.busybox, self.target_script, self.target.busybox, @@ -54,13 +63,22 @@ def setup(self, context): # pylint: disable=unused-argument self.target.remove(self.target_output) self.target.remove(self.stop_file) - def start(self, context): # pylint: disable=unused-argument + def start(self, context: 'ExecutionContext') -> None: # pylint: disable=unused-argument + """ + start proc stat collector + """ self.target.kick_off(self.command) - def stop(self, context): # pylint: disable=unused-argument + def stop(self, context: 'ExecutionContext') -> None: # pylint: disable=unused-argument + """ + stop proc stat collector + """ self.target.execute('{} touch {}'.format(self.target.busybox, self.stop_file)) - def update_output(self, context): + def update_output(self, context: 'ExecutionContext') -> None: + """ + update output of proc stat collector + """ self.logger.debug('Waiting for collector script to terminate...') self._wait_for_script() self.logger.debug('Waiting for collector script to terminate...') @@ -74,18 +92,24 @@ def update_output(self, context): total = deltas.sum(axis=1) util = (total - deltas.idle) / total * 100 out_df = pd.concat([df.timestamp, util], axis=1).dropna() - out_df.columns = ['timestamp', 'cpu_util'] + out_df.columns = cast(pd.Index, ['timestamp', 'cpu_util']) util_file = os.path.join(context.output_directory, 'proc-stat.csv') out_df.to_csv(util_file, index=False) context.add_artifact('proc-stat', util_file, kind='data') - def finalize(self, context): # pylint: disable=unused-argument + def finalize(self, context: 'ExecutionContext') -> None: # pylint: disable=unused-argument + """ + finalize proc stat collector + """ if self.cleanup_assets and getattr(self, 'target_output'): self.target.remove(self.target_output) self.target.remove(self.target_script) - def _wait_for_script(self): + def _wait_for_script(self) -> None: + """ + wait for proc stat collector to terminate + """ start_time = datetime.utcnow() timeout = timedelta(seconds=300) while self.target.file_exists(self.stop_file): diff --git a/wa/utils/android.py b/wa/utils/android.py index 176cea110..8f5e648d1 100644 --- a/wa/utils/android.py +++ b/wa/utils/android.py @@ -24,23 +24,26 @@ from wa.utils.serializer import read_pod, 
write_pod, Podable from wa.utils.types import enum from wa.utils.misc import atomic_write_path - +from typing import Optional, List, Generator, Any, Dict LogcatLogLevel = enum(['verbose', 'debug', 'info', 'warn', 'error', 'assert'], start=2) -log_level_map = ''.join(n[0].upper() for n in LogcatLogLevel.names) +log_level_map: str = ''.join(n[0].upper() for n in LogcatLogLevel.names) -logcat_logger = logging.getLogger('logcat') -apk_info_cache_logger = logging.getLogger('apk_info_cache') +logcat_logger: logging.Logger = logging.getLogger('logcat') +apk_info_cache_logger: logging.Logger = logging.getLogger('apk_info_cache') apk_info_cache = None class LogcatEvent(object): + """ + Represents a Logcat event + """ + __slots__: List[str] = ['timestamp', 'pid', 'tid', 'level', 'tag', 'message'] - __slots__ = ['timestamp', 'pid', 'tid', 'level', 'tag', 'message'] - - def __init__(self, timestamp, pid, tid, level, tag, message): + def __init__(self, timestamp: datetime, pid: int, tid: int, + level, tag: str, message: str): self.timestamp = timestamp self.pid = pid self.tid = tid @@ -59,29 +62,37 @@ def __repr__(self): class LogcatParser(object): - - def parse(self, filepath): + """ + Logcat parser + """ + def parse(self, filepath: str) -> Generator[LogcatEvent, Any, None]: + """ + parse logcat event file + """ with open(filepath, errors='replace') as fh: for line in fh: - event = self.parse_line(line) + event: Optional[LogcatEvent] = self.parse_line(line) if event: yield event - def parse_line(self, line): # pylint: disable=no-self-use + def parse_line(self, line: str) -> Optional[LogcatEvent]: # pylint: disable=no-self-use + """ + parse one logcat line + """ line = line.strip() if not line or line.startswith('-') or ': ' not in line: return None metadata, message = line.split(': ', 1) - parts = metadata.split(None, 5) + parts: List[str] = metadata.split(None, 5) try: - ts = ' '.join([parts.pop(0), parts.pop(0)]) - timestamp = datetime.strptime(ts, '%m-%d %H:%M:%S.%f').replace(year=datetime.now().year) + ts: str = ' '.join([parts.pop(0), parts.pop(0)]) + timestamp: datetime = datetime.strptime(ts, '%m-%d %H:%M:%S.%f').replace(year=datetime.now().year) pid = int(parts.pop(0)) tid = int(parts.pop(0)) level = LogcatLogLevel.levels[log_level_map.index(parts.pop(0))] - tag = (parts.pop(0) if parts else '').strip() + tag: str = (parts.pop(0) if parts else '').strip() except Exception as e: # pylint: disable=broad-except message = 'Invalid metadata for line:\n\t{}\n\tgot: "{}"' logcat_logger.warning(message.format(line, e)) @@ -94,10 +105,13 @@ def parse_line(self, line): # pylint: disable=no-self-use class ApkInfo(_ApkInfo, Podable): '''Implement ApkInfo as a Podable class.''' - _pod_serialization_version = 1 + _pod_serialization_version: int = 1 @staticmethod - def from_pod(pod): + def from_pod(pod: Dict[str, Any]) -> 'ApkInfo': + """ + create ApkInfo from pod + """ instance = ApkInfo() instance.path = pod['path'] instance.package = pod['package'] @@ -112,11 +126,14 @@ def from_pod(pod): instance._methods = pod['_methods'] return instance - def __init__(self, path=None): + def __init__(self, path: Optional[str] = None): super().__init__(path) self._pod_version = self._pod_serialization_version - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + convert ApkInfo to pod + """ pod = super().to_pod() pod['path'] = self.path pod['package'] = self.package @@ -132,26 +149,37 @@ def to_pod(self): return pod @staticmethod - def _pod_upgrade_v1(pod): + def _pod_upgrade_v1(pod: Dict[str, Any]) -> 
Dict[str, Any]: + """ + pod upgrade function version 1 + """ pod['_pod_version'] = pod.get('_pod_version', 1) return pod class ApkInfoCache: - + """ + cache of Apk info + """ @staticmethod - def _check_env(): + def _check_env() -> None: + """ + check environment + """ if not os.path.exists(settings.cache_directory): os.makedirs(settings.cache_directory) - def __init__(self, path=settings.apk_info_cache_file): + def __init__(self, path: str = settings.apk_info_cache_file): self._check_env() self.path = path - self.last_modified = None - self.cache = {} + self.last_modified: Optional[os.stat_result] = None + self.cache: Dict[str, Dict] = {} self._update_cache() - def store(self, apk_info, apk_id, overwrite=True): + def store(self, apk_info: ApkInfo, apk_id: str, overwrite: bool = True) -> None: + """ + store Apk Info into cache + """ self._update_cache() if apk_id in self.cache and not overwrite: raise ValueError('ApkInfo for {} is already in cache.'.format(apk_info.path)) @@ -160,14 +188,20 @@ def store(self, apk_info, apk_id, overwrite=True): write_pod(self.cache, at_path) self.last_modified = os.stat(self.path) - def get_info(self, key): + def get_info(self, key: str) -> Optional[ApkInfo]: + """ + get apk info from cache + """ self._update_cache() pod = self.cache.get(key) info = ApkInfo.from_pod(pod) if pod else None return info - def _update_cache(self): + def _update_cache(self) -> None: + """ + update apk info cache + """ if not os.path.exists(self.path): return if self.last_modified != os.stat(self.path): @@ -176,22 +210,29 @@ def _update_cache(self): self.last_modified = os.stat(self.path) -def get_cacheable_apk_info(path): +def get_cacheable_apk_info(path: Optional[str]) -> Optional[ApkInfo]: + """ + get cacheable apk info + """ # pylint: disable=global-statement global apk_info_cache if not path: - return - stat = os.stat(path) - modified = stat.st_mtime - apk_id = '{}-{}'.format(path, modified) - info = apk_info_cache.get_info(apk_id) + return None + stat: os.stat_result = os.stat(path) + modified: float = stat.st_mtime + apk_id: str = '{}-{}'.format(path, modified) + if apk_info_cache: + info = apk_info_cache.get_info(apk_id) + else: + info = None if info: - msg = 'Using ApkInfo ({}) from cache'.format(info.package) + msg: str = 'Using ApkInfo ({}) from cache'.format(info.package) else: info = ApkInfo(path) - apk_info_cache.store(info, apk_id, overwrite=True) - msg = 'Storing ApkInfo ({}) in cache'.format(info.package) + if apk_info_cache: + apk_info_cache.store(info, apk_id, overwrite=True) + msg = 'Storing ApkInfo ({}) in cache'.format(info.package) apk_info_cache_logger.debug(msg) return info @@ -199,7 +240,11 @@ def get_cacheable_apk_info(path): apk_info_cache = ApkInfoCache() -def build_apk_launch_command(package, activity=None, apk_args=None): +def build_apk_launch_command(package: Optional[str], activity: Optional[str] = None, + apk_args: Optional[Dict] = None) -> str: + """ + build apk launch command + """ args_string = '' if apk_args: for k, v in apk_args.items(): diff --git a/wa/utils/cpustates.py b/wa/utils/cpustates.py index 60a16fb2f..2b23bcb80 100755 --- a/wa/utils/cpustates.py +++ b/wa/utils/cpustates.py @@ -22,21 +22,30 @@ from devlib.utils.csvutil import create_writer, csvwriter -from wa.utils.trace_cmd import TraceCmdParser, trace_has_marker, TRACE_MARKER_START, TRACE_MARKER_STOP +from wa.utils.trace_cmd import (TraceCmdParser, trace_has_marker, TRACE_MARKER_START, TRACE_MARKER_STOP, + DroppedEventsEvent, TraceCmdEvent) +from typing import (DefaultDict, 
Optional, List, TYPE_CHECKING, Any, Set, + Generator, Tuple, Union, cast, Pattern, Dict) +from typing_extensions import Protocol +if TYPE_CHECKING: + from wa.framework.target.info import CpuInfo -logger = logging.getLogger('cpustates') +logger: logging.Logger = logging.getLogger('cpustates') -INIT_CPU_FREQ_REGEX = re.compile(r'CPU (?P<cpu>\d+) FREQUENCY: (?P<freq>\d+) kHZ') -DEVLIB_CPU_FREQ_REGEX = re.compile(r'cpu_frequency(?:_devlib):\s+state=(?P<freq>\d+)\s+cpu_id=(?P<cpu>\d+)') +INIT_CPU_FREQ_REGEX: Pattern[str] = re.compile(r'CPU (?P<cpu>\d+) FREQUENCY: (?P<freq>\d+) kHZ') +DEVLIB_CPU_FREQ_REGEX: Pattern[str] = re.compile(r'cpu_frequency(?:_devlib):\s+state=(?P<freq>\d+)\s+cpu_id=(?P<cpu>\d+)') class CorePowerTransitionEvent(object): + """ + represents a core power transition event + """ + kind: str = 'transition' + __slots__: List[str] = ['timestamp', 'cpu_id', 'frequency', 'idle_state'] - kind = 'transition' - __slots__ = ['timestamp', 'cpu_id', 'frequency', 'idle_state'] - - def __init__(self, timestamp, cpu_id, frequency=None, idle_state=None): + def __init__(self, timestamp: Optional[Union[int, float]], cpu_id: int, + frequency: Optional[int] = None, idle_state: Optional[int] = None): if (frequency is None) == (idle_state is None): raise ValueError('Power transition must specify a frequency or an idle_state, but not both.') self.timestamp = timestamp @@ -54,11 +63,13 @@ def __repr__(self): class CorePowerDroppedEvents(object): + """ + represents core power dropped events + """ + kind: str = 'dropped_events' + __slots__: List[str] = ['cpu_id'] - kind = 'dropped_events' - __slots__ = ['cpu_id'] - - def __init__(self, cpu_id): + def __init__(self, cpu_id: int): self.cpu_id = cpu_id def __str__(self): @@ -68,11 +79,13 @@ def __str__(self): class TraceMarkerEvent(object): + """ + represents a trace marker event + """ + kind: str = 'marker' + __slots__: List[str] = ['name'] - kind = 'marker' - __slots__ = ['name'] - - def __init__(self, name): + def __init__(self, name: str): self.name = name def __str__(self): @@ -80,18 +93,26 @@ def __str__(self): class CpuPowerState(object): - - __slots__ = ['frequency', 'idle_state'] + """ + represents a cpu power state + """ + __slots__: List[str] = ['frequency', 'idle_state'] @property - def is_idling(self): + def is_idling(self) -> bool: + """ + check whether the cpu is idling + """ return self.idle_state is not None and self.idle_state >= 0 @property - def is_active(self): + def is_active(self) -> bool: + """ + check whether the cpu is active + """ return self.idle_state == -1 - def __init__(self, frequency=None, idle_state=None): + def __init__(self, frequency: Optional[int] = None, idle_state: Optional[int] = None): self.frequency = frequency self.idle_state = idle_state @@ -102,21 +123,29 @@ def __str__(self): class SystemPowerState(object): - - __slots__ = ['timestamp', 'cpus'] + """ + represents system power state + """ + __slots__: List[str] = ['timestamp', 'cpus'] @property - def num_cores(self): + def num_cores(self) -> int: + """ + number of cores + """ return len(self.cpus) - def __init__(self, num_cores, no_idle=False): - self.timestamp = None - self.cpus = [] - idle_state = -1 if no_idle else None + def __init__(self, num_cores: int, no_idle: bool = False): + self.timestamp: Optional[Union[int, float]] = None + self.cpus: List[CpuPowerState] = [] + idle_state: Optional[int] = -1 if no_idle else None for _ in range(num_cores): self.cpus.append(CpuPowerState(idle_state=idle_state)) - def copy(self): + def copy(self) -> 'SystemPowerState': + """ + return a copy of the current system power
state + """ new = SystemPowerState(self.num_cores) new.timestamp = self.timestamp for i, c in enumerate(self.cpus): @@ -138,30 +167,46 @@ class PowerStateProcessor(object): """ @property - def cpu_states(self): + def cpu_states(self) -> List[CpuPowerState]: + """ + get a list of cpu power states + """ return self.power_state.cpus @property - def current_time(self): + def current_time(self) -> Optional[Union[int, float]]: + """ + get current timestamp + """ return self.power_state.timestamp @current_time.setter - def current_time(self, value): + def current_time(self, value: Optional[Union[int, float]]) -> None: + """ + set current timestamp + """ self.power_state.timestamp = value - def __init__(self, cpus, wait_for_marker=True, no_idle=None): + def __init__(self, cpus: List['CpuInfo'], wait_for_marker: bool = True, + no_idle: Optional[bool] = None): if no_idle is None: no_idle = not (cpus[0].cpuidle and cpus[0].cpuidle.states) self.power_state = SystemPowerState(len(cpus), no_idle=no_idle) - self.requested_states = {} # cpu_id -> requeseted state + self.requested_states: Dict[int, Optional[int]] = {} # cpu_id -> requeseted state self.wait_for_marker = wait_for_marker - self._saw_start_marker = False - self._saw_stop_marker = False - self.exceptions = [] + self._saw_start_marker: bool = False + self._saw_stop_marker: bool = False + self.exceptions: List[Exception] = [] - self.idle_related_cpus = build_idle_state_map(cpus) + self.idle_related_cpus: DefaultDict[Tuple[int, Optional[int]], + List[int]] = build_idle_state_map(cpus) - def process(self, event_stream): + def process(self, event_stream: Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None]) -> Generator[SystemPowerState, Any, None]: + """ + process the power state event stream + """ for event in event_stream: try: next_state = self.update_power_state(event) @@ -175,26 +220,30 @@ def process(self, event_stream): if self.wait_for_marker: logger.warning("Did not see a STOP marker in the trace") - def update_power_state(self, event): + def update_power_state(self, + event: Union[CorePowerTransitionEvent, CorePowerDroppedEvents, TraceMarkerEvent]) -> SystemPowerState: """ Update the tracked power state based on the specified event and return updated power state. 
""" if event.kind == 'transition': - self._process_transition(event) + self._process_transition(cast(CorePowerTransitionEvent, event)) elif event.kind == 'dropped_events': - self._process_dropped_events(event) + self._process_dropped_events(cast(CorePowerDroppedEvents, event)) elif event.kind == 'marker': - if event.name == 'START': + if cast(TraceMarkerEvent, event).name == 'START': self._saw_start_marker = True - elif event.name == 'STOP': + elif cast(TraceMarkerEvent, event).name == 'STOP': self._saw_stop_marker = True else: raise ValueError('Unexpected event type: {}'.format(event.kind)) return self.power_state.copy() - def _process_transition(self, event): + def _process_transition(self, event: CorePowerTransitionEvent) -> None: + """ + process power state transition + """ self.current_time = event.timestamp if event.idle_state is None: self.cpu_states[event.cpu_id].frequency = event.frequency @@ -204,41 +253,53 @@ def _process_transition(self, event): else: self._process_idle_entry(event) - def _process_dropped_events(self, event): + def _process_dropped_events(self, event: CorePowerDroppedEvents) -> None: + """ + process dropped power state events + """ self.cpu_states[event.cpu_id].frequency = None old_idle_state = self.cpu_states[event.cpu_id].idle_state self.cpu_states[event.cpu_id].idle_state = None - related_ids = self.idle_related_cpus[(event.cpu_id, old_idle_state)] + related_ids: List[int] = self.idle_related_cpus[(event.cpu_id, old_idle_state)] for rid in related_ids: self.cpu_states[rid].idle_state = None - def _process_idle_entry(self, event): + def _process_idle_entry(self, event: CorePowerTransitionEvent) -> None: + """ + process idle power state entry + """ if self.cpu_states[event.cpu_id].is_idling: raise ValueError('Got idle state entry event for an idling core: {}'.format(event)) self.requested_states[event.cpu_id] = event.idle_state - self._try_transition_to_idle_state(event.cpu_id, event.idle_state) + self._try_transition_to_idle_state(event.cpu_id, event.idle_state or 0) - def _process_idle_exit(self, event): + def _process_idle_exit(self, event: CorePowerTransitionEvent) -> None: + """ + process idle power state exit + """ if self.cpu_states[event.cpu_id].is_active: raise ValueError('Got idle state exit event for an active core: {}'.format(event)) self.requested_states.pop(event.cpu_id, None) # remove outstanding request if there is one - old_state = self.cpu_states[event.cpu_id].idle_state + old_state: Optional[int] = self.cpu_states[event.cpu_id].idle_state self.cpu_states[event.cpu_id].idle_state = -1 - related_ids = self.idle_related_cpus[(event.cpu_id, old_state)] + related_ids: List[int] = self.idle_related_cpus[(event.cpu_id, old_state)] if old_state is not None: - new_state = old_state - 1 + new_state: int = old_state - 1 for rid in related_ids: - if self.cpu_states[rid].idle_state > new_state: + if (self.cpu_states[rid].idle_state or 0) > new_state: self._try_transition_to_idle_state(rid, new_state) - def _try_transition_to_idle_state(self, cpu_id, idle_state): - related_ids = self.idle_related_cpus[(cpu_id, idle_state)] + def _try_transition_to_idle_state(self, cpu_id: int, idle_state: int) -> None: + """ + try transition to idle state + """ + related_ids: List[int] = self.idle_related_cpus[(cpu_id, idle_state)] # Tristate: True - can transition, False - can't transition, # None - unknown idle state on at least one related cpu - transition_check = self._can_enter_state(related_ids, idle_state) + transition_check: Optional[bool] = 
self._can_enter_state(related_ids, idle_state) if transition_check is None: # Unknown state on a related cpu means we're not sure whether we're @@ -255,7 +316,7 @@ def _try_transition_to_idle_state(self, cpu_id, idle_state): for rid in related_ids: self.cpu_states[rid].idle_state = idle_state - def _can_enter_state(self, related_ids, state): + def _can_enter_state(self, related_ids: List[int], state: int) -> Optional[bool]: """ This is a tri-state check. Returns ``True`` if related cpu states allow transition into this state, ``False`` if related cpu states don't allow transition into this @@ -274,7 +335,13 @@ def _can_enter_state(self, related_ids, state): return True -def stream_cpu_power_transitions(events): +def stream_cpu_power_transitions(events: Generator[Union[DroppedEventsEvent, TraceCmdEvent], + Any, None]) -> Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None]: + """ + stream cpu power transition events + """ for event in events: if event.name == 'cpu_idle': state = c_int32(event.state).value @@ -284,26 +351,31 @@ def stream_cpu_power_transitions(events): elif event.name == 'DROPPED EVENTS DETECTED': yield CorePowerDroppedEvents(event.cpu_id) elif event.name == 'print': - if TRACE_MARKER_START in event.text: + if TRACE_MARKER_START in (event.text or ''): yield TraceMarkerEvent('START') - elif TRACE_MARKER_STOP in event.text: + elif TRACE_MARKER_STOP in (event.text or ''): yield TraceMarkerEvent('STOP') else: - if 'cpu_frequency' in event.text: - match = DEVLIB_CPU_FREQ_REGEX.search(event.text) + if 'cpu_frequency' in (event.text or ''): + match = DEVLIB_CPU_FREQ_REGEX.search(event.text or '') else: - match = INIT_CPU_FREQ_REGEX.search(event.text) + match = INIT_CPU_FREQ_REGEX.search(event.text or '') if match: yield CorePowerTransitionEvent(event.timestamp, int(match.group('cpu')), frequency=int(match.group('freq'))) -def gather_core_states(system_state_stream, freq_dependent_idle_states=None): # NOQA +def gather_core_states(system_state_stream: Generator[SystemPowerState, Any, None], + freq_dependent_idle_states: Optional[List[int]] = None) -> Generator[Tuple[Optional[Union[int, float]], + List[Tuple[Optional[int], Optional[int]]]], Any, None]: # NOQA + """ + gather core power states + """ if freq_dependent_idle_states is None: freq_dependent_idle_states = [] for system_state in system_state_stream: - core_states = [] + core_states: List[Tuple[Optional[int], Optional[int]]] = [] for cpu in system_state.cpus: if cpu.idle_state == -1: core_states.append((-1, cpu.frequency)) @@ -317,53 +389,76 @@ def gather_core_states(system_state_stream, freq_dependent_idle_states=None): # yield (system_state.timestamp, core_states) -def record_state_transitions(reporter, stream): +def record_state_transitions(reporter: 'PowerStateTransitions', + stream: Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None]) -> Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None]: + """ + record power state transitions + """ for event in stream: if event.kind == 'transition': - reporter.record_transition(event) + reporter.record_transition(cast(CorePowerTransitionEvent, event)) yield event class PowerStateTransitions(object): - name = 'transitions-timeline' + name: str = 'transitions-timeline' - def __init__(self, output_directory): - self.filepath = os.path.join(output_directory, 'state-transitions-timeline.csv') + def __init__(self, output_directory: str): + 
self.filepath: str = os.path.join(output_directory, 'state-transitions-timeline.csv') self.writer, self._wfh = create_writer(self.filepath) - headers = ['timestamp', 'cpu_id', 'frequency', 'idle_state'] + headers: List[str] = ['timestamp', 'cpu_id', 'frequency', 'idle_state'] self.writer.writerow(headers) - def update(self, timestamp, core_states): # NOQA + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: # NOQA # Just recording transitions, not doing anything # with states. pass - def record_transition(self, transition): + def record_transition(self, transition: CorePowerTransitionEvent) -> None: + """ + record power transition + """ row = [transition.timestamp, transition.cpu_id, transition.frequency, transition.idle_state] self.writer.writerow(row) - def report(self): + def report(self) -> 'PowerStateTransitions': + """ + report power state transitions + """ return self - def write(self): + def write(self) -> None: + """ + write the power state transition and close file handle + """ self._wfh.close() class PowerStateTimeline(object): - name = 'state-timeline' + name: str = 'state-timeline' - def __init__(self, output_directory, cpus): - self.filepath = os.path.join(output_directory, 'power-state-timeline.csv') - self.idle_state_names = {cpu.id: [s.name for s in cpu.cpuidle.states] for cpu in cpus} + def __init__(self, output_directory: Optional[str], cpus: List['CpuInfo']): + self.filepath: str = os.path.join(output_directory or '', 'power-state-timeline.csv') + self.idle_state_names: Dict[Optional[int], + List[Optional[str]]] = {cpu.id: [s.name for s in cpu.cpuidle.states] for cpu in cpus} self.writer, self._wfh = create_writer(self.filepath) - headers = ['ts'] + ['{} CPU{}'.format(cpu.name, cpu.id) for cpu in cpus] + headers: List[str] = ['ts'] + ['{} CPU{}'.format(cpu.name, cpu.id) for cpu in cpus] self.writer.writerow(headers) - def update(self, timestamp, core_states): # NOQA - row = [timestamp] + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: # NOQA + """ + update power state timeline + """ + row: List[Union[int, float, str]] = [timestamp] for cpu_idx, (idle_state, frequency) in enumerate(core_states): if frequency is None: if idle_state == -1: @@ -373,7 +468,7 @@ def update(self, timestamp, core_states): # NOQA elif not self.idle_state_names[cpu_idx]: row.append('idle[{}]'.format(idle_state)) else: - row.append(self.idle_state_names[cpu_idx][idle_state]) + row.append(self.idle_state_names[cpu_idx][idle_state] or '') else: # frequency is not None if idle_state == -1: row.append(frequency) @@ -384,42 +479,52 @@ def update(self, timestamp, core_states): # NOQA frequency)) self.writer.writerow(row) - def report(self): + def report(self) -> 'PowerStateTimeline': + """ + report the power state timeline + """ return self - def write(self): + def write(self) -> None: + """ + write the power state timeline and close the file handle + """ self._wfh.close() class ParallelStats(object): - def __init__(self, output_directory, cpus, use_ratios=False): - self.filepath = os.path.join(output_directory, 'parallel-stats.csv') - self.clusters = defaultdict(set) + def __init__(self, output_directory: str, cpus: List['CpuInfo'], use_ratios: bool = False): + self.filepath: str = os.path.join(output_directory, 'parallel-stats.csv') + self.clusters: DefaultDict[str, Set] = defaultdict(set) self.use_ratios = use_ratios - clusters = [] + clusters: List[List[int]] = [] for cpu 
in cpus: if cpu.cpufreq.related_cpus not in clusters: clusters.append(cpu.cpufreq.related_cpus) for i, clust in enumerate(clusters): self.clusters[str(i)] = set(clust) - self.clusters['all'] = {cpu.id for cpu in cpus} + self.clusters['all'] = {cpu.id or 0 for cpu in cpus} - self.first_timestamp = None - self.last_timestamp = None - self.previous_states = None - self.parallel_times = defaultdict(lambda: defaultdict(int)) - self.running_times = defaultdict(int) + self.first_timestamp: Optional[Union[int, float]] = None + self.last_timestamp: Optional[Union[int, float]] = None + self.previous_states: Optional[List[Tuple[Optional[int], Optional[int]]]] = None + self.parallel_times: DefaultDict[str, Dict[int, Union[int, float]]] = defaultdict(lambda: defaultdict(int)) + self.running_times: DefaultDict[str, Union[int, float]] = defaultdict(int) - def update(self, timestamp, core_states): + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: + """ + update parallel stats + """ if self.last_timestamp is not None: - delta = timestamp - self.last_timestamp - active_cores = [i for i, c in enumerate(self.previous_states) - if c and c[0] == -1] + delta: Union[int, float] = timestamp - self.last_timestamp + active_cores: List[int] = [i for i, c in enumerate(self.previous_states or '') + if c and c[0] == -1] for cluster, cluster_cores in self.clusters.items(): - clust_active_cores = len(cluster_cores.intersection(active_cores)) + clust_active_cores: int = len(cluster_cores.intersection(active_cores)) self.parallel_times[cluster][clust_active_cores] += delta if clust_active_cores: self.running_times[cluster] += delta @@ -429,17 +534,20 @@ def update(self, timestamp, core_states): self.last_timestamp = timestamp self.previous_states = core_states - def report(self): # NOQA + def report(self) -> Optional['ParallelReport']: # NOQA + """ + report parallel stats + """ if self.last_timestamp is None: return None report = ParallelReport(self.filepath) - total_time = self.last_timestamp - self.first_timestamp + total_time: Union[int, float] = self.last_timestamp - (self.first_timestamp or 0) for cluster in sorted(self.parallel_times): - running_time = self.running_times[cluster] + running_time: Union[int, float] = self.running_times[cluster] for n in range(len(self.clusters[cluster]) + 1): - time = self.parallel_times[cluster][n] - time_pc = time / total_time + time: Union[int, float] = self.parallel_times[cluster][n] + time_pc: float = time / total_time if not self.use_ratios: time_pc *= 100 if n: @@ -451,8 +559,8 @@ def report(self): # NOQA running_time_pc *= 100 else: running_time_pc = 0 - precision = 3 if self.use_ratios else 1 - fmt = '{{:.{}f}}'.format(precision) + precision: int = 3 if self.use_ratios else 1 + fmt: str = '{{:.{}f}}'.format(precision) report.add([cluster, n, fmt.format(time), fmt.format(time_pc), @@ -463,16 +571,22 @@ def report(self): # NOQA class ParallelReport(object): - name = 'parallel-stats' + name: str = 'parallel-stats' - def __init__(self, filepath): + def __init__(self, filepath: str): self.filepath = filepath - self.values = [] + self.values: List[List[Union[int, str]]] = [] - def add(self, value): + def add(self, value: List[Union[int, str]]): + """ + add value to report + """ self.values.append(value) - def write(self): + def write(self) -> None: + """ + write report to csv file + """ with csvwriter(self.filepath) as writer: writer.writerow(['cluster', 'number_of_cores', 'total_time', '%time', '%running_time']) 
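            # Each row previously passed to add() follows the header above; an
            # illustrative (made-up) row for a two-core burst on the 'all'
            # cluster would be ['all', 2, '1.234', '45.6', '78.9'].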
writer.writerows(self.values) @@ -480,26 +594,36 @@ def write(self): class PowerStateStats(object): - def __init__(self, output_directory, cpus, use_ratios=False): - self.filepath = os.path.join(output_directory, 'power-state-stats.csv') - self.core_names = [cpu.name for cpu in cpus] - self.idle_state_names = {cpu.id: [s.name for s in cpu.cpuidle.states] for cpu in cpus} + def __init__(self, output_directory: str, cpus: List['CpuInfo'], use_ratios: bool = False): + self.filepath: str = os.path.join(output_directory, 'power-state-stats.csv') + self.core_names: List[Optional[str]] = [cpu.name for cpu in cpus] + self.idle_state_names: Dict[Optional[int], + List[Optional[str]]] = {cpu.id: [s.name for s in cpu.cpuidle.states] for cpu in cpus} self.use_ratios = use_ratios - self.first_timestamp = None - self.last_timestamp = None - self.previous_states = None - self.cpu_states = defaultdict(lambda: defaultdict(int)) + self.first_timestamp: Optional[Union[int, float]] = None + self.last_timestamp: Optional[Union[int, float]] = None + self.previous_states: Optional[List[Tuple[Optional[int], Optional[int]]]] = None + self.cpu_states: DefaultDict[int, Dict[Optional[str], Union[int, float]]] = defaultdict(lambda: defaultdict(int)) - def update(self, timestamp, core_states): # NOQA + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: # NOQA + """ + update power state stats + """ if self.last_timestamp is not None: - delta = timestamp - self.last_timestamp + delta: Union[int, float] = timestamp - self.last_timestamp + if self.previous_states is None: + raise ValueError("previous_states should not be None here") for cpu, (idle, freq) in enumerate(self.previous_states): if idle == -1: if freq is not None: - state = '{:07}KHz'.format(freq) + state: Optional[str] = '{:07}KHz'.format(freq) else: state = 'Running (unknown KHz)' elif freq: + # Ensure idle is not None in this branch. 
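+                    # (illustrative values: an idle state named 'C1' at
+                    # freq == 1000000 renders below as 'C1-1000000KHz')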
+ if idle is None: + raise ValueError("idle must not be None when freq is provided") state = '{}-{:07}KHz'.format(self.idle_state_names[cpu][idle], freq) elif idle is not None and self.idle_state_names[cpu]: state = self.idle_state_names[cpu][idle] @@ -512,11 +636,14 @@ def update(self, timestamp, core_states): # NOQA self.last_timestamp = timestamp self.previous_states = core_states - def report(self): + def report(self) -> Optional['PowerStateStatsReport']: + """ + report powerstate stats + """ if self.last_timestamp is None: return None - total_time = self.last_timestamp - self.first_timestamp - state_stats = defaultdict(lambda: [None] * len(self.core_names)) + total_time = self.last_timestamp - (self.first_timestamp or 0) + state_stats: Dict[Optional[str], List[Optional[float]]] = defaultdict(lambda: [None] * len(self.core_names)) for cpu, states in self.cpu_states.items(): for state in states: @@ -526,66 +653,94 @@ def report(self): time_pc *= 100 state_stats[state][cpu] = time_pc - precision = 3 if self.use_ratios else 1 + precision: int = 3 if self.use_ratios else 1 return PowerStateStatsReport(self.filepath, state_stats, self.core_names, precision) class PowerStateStatsReport(object): - name = 'power-state-stats' + name: str = 'power-state-stats' - def __init__(self, filepath, state_stats, core_names, precision=2): + def __init__(self, filepath: str, state_stats: Dict[Optional[str], List[Optional[float]]], + core_names: List[Optional[str]], precision: int = 2): self.filepath = filepath self.state_stats = state_stats self.core_names = core_names self.precision = precision - def write(self): + def write(self) -> None: + """ + write powerstate stats into csv file + """ with csvwriter(self.filepath) as writer: headers = ['state'] + ['{} CPU{}'.format(c, i) for i, c in enumerate(self.core_names)] writer.writerow(headers) - for state in sorted(self.state_stats): + for state in sorted(cast(Dict, self.state_stats)): stats = self.state_stats[state] fmt = '{{:.{}f}}'.format(self.precision) writer.writerow([state] + [fmt.format(s if s is not None else 0) for s in stats]) +class ReporterProtocol(Protocol): + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: + ... + + def report(self) -> Union[Optional[PowerStateStatsReport], Optional[ParallelReport], + PowerStateTimeline, 'CpuUtilizationTimeline', + PowerStateTransitions]: + ... 
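+# A minimal sketch (not part of this patch) of a reporter satisfying
+# ReporterProtocol purely structurally -- no inheritance is required,
+# only matching update()/report() shapes:
+#
+#     class NullReporter:
+#         def update(self, timestamp, core_states):
+#             pass
+#
+#         def report(self):
+#             return None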
+ + class CpuUtilizationTimeline(object): - name = 'utilization-timeline' + name: str = 'utilization-timeline' - def __init__(self, output_directory, cpus): - self.filepath = os.path.join(output_directory, 'utilization-timeline.csv') + def __init__(self, output_directory: str, cpus: List['CpuInfo']): + self.filepath: str = os.path.join(output_directory, 'utilization-timeline.csv') self.writer, self._wfh = create_writer(self.filepath) - headers = ['ts'] + ['{} CPU{}'.format(cpu.name, cpu.id) for cpu in cpus] + headers: List[str] = ['ts'] + ['{} CPU{}'.format(cpu.name, cpu.id) for cpu in cpus] self.writer.writerow(headers) self._max_freq_list = [cpu.cpufreq.available_frequencies[-1] for cpu in cpus if cpu.cpufreq.available_frequencies] - def update(self, timestamp, core_states): # NOQA - row = [timestamp] + def update(self, timestamp: Union[int, float], + core_states: List[Tuple[Optional[int], Optional[int]]]) -> None: # NOQA + """ + update cpu utilization timeline + """ + row: List[Optional[Union[int, float]]] = [timestamp] for core, [_, frequency] in enumerate(core_states): if frequency is not None and core in self._max_freq_list: - frequency /= float(self._max_freq_list[core]) - row.append(frequency) + frequency_ = frequency / float(self._max_freq_list[core]) + row.append(frequency_) else: row.append(None) self.writer.writerow(row) - def report(self): + def report(self) -> 'CpuUtilizationTimeline': + """ + report cpu utilization timeline + """ return self - def write(self): + def write(self) -> None: + """ + write cpu utilization timeline to file and close it + """ self._wfh.close() -def build_idle_state_map(cpus): - idle_state_map = defaultdict(list) +def build_idle_state_map(cpus: List['CpuInfo']) -> DefaultDict[Tuple[int, Optional[int]], List[int]]: + """ + build map of idle states + """ + idle_state_map: DefaultDict[Tuple[int, Optional[int]], List[int]] = defaultdict(list) for cpu_idx, cpu in enumerate(cpus): - related_cpus = set(cpu.cpufreq.related_cpus) - set([cpu_idx]) - first_cluster_state = cpu.cpuidle.num_states - 1 + related_cpus: Set[int] = set(cpu.cpufreq.related_cpus) - set([cpu_idx]) + first_cluster_state: int = cpu.cpuidle.num_states - 1 for state_idx, _ in enumerate(cpu.cpuidle.states): if state_idx < first_cluster_state: idle_state_map[(cpu_idx, state_idx)] = [] @@ -594,8 +749,9 @@ def build_idle_state_map(cpus): return idle_state_map -def report_power_stats(trace_file, cpus, output_basedir, use_ratios=False, no_idle=None, # pylint: disable=too-many-locals - split_wfi_states=False): +def report_power_stats(trace_file: str, cpus: List['CpuInfo'], output_basedir: str, + use_ratios: bool = False, no_idle: Optional[bool] = None, # pylint: disable=too-many-locals + split_wfi_states: bool = False): """ Process trace-cmd output to generate timelines and statistics of CPU power state (a.k.a P- and C-state) transitions in the trace. @@ -656,11 +812,11 @@ def report_power_stats(trace_file, cpus, output_basedir, use_ratios=False, no_id 6. Update reporters/stats generators with cpu states. 
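    A minimal invocation might look like this (the path and key are
    illustrative; the returned dict maps each report's ``name`` to the report
    object, which has already been written under ``output_basedir``)::

        reports = report_power_stats('trace.dat', cpus, 'analysis_output')
        stats_report = reports['power-state-stats']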
""" - output_directory = os.path.join(output_basedir, 'power-states') + output_directory: str = os.path.join(output_basedir, 'power-states') if not os.path.isdir(output_directory): os.mkdir(output_directory) - freq_dependent_idle_states = [] + freq_dependent_idle_states: List[int] = [] if split_wfi_states: freq_dependent_idle_states = [0] @@ -673,7 +829,7 @@ def report_power_stats(trace_file, cpus, output_basedir, use_ratios=False, no_id ps_processor = PowerStateProcessor(cpus, wait_for_marker=trace_has_marker(trace_file), no_idle=no_idle) transitions_reporter = PowerStateTransitions(output_directory) - reporters = [ + reporters: List[ReporterProtocol] = [ ParallelStats(output_directory, cpus, use_ratios), PowerStateStats(output_directory, cpus, use_ratios), PowerStateTimeline(output_directory, cpus), @@ -682,16 +838,21 @@ def report_power_stats(trace_file, cpus, output_basedir, use_ratios=False, no_id ] # assemble the pipeline - event_stream = parser.parse(trace_file) - transition_stream = stream_cpu_power_transitions(event_stream) - recorded_trans_stream = record_state_transitions(transitions_reporter, transition_stream) - power_state_stream = ps_processor.process(recorded_trans_stream) - core_state_stream = gather_core_states(power_state_stream, freq_dependent_idle_states) + event_stream: Generator[Union[DroppedEventsEvent, TraceCmdEvent], Any, None] = parser.parse(trace_file) + transition_stream: Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None] = stream_cpu_power_transitions(event_stream) + recorded_trans_stream: Generator[Union[CorePowerTransitionEvent, + CorePowerDroppedEvents, + TraceMarkerEvent], Any, None] = record_state_transitions(transitions_reporter, transition_stream) + power_state_stream: Generator[SystemPowerState, Any, None] = ps_processor.process(recorded_trans_stream) + core_state_stream: Generator[Tuple[Optional[Union[int, float]], + List[Tuple[Optional[int], Optional[int]]]], Any, None] = gather_core_states(power_state_stream, freq_dependent_idle_states) # execute the pipeline for timestamp, states in core_state_stream: for reporter in reporters: - reporter.update(timestamp, states) + cast(ReporterProtocol, reporter).update(timestamp or 0, states) # report any issues encountered while executing the pipeline if ps_processor.exceptions: @@ -700,9 +861,12 @@ def report_power_stats(trace_file, cpus, output_basedir, use_ratios=False, no_id logger.warning(str(e)) # generate reports - reports = {} + reports: Dict[str, Union[Optional[PowerStateStatsReport], Optional[ParallelReport], + PowerStateTimeline, 'CpuUtilizationTimeline', + PowerStateTransitions]] = {} for reporter in reporters: - report = reporter.report() - report.write() - reports[report.name] = report + report = cast(ReporterProtocol, reporter).report() + if report: + report.write() + reports[report.name] = report return reports diff --git a/wa/utils/diff.py b/wa/utils/diff.py index 1db18bf1f..8c52bc593 100644 --- a/wa/utils/diff.py +++ b/wa/utils/diff.py @@ -19,24 +19,28 @@ from builtins import zip # pylint: disable=redefined-builtin -from future.moves.itertools import zip_longest +from future.moves.itertools import zip_longest # type:ignore from wa.utils.misc import diff_tokens, write_table from wa.utils.misc import ensure_file_directory_exists as _f +from typing import Optional, List -logger = logging.getLogger('diff') +logger: logging.Logger = logging.getLogger('diff') -def diff_interrupt_files(before, after, result): # pylint: disable=R0914 - output_lines = [] 
+def diff_interrupt_files(before: str, after: str, result: str) -> None: # pylint: disable=R0914 + """ + diff between interrupt stats files + """ + output_lines: List[List[str]] = [] with open(before) as bfh: with open(after) as ofh: for bline, aline in zip(bfh, ofh): bchunks = bline.strip().split() while True: - achunks = aline.strip().split() + achunks: List[str] = aline.strip().split() if achunks[0] == bchunks[0]: - diffchunks = [''] + diffchunks: List[str] = [''] diffchunks.append(achunks[0]) diffchunks.extend([diff_tokens(b, a) for b, a in zip(bchunks[1:], achunks[1:])]) @@ -58,8 +62,8 @@ def diff_interrupt_files(before, after, result): # pylint: disable=R0914 # columns -- they are a single column where space-separated words got # split. Merge them back together to prevent them from being # column-aligned by write_table. - table_rows = [output_lines[0]] - num_cols = len(output_lines[0]) + table_rows: List[List[str]] = [output_lines[0]] + num_cols: int = len(output_lines[0]) for row in output_lines[1:]: table_row = row[:num_cols] table_row.append(' '.join(row[num_cols:])) @@ -69,14 +73,17 @@ def diff_interrupt_files(before, after, result): # pylint: disable=R0914 write_table(table_rows, wfh) -def diff_sysfs_dirs(before, after, result): # pylint: disable=R0914 - before_files = [] +def diff_sysfs_dirs(before: str, after: str, result: str) -> None: # pylint: disable=R0914 + """ + diff between sysfs directories + """ + before_files: List[str] = [] for root, _, files in os.walk(before): before_files.extend([os.path.join(root, f) for f in files]) before_files = list(filter(os.path.isfile, before_files)) files = [os.path.relpath(f, before) for f in before_files] - after_files = [os.path.join(after, f) for f in files] - diff_files = [os.path.join(result, f) for f in files] + after_files: List[str] = [os.path.join(after, f) for f in files] + diff_files: List[str] = [os.path.join(result, f) for f in files] for bfile, afile, dfile in zip(before_files, after_files, diff_files): if not os.path.isfile(afile): @@ -89,8 +96,8 @@ def diff_sysfs_dirs(before, after, result): # pylint: disable=R0914 if aline is None: logger.debug('Lines missing from {}'.format(afile)) break - bchunks = re.split(r'(\W+)', bline) - achunks = re.split(r'(\W+)', aline) + bchunks: List[str] = re.split(r'(\W+)', bline) + achunks: List[str] = re.split(r'(\W+)', aline) if len(bchunks) != len(achunks): logger.debug('Token length mismatch in {} on line {}'.format(bfile, i)) dfh.write('xxx ' + bline) diff --git a/wa/utils/doc.py b/wa/utils/doc.py index 7cf94dc12..7a1c0118d 100644 --- a/wa/utils/doc.py +++ b/wa/utils/doc.py @@ -23,12 +23,18 @@ import inspect from itertools import cycle -USER_HOME = os.path.expanduser('~') +from typing import (Type, Any, Match, Optional, List, Iterable, + TYPE_CHECKING, Union) +if TYPE_CHECKING: + from wa.framework.configuration.core import ConfigurationPoint + from wa.framework.plugin import Alias, Plugin, AliasCollection -BULLET_CHARS = '-*' +USER_HOME: str = os.path.expanduser('~') +BULLET_CHARS: str = '-*' -def get_summary(aclass): + +def get_summary(aclass: Type): """ Returns the summary description for an extension class. The summary is the first paragraph (separated by blank line) of the description taken either from @@ -39,7 +45,7 @@ def get_summary(aclass): return get_description(aclass).split('\n\n')[0] -def get_description(aclass): +def get_description(aclass: Type) -> str: """ Return the description of the specified extension class. 
The description is taken either from ``description`` attribute of the class or its docstring. @@ -48,34 +54,34 @@ if hasattr(aclass, 'description') and aclass.description: return inspect.cleandoc(aclass.description) if aclass.__doc__: - return inspect.getdoc(aclass) + return inspect.getdoc(aclass) or '' else: return 'no documentation found for {}'.format(aclass.__name__) -def get_type_name(obj): +def get_type_name(obj: Any) -> str: """Returns the name of the type object or function specified. In case of a lambda, the definition is returned with the parameter replaced by "value".""" - match = re.search(r"<(type|class|function) '?(.*?)'?>", str(obj)) + match: Optional[Match[str]] = re.search(r"<(type|class|function) '?(.*?)'?>", str(obj)) if isinstance(obj, tuple): - name = obj[1] - elif match.group(1) == 'function': - text = str(obj) + name: str = obj[1] + elif match is not None and match.group(1) == 'function': + text: str = str(obj) name = text.split()[1] if name.endswith('<lambda>'): - source = inspect.getsource(obj).strip().replace('\n', ' ') + source: str = inspect.getsource(obj).strip().replace('\n', ' ') match = re.search(r'lambda\s+(\w+)\s*:\s*(.*?)\s*[\n,]', source) if not match: raise ValueError('could not get name for {}'.format(obj)) name = match.group(2).replace(match.group(1), 'value') else: - name = match.group(2) + name = match.group(2) if match else '' if '.' in name: name = name.split('.')[-1] return name -def count_leading_spaces(text): +def count_leading_spaces(text: str) -> int: """ Counts the number of leading space characters in a string. @@ -92,7 +98,7 @@ def count_leading_spaces(text): return nspaces -def format_column(text, width): +def format_column(text: str, width: int) -> str: """ Formats text into a column of specified width. If a line is too long, it will be broken on a word boundary. The new lines will have the same @@ -101,16 +107,16 @@ def format_column(text, width): Note: this will not attempt to join up lines that are too short. """ - formatted = [] + formatted: List[str] = [] for line in text.split('\n'): - line_len = len(line) + line_len: int = len(line) if line_len <= width: formatted.append(line) else: - words = line.split(' ') - new_line = words.pop(0) + words: List[str] = line.split(' ') + new_line: str = words.pop(0) while words: - next_word = words.pop(0) + next_word: str = words.pop(0) if (len(new_line) + len(next_word) + 1) < width: new_line += ' ' + next_word else: @@ -120,7 +126,8 @@ def format_column(text, width): return '\n'.join(formatted) -def format_bullets(text, width, char='-', shift=3, outchar=None): +def format_bullets(text: str, width: int, char: str = '-', + shift: int = 3, outchar: Optional[str] = None) -> str: """ Formats text into bulleted list. Assumes each line of input that starts with ``char`` (possibly preceded with whitespace) is a new bullet point. Note: leading @@ -136,13 +143,13 @@ def format_bullets(text, width, char='-', shift=3, outchar=None): left as ``None``, ``char`` will be used. 
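    For example (illustrative)::

        format_bullets('- first point\n- second point', 60, char='-', outchar='*')

    re-wraps both bullets to 60 columns and emits them with ``*`` markers.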
""" - bullet_lines = [] - output = '' + bullet_lines: List[str] = [] + output: str = '' - def __process_bullet(bullet_lines): + def __process_bullet(bullet_lines: List[str]) -> str: if bullet_lines: bullet = format_paragraph(indent(' '.join(bullet_lines), shift + 2), width) - bullet = bullet[:3] + outchar + bullet[4:] + bullet = bullet[:3] + (outchar or '') + bullet[4:] del bullet_lines[:] return bullet + '\n' else: @@ -160,26 +167,28 @@ def __process_bullet(bullet_lines): return output -def format_simple_table(rows, headers=None, align='>', show_borders=True, borderchar='='): # pylint: disable=R0914 +def format_simple_table(rows: Iterable, headers: Optional[List[str]] = None, + align: str = '>', show_borders: bool = True, + borderchar: str = '=') -> str: # pylint: disable=R0914 """Formats a simple table.""" if not rows: return '' rows = [list(map(str, r)) for r in rows] - num_cols = len(rows[0]) + num_cols: int = len(rows[0]) # cycle specified alignments until we have num_cols of them. This is # consitent with how such cases are handled in R, pandas, etc. it = cycle(align) - align = [next(it) for _ in range(num_cols)] + align_: List[str] = [next(it) for _ in range(num_cols)] - cols = list(zip(*rows)) - col_widths = [max(list(map(len, c))) for c in cols] + cols: List = list(zip(*rows)) + col_widths: List[int] = [max(list(map(len, c))) for c in cols] if headers: col_widths = [max(len(h), cw) for h, cw in zip(headers, col_widths)] - row_format = ' '.join(['{:%s%s}' % (align[i], w) for i, w in enumerate(col_widths)]) + row_format: str = ' '.join(['{:%s%s}' % (align_[i], w) for i, w in enumerate(col_widths)]) row_format += '\n' - border = row_format.format(*[borderchar * cw for cw in col_widths]) + border: str = row_format.format(*[borderchar * cw for cw in col_widths]) result = border if show_borders else '' if headers: @@ -192,7 +201,7 @@ def format_simple_table(rows, headers=None, align='>', show_borders=True, border return result -def format_paragraph(text, width): +def format_paragraph(text: str, width: int) -> str: """ Format the specified text into a column of specified with. The text is assumed to be a single paragraph and existing line breaks will not be preserved. @@ -203,7 +212,7 @@ def format_paragraph(text, width): return format_column(text, width) -def format_body(text, width): +def format_body(text: str, width: int) -> str: """ Format the specified text into a column of specified width. The text is assumed to be a "body" of one or more paragraphs separated by one or more @@ -212,8 +221,8 @@ def format_body(text, width): """ text = re.sub('\n\\s*\n', '\n\n', text.strip('\n')) # get rid of all-whitespace lines - paragraphs = re.split('\n\n+', text) - formatted_paragraphs = [] + paragraphs: List[str] = re.split('\n\n+', text) + formatted_paragraphs: List[str] = [] for p in paragraphs: if p.strip() and p.strip()[0] in BULLET_CHARS: formatted_paragraphs.append(format_bullets(p, width)) @@ -222,7 +231,7 @@ def format_body(text, width): return '\n\n'.join(formatted_paragraphs) -def strip_inlined_text(text): +def strip_inlined_text(text: str) -> str: """ This function processes multiline inlined text (e.g. form docstrings) to strip away leading spaces and leading and trailing new lines. @@ -233,12 +242,12 @@ def strip_inlined_text(text): # first line is special as it may not have the indet that follows the # others, e.g. if it starts on the same as the multiline quote ("""). 
- nspaces = count_leading_spaces(lines[0]) + nspaces: int = count_leading_spaces(lines[0]) if len([ln for ln in lines if ln]) > 1: - to_strip = min(count_leading_spaces(ln) for ln in lines[1:] if ln) + to_strip: int = min(count_leading_spaces(ln) for ln in lines[1:] if ln) if nspaces >= to_strip: - stripped = [lines[0][to_strip:]] + stripped: List[str] = [lines[0][to_strip:]] else: stripped = [lines[0][nspaces:]] stripped += [ln[to_strip:] for ln in lines[1:]] @@ -247,9 +256,9 @@ def strip_inlined_text(text): return '\n'.join(stripped).strip('\n') -def indent(text, spaces=4): +def indent(text: str, spaces: int = 4) -> str: """Indent the lines in the specified text by ``spaces`` spaces.""" - indented = [] + indented: List[str] = [] for line in text.split('\n'): if line: indented.append(' ' * spaces + line) @@ -258,7 +267,7 @@ def indent(text, spaces=4): return '\n'.join(indented) -def format_literal(lit): +def format_literal(lit: str) -> str: if isinstance(lit, str): return '``\'{}\'``'.format(lit) elif hasattr(lit, 'pattern'): # regex @@ -270,12 +279,15 @@ def format_literal(lit): return '``{}``'.format(lit) -def get_params_rst(parameters): - text = '' +def get_params_rst(parameters: List['ConfigurationPoint']) -> str: + """ + get parameters restructured text form + """ + text: str = '' for param in parameters: text += '{}: {}\n'.format(param.name, param.mandatory and '(mandatory)' or ' ') text += indent("type: ``'{}'``\n\n".format(get_type_name(param.kind))) - desc = strip_inlined_text(param.description or '') + desc: str = strip_inlined_text(param.description or '') text += indent('{}\n'.format(desc)) if param.aliases: text += indent('\naliases: {}\n'.format(', '.join(map(format_literal, param.aliases)))) @@ -294,35 +306,41 @@ def get_params_rst(parameters): return text -def get_aliases_rst(aliases): - text = '' +def get_aliases_rst(aliases: Union[List['Alias'], 'AliasCollection']) -> str: + """ + get aliases restructured text form + """ + text: str = '' for alias in aliases: - param_str = ', '.join(['{}={}'.format(n, format_literal(v)) - for n, v in alias.params.items()]) + param_str: str = ', '.join(['{}={}'.format(n, format_literal(v)) + for n, v in alias.params.items()]) text += '{}\n{}\n\n'.format(alias.name, indent(param_str)) return text -def underline(text, symbol='='): +def underline(text: str, symbol: str = '=') -> str: + """ + underline the text + """ return '{}\n{}\n\n'.format(text, symbol * len(text)) -def line_break(length=10, symbol='-'): +def line_break(length: int = 10, symbol: str = '-') -> str: """Insert a line break""" return '\n{}\n\n'.format(symbol * length) -def get_rst_from_plugin(plugin): - text = underline(plugin.name, '-') +def get_rst_from_plugin(plugin: Type['Plugin']): + text: str = underline(plugin.name or '', '-') if hasattr(plugin, 'description'): - desc = strip_inlined_text(plugin.description or '') + desc: str = strip_inlined_text(plugin.description or '') # type: ignore elif plugin.__doc__: desc = strip_inlined_text(plugin.__doc__) else: desc = '' text += desc + '\n\n' - aliases_rst = get_aliases_rst(plugin.aliases) + aliases_rst: str = get_aliases_rst(plugin.aliases) if aliases_rst: text += underline('aliases', '~') + aliases_rst diff --git a/wa/utils/exec_control.py b/wa/utils/exec_control.py index ba4bd1945..1463c3eb9 100644 --- a/wa/utils/exec_control.py +++ b/wa/utils/exec_control.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from typing import Dict, List, Optional, Callable, Any # "environment" management: -__environments = {} -__active_environment = None +__environments: Dict[Optional[str], List[str]] = {} +__active_environment: Optional[str] = None -def activate_environment(name): +def activate_environment(name: str) -> None: """ Sets the current tracking environment to ``name``. If an environment with that name does not already exist, it will be @@ -32,7 +33,7 @@ def activate_environment(name): __active_environment = name -def init_environment(name): +def init_environment(name: str) -> None: """ Create a new environment called ``name``, but do not set it as the current environment. @@ -41,12 +42,12 @@ def init_environment(name): already exists. """ if name in list(__environments.keys()): - msg = "Environment {} already exists".format(name) + msg: str = "Environment {} already exists".format(name) raise ValueError(msg) __environments[name] = [] -def reset_environment(name=None): +def reset_environment(name: Optional[str] = None) -> None: """ Reset method call tracking for environment ``name``. If ``name`` is not specified or is ``None``, reset the current active environment. @@ -57,7 +58,7 @@ def reset_environment(name=None): if name is not None: if name not in list(__environments.keys()): - msg = "Environment {} does not exist".format(name) + msg: str = "Environment {} does not exist".format(name) raise ValueError(msg) __environments[name] = [] else: @@ -67,7 +68,7 @@ def reset_environment(name=None): # The decorators: -def once_per_instance(method): +def once_per_instance(method: Callable) -> Callable: """ The specified method will be invoked only once for every bound instance within the environment. @@ -76,7 +77,7 @@ def wrapper(*args, **kwargs): if __active_environment is None: activate_environment('default') func_id = repr(method.__hash__()) + repr(args[0].__hash__()) - if func_id in __environments[__active_environment]: + if func_id in (__environments[__active_environment] or ''): return else: __environments[__active_environment].append(func_id) @@ -85,16 +86,16 @@ def wrapper(*args, **kwargs): return wrapper -def once_per_class(method): +def once_per_class(method: Callable) -> Callable: """ The specified method will be invoked only once for all instances of a class within the environment. """ - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Any: if __active_environment is None: activate_environment('default') - func_id = repr(method.__name__) + repr(args[0].__class__) + func_id: str = repr(method.__name__) + repr(args[0].__class__) if func_id in __environments[__active_environment]: return @@ -105,19 +106,19 @@ def wrapper(*args, **kwargs): return wrapper -def once_per_attribute_value(attr_name): +def once_per_attribute_value(attr_name: str) -> Callable: """ The specified method will be invoked once for all instances that share the same value for the specified attribute (sameness is established by comparing repr() of the values). 
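    For example (an illustrative class, not from this codebase)::

        class Device(object):
            def __init__(self, serial):
                self.serial = serial

            @once_per_attribute_value('serial')
            def initialize(self):
                pass  # invoked once per distinct ``serial`` value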
""" - def wrapped_once_per_attribute_value(method): - def wrapper(*args, **kwargs): + def wrapped_once_per_attribute_value(method: Callable) -> Any: + def wrapper(*args, **kwargs) -> Any: if __active_environment is None: activate_environment('default') attr_value = getattr(args[0], attr_name) - func_id = repr(method.__name__) + repr(args[0].__class__) + repr(attr_value) + func_id: str = repr(method.__name__) + repr(args[0].__class__) + repr(attr_value) if func_id in __environments[__active_environment]: return @@ -129,16 +130,16 @@ def wrapper(*args, **kwargs): return wrapped_once_per_attribute_value -def once(method): +def once(method: Callable) -> Callable: """ The specified method will be invoked only once within the environment. """ - def wrapper(*args, **kwargs): + def wrapper(*args, **kwargs) -> Any: if __active_environment is None: activate_environment('default') - func_id = repr(method.__code__) + func_id: str = repr(method.__code__) if func_id in __environments[__active_environment]: return diff --git a/wa/utils/formatter.py b/wa/utils/formatter.py index 970508686..6f109762f 100644 --- a/wa/utils/formatter.py +++ b/wa/utils/formatter.py @@ -15,9 +15,10 @@ from wa.utils.terminalsize import get_terminal_size +from typing import Optional, List, Any, Union +from typing_extensions import LiteralString - -INDENTATION_FROM_TITLE = 4 +INDENTATION_FROM_TITLE: int = 4 class TextFormatter(object): @@ -29,13 +30,13 @@ class TextFormatter(object): attribute represents the name of the foramtter. """ - name = None - data = None + name: Optional[str] = None + data: Optional[List[Any]] = None def __init__(self): pass - def add_item(self, new_data, item_title): + def add_item(self, new_data: str, item_title: str) -> None: """ Add new item to the text formatter. 
@@ -44,7 +45,7 @@ def add_item(self, new_data, item_title): """ raise NotImplementedError() - def format_data(self): + def format_data(self) -> Optional[str]: """ It returns the formatted text """ @@ -52,49 +53,57 @@ class DescriptionListFormatter(TextFormatter): + """ + description list formatter + """ + name: str = 'description_list_formatter' + data: Optional[List[Any]] = None - name = 'description_list_formatter' - data = None + def __init__(self, title: Optional[str] = None, width: Optional[int] = None): + super(DescriptionListFormatter, self).__init__() + self.data_title = title + self._text_width = width + self.longest_word_length: int = 0 + self.data = [] - def get_text_width(self): + def get_text_width(self) -> Optional[int]: if not self._text_width: self._text_width, _ = get_terminal_size() # pylint: disable=unpacking-non-sequence return self._text_width - def set_text_width(self, value): + def set_text_width(self, value: int) -> None: self._text_width = value text_width = property(get_text_width, set_text_width) - def __init__(self, title=None, width=None): - super(DescriptionListFormatter, self).__init__() - self.data_title = title - self._text_width = width - self.longest_word_length = 0 - self.data = [] - - def add_item(self, new_data, item_title): + def add_item(self, new_data: str, item_title: str) -> None: + """ + add item to formatter + """ if len(item_title) > self.longest_word_length: self.longest_word_length = len(item_title) - self.data[len(self.data):] = [(item_title, self._remove_newlines(new_data))] + self.data[len(self.data):] = [(item_title, self._remove_newlines(new_data))] # type:ignore - def format_data(self): - parag_indentation = self.longest_word_length + INDENTATION_FROM_TITLE - string_formatter = '{}:<{}{} {}'.format('{', parag_indentation, '}', '{}') + def format_data(self) -> Optional[str]: + """ + format data + """ + parag_indentation: int = self.longest_word_length + INDENTATION_FROM_TITLE + string_formatter: str = '{}:<{}{} {}'.format('{', parag_indentation, '}', '{}') - formatted_data = '' + formatted_data: str = '' if self.data_title: formatted_data += self.data_title - line_width = self.text_width - parag_indentation - for title, paragraph in self.data: + line_width: int = (self.text_width or 0) - parag_indentation + for title, paragraph in (self.data or []): formatted_data += '\n' - title_len = self.longest_word_length - len(title) + title_len: int = self.longest_word_length - len(title) title += ':' if title_len > 0: title = (' ' * title_len) + title - parag_lines = self._break_lines(paragraph, line_width).splitlines() + parag_lines: List[LiteralString] = self._break_lines(paragraph, line_width).splitlines() if parag_lines: formatted_data += string_formatter.format(title, parag_lines[0]) for line in parag_lines[1:]: @@ -107,10 +116,13 @@ def format_data(self): # Return text's paragraphs separated in a list, such that each index in the # list is a single text paragraph with no new lines - def _remove_newlines(self, new_data): # pylint: disable=R0201 - parag_list = [''] - parag_num = 0 - prv_parag = None + def _remove_newlines(self, new_data: str): # pylint: disable=R0201 + """ + remove newline characters + """ + parag_list: List[str] = [''] + parag_num: int = 0 + prv_parag: Optional[str] = None # For each paragraph separated by a new line for paragraph in new_data.splitlines(): if paragraph: @@ -127,8 +139,11 @@ def _remove_newlines(self, new_data): # pylint: disable=R0201 return parag_list[:-1] return parag_list - def 
_break_lines(self, parag_list, line_width): # pylint: disable=R0201 - formatted_paragraphs = [] + def _break_lines(self, parag_list: List[LiteralString], line_width: int): # pylint: disable=R0201 + """ + break lines + """ + formatted_paragraphs: List[LiteralString] = [] for para in parag_list: words = para.split() if words: diff --git a/wa/utils/log.py b/wa/utils/log.py index 2868aa591..00611ad01 100644 --- a/wa/utils/log.py +++ b/wa/utils/log.py @@ -22,10 +22,15 @@ import subprocess import threading from contextlib import contextmanager +from typing_extensions import Protocol +from typing import (cast, Type, Optional, Union, + List, Generator, Any, Dict, Callable, + IO) +from louie import dispatcher # type: ignore -import colorama +import colorama # type: ignore -from devlib import DevlibError +from devlib.exception import DevlibError from wa.framework import signal from wa.framework.exception import WAError @@ -44,22 +49,29 @@ DEFAULT_INIT_BUFFER_CAPACITY = 1000 -_indent_level = 0 -_indent_width = 4 -_console_handler = None -_init_handler = None +_indent_level: int = 0 +_indent_width: int = 4 +_console_handler: Optional[logging.StreamHandler] = None +_init_handler: Optional['InitHandler'] = None + + +class LoggedExc(Protocol): + logged: bool # Declares the attribute for type checkers # pylint: disable=global-statement -def init(verbosity=logging.INFO, color=True, indent_with=4, - regular_fmt='%(levelname)-8s %(message)s', - verbose_fmt='%(asctime)s %(levelname)-8s %(name)10.10s: %(message)s', - debug=False): +def init(verbosity: int = logging.INFO, color: bool = True, indent_with: int = 4, + regular_fmt: str = '%(levelname)-8s %(message)s', + verbose_fmt: str = '%(asctime)s %(levelname)-8s %(name)10.10s: %(message)s', + debug: bool = False) -> None: + """ + initialize logger + """ global _indent_width, _console_handler, _init_handler _indent_width = indent_with - signal.log_error_func = lambda m: log_error(m, signal.logger) + signal.log_error_func = lambda m: log_error(m, signal.logger) # type: ignore - root_logger = logging.getLogger() + root_logger: logging.Logger = logging.getLogger() root_logger.setLevel(logging.DEBUG) error_handler = ErrorSignalHandler(logging.DEBUG) @@ -67,7 +79,7 @@ def init(verbosity=logging.INFO, color=True, indent_with=4, _console_handler = logging.StreamHandler() if color: - formatter = ColorFormatter + formatter: Type[logging.Formatter] = ColorFormatter else: formatter = LineFormatter if verbosity: @@ -89,18 +101,26 @@ def init(verbosity=logging.INFO, color=True, indent_with=4, logging.raiseExceptions = False logger = logging.getLogger('CGroups') + # FIXME - cannot assign to a method logger.info = logger.debug -def set_level(level): - _console_handler.setLevel(level) +def set_level(level: Union[int, str]) -> None: + """ + set log level + """ + if _console_handler: + _console_handler.setLevel(level) # pylint: disable=global-statement -def add_file(filepath, level=logging.DEBUG, - fmt='%(asctime)s %(levelname)-8s %(name)10.10s: %(message)s'): +def add_file(filepath: str, level: int = logging.DEBUG, + fmt: str = '%(asctime)s %(levelname)-8s %(name)10.10s: %(message)s') -> None: + """ + add log file + """ global _init_handler - root_logger = logging.getLogger() + root_logger: logging.Logger = logging.getLogger() file_handler = logging.FileHandler(filepath) file_handler.setLevel(level) file_handler.setFormatter(LineFormatter(fmt)) @@ -113,7 +133,10 @@ def add_file(filepath, level=logging.DEBUG, root_logger.addHandler(file_handler) -def enable(logs): +def 
enable(logs: Union[str, List[str]]) -> None: + """ + enable logging + """ if isinstance(logs, list): for log in logs: __enable_logger(log) @@ -121,7 +144,10 @@ def enable(logs): __enable_logger(logs) -def disable(logs): +def disable(logs: Union[str, List[str]]) -> None: + """ + disable logging + """ if isinstance(logs, list): for log in logs: __disable_logger(log) @@ -129,32 +155,44 @@ def disable(logs): __disable_logger(logs) -def __enable_logger(logger): +def __enable_logger(logger: Union[str, logging.Logger]) -> None: + """ + enable logger + """ if isinstance(logger, str): logger = logging.getLogger(logger) logger.propagate = True -def __disable_logger(logger): +def __disable_logger(logger: Union[str, logging.Logger]) -> None: + """ + disable logger + """ if isinstance(logger, str): logger = logging.getLogger(logger) logger.propagate = False # pylint: disable=global-statement -def indent(): +def indent() -> None: + """ + increase indent level + """ global _indent_level _indent_level += 1 # pylint: disable=global-statement -def dedent(): +def dedent() -> None: + """ + decrease indent level + """ global _indent_level _indent_level -= 1 @contextmanager -def indentcontext(): +def indentcontext() -> Generator[None, Any, None]: indent() try: yield @@ -163,14 +201,14 @@ def indentcontext(): # pylint: disable=global-statement -def set_indent_level(level): +def set_indent_level(level: int): global _indent_level old_level = _indent_level _indent_level = level return old_level -def log_error(e, logger, critical=False): +def log_error(e: BaseException, logger: logging.Logger, critical: Optional[bool] = False) -> None: """ Log the specified Exception as an error. The Error message will be formatted differently depending on the nature of the exception. @@ -190,18 +228,18 @@ def log_error(e, logger, critical=False): log_func = logger.error if isinstance(e, KeyboardInterrupt): - old_level = set_indent_level(0) + old_level: int = set_indent_level(0) logger.info('Got CTRL-C. 
Aborting.') set_indent_level(old_level) elif isinstance(e, (WAError, DevlibError)): log_func(str(e)) elif isinstance(e, subprocess.CalledProcessError): - tb = get_traceback() + tb: Optional[str] = get_traceback() log_func(tb) - command = e.cmd + command: str = e.cmd if e.args: command = '{} {}'.format(command, ' '.join(map(str, e.args))) - message = 'Command \'{}\' returned non-zero exit status {}\nOUTPUT:\n{}\n' + message: str = 'Command \'{}\' returned non-zero exit status {}\nOUTPUT:\n{}\n' log_func(message.format(command, e.returncode, e.output)) elif isinstance(e, SyntaxError): tb = get_traceback() @@ -214,7 +252,7 @@ def log_error(e, logger, critical=False): log_func(tb) log_func('{}({})'.format(e.__class__.__name__, e)) - e.logged = True + cast(LoggedExc, e).logged = True class ErrorSignalHandler(logging.Handler): @@ -223,11 +261,14 @@ class ErrorSignalHandler(logging.Handler): """ - def emit(self, record): + def emit(self, record: logging.LogRecord): + """ + emit a log record + """ if record.levelno == logging.ERROR: - signal.send(signal.ERROR_LOGGED, self, record) + signal.send(signal.ERROR_LOGGED, cast(Type[dispatcher.Anonymous], self), record) elif record.levelno == logging.WARNING: - signal.send(signal.WARNING_LOGGED, self, record) + signal.send(signal.WARNING_LOGGED, cast(Type[dispatcher.Anonymous], self), record) class InitHandler(logging.handlers.BufferingHandler): @@ -236,24 +277,36 @@ class InitHandler(logging.handlers.BufferingHandler): """ - def __init__(self, capacity): + def __init__(self, capacity: int): super(InitHandler, self).__init__(capacity) - self.targets = [] + self.targets: List[logging.Handler] = [] - def emit(self, record): + def emit(self, record: logging.LogRecord) -> None: + """ + emit a log record + """ record.indent_level = _indent_level super(InitHandler, self).emit(record) - def flush(self): + def flush(self) -> None: + """ + flush logs + """ for target in self.targets: self.flush_to_target(target) - self.buffer = [] + self.buffer: List[logging.LogRecord] = [] - def add_target(self, target): + def add_target(self, target: logging.Handler): + """ + add target handler + """ if target not in self.targets: self.targets.append(target) - def flush_to_target(self, target): + def flush_to_target(self, target: logging.Handler): + """ + emit log to target handler + """ for record in self.buffer: target.emit(record) @@ -264,19 +317,23 @@ class LineFormatter(logging.Formatter): """ - def format(self, record): + def format(self, record: logging.LogRecord) -> str: + """ + format lines of the message + """ record.message = record.getMessage() if self.usesTime(): record.asctime = self.formatTime(record, self.datefmt) - indent_level = getattr(record, 'indent_level', _indent_level) - cur_indent = _indent_width * indent_level - d = record.__dict__ - parts = [] + indent_level: int = getattr(record, 'indent_level', _indent_level) + cur_indent: int = _indent_width * indent_level + d: Dict[str, Any] = record.__dict__ + parts: List[str] = [] for line in record.message.split('\n'): line = ' ' * cur_indent + line d.update({'message': line.strip('\r')}) - parts.append(self._fmt % d) + if self._fmt: + parts.append(self._fmt % d) return '\n'.join(parts) @@ -294,23 +351,29 @@ class ColorFormatter(LineFormatter): """ - def __init__(self, fmt=None, datefmt=None): + def __init__(self, fmt: Optional[str] = None, datefmt: Optional[str] = None): super(ColorFormatter, self).__init__(fmt, datefmt) - template_text = self._fmt.replace('%(message)s', RESET_COLOR + '%(message)s${color}') + 
template_text = self._fmt.replace('%(message)s', RESET_COLOR + '%(message)s${color}') if self._fmt else '' template_text = '${color}' + template_text + RESET_COLOR self.fmt_template = string.Template(template_text) - def format(self, record): + def format(self, record: logging.LogRecord) -> str: + """ + format line with color + """ self._set_color(COLOR_MAP[record.levelno]) return super(ColorFormatter, self).format(record) - def _set_color(self, color): + def _set_color(self, color: str) -> None: + """ + set log color + """ self._fmt = self.fmt_template.substitute(color=color) class BaseLogWriter(object): - def __init__(self, name, level=logging.DEBUG): + def __init__(self, name: str, level: int = logging.DEBUG): """ File-like object class designed to be used for logging from streams. Each complete line (terminated by new line character) gets logged @@ -319,10 +382,10 @@ def __init__(self, name, level=logging.DEBUG): :param name: The name of the logger that will be used. """ - self.logger = logging.getLogger(name) - self.buffer = '' + self.logger: logging.Logger = logging.getLogger(name) + self.buffer: str = '' if level == logging.DEBUG: - self.do_write = self.logger.debug + self.do_write: Callable = self.logger.debug elif level == logging.INFO: self.do_write = self.logger.info elif level == logging.WARNING: @@ -332,24 +395,36 @@ def __init__(self, name, level=logging.DEBUG): else: raise Exception('Unknown logging level: {}'.format(level)) - def flush(self): + def flush(self) -> 'BaseLogWriter': + """ + flush base log writer + """ # Defined to match the interface expected by pexpect. return self - def close(self): + def close(self) -> 'BaseLogWriter': + """ + close base log writer + """ if self.buffer: self.logger.debug(self.buffer) self.buffer = '' return self - def __del__(self): + def __del__(self) -> None: # Ensure we don't lose buffered output self.close() class LogWriter(BaseLogWriter): + """ + Log writer + """ - def write(self, data): + def write(self, data: str) -> 'LogWriter': + """ + write logs + """ data = data.replace('\r\n', '\n').replace('\r', '\n') if '\n' in data: parts = data.split('\n') @@ -363,8 +438,14 @@ def write(self, data): class LineLogWriter(BaseLogWriter): + """ + Line log writer + """ - def write(self, data): + def write(self, data: str) -> None: + """ + write logs as lines + """ self.do_write(data) @@ -374,14 +455,17 @@ class StreamLogger(threading.Thread): """ - def __init__(self, name, stream, level=logging.DEBUG, klass=LogWriter): + def __init__(self, name: str, stream: IO, level: int = logging.DEBUG, klass: Type = LogWriter): super(StreamLogger, self).__init__() self.writer = klass(name, level) self.stream = stream self.daemon = True - def run(self): - line = self.stream.readline() + def run(self) -> None: + """ + run the stream logger + """ + line: str = self.stream.readline() while line: self.writer.write(line.rstrip('\n')) line = self.stream.readline() diff --git a/wa/utils/misc.py b/wa/utils/misc.py index 75cc3b892..2849141bb 100644 --- a/wa/utils/misc.py +++ b/wa/utils/misc.py @@ -35,6 +35,7 @@ import sys import traceback import uuid +import importlib.util from contextlib import contextmanager from datetime import datetime, timedelta from functools import reduce # pylint: disable=redefined-builtin @@ -50,7 +51,7 @@ except ImportError: from distutils.spawn import find_executable # pylint: disable=no-name-in-module, import-error -from dateutil import tz +from dateutil import tz # type:ignore # pylint: disable=wrong-import-order from devlib.exception import 
TargetError @@ -60,15 +61,19 @@ isiterable, getch, as_relative, ranges_to_list, memoized, list_to_ranges, list_to_mask, mask_to_list, which, to_identifier, safe_extract, LoadSyntaxError) +from devlib.target import Target +from devlib.module.cpufreq import CpufreqModule +from typing import (List, Union, Optional, cast, Dict, Any, IO, Type, + Tuple, Pattern, Match, OrderedDict, Generator) +from types import TracebackType, ModuleType +check_output_logger: logging.Logger = logging.getLogger('check_output') -check_output_logger = logging.getLogger('check_output') - -file_lock_logger = logging.getLogger('file_lock') -at_write_logger = logging.getLogger('at_write') +file_lock_logger: logging.Logger = logging.getLogger('file_lock') +at_write_logger: logging.Logger = logging.getLogger('at_write') # Defined here rather than in wa.exceptions due to module load dependencies -def diff_tokens(before_token, after_token): +def diff_tokens(before_token: str, after_token: str) -> str: """ Creates a diff of two tokens. @@ -96,7 +101,7 @@ def diff_tokens(before_token, after_token): return "[%s -> %s]" % (before_token, after_token) -def prepare_table_rows(rows): +def prepare_table_rows(rows: List[List[str]]): """Given a list of lists, make sure they are prepared to be formatted into a table by making sure each row has the same number of columns and stringifying all values.""" rows = [list(map(str, r)) for r in rows] @@ -108,23 +113,23 @@ def prepare_table_rows(rows): return rows -def write_table(rows, wfh, align='>', headers=None): # pylint: disable=R0914 +def write_table(rows: List[List[str]], wfh: IO[str], align: str = '>', headers: Optional[List[str]] = None): # pylint: disable=R0914 """Write a column-aligned table to the specified file object.""" if not rows: return rows = prepare_table_rows(rows) - num_cols = len(rows[0]) + num_cols: int = len(rows[0]) # cycle specified alignments until we have max_cols of them. This is # consistent with how such cases are handled in R, pandas, etc. it = cycle(align) - align = [next(it) for _ in range(num_cols)] + align_ = [next(it) for _ in range(num_cols)] - cols = list(zip(*rows)) - col_widths = [max(list(map(len, c))) for c in cols] + cols: List = list(zip(*rows)) + col_widths: List[int] = [max(list(map(len, c))) for c in cols] if headers: col_widths = [max([c, len(h)]) for c, h in zip(col_widths, headers)] - row_format = ' '.join(['{:%s%s}' % (align[i], w) for i, w in enumerate(col_widths)]) + row_format: str = ' '.join(['{:%s%s}' % (align_[i], w) for i, w in enumerate(col_widths)]) row_format += '\n' if headers: @@ -136,12 +141,13 @@ def write_table(rows, wfh, align='>', headers=None): # pylint: disable=R0914 wfh.write(row_format.format(*row)) -def get_null(): +def get_null() -> str: """Returns the correct null sink based on the OS.""" return 'NUL' if os.name == 'nt' else '/dev/null' -def get_traceback(exc=None): +def get_traceback(exc: Optional[Tuple[Optional[Type[BaseException]], + Optional[BaseException], Optional[TracebackType]]] = None) -> Optional[str]: """ Returns the string with the traceback for the specified exc object, or for the current exception if exc is not specified. 
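    For example (illustrative; ``do_something`` is a hypothetical call that
    may raise)::

        try:
            do_something()
        except Exception:
            print(get_traceback())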
@@ -151,14 +157,14 @@ exc = sys.exc_info() if not exc: return None - tb = exc[2] + tb: Optional[TracebackType] = exc[2] sio = StringIO() traceback.print_tb(tb, file=sio) del tb # needs to be done explicitly see: http://docs.python.org/2/library/sys.html#sys.exc_info return sio.getvalue() -def _check_remove_item(the_list, item): +def _check_remove_item(the_list: List[str], item: str) -> bool: """Helper function for merge_lists that implements checking whether an item should be removed from the list and doing so if needed. Returns ``True`` if the item has been removed and ``False`` otherwise.""" @@ -172,9 +178,9 @@ def _check_remove_item(the_list, item): return True -VALUE_REGEX = re.compile(r'(\d+(?:\.\d+)?)\s*(\w*)') +VALUE_REGEX: Pattern[str] = re.compile(r'(\d+(?:\.\d+)?)\s*(\w*)') -UNITS_MAP = { +UNITS_MAP: Dict[str, str] = { 's': 'seconds', 'ms': 'milliseconds', 'us': 'microseconds', @@ -186,7 +192,8 @@ def _check_remove_item(the_list, item): } -def parse_value(value_string): +def parse_value(value_string: str) -> Union[Tuple[Union[float, int], str], + Tuple[str, None]]: """parses a string representing a numerical value and returns a tuple (value, units), where value will be either int or float, and units will be a string representing the units or None.""" @@ -201,7 +208,7 @@ def parse_value(value_string): return (value_string, None) -def get_meansd(values): +def get_meansd(values: List[Union[int, float]]) -> Tuple[float, float]: """Returns mean and standard deviation of the specified values.""" if not values: return float('nan'), float('nan') @@ -210,12 +217,12 @@ def get_meansd(values): return mean, sd -def geomean(values): +def geomean(values: List[Union[int, float]]) -> float: """Returns the geometric mean of the values.""" return reduce(mul, values) ** (1.0 / len(values)) -def capitalize(text): +def capitalize(text: str) -> str: """Capitalises the specified text: first letter upper case, all subsequent letters lower case.""" if not text: @@ -223,31 +230,31 @@ def capitalize(text): return text[0].upper() + text[1:].lower() -def utc_to_local(dt): +def utc_to_local(dt: datetime) -> datetime: """Convert naive datetime to local time zone, assuming UTC.""" return dt.replace(tzinfo=tz.tzutc()).astimezone(tz.tzlocal()) -def local_to_utc(dt): +def local_to_utc(dt: datetime) -> datetime: """Convert naive datetime to UTC, assuming local time zone.""" return dt.replace(tzinfo=tz.tzlocal()).astimezone(tz.tzutc()) -def load_class(classpath): +def load_class(classpath: str) -> Type: """Loads the specified Python class. ``classpath`` must be a fully-qualified class name (i.e. 
namespaced under module/package)."""
     modname, clsname = classpath.rsplit('.', 1)
-    mod = importlib.import_module(modname)
-    cls = getattr(mod, clsname)
+    mod: ModuleType = importlib.import_module(modname)
+    cls: Type = getattr(mod, clsname)
     if isinstance(cls, type):
         return cls
     else:
         raise ValueError(f'The classpath "{classpath}" does not point at a class: {cls}')

-def get_pager():
+def get_pager() -> Optional[str]:
     """Returns the name of the system pager program."""
-    pager = os.getenv('PAGER')
+    pager: Optional[str] = os.getenv('PAGER')
     if pager is None:
         pager = find_executable('less')
     if pager is None:
@@ -255,27 +262,31 @@ def get_pager():
     return pager

-_bash_color_regex = re.compile('\x1b\[[0-9;]+m')
+_bash_color_regex: Pattern[str] = re.compile(r'\x1b\[[0-9;]+m')

-def strip_bash_colors(text):
+def strip_bash_colors(text: str) -> str:
+    """
+    strip bash colors
+    """
     return _bash_color_regex.sub('', text)

-def format_duration(seconds, sep=' ', order=['day', 'hour', 'minute', 'second']):  # pylint: disable=dangerous-default-value
+def format_duration(seconds: Union[int, timedelta], sep: str = ' ',
+                    order: List[str] = ['day', 'hour', 'minute', 'second']) -> str:  # pylint: disable=dangerous-default-value
     """
     Formats the specified number of seconds into human-readable duration.
     """
     if isinstance(seconds, timedelta):
-        td = seconds
+        td: timedelta = seconds
     else:
         td = timedelta(seconds=seconds or 0)
     dt = datetime(1, 1, 1) + td
-    result = []
+    result: List[str] = []
     for item in order:
         value = getattr(dt, item, None)
-        if item == 'day':
+        if item == 'day' and value is not None:
             value -= 1
         if not value:
             continue
@@ -284,7 +295,7 @@ def format_duration(seconds, sep=' ', order=['day', 'hour', 'minute', 'second'])
     return sep.join(result) if result else 'N/A'

-def get_article(word):
+def get_article(word: str) -> str:
     """
     Returns the appropriate indefinite article for the word (ish).
@@ -296,12 +307,12 @@ def get_article(word):
     return 'an' if word[0] in 'aoeiu' else 'a'

-def get_random_string(length):
+def get_random_string(length: int) -> str:
     """Returns a random ASCII string of the specified length."""
     return ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(length))

-def import_path(filepath, module_name=None):
+def import_path(filepath: str, module_name: Optional[str] = None) -> ModuleType:
     """
     Programmatically import the given Python source file under the name
     ``module_name``. If ``module_name`` is not provided, a stable name based on
@@ -318,10 +329,12 @@ def import_path(filepath, module_name=None):
         return sys.modules[module_name]
     except KeyError:
         spec = importlib.util.spec_from_file_location(module_name, filepath)
-        module = importlib.util.module_from_spec(spec)
+        if spec is None or spec.loader is None:
+            raise ImportError('Cannot create a module spec for {}'.format(filepath))
+        module = importlib.util.module_from_spec(spec)
         try:
             sys.modules[module_name] = module
             spec.loader.exec_module(module)
         except BaseException:
             sys.modules.pop(module_name, None)
             raise
@@ -333,14 +346,14 @@ def import_path(filepath, module_name=None):
     return sys.modules[module_name]

-def load_struct_from_python(filepath):
+def load_struct_from_python(filepath: str) -> Dict[str, Any]:
     """Parses a config structure from a .py file.
The structure should be composed of basic Python types (strings, ints,
     lists, dicts, etc.)."""
     try:
-        mod = import_path(filepath)
+        mod: ModuleType = import_path(filepath)
     except SyntaxError as e:
-        raise LoadSyntaxError(e.message, filepath, e.lineno)
+        raise LoadSyntaxError(e.msg, filepath, e.lineno)
     else:
         return {
             k: v
@@ -349,7 +362,7 @@ def load_struct_from_python(filepath):
         }

-def open_file(filepath):
+def open_file(filepath: str) -> int:
     """
     Open the specified file path with the associated launcher in an OS-agnostic way.
@@ -362,32 +375,32 @@ def open_file(filepath):
     return subprocess.call(['xdg-open', filepath])

-def sha256(path, chunk=2048):
+def sha256(path: str, chunk: int = 2048) -> str:
     """Calculates SHA256 hexdigest of the file at the specified path."""
     h = hashlib.sha256()
     with open(path, 'rb') as fh:
-        buf = fh.read(chunk)
+        buf: bytes = fh.read(chunk)
         while buf:
             h.update(buf)
             buf = fh.read(chunk)
     return h.hexdigest()

-def urljoin(*parts):
+def urljoin(*parts) -> str:
     return '/'.join(p.rstrip('/') for p in parts)

 # From: http://eli.thegreenplace.net/2011/10/19/perls-guess-if-file-is-text-or-binary-implemented-in-python/
-def istextfile(fileobj, blocksize=512):
+def istextfile(fileobj: IO, blocksize: int = 512) -> bool:
     """
     Uses heuristics to guess whether the given file is text or binary,
     by reading a single block of bytes from the file.
     If more than 30% of the chars in the block are non-text, or there
     are NUL ('\x00') bytes in the block, assume this is a binary file.
     """
-    _text_characters = (b''.join(chr(i) for i in range(32, 127))
-                        + b'\n\r\t\f\b')
+    _text_characters: bytes = (bytes(range(32, 127))
+                               + b'\n\r\t\f\b')

-    block = fileobj.read(blocksize)
+    block: bytes = fileobj.read(blocksize)
     if b'\x00' in block:
         # Files with null bytes are binary
         return False
@@ -401,7 +414,10 @@ def istextfile(fileobj, blocksize=512):
     return float(len(nontext)) / len(block) <= 0.30

 # pylint: disable=too-many-return-statements,too-many-branches
-def categorize(v):
+def categorize(v: Any) -> str:
+    """
+    categorize an object
+    """
     if hasattr(v, 'merge_with') and hasattr(v, 'merge_into'):
         return 'o'
     elif hasattr(v, 'items'):
@@ -415,7 +431,7 @@ def categorize(v):

 # pylint: disable=too-many-return-statements,too-many-branches
-def merge_config_values(base, other):
+def merge_config_values(base: Any, other: Any) -> Any:
     """
     This is used to merge two objects, typically when setting the value of a
     ``ConfigurationPoint``. First, both objects are categorized into
@@ -476,8 +492,8 @@ def merge_config_values(base, other):
     configuration point values.
""" - cat_base = categorize(base) - cat_other = categorize(other) + cat_base: str = categorize(base) + cat_other: str = categorize(other) if cat_base == 'n': return other @@ -512,27 +528,42 @@ def merge_config_values(base, other): return other -def merge_sequencies(s1, s2): - return type(s2)(unique(chain(s1, s2))) +def merge_sequencies(s1: List, s2: List) -> List: + """ + merge sequences + """ + return type(s2)(unique(cast(List, chain(s1, s2)))) -def merge_maps(m1, m2): +def merge_maps(m1: Dict, m2: Dict) -> Dict: + """ + merge dicts + """ return type(m2)(chain(iter(m1.items()), iter(m2.items()))) -def merge_dicts_simple(base, other): - result = base.copy() +def merge_dicts_simple(base: Dict, other: Dict) -> Dict: + """ + merge dicts + """ + result: Dict = base.copy() for key, value in (other or {}).items(): result[key] = merge_config_values(result.get(key), value) return result -def touch(path): +def touch(path: str) -> None: + """ + open and clear file in the path + """ with open(path, 'w'): pass -def get_object_name(obj): +def get_object_name(obj: Any) -> Optional[str]: + """ + get name of the object + """ if hasattr(obj, 'name'): return obj.name elif hasattr(obj, '__func__') and hasattr(obj, '__self__'): @@ -547,7 +578,7 @@ def get_object_name(obj): return None -def resolve_cpus(name, target): +def resolve_cpus(name: Optional[Union[str, int]], target: Target) -> List[int]: """ Returns a list of cpu numbers that corresponds to a passed name. Allowed formats are: @@ -558,7 +589,7 @@ def resolve_cpus(name, target): - 'all' - returns all cpus - '' - Empty name will also return all cpus """ - cpu_list = list(range(target.number_of_cpus)) + cpu_list: List[int] = list(range(target.number_of_cpus)) # Support for passing cpu no directly if isinstance(name, int): @@ -587,7 +618,7 @@ def resolve_cpus(name, target): # Check if core number has been supplied. else: - core_no = re.match('cpu([0-9]+)', name, re.IGNORECASE) + core_no: Optional[Match[str]] = re.match('cpu([0-9]+)', name, re.IGNORECASE) if core_no: cpu = int(core_no.group(1)) if cpu not in cpu_list: @@ -595,33 +626,33 @@ def resolve_cpus(name, target): raise ValueError(message.format(cpu, cpu_list)) return [cpu] else: - msg = 'Unexpected core name "{}"' + msg: str = 'Unexpected core name "{}"' raise ValueError(msg.format(name)) @memoized -def resolve_unique_domain_cpus(name, target): +def resolve_unique_domain_cpus(name: str, target: Target) -> List[int]: """ Same as `resolve_cpus` above but only returns only the first cpu in each of the different frequency domains. Requires cpufreq. 
""" - cpus = resolve_cpus(name, target) + cpus: List[int] = resolve_cpus(name, target) if not target.has('cpufreq'): - msg = 'Device does not appear to support cpufreq; ' \ - 'Cannot obtain cpu domain information' + msg: str = 'Device does not appear to support cpufreq; ' \ + 'Cannot obtain cpu domain information' raise TargetError(msg) - unique_cpus = [] - domain_cpus = [] + unique_cpus: List[int] = [] + domain_cpus: List[int] = [] for cpu in cpus: if cpu not in domain_cpus: - domain_cpus = target.cpufreq.get_related_cpus(cpu) + domain_cpus = cast(CpufreqModule, target.cpufreq).get_related_cpus(cpu) if domain_cpus[0] not in unique_cpus: unique_cpus.append(domain_cpus[0]) return unique_cpus -def format_ordered_dict(od): +def format_ordered_dict(od: OrderedDict) -> str: """ Provide a string representation of ordered dict that is similar to the regular dict representation, as that is more concise and easier to read @@ -632,14 +663,14 @@ def format_ordered_dict(od): @contextmanager -def atomic_write_path(path, mode='w'): +def atomic_write_path(path: str, mode: str = 'w') -> Generator[str, Any, None]: """ Gets a file path to write to which will be replaced with the original file path to simulate an atomic write from the point of view of other processes. This is achieved by writing to a tmp file and replacing the exiting file to prevent inconsistencies. """ - tmp_file = None + tmp_file: Optional[IO] = None try: tmp_file = NamedTemporaryFile(mode=mode, delete=False, suffix=os.path.basename(path)) @@ -649,11 +680,11 @@ def atomic_write_path(path, mode='w'): finally: if tmp_file: tmp_file.close() - at_write_logger.debug('Moving {} to {}'.format(tmp_file.name, path)) - safe_move(tmp_file.name, path) + at_write_logger.debug('Moving {} to {}'.format(tmp_file.name if tmp_file else '', path)) + safe_move(tmp_file.name if tmp_file else '', path) -def safe_move(src, dst): +def safe_move(src: str, dst: str) -> None: """ Taken from: https://alexwlchan.net/2019/03/atomic-cross-filesystem-moves-in-python/ @@ -676,7 +707,7 @@ def safe_move(src, dst): # across a filesystem boundary, this initial copy may not be # atomic. We intersperse a random UUID so if different processes # are copying into ``, they don't overlap in their tmp copies. - copy_id = uuid.uuid4() + copy_id: uuid.UUID = uuid.uuid4() tmp_dst = "%s.%s.tmp" % (dst, copy_id) shutil.copyfile(src, tmp_dst) @@ -689,7 +720,7 @@ def safe_move(src, dst): @contextmanager -def lock_file(path, timeout=30): +def lock_file(path: str, timeout: int = 30): """ Enable automatic locking and unlocking of a file path given. Used to prevent synchronisation issues between multiple wa processes. @@ -701,8 +732,8 @@ def lock_file(path, timeout=30): # pylint: disable=wrong-import-position,cyclic-import, import-outside-toplevel from wa.framework.exception import ResourceError - locked = False - l_file = 'wa-{}.lock'.format(path) + locked: bool = False + l_file: str = 'wa-{}.lock'.format(path) l_file = os.path.join(gettempdir(), l_file.replace(os.path.sep, '_')) file_lock_logger.debug('Acquiring lock on "{}"'.format(path)) try: @@ -713,7 +744,7 @@ def lock_file(path, timeout=30): file_lock_logger.debug('Lock acquired on "{}"'.format(path)) break except FileExistsError: - msg = 'Failed to acquire lock on "{}" Retrying...' + msg: str = 'Failed to acquire lock on "{}" Retrying...' 
file_lock_logger.debug(msg.format(l_file)) sleep(1) timeout -= 1 diff --git a/wa/utils/postgres.py b/wa/utils/postgres.py index 69c4f6164..e9f71dd2a 100644 --- a/wa/utils/postgres.py +++ b/wa/utils/postgres.py @@ -31,31 +31,44 @@ import os try: - from psycopg2 import InterfaceError - from psycopg2.extensions import AsIs + from psycopg2 import InterfaceError # type:ignore + from psycopg2.extensions import AsIs # type:ignore except ImportError: - InterfaceError = None - AsIs = None + InterfaceError = None # type:ignore + AsIs = None # type:ignore from wa.utils.types import level +from typing import Callable, Optional, List, Any, Tuple, TYPE_CHECKING +from enum import Enum +if TYPE_CHECKING: + from psycopg2.extensions import cursor, connection # type:ignore +else: + cursor = None + connection = None -POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__), - '..', - 'commands', - 'postgres_schemas') +POSTGRES_SCHEMA_DIR: str = os.path.join(os.path.dirname(__file__), + '..', + 'commands', + 'postgres_schemas') -def cast_level(value, cur): # pylint: disable=unused-argument +class Level(Enum): + LOW = 1 + MEDIUM = 2 + HIGH = 3 + + +def cast_level(value: str, cur: Optional['cursor']): # pylint: disable=unused-argument """Generic Level caster for psycopg2""" - if not InterfaceError: + if InterfaceError is None: raise ImportError('There was a problem importing psycopg2.') if value is None: return None m = re.match(r"([^\()]*)\((\d*)\)", value) - name = str(m.group(1)) - number = int(m.group(2)) + name = str(m.group(1)) if m else '' + number = int(m.group(2)) if m else 0 if m: return level(name, number) @@ -63,7 +76,7 @@ def cast_level(value, cur): # pylint: disable=unused-argument raise InterfaceError("Bad level representation: {}".format(value)) -def cast_vanilla(value, cur): # pylint: disable=unused-argument +def cast_vanilla(value: Optional[str], cur: Optional['cursor']) -> Optional[str]: # pylint: disable=unused-argument """Vanilla Type caster for psycopg2 Simply returns the string representation. @@ -76,26 +89,26 @@ def cast_vanilla(value, cur): # pylint: disable=unused-argument # List functions and classes for adapting -def adapt_level(a_level): +def adapt_level(a_level: Level): """Generic Level Adapter for psycopg2""" return "{}({})".format(a_level.name, a_level.value) class ListOfLevel(object): - value = None + value: Optional[Level] = None - def __init__(self, a_level): + def __init__(self, a_level: Level): self.value = a_level - def return_original(self): + def return_original(self) -> Optional[Level]: return self.value -def adapt_ListOfX(adapt_X): +def adapt_ListOfX(adapt_X: Callable): """This will create a multi-column adapter for a particular type. Note that the type must itself need to be in array form. Therefore - this function serves to seaprate out individual lists into multiple + this function serves to separate out individual lists into multiple big lists. E.g. if the X adapter produces array (a,b,c) then this adapter will take an list of Xs and produce a master array: @@ -115,16 +128,16 @@ def adapt_ListOfX(adapt_X): subarray following processing then the outer {} are stripped to give a 1 dimensional array. 
""" - def adapter_function(param): - if not AsIs: + def adapter_function(param: Any) -> AsIs: # type:ignore + if AsIs is None: raise ImportError('There was a problem importing psycopg2.') param = param.value - result_list = [] + result_list: List[str] = [] for element in param: # Where param will be a list of X's result_list.append(adapt_X(element)) test_element = result_list[0] - num_items = len(test_element.split(",")) - master_list = [] + num_items: int = len(test_element.split(",")) + master_list: List[str] = [] for x in range(num_items): master_list.append("") for element in result_list: @@ -133,7 +146,7 @@ def adapter_function(param): for x in range(num_items): master_list[x] = master_list[x] + element[x] + "," if num_items > 1: - master_sql_string = "{" + master_sql_string: str = "{" else: master_sql_string = "" for x in range(num_items): @@ -148,29 +161,30 @@ def adapter_function(param): return adapter_function -def return_as_is(adapt_X): +def return_as_is(adapt_X: Callable) -> Callable: """Returns the AsIs appended function of the function passed This is useful for adapter functions intended to be used with the adapt_ListOfX function, which must return strings, as it allows them to be standalone adapters. """ - if not AsIs: + if AsIs is None: raise ImportError('There was a problem importing psycopg2.') - def adapter_function(param): - return AsIs("'{}'".format(adapt_X(param))) + def adapter_function(param: Any) -> AsIs: # type:ignore + if AsIs is not None: + return AsIs("'{}'".format(adapt_X(param))) return adapter_function -def adapt_vanilla(param): +def adapt_vanilla(param: Any) -> AsIs: # type:ignore """Vanilla adapter: simply returns the string representation""" - if not AsIs: + if AsIs is None: raise ImportError('There was a problem importing psycopg2.') return AsIs("'{}'".format(param)) -def create_iterable_adapter(array_columns, explicit_iterate=False): +def create_iterable_adapter(array_columns: int, explicit_iterate: bool = False) -> Callable: """Create an iterable adapter of a specified dimension If explicit_iterate is True, then it will be assumed that the param needs @@ -183,16 +197,16 @@ def create_iterable_adapter(array_columns, explicit_iterate=False): If array_columns is 0, then this indicates that the iterable contains single items. 
""" - if not AsIs: + if AsIs is None: raise ImportError('There was a problem importing psycopg2.') - def adapt_iterable(param): + def adapt_iterable(param: Any) -> AsIs: # type:ignore """Adapts an iterable object into an SQL array""" - final_string = "" # String stores a string representation of the array + final_string: str = "" # String stores a string representation of the array if param: if array_columns > 1: for index in range(array_columns): - array_string = "" + array_string: str = "" for item in param.iteritems(): array_string = array_string + str(item[index]) + "," array_string = array_string.strip(",") @@ -207,16 +221,17 @@ def adapt_iterable(param): else: for item in param: final_string = final_string + str(item) + "," - return AsIs("'{{{}}}'".format(final_string)) + if AsIs is not None: + return AsIs("'{{{}}}'".format(final_string)) return adapt_iterable # For reference only and future use -def adapt_list(param): +def adapt_list(param: Any) -> AsIs: # type: ignore """Adapts a list into an array""" - if not AsIs: + if AsIs is None: raise ImportError('There was a problem importing psycopg2.') - final_string = "" + final_string: str = "" if param: for item in param: final_string = final_string + str(item) + "," @@ -224,34 +239,38 @@ def adapt_list(param): return AsIs("'{}'".format(final_string)) -def get_schema(schemafilepath): +def get_schema(schemafilepath: str) -> Tuple[Optional[int], Optional[int], str]: + """ + get schema + """ with open(schemafilepath, 'r') as sqlfile: sql_commands = sqlfile.read() - schema_major = None - schema_minor = None + schema_major: Optional[str] = None + schema_minor: Optional[str] = None # Extract schema version if present if sql_commands.startswith('--!VERSION'): splitcommands = sql_commands.split('!ENDVERSION!\n') schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.') - schema_major = int(schema_major) - schema_minor = int(schema_minor) + schema_major_ = int(schema_major) + schema_minor_ = int(schema_minor) sql_commands = splitcommands[1] - return schema_major, schema_minor, sql_commands + return schema_major_, schema_minor_, sql_commands -def get_database_schema_version(conn): +def get_database_schema_version(conn: 'connection') -> Tuple[Optional[int], Optional[int]]: with conn.cursor() as cursor: cursor.execute('''SELECT DatabaseMeta.schema_major, DatabaseMeta.schema_minor FROM DatabaseMeta;''') - schema_major, schema_minor = cursor.fetchone() + schema_major, schema_minor = cursor.fetchone() or (0, 0) return (schema_major, schema_minor) -def get_schema_versions(conn): +def get_schema_versions(conn: 'connection') -> Tuple[Tuple[Optional[int], Optional[int]], + Tuple[Optional[int], Optional[int]]]: schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql') cur_major_version, cur_minor_version, _ = get_schema(schemafilepath) db_schema_version = get_database_schema_version(conn) diff --git a/wa/utils/revent.py b/wa/utils/revent.py index f858ab471..5244ac707 100644 --- a/wa/utils/revent.py +++ b/wa/utils/revent.py @@ -24,10 +24,13 @@ from wa.framework.resource import Executable, NO_ONE, ResourceResolver from wa.utils.exec_control import once_per_class +from typing import (TYPE_CHECKING, IO, Tuple, Any, List, Union, + Optional, cast, Generator) +if TYPE_CHECKING: + from devlib.target import Target - -GENERAL_MODE = 0 -GAMEPAD_MODE = 1 +GENERAL_MODE: int = 0 +GAMEPAD_MODE: int = 1 u16_struct = struct.Struct(' Tuple[Any, ...]: + """ + read struct + """ data = fh.read(struct_spec.size) return 
-def read_string(fh):
+def read_string(fh: IO) -> str:
+    """
+    read string from struct
+    """
     length, = read_struct(fh, u32_struct)
     str_struct = struct.Struct('<{}s'.format(length))
     return read_struct(fh, str_struct)[0]

-def count_bits(bitarr):
+def count_bits(bitarr: List[int]) -> int:
+    """
+    count bits in bit array
+    """
     return sum(bin(b).count('1') for b in bitarr)

-def is_set(bitarr, bit):
+def is_set(bitarr: List[int], bit: int) -> int:
+    """
+    check if bit is set
+    """
     byte = bit // 8
     bytebit = bit % 8
-    return bitarr[byte] & bytebit
+    return bitarr[byte] & (1 << bytebit)
@@ -72,32 +87,36 @@ def is_set(bitarr, bit):

 class UinputDeviceInfo(object):
+    """
+    Uinput device information
+    """
+    def __init__(self, fh: IO):
+        parts: Tuple[Any, ...] = read_struct(fh, devid_struct)
+        self.bustype: int = parts[0]
+        self.vendor: int = parts[1]
+        self.product: int = parts[2]
+        self.version: int = parts[3]

-    def __init__(self, fh):
-        parts = read_struct(fh, devid_struct)
-        self.bustype = parts[0]
-        self.vendor = parts[1]
-        self.product = parts[2]
-        self.version = parts[3]
-
-        self.name = read_string(fh)
+        self.name: str = read_string(fh)
         parts = read_struct(fh, devinfo_struct)
         self.ev_bits = bytearray(parts[0])
         self.key_bits = bytearray(parts[1])
         self.rel_bits = bytearray(parts[2])
         self.abs_bits = bytearray(parts[3])
-        self.num_absinfo = parts[4]
-        self.absinfo = [absinfo(*read_struct(fh, absinfo_struct))
-                        for _ in range(self.num_absinfo)]
+        self.num_absinfo: int = parts[4]
+        self.absinfo: List[absinfo] = [absinfo(*read_struct(fh, absinfo_struct))
+                                       for _ in range(self.num_absinfo)]

-    def __str__(self):
+    def __str__(self) -> str:
         return 'UInputInfo({})'.format(self.__dict__)

 class ReventEvent(object):
-
-    def __init__(self, fh, legacy=False):
+    """
+    represents a revent event
+    """
+    def __init__(self, fh: IO, legacy: bool = False):
         if not legacy:
             dev_id, ts_sec, ts_usec, type_, code, value = read_struct(fh, event_struct)
         else:
@@ -117,16 +136,16 @@ class ReventRecording(object):
     Represents a parsed revent recording. This contains input events and
     device descriptions recorded by revent. Two parsing modes are supported.
     By default, the recording will be parsed in the "streaming" mode. In this
-    mode, initial headers and device descritions are parsed on creation and an
+    mode, initial headers and device descriptions are parsed on creation and an
     open file handle to the recording is saved. Events will be read from the
     file as they are being iterated over. In this mode, the entire recording is
     never loaded into memory at once. The underlying file may be "released" by
-    calling ``close`` on the recroding, after which further iteration over the
+    calling ``close`` on the recording, after which further iteration over the
     events will not be possible (but would still be possible to access the
     file description and header information).

     The alternative is to load the entire recording on creation (in which case
-    the file handle will be closed once the recroding is loaded). This can be
+    the file handle will be closed once the recording is loaded). This can be
-    enabled by specifying ``streaming=False``. This will make it faster to
+    enabled by specifying ``stream=False``. This will make it faster to
     subsequently iterate over the events, and also will not "hold" the file
     open.
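# --- Illustrative sketch (not part of this patch): the two parsing modes
# described above. The recording file name is made up.
with ReventRecording('home_screen.revent') as rec:       # streaming (default)
    print(rec.duration)                                  # seconds, computed lazily
    for event in rec.events:                             # events read from the file
        pass

eager = ReventRecording('home_screen.revent', stream=False)
events = list(eager.events or [])                        # fully loaded; file released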
@@ -140,11 +159,41 @@ class ReventRecording(object): """ + def __init__(self, f: Union[str, IO], stream: bool = True): + self.device_paths: List[str] = [] + self.gamepad_device: Optional[UinputDeviceInfo] = None + self.num_events: Optional[int] = None + self.stream = stream + self._events: Optional[List[ReventEvent]] = None + self._close_when_done: bool = False + self._events_start: Optional[int] = None + self._duration: Optional[Union[int, float]] = None + + if hasattr(f, 'name'): # file-like object + self.filepath: str = cast(IO, f).name + self.fh: Optional[IO] = cast(IO, f) + else: # path to file + self.filepath = cast(str, f) + self.fh = open(self.filepath, 'rb') + if not self.stream: + self._close_when_done = True + try: + self._parse_header_and_devices(self.fh) + self._events_start = self.fh.tell() + if not self.stream: + self._events = list(self._iter_events()) + finally: + if self._close_when_done: + self.close() + @property - def duration(self): + def duration(self) -> Optional[Union[float, int]]: + """ + recording duration in seconds + """ if self._duration is None: if self.stream: - events = self._iter_events() + events: Generator[ReventEvent, Any, None] = self._iter_events() try: first = last = next(events) except StopIteration: @@ -155,54 +204,38 @@ def duration(self): else: # not streaming if not self._events: self._duration = 0 - self._duration = (self._events[-1].time - - self._events[0].time).total_seconds() + else: + self._duration = (self._events[-1].time + - self._events[0].time).total_seconds() return self._duration @property - def events(self): + def events(self) -> Union[Generator[ReventEvent, Any, None], + Optional[List[ReventEvent]]]: + """ + Revent events + """ if self.stream: return self._iter_events() else: return self._events - def __init__(self, f, stream=True): - self.device_paths = [] - self.gamepad_device = None - self.num_events = None - self.stream = stream - self._events = None - self._close_when_done = False - self._events_start = None - self._duration = None - - if hasattr(f, 'name'): # file-like object - self.filepath = f.name - self.fh = f - else: # path to file - self.filepath = f - self.fh = open(self.filepath, 'rb') - if not self.stream: - self._close_when_done = True - try: - self._parse_header_and_devices(self.fh) - self._events_start = self.fh.tell() - if not self.stream: - self._events = list(self._iter_events()) - finally: - if self._close_when_done: - self.close() - - def close(self): + def close(self) -> None: + """ + close file handle + """ if self.fh is not None: self.fh.close() self.fh = None self._events_start = None - def _parse_header_and_devices(self, fh): + def _parse_header_and_devices(self, fh: IO) -> None: + """ + parse header and devices + """ magic, version = read_struct(fh, header_one_struct) if magic != b'REVENT': - msg = '{} does not appear to be an revent recording' + msg: str = '{} does not appear to be an revent recording' raise ValueError(msg.format(self.filepath)) self.version = version @@ -216,8 +249,8 @@ def _parse_header_and_devices(self, fh): raise ValueError('Unexpected recording mode: {}'.format(self.mode)) self.num_events, = read_struct(fh, u64_struct) if self.version > 2: - ts_sec = read_struct(fh, u64_struct)[0] - ts_usec = read_struct(fh, u64_struct)[0] + ts_sec: float = read_struct(fh, u64_struct)[0] + ts_usec: float = read_struct(fh, u64_struct)[0] self.start_time = datetime.fromtimestamp(ts_sec + float(ts_usec) / 1000000) ts_sec = read_struct(fh, u64_struct)[0] ts_usec = read_struct(fh, u64_struct)[0] @@ 
-229,31 +262,41 @@
     else:
         raise ValueError('Invalid recording version: {}'.format(self.version))

-    def _read_devices(self, fh):
+    def _read_devices(self, fh: IO) -> None:
+        """
+        read devices
+        """
         num_devices, = read_struct(fh, u32_struct)
         for _ in range(num_devices):
             self.device_paths.append(read_string(fh))

-    def _read_gamepad_info(self, fh):
+    def _read_gamepad_info(self, fh: IO) -> None:
+        """
+        read gamepad info
+        """
         self.gamepad_device = UinputDeviceInfo(fh)
         self.device_paths.append('[GAMEPAD]')

-    def _iter_events(self):
+    def _iter_events(self) -> Generator[ReventEvent, Any, None]:
+        """
+        iterate over recorded events
+        """
         if self.fh is None:
-            msg = 'Attempting to iterate over events of a closed recording'
+            msg: str = 'Attempting to iterate over events of a closed recording'
             raise RuntimeError(msg)
-        self.fh.seek(self._events_start)
+        self.fh.seek(self._events_start or 0)
         if self.version >= 2:
-            for _ in range(self.num_events):
+            for _ in range(self.num_events or 0):
                 yield ReventEvent(self.fh)
         else:
-            file_size = os.path.getsize(self.filepath)
+            file_size: int = os.path.getsize(self.filepath)
             while self.fh.tell() < file_size:
                 yield ReventEvent(self.fh, legacy=True)

     def __iter__(self):
-        for event in self.events:
-            yield event
+        if self.events:
+            for event in self.events:
+                yield event

     def __enter__(self):
         return self
@@ -265,7 +308,10 @@ def __del__(self):
         self.close()

-def get_revent_binary(abi):
+def get_revent_binary(abi: str) -> Optional[str]:
+    """
+    get revent binary
+    """
     resolver = ResourceResolver()
     resolver.load()
     resource = Executable(NO_ONE, abi, 'revent')
@@ -273,40 +319,60 @@ def get_revent_binary(abi):

 class ReventRecorder(object):
-
+    """
+    revent recorder
+    """
     # Share location of target executable across all instances
-    target_executable = None
+    target_executable: Optional[str] = None

-    def __init__(self, target):
+    def __init__(self, target: 'Target'):
         self.target = target
         if not ReventRecorder.target_executable:
             ReventRecorder.target_executable = self._get_target_path(self.target)

     @once_per_class
-    def deploy(self):
+    def deploy(self) -> None:
+        """
+        deploy the revent recorder
+        """
         if not ReventRecorder.target_executable:
             ReventRecorder.target_executable = self.target.get_installed('revent')
-        host_executable = get_revent_binary(self.target.abi)
+        host_executable = get_revent_binary(self.target.abi or '')
         ReventRecorder.target_executable = self.target.install(host_executable)

     @once_per_class
-    def remove(self):
+    def remove(self) -> None:
+        """
+        uninstall revent on target
+        """
         if ReventRecorder.target_executable:
             self.target.uninstall('revent')

-    def start_record(self, revent_file):
-        command = '{} record -s {}'.format(ReventRecorder.target_executable, revent_file)
+    def start_record(self, revent_file: str) -> None:
+        """
+        start recording
+        """
+        command: str = '{} record -s {}'.format(ReventRecorder.target_executable, revent_file)
         self.target.kick_off(command, self.target.is_rooted)

-    def stop_record(self):
+    def stop_record(self) -> None:
+        """
+        stop recording
+        """
         self.target.killall('revent', signal.SIGINT, as_root=self.target.is_rooted)

-    def replay(self, revent_file, timeout=None):
+    def replay(self, revent_file: str, timeout: Optional[int] = None) -> None:
+        """
+        replay the recording
+        """
         self.target.killall('revent')
-        command = "{} replay {}".format(ReventRecorder.target_executable, revent_file)
+        command: str = "{} replay {}".format(ReventRecorder.target_executable, revent_file)
         self.target.execute(command, timeout=timeout)
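# --- Illustrative sketch (not part of this patch): a record/replay round
# trip with ReventRecorder. `target` is assumed to be a connected devlib
# Target; the on-target path is made up.
recorder = ReventRecorder(target)
recorder.deploy()
recorder.start_record('/data/local/tmp/setup.revent')
# ... drive the UI on the device here ...
recorder.stop_record()
recorder.replay('/data/local/tmp/setup.revent', timeout=60)
recorder.remove()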
     @memoized
     @staticmethod
-    def _get_target_path(target):
+    def _get_target_path(target: 'Target') -> str:
+        """
+        get path of revent installation on target
+        """
         return target.get_installed('revent')
diff --git a/wa/utils/serializer.py b/wa/utils/serializer.py
index 23a313bf4..9c3a14612 100644
--- a/wa/utils/serializer.py
+++ b/wa/utils/serializer.py
@@ -22,8 +22,8 @@ The modifications to standard serialization procedures are:
-    - mappings are deserialized as ``OrderedDict``\ 's rather than standard
-      Python ``dict``\ 's. This allows for cleaner syntax in certain parts
+    - mappings are deserialized as ``OrderedDict`` 's rather than standard
+      Python ``dict`` 's. This allows for cleaner syntax in certain parts
       of WA configuration (e.g. values to be written to files can be specified
       as a dict, and they will be written in the order specified in the config).
     - regular expressions are automatically encoded/decoded. This allows for
@@ -62,25 +62,22 @@
 from collections import OrderedDict
 from collections.abc import Hashable
 from datetime import datetime

-import dateutil.parser
+import dateutil.parser  # type:ignore
 import yaml as _yaml  # pylint: disable=wrong-import-order
-from yaml import MappingNode
+from yaml import MappingNode, Dumper, Node, ScalarNode
 try:
     from yaml import FullLoader as _yaml_loader
 except ImportError:
     from yaml import Loader as _yaml_loader
-from yaml.constructor import ConstructorError
-
-
-# pylint: disable=redefined-builtin
-from past.builtins import basestring  # pylint: disable=wrong-import-order
+from yaml.constructor import ConstructorError  # type:ignore

 from wa.framework.exception import SerializerSyntaxError
 from wa.utils.misc import isiterable
 from wa.utils.types import regex_type, none_type, level, cpu_mask
+from typing import (Dict, Any, Callable, Optional, IO, Union,
+                    List, Type, cast, Pattern)

-
-__all__ = [
+__all__: List[str] = [
     'json',
     'yaml',
     'read_pod',
@@ -90,12 +87,11 @@
     'POD_TYPES',
 ]

-POD_TYPES = [
+POD_TYPES: List[Type] = [
     list,
     tuple,
     dict,
     set,
-    basestring,
     str,
     int,
     float,
@@ -110,8 +106,10 @@

 class WAJSONEncoder(_json.JSONEncoder):
-
-    def default(self, obj):  # pylint: disable=method-hidden,arguments-differ
+    """
+    JSON encoder for WA
+    """
+    def default(self, obj: Any) -> str:  # pylint: disable=method-hidden,arguments-differ
         if isinstance(obj, regex_type):
             return 'REGEX:{}:{}'.format(obj.flags, obj.pattern)
         elif isinstance(obj, datetime):
@@ -125,12 +123,14 @@ def default(self, obj):  # pylint: disable=method-hidden,arguments-differ

 class WAJSONDecoder(_json.JSONDecoder):
-
+    """
+    JSON decoder for WA
+    """
     def decode(self, s, **kwargs):  # pylint: disable=arguments-differ
         d = _json.JSONDecoder.decode(self, s, **kwargs)

-        def try_parse_object(v):
-            if isinstance(v, basestring):
+        def try_parse_object(v: Any) -> Any:
+            if isinstance(v, str):
                 if v.startswith('REGEX:'):
                     _, flags, pattern = v.split(':', 2)
                     return re.compile(pattern, int(flags or 0))
@@ -146,10 +146,10 @@ def try_parse_object(v):

             return v

-        def load_objects(d):
+        def load_objects(d: Dict) -> Union[Dict, OrderedDict]:
             if not hasattr(d, 'items'):
                 return d
-            pairs = []
+            pairs: List = []
             for k, v in d.items():
                 if hasattr(v, 'items'):
                     pairs.append((k, load_objects(v)))
@@ -165,77 +165,113 @@ def load_objects(d):

 class json(object):

     @staticmethod
-    def dump(o, wfh, indent=4, *args, **kwargs):
+    def dump(o: Any, wfh: IO, indent: int = 4, *args, **kwargs) -> None:
+        """
+        serialize o as json formatted stream to wfh
+        """
         return _json.dump(o, wfh,
cls=WAJSONEncoder, indent=indent, *args, **kwargs)

     @staticmethod
-    def dumps(o, indent=4, *args, **kwargs):
+    def dumps(o: Any, indent: Optional[int] = 4, *args, **kwargs) -> str:
+        """
+        serialize o to json formatted string
+        """
         return _json.dumps(o, cls=WAJSONEncoder, indent=indent, *args, **kwargs)

     @staticmethod
-    def load(fh, *args, **kwargs):
+    def load(fh: IO, *args, **kwargs) -> Any:
+        """
+        deserialize json from file
+        """
         try:
             return _json.load(fh, cls=WAJSONDecoder, object_pairs_hook=OrderedDict, *args, **kwargs)
         except ValueError as e:
             raise SerializerSyntaxError(e.args[0])

     @staticmethod
-    def loads(s, *args, **kwargs):
+    def loads(s: str, *args, **kwargs) -> Any:
+        """
+        deserialize json string to python object
+        """
         try:
             return _json.loads(s, cls=WAJSONDecoder, object_pairs_hook=OrderedDict, *args, **kwargs)
         except ValueError as e:
             raise SerializerSyntaxError(e.args[0])

-_mapping_tag = _yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
-_regex_tag = 'tag:wa:regex'
-_level_tag = 'tag:wa:level'
-_cpu_mask_tag = 'tag:wa:cpu_mask'
+_mapping_tag: str = _yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
+_regex_tag: str = 'tag:wa:regex'
+_level_tag: str = 'tag:wa:level'
+_cpu_mask_tag: str = 'tag:wa:cpu_mask'

-def _wa_dict_representer(dumper, data):
+def _wa_dict_representer(dumper: Dumper, data: OrderedDict) -> Any:
+    """
+    represent ordered dict in dumped yaml
+    """
     return dumper.represent_mapping(_mapping_tag, iter(data.items()))

-def _wa_regex_representer(dumper, data):
-    text = '{}:{}'.format(data.flags, data.pattern)
+def _wa_regex_representer(dumper: Dumper, data: re.Pattern) -> Any:
+    """
+    represent regex in dumped yaml
+    """
+    text: str = '{}:{}'.format(data.flags, data.pattern)
     return dumper.represent_scalar(_regex_tag, text)

-def _wa_level_representer(dumper, data):
-    text = '{}:{}'.format(data.name, data.level)
+def _wa_level_representer(dumper: Dumper, data: level) -> Any:
+    """
+    represent level in dumped yaml
+    """
+    text = '{}:{}'.format(data.name, data.level)  # type: ignore
     return dumper.represent_scalar(_level_tag, text)

-def _wa_cpu_mask_representer(dumper, data):
+def _wa_cpu_mask_representer(dumper: Dumper, data: cpu_mask) -> Any:
+    """
+    represent cpu mask in dumped yaml
+    """
     return dumper.represent_scalar(_cpu_mask_tag, data.mask())

-def _wa_regex_constructor(loader, node):
-    value = loader.construct_scalar(node)
+def _wa_regex_constructor(loader: _yaml_loader, node: Node) -> Pattern[str]:
+    """
+    regex constructor
+    """
+    value: str = cast(str, loader.construct_scalar(cast(ScalarNode, node)))
     flags, pattern = value.split(':', 1)
     return re.compile(pattern, int(flags or 0))

-def _wa_level_constructor(loader, node):
-    value = loader.construct_scalar(node)
-    name, value = value.split(':', 1)
+def _wa_level_constructor(loader: _yaml_loader, node: Node) -> level:
+    """
+    level constructor
+    """
+    value = loader.construct_scalar(cast(ScalarNode, node))
+    name, value = cast(str, value).split(':', 1)
     return level(name, value)

-def _wa_cpu_mask_constructor(loader, node):
-    value = loader.construct_scalar(node)
+def _wa_cpu_mask_constructor(loader: _yaml_loader, node: Node) -> cpu_mask:
+    value = cast(str, loader.construct_scalar(cast(ScalarNode, node)))
     return cpu_mask(value)

 class _WaYamlLoader(_yaml_loader):  # pylint: disable=too-many-ancestors
-
-    def construct_mapping(self, node, deep=False):
+    """
+    yaml loader for WA
+    """
+
+    def construct_mapping(self, node: Type[Node], deep: bool = False) -> OrderedDict:
+        """
+        construct mapping
+        """
        if isinstance(node,
MappingNode): self.flatten_mapping(node) if not isinstance(node, MappingNode): raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, + "expected a mapping node, but found %s" % node.id, # type:ignore node.start_mark) mapping = OrderedDict() for key_node, value_node in node.value: @@ -261,11 +297,17 @@ def construct_mapping(self, node, deep=False): class yaml(object): @staticmethod - def dump(o, wfh, *args, **kwargs): + def dump(o: Any, wfh: IO, *args, **kwargs) -> None: + """ + serialize object into yaml format and dump into file + """ return _yaml.dump(o, wfh, *args, **kwargs) @staticmethod - def load(fh, *args, **kwargs): + def load(fh: IO, *args, **kwargs) -> Any: + """ + deserialize yaml from file and create python object + """ try: return _yaml.load(fh, *args, Loader=_WaYamlLoader, **kwargs) except _yaml.YAMLError as e: @@ -281,49 +323,66 @@ def load(fh, *args, **kwargs): class python(object): @staticmethod - def dump(o, wfh, *args, **kwargs): + def dump(o: Any, wfh: IO, *args, **kwargs): + """ + serialize object and dump into file + """ raise NotImplementedError() @classmethod - def load(cls, fh, *args, **kwargs): + def load(cls: Type, fh: IO, *args, **kwargs) -> Dict[str, Any]: + """ + load object from file + """ return cls.loads(fh.read()) @staticmethod - def loads(s, *args, **kwargs): - pod = {} + def loads(s: str, *args, **kwargs) -> Dict[str, Any]: + """ + load object from string + """ + pod: Dict[str, Any] = {} try: exec(s, pod) # pylint: disable=exec-used except SyntaxError as e: - raise SerializerSyntaxError(e.message, e.lineno) + raise SerializerSyntaxError(e.msg, e.lineno) for k in list(pod.keys()): # pylint: disable=consider-iterating-dictionary if k.startswith('__'): del pod[k] return pod -def read_pod(source, fmt=None): +def read_pod(source: Union[str, IO], fmt: Optional[str] = None) -> Dict[str, Any]: + """ + read plain old datastructure from file. 
+    source -> file handle or a file path
+    fmt -> file type - py, json or yaml
+    """
     if isinstance(source, str):
         with open(source) as fh:
             return _read_pod(fh, fmt)
     elif hasattr(source, 'read') and (hasattr(source, 'name') or fmt):
         return _read_pod(source, fmt)
     else:
-        message = 'source must be a path or an open file handle; got {}'
+        message: str = 'source must be a path or an open file handle; got {}'
         raise ValueError(message.format(type(source)))

-def write_pod(pod, dest, fmt=None):
+def write_pod(pod: Dict[str, Any], dest: Union[str, IO], fmt: Optional[str] = None) -> None:
+    """
+    write pod to a file path or open file handle
+    """
     if isinstance(dest, str):
         with open(dest, 'w') as wfh:
             return _write_pod(pod, wfh, fmt)
     elif hasattr(dest, 'write') and (hasattr(dest, 'name') or fmt):
         return _write_pod(pod, dest, fmt)
     else:
-        message = 'dest must be a path or an open file handle; got {}'
+        message: str = 'dest must be a path or an open file handle; got {}'
         raise ValueError(message.format(type(dest)))

-def dump(o, wfh, fmt='json', *args, **kwargs):
+def dump(o: Any, wfh: IO, fmt: str = 'json', *args, **kwargs):
     serializer = {'yaml': yaml,
                   'json': json,
                   'python': python,
@@ -331,14 +390,20 @@ def dump(o, wfh, fmt='json', *args, **kwargs):
                   }.get(fmt)
     if serializer is None:
         raise ValueError('Unknown serialization format: "{}"'.format(fmt))
-    serializer.dump(o, wfh, *args, **kwargs)
+    serializer.dump(o, wfh, *args, **kwargs)  # type:ignore

-def load(s, fmt='json', *args, **kwargs):
+def load(s: str, fmt: str = 'json', *args, **kwargs):
+    """
+    load a pod from a file path or open file handle
+    """
     return read_pod(s, fmt=fmt)

-def _read_pod(fh, fmt=None):
+def _read_pod(fh: IO, fmt: Optional[str] = None) -> Dict[str, Any]:
+    """
+    read pod from file
+    """
     if fmt is None:
         fmt = os.path.splitext(fh.name)[1].lower().strip('.')
         if fmt == '':
@@ -357,7 +422,10 @@ def _read_pod(fh, fmt=None):
     raise ValueError('Unknown format "{}": {}'.format(fmt, getattr(fh, 'name', '')))

-def _write_pod(pod, wfh, fmt=None):
+def _write_pod(pod: Dict[str, Any], wfh: IO, fmt: Optional[str] = None) -> None:
+    """
+    write pod into file
+    """
     if fmt is None:
         fmt = os.path.splitext(wfh.name)[1].lower().strip('.')
     if fmt == 'yaml':
@@ -370,7 +438,10 @@ def _write_pod(pod, wfh, fmt=None):
     raise ValueError('Unknown format "{}": {}'.format(fmt, getattr(wfh, 'name', '')))

-def is_pod(obj):
+def is_pod(obj: Any) -> bool:
+    """
+    check whether the object is plain old data (POD)
+    """
     if type(obj) not in POD_TYPES:  # pylint: disable=unidiomatic-typecheck
         return False
     if hasattr(obj, 'items'):
@@ -386,28 +457,37 @@

 class Podable(object):

-    _pod_serialization_version = 0
+    _pod_serialization_version: int = 0

     @classmethod
-    def from_pod(cls, pod):
+    def from_pod(cls: Type, pod: Dict[str, Any]) -> 'Podable':
+        """
+        create an instance of cls from a plain old data structure
+        """
         pod = cls._upgrade_pod(pod)
         instance = cls()
         instance._pod_version = pod.pop('_pod_version')  # pylint: disable=protected-access
         return instance

     @classmethod
-    def _upgrade_pod(cls, pod):
+    def _upgrade_pod(cls: Type, pod: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        apply successive _pod_upgrade_vN methods until the pod reaches the
+        current serialization version
+        """
         _pod_serialization_version = pod.pop('_pod_serialization_version', None) or 0
         while _pod_serialization_version < cls._pod_serialization_version:
             _pod_serialization_version += 1
-            upgrade = getattr(cls, '_pod_upgrade_v{}'.format(_pod_serialization_version))
+            upgrade: Callable = getattr(cls, '_pod_upgrade_v{}'.format(_pod_serialization_version))
            pod
= upgrade(pod) return pod def __init__(self): self._pod_version = self._pod_serialization_version - def to_pod(self): + def to_pod(self) -> Dict[str, Any]: + """ + convert the cls to a plain old datastructure + """ pod = {} pod['_pod_version'] = self._pod_version pod['_pod_serialization_version'] = self._pod_serialization_version diff --git a/wa/utils/terminalsize.py b/wa/utils/terminalsize.py index 0ee8e7dc0..4a3f28b13 100644 --- a/wa/utils/terminalsize.py +++ b/wa/utils/terminalsize.py @@ -22,16 +22,18 @@ import platform import subprocess +from typing import Tuple, Optional, Any -def get_terminal_size(): + +def get_terminal_size() -> Tuple[int, int]: """ getTerminalSize() - get width and height of console - works on linux,os x,windows,cygwin(windows) originally retrieved from: http://stackoverflow.com/questions/566746/how-to-get-console-window-width-in-python """ - current_os = platform.system() - tuple_xy = None + current_os: str = platform.system() + tuple_xy: Optional[Tuple[int, int]] = None if current_os == 'Windows': tuple_xy = _get_terminal_size_windows() if tuple_xy is None: @@ -44,10 +46,13 @@ def get_terminal_size(): return tuple_xy -def _get_terminal_size_windows(): +def _get_terminal_size_windows() -> Optional[Tuple[int, int]]: + """ + get terminal size in windows os + """ # pylint: disable=unused-variable,redefined-outer-name,too-many-locals, import-outside-toplevel try: - from ctypes import windll, create_string_buffer + from ctypes import windll, create_string_buffer # type:ignore # stdin handle is -10 # stdout handle is -11 # stderr handle is -12 @@ -58,14 +63,18 @@ def _get_terminal_size_windows(): (bufx, bufy, curx, cury, wattr, left, top, right, bottom, maxx, maxy) = struct.unpack("hhhhHhhhhhh", csbi.raw) - sizex = right - left + 1 - sizey = bottom - top + 1 + sizex: int = right - left + 1 + sizey: int = bottom - top + 1 return sizex, sizey except: # NOQA pass + return None -def _get_terminal_size_tput(): +def _get_terminal_size_tput() -> Optional[Tuple[int, int]]: + """ + get terminal size tput + """ # get terminal width # src: http://stackoverflow.com/questions/263890/how-do-i-find-the-width-height-of-a-terminal-window try: @@ -74,23 +83,27 @@ def _get_terminal_size_tput(): return (cols, rows) except: # NOQA pass + return None -def _get_terminal_size_linux(): +def _get_terminal_size_linux() -> Optional[Tuple[int, int]]: + """ + get terminal size in linux os + """ # pylint: disable=import-outside-toplevel - def ioctl_GWINSZ(fd): + def ioctl_GWINSZ(fd: int): try: import fcntl import termios - cr = struct.unpack('hh', - fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234')) + cr: Tuple[Any, ...] 
= struct.unpack('hh', + fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234')) # type:ignore return cr except: # NOQA pass - cr = ioctl_GWINSZ(0) or ioctl_GWINSZ(1) or ioctl_GWINSZ(2) + cr: Optional[Tuple[Any, ...]] = ioctl_GWINSZ(0) or ioctl_GWINSZ(1) or ioctl_GWINSZ(2) if not cr: try: - fd = os.open(os.ctermid(), os.O_RDONLY) + fd: int = os.open(os.ctermid(), os.O_RDONLY) cr = ioctl_GWINSZ(fd) os.close(fd) except: # NOQA diff --git a/wa/utils/trace_cmd.py b/wa/utils/trace_cmd.py index c6a14fbea..f6c24d324 100644 --- a/wa/utils/trace_cmd.py +++ b/wa/utils/trace_cmd.py @@ -21,9 +21,10 @@ from wa.utils.misc import isiterable from wa.utils.types import numeric +from typing import (Optional, Union, Pattern, Any, List, + Callable, Dict, Generator, Match) - -logger = logging.getLogger('trace-cmd') +logger: logging.Logger = logging.getLogger('trace-cmd') class TraceCmdEvent(object): @@ -37,9 +38,10 @@ class TraceCmdEvent(object): """ - __slots__ = ['thread', 'reporting_cpu_id', 'timestamp', 'name', 'text', 'fields'] + __slots__: List[str] = ['thread', 'reporting_cpu_id', 'timestamp', 'name', 'text', 'fields'] - def __init__(self, thread, cpu_id, ts, name, body, parser=None): + def __init__(self, thread: str, cpu_id: str, ts: int, + name: str, body: str, parser: Optional[Callable] = None): """ parameters: @@ -66,7 +68,7 @@ def __init__(self, thread, cpu_id, ts, name, body, parser=None): self.timestamp = numeric(ts) self.name = name self.text = body - self.fields = {} + self.fields: Dict[str, Any] = {} if parser: try: @@ -76,13 +78,13 @@ def __init__(self, thread, cpu_id, ts, name, body, parser=None): # parse self.text pass - def __getattr__(self, name): + def __getattr__(self, name: str): try: return self.fields[name] except KeyError: raise AttributeError(name) - def __str__(self): + def __str__(self) -> str: return 'TE({} @ {})'.format(self.name, self.timestamp) __repr__ = __str__ @@ -90,29 +92,32 @@ def __str__(self): class DroppedEventsEvent(object): - __slots__ = ['thread', 'reporting_cpu_id', 'timestamp', 'name', 'text', 'fields'] + __slots__: List[str] = ['thread', 'reporting_cpu_id', 'timestamp', 'name', 'text', 'fields'] - def __init__(self, cpu_id): + def __init__(self, cpu_id: Union[int, str]): self.thread = None self.reporting_cpu_id = None self.timestamp = None self.name = 'DROPPED EVENTS DETECTED' - self.text = None - self.fields = {'cpu_id': int(cpu_id)} + self.text: Optional[str] = None + self.fields: Dict[str, Any] = {'cpu_id': int(cpu_id)} - def __getattr__(self, name): + def __getattr__(self, name: str): try: return self.fields[name] except KeyError: raise AttributeError(name) - def __str__(self): + def __str__(self) -> str: return 'DROPPED_EVENTS_ON_CPU{}'.format(self.cpu_id) __repr__ = __str__ -def try_convert_to_numeric(v): +def try_convert_to_numeric(v: Any) -> Union[int, float, List[Union[int, float]]]: + """ + convert to numeric + """ try: if isiterable(v): return list(map(numeric, v)) @@ -122,7 +127,7 @@ def try_convert_to_numeric(v): return v -def default_body_parser(event, text): +def default_body_parser(event: TraceCmdEvent, text: str) -> None: """ Default parser to attempt to use to parser body text for the event (i.e. after the "header" common to all events has been parsed). 
This assumes that the body is
@@ -131,9 +136,10 @@
     """
     parts = [e.rsplit(' ', 1) for e in text.strip().split('=')]
-    parts = [p.strip() for p in chain.from_iterable(parts)]
-    if not len(parts) % 2:
-        i = iter(parts)
+    parts_ = [p.strip() for p in chain.from_iterable(parts)]
+    if not len(parts_) % 2:
+        i = iter(parts_)
+        v: Union[int, str]
         for k, v in zip(i, i):
             try:
                 v = int(v)
@@ -142,7 +148,7 @@
             event.fields[k] = v

-def regex_body_parser(regex, flags=0):
+def regex_body_parser(regex: Union[str, Pattern[str]], flags: int = 0) -> Callable:
     """
     Creates an event body parser from the specified regular expression (could be an
     ``re.RegexObject``, or a string). The regular expression should contain some named
@@ -157,8 +163,11 @@
     if isinstance(regex, str):
         regex = re.compile(regex, flags)

-    def regex_parser_func(event, text):
-        match = regex.search(text)
+    def regex_parser_func(event: TraceCmdEvent, text: str) -> None:
+        """
+        regex parser function
+        """
+        match: Optional[Match[str]] = regex.search(text)
         if match:
             for k, v in match.groupdict().items():
                 try:
@@ -169,25 +178,25 @@
     return regex_parser_func

-def sched_switch_parser(event, text):
+def sched_switch_parser(event: TraceCmdEvent, text: str) -> None:
     """
     Sched switch output may be presented in a couple of different formats. One is
     handled by a regex. The other format can *almost* be handled by the default
     parser, if it weren't for the ``==>`` that appears in the middle.
     """
     if text.count('=') == 2:  # old format
-        regex = re.compile(
+        regex: Pattern[str] = re.compile(
             r'(?P<prev_comm>\S.*):(?P<prev_pid>\d+) \[(?P<prev_prio>\d+)\] (?P<status>\S+)'
             r' ==> '
             r'(?P<next_comm>\S.*):(?P<next_pid>\d+) \[(?P<next_prio>\d+)\]'
         )
-        parser_func = regex_body_parser(regex)
+        parser_func: Callable = regex_body_parser(regex)
         return parser_func(event, text)
     else:  # there are more than two "=" -- new format
         return default_body_parser(event, text.replace('==>', ''))

-def sched_stat_parser(event, text):
+def sched_stat_parser(event: TraceCmdEvent, text: str) -> None:
     """
     sched_stat_* events include the units, "[ns]", in an otherwise
     regular key=value sequence; so the units need to be stripped out first.
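# --- Illustrative sketch (not part of this patch): building a parser for
# another event with regex_body_parser(). Named groups become entries in
# event.fields; the event name and field names here are assumptions.
cpu_idle_parser = regex_body_parser(
    r'state=(?P<state>\d+) cpu_id=(?P<cpu_id>\d+)')
# It would be registered alongside the sched_* parsers in the
# EVENT_PARSER_MAP defined below:
# EVENT_PARSER_MAP['cpu_idle'] = cpu_idle_parser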
@@ -195,8 +204,11 @@
     """
     return default_body_parser(event, text.replace(' [ns]', ''))

-def sched_wakeup_parser(event, text):
-    regex = re.compile(r'(?P<comm>\S+):(?P<pid>\d+) \[(?P<prio>\d+)\] success=(?P<success>\d) CPU:(?P<cpu>\d+)')
+def sched_wakeup_parser(event: TraceCmdEvent, text: str) -> None:
+    """
+    sched wakeup parser
+    """
+    regex: Pattern[str] = re.compile(r'(?P<comm>\S+):(?P<pid>\d+) \[(?P<prio>\d+)\] success=(?P<success>\d) CPU:(?P<cpu>\d+)')
     parse_func = regex_body_parser(regex)
     return parse_func(event, text)
@@ -220,14 +232,14 @@
     'sched_wakeup_new': sched_wakeup_parser,
 }

-TRACE_EVENT_REGEX = re.compile(r'^\s+(?P<thread>\S+.*?\S+)\s+\[(?P<cpu_id>\d+)\]\s+(?P<ts>[\d.]+):\s+'
-                               r'(?P<name>[^:]+):\s+(?P<body>.*?)\s*$')
+TRACE_EVENT_REGEX: Pattern[str] = re.compile(r'^\s+(?P<thread>\S+.*?\S+)\s+\[(?P<cpu_id>\d+)\]\s+(?P<ts>[\d.]+):\s+'
+                                             r'(?P<name>[^:]+):\s+(?P<body>.*?)\s*$')

-HEADER_REGEX = re.compile(r'^\s*(?:version|cpus)\s*=\s*([\d.]+)\s*$')
+HEADER_REGEX: Pattern[str] = re.compile(r'^\s*(?:version|cpus)\s*=\s*([\d.]+)\s*$')

-DROPPED_EVENTS_REGEX = re.compile(r'CPU:(?P<cpu_id>\d+) \[\d*\s*EVENTS DROPPED\]')
+DROPPED_EVENTS_REGEX: Pattern[str] = re.compile(r'CPU:(?P<cpu_id>\d+) \[\d*\s*EVENTS DROPPED\]')

-EMPTY_CPU_REGEX = re.compile(r'CPU \d+ is empty')
+EMPTY_CPU_REGEX: Pattern[str] = re.compile(r'CPU \d+ is empty')

 class TraceCmdParser(object):
@@ -236,7 +248,8 @@
     """

-    def __init__(self, filter_markers=True, check_for_markers=True, events=None):
+    def __init__(self, filter_markers: bool = True, check_for_markers: bool = True,
+                 events: Optional[List[str]] = None):
         """
         Initialize a new trace parser.
@@ -258,17 +271,18 @@
         self.check_for_markers = check_for_markers
         self.events = events

-    def parse(self, filepath):  # pylint: disable=too-many-branches,too-many-locals
+    def parse(self, filepath: str) -> Generator[Union[DroppedEventsEvent,
+                                                      TraceCmdEvent], Any, None]:  # pylint: disable=too-many-branches,too-many-locals
         """
         This is a generator for the trace event stream.
:param filepath: The path to the file containing the text trace as reported
                          by trace-cmd
         """
-        inside_maked_region = False
+        inside_maked_region: bool = False
         # pylint: disable=superfluous-parens
-        filters = [re.compile('^{}$'.format(e)) for e in (self.events or [])]
-        filter_markers = self.filter_markers
+        filters: List[Pattern[str]] = [re.compile('^{}$'.format(e)) for e in (self.events or [])]
+        filter_markers: bool = self.filter_markers
         if filter_markers and self.check_for_markers:
             with open(filepath) as fh:
                 for line in fh:
@@ -291,12 +305,12 @@
                     inside_maked_region = False
                     continue

-            match = DROPPED_EVENTS_REGEX.search(line)
+            match: Optional[Match[str]] = DROPPED_EVENTS_REGEX.search(line)
             if match:
                 yield DroppedEventsEvent(match.group('cpu_id'))
                 continue

-            matched = False
+            matched: bool = False
             for rx in [HEADER_REGEX, EMPTY_CPU_REGEX]:
                 match = rx.search(line)
                 if match:
@@ -311,10 +325,10 @@
                 logger.warning('Invalid trace event: "{}"'.format(line))
                 continue

-            event_name = match.group('name')
+            event_name: str = match.group('name')

             if filters:
-                found = False
+                found: bool = False
                 for f in filters:
                     if f.search(event_name):
                         found = True
@@ -322,13 +336,16 @@
                 if not found:
                     continue

-            body_parser = EVENT_PARSER_MAP.get(event_name, default_body_parser)
+            body_parser: Callable = EVENT_PARSER_MAP.get(event_name, default_body_parser)
             if isinstance(body_parser, (str, re.Pattern)):  # pylint: disable=protected-access
                 body_parser = regex_body_parser(body_parser)
-            yield TraceCmdEvent(parser=body_parser, **match.groupdict())
+            yield TraceCmdEvent(parser=body_parser, **match.groupdict())  # type:ignore

-def trace_has_marker(filepath, max_lines_to_check=2000000):
+def trace_has_marker(filepath: str, max_lines_to_check: int = 2000000) -> bool:
+    """
+    check if trace has marker
+    """
     with open(filepath) as fh:
         for i, line in enumerate(fh):
             if TRACE_MARKER_START in line:
diff --git a/wa/utils/types.py b/wa/utils/types.py
index 767b882bb..09f2d7a1e 100644
--- a/wa/utils/types.py
+++ b/wa/utils/types.py
@@ -36,16 +36,18 @@
 from collections.abc import MutableMapping
 from functools import total_ordering

-from past.builtins import basestring  # pylint: disable=redefined-builtin
-from future.utils import with_metaclass
+from future.utils import with_metaclass  # type:ignore

 from devlib.utils.types import identifier, boolean, integer, numeric, caseless_string

 from wa.utils.misc import (isiterable, list_to_ranges, list_to_mask,
                            mask_to_list, ranges_to_list)
+from typing import (List, Any, Optional, Iterable, Callable,
+                    Union, Type, Pattern, Tuple, DefaultDict,
+                    Dict, Set)

-def list_of_strs(value):
+def list_of_strs(value: Iterable) -> List[str]:
     """
     Value must be iterable. All elements will be converted to strings.
@@ -55,12 +57,12 @@
     return list(map(str, value))

-list_of_strings = list_of_strs
+list_of_strings: Callable[[Iterable], List[str]] = list_of_strs

-def list_of_ints(value):
+def list_of_ints(value: Iterable) -> List[int]:
     """
-    Value must be iterable. All elements will be converted to ``int``\ s.
+    Value must be iterable. All elements will be converted to ``int`` s.
""" if not isiterable(value): @@ -68,13 +70,13 @@ def list_of_ints(value): return list(map(int, value)) -list_of_integers = list_of_ints +list_of_integers: Callable[[Iterable], List[int]] = list_of_ints -def list_of_numbers(value): +def list_of_numbers(value: Iterable) -> List[Union[float, int]]: """ Value must be iterable. All elements will be converted to numbers (either ``ints`` or - ``float``\ s depending on the elements). + ``float`` s depending on the elements). """ if not isiterable(value): @@ -82,9 +84,9 @@ def list_of_numbers(value): return list(map(numeric, value)) -def list_of_bools(value, interpret_strings=True): +def list_of_bools(value: Iterable, interpret_strings: bool = True) -> List[bool]: """ - Value must be iterable. All elements will be converted to ``bool``\ s. + Value must be iterable. All elements will be converted to ``bool`` s. .. note:: By default, ``boolean()`` conversion function will be used, which means that strings like ``"0"`` or ``"false"`` will be @@ -100,26 +102,26 @@ def list_of_bools(value, interpret_strings=True): return list(map(bool, value)) -def list_of(type_): +def list_of(type_: Type) -> Type[List]: """Generates a "list of" callable for the specified type. The callable attempts to convert all elements in the passed value to the specified ``type_``, raising ``ValueError`` on error.""" - def __init__(self, values): + def __init__(self, values: Iterable): list.__init__(self, list(map(type_, values))) - def append(self, value): + def append(self, value: Any) -> None: list.append(self, type_(value)) - def extend(self, other): + def extend(self, other: Iterable) -> None: list.extend(self, list(map(type_, other))) - def from_pod(cls, pod): + def from_pod(cls: Type, pod: Iterable): return cls(list(map(type_, pod))) def _to_pod(self): return self - def __setitem__(self, idx, value): + def __setitem__(self, idx: int, value: Any) -> None: list.__setitem__(self, idx, type_(value)) return type('list_of_{}s'.format(type_.__name__), @@ -133,7 +135,7 @@ def __setitem__(self, idx, value): }) -def list_or_string(value): +def list_or_string(value: Union[str, Iterable]) -> List[str]: """ Converts the value into a list of strings. If the value is not iterable, a one-element list with stringified value will be returned. @@ -148,7 +150,7 @@ def list_or_string(value): return [str(value)] -def list_or_caseless_string(value): +def list_or_caseless_string(value: Union[str, Iterable]) -> List[caseless_string]: """ Converts the value into a list of ``caseless_string``'s. If the value is not iterable a one-element list with stringified value will be returned. @@ -188,11 +190,11 @@ def __init__(self, value): list_or_bool = list_or(boolean) -regex_type = type(re.compile('')) -none_type = type(None) +regex_type: Type[Pattern[str]] = type(re.compile('')) +none_type: Type[None] = type(None) -def regex(value): +def regex(value: Union[str, Pattern[str]]) -> Pattern[str]: """ Regular expression. If value is a string, it will be complied with no flags. If you want to specify flags, value must be precompiled. @@ -204,7 +206,7 @@ def regex(value): return re.compile(value) -def version_tuple(v): +def version_tuple(v: str) -> Tuple[str, ...]: """ Converts a version string into a tuple of strings that can be used for natural comparison allowing delimeters of "-" and ".". 
@@ -213,7 +215,7 @@
     return tuple(map(str, (v.split("."))))


-def module_name_set(l):  # noqa: E741
+def module_name_set(l: List):  # noqa: E741
     """
     Converts a list of target modules into a set of module names, disregarding
     any configuration that may be present.
@@ -227,19 +229,25 @@
     return modules


-__counters = defaultdict(int)
+__counters: DefaultDict = defaultdict(int)


-def reset_counter(name=None, value=0):
+def reset_counter(name: Optional[str] = None, value: int = 0) -> None:
+    """
+    Reset the named counter to ``value``.
+    """
     __counters[name] = value


-def reset_all_counters(value=0):
+def reset_all_counters(value: int = 0) -> None:
+    """
+    Reset every known counter to ``value``.
+    """
     for k in __counters:
         reset_counter(k, value)


-def counter(name=None):
+def counter(name: Optional[str] = None) -> int:
     """
     An auto incrementing value (kind of like an AUTO INCREMENT field in SQL).
     Optionally, the name of the counter to be used is specified (each counter
@@ -249,7 +257,7 @@
     """
     __counters[name] += 1
-    value = __counters[name]
+    value: int = __counters[name]
     return value


@@ -259,9 +267,9 @@ class arguments(list):

     """

-    def __init__(self, value=None):
+    def __init__(self, value: Optional[Union[Iterable, str]] = None):
         if isiterable(value):
-            super(arguments, self).__init__(list(map(str, value)))
+            super(arguments, self).__init__(list(map(str, value or [])))
         elif isinstance(value, str):
             posix = os.name != 'nt'
             super(arguments, self).__init__(shlex.split(value, posix=posix))
@@ -270,31 +278,31 @@
         else:
             super(arguments, self).__init__([str(value)])

-    def append(self, value):
+    def append(self, value: Optional[Union[Iterable, str]]):
         return super(arguments, self).append(str(value))

-    def extend(self, values):
+    def extend(self, values: Iterable):
         return super(arguments, self).extend(list(map(str, values)))

-    def __str__(self):
+    def __str__(self) -> str:
         return ' '.join(self)


 class prioritylist(object):

-    def __init__(self):
+    def __init__(self) -> None:
         """
         Returns an OrderedReceivers object that externally behaves like a list
         but it maintains the order of its elements according to their priority.
         """
-        self.elements = defaultdict(list)
-        self.is_ordered = True
-        self.priorities = []
-        self.size = 0
-        self._cached_elements = None
+        self.elements: DefaultDict = defaultdict(list)
+        self.is_ordered: bool = True
+        self.priorities: List[int] = []
+        self.size: int = 0
+        self._cached_elements: Optional[List[Any]] = None

-    def add(self, new_element, priority=0):
+    def add(self, new_element: Any, priority: int = 0) -> None:
         """
         adds a new item in the list.
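The counter helpers annotated above act as named auto-increment values backed by the module-level `defaultdict(int)`; the semantics below follow directly from the code in the hunk:

```python
from wa.utils.types import counter, reset_counter, reset_all_counters

counter('iteration')        # -> 1 (each name starts at 0 in the defaultdict)
counter('iteration')        # -> 2
counter('other')            # -> 1 (counters are independent)
reset_counter('iteration')  # back to 0
counter('iteration')        # -> 1 again
reset_all_counters()        # zeroes every counter seen so far
```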
@@ -304,35 +312,54 @@ def add(self, new_element, priority=0):
         """
         self._add_element(new_element, priority)

-    def add_before(self, new_element, element):
+    def add_before(self, new_element: Any, element: Any) -> None:
+        """
+        Add ``new_element`` immediately before ``element``.
+        """
         priority, index = self._priority_index(element)
         self._add_element(new_element, priority, index)

-    def add_after(self, new_element, element):
+    def add_after(self, new_element: Any, element: Any) -> None:
+        """
+        Add ``new_element`` immediately after ``element``.
+        """
         priority, index = self._priority_index(element)
         self._add_element(new_element, priority, index + 1)

-    def index(self, element):
-        return self._to_list().index(element)
+    def index(self, element: Any) -> Optional[int]:
+        return self._to_list().index(element)  # type:ignore

-    def remove(self, element):
+    def remove(self, element: Any) -> None:
+        """
+        Remove ``element`` from the list.
+        """
         index = self.index(element)
         self.__delitem__(index)

-    def _priority_index(self, element):
+    def _priority_index(self, element: Any) -> Tuple[int, int]:
+        """
+        Return the priority and index of ``element``.
+        """
         for priority, elements in self.elements.items():
             if element in elements:
                 return (priority, elements.index(element))
         raise IndexError(element)

-    def _to_list(self):
+    def _to_list(self) -> Optional[List]:
+        """
+        Flatten the priority buckets into a single cached list.
+        """
         if self._cached_elements is None:
             self._cached_elements = []
             for priority in self.priorities:
                 self._cached_elements += self.elements[priority]
         return self._cached_elements

-    def _add_element(self, element, priority, index=None):
+    def _add_element(self, element: Any, priority: int,
+                     index: Optional[int] = None) -> None:
+        """
+        Add ``element`` at the given ``priority`` (and optional ``index``).
+        """
         if index is None:
             self.elements[priority].append(element)
         else:
@@ -342,7 +369,10 @@
         if priority not in self.priorities:
             insort(self.priorities, priority)

-    def _delete(self, priority, priority_index):
+    def _delete(self, priority: int, priority_index: int) -> None:
+        """
+        Delete the element at ``priority_index`` within ``priority``.
+        """
         del self.elements[priority][priority_index]
         self.size -= 1
         if not self.elements[priority]:
@@ -354,39 +384,39 @@
             for element in self.elements[priority]:
                 yield element

-    def __getitem__(self, index):
-        return self._to_list()[index]
+    def __getitem__(self, index: int) -> Any:
+        return self._to_list()[index]  # type:ignore

-    def __delitem__(self, index):
+    def __delitem__(self, index: Optional[int]):
         if isinstance(index, numbers.Integral):
             index = int(index)
             if index < 0:
-                index_range = [len(self) + index]
+                index_range: List[int] = [len(self) + index]
             else:
                 index_range = [index]
         elif isinstance(index, slice):
             index_range = list(range(index.start or 0, index.stop, index.step or 1))
         else:
             raise ValueError('Invalid index {}'.format(index))
-        current_global_offset = 0
-        priority_counts = dict(zip(self.priorities, [len(self.elements[p])
-                                                     for p in self.priorities]))
+        current_global_offset: int = 0
+        priority_counts: Dict[int, int] = dict(zip(self.priorities, [len(self.elements[p])
                                                                      for p in self.priorities]))
         for priority in self.priorities:
             if not index_range:
                 break
-            priority_offset = 0
+            priority_offset: int = 0
             while index_range:
-                del_index = index_range[0]
+                del_index: int = index_range[0]
                 if priority_counts[priority] + current_global_offset <= del_index:
                     current_global_offset += priority_counts[priority]
                     break
-                within_priority_index = del_index - \
+                within_priority_index: int = del_index - \
                    (current_global_offset + priority_offset)
                 self._delete(priority, within_priority_index)
                 priority_offset += 1
                 index_range.pop(0)

-    def __len__(self):
+    def __len__(self) -> int:
         return self.size

@@ -400,33 +430,36 @@ class toggle_set(set):

     """

     @staticmethod
-    def from_pod(pod):
+    def from_pod(pod: Any) -> 'toggle_set':
         return toggle_set(pod)

     @staticmethod
-    def merge(dest, source):
+    def merge(dest: 'toggle_set', source: Union[Set, 'toggle_set']) -> 'toggle_set':
+        """
+        Merge ``source`` into ``dest`` and return the result.
+        """
         if '~~' in source:
             return toggle_set(source)

         dest = toggle_set(dest)
         for item in source:
             if item not in dest:
-                #Disable previously enabled item
+                # Disable previously enabled item
                 if item.startswith('~') and item[1:] in dest:
                     dest.remove(item[1:])
-                #Enable previously disabled item
+                # Enable previously disabled item
                 if not item.startswith('~') and ('~' + item) in dest:
                     dest.remove('~' + item)
                 dest.add(item)
         return dest

-    def __init__(self, *args):
+    def __init__(self, *args) -> None:
         if args:
             value = args[0]
             if isinstance(value, str):
-                msg = 'invalid type for toggle_set: "{}"'
+                msg: str = 'invalid type for toggle_set: "{}"'
                 raise TypeError(msg.format(type(value)))
-            updated_value = []
+            updated_value: List[str] = []
             for v in value:
                 if v.startswith('~') and v[1:] in updated_value:
                     updated_value.remove(v[1:])
@@ -436,29 +469,38 @@
             args = tuple([updated_value] + list(args[1:]))
         set.__init__(self, *args)

-    def merge_with(self, other):
+    def merge_with(self, other: Union[Set, 'toggle_set']) -> 'toggle_set':
+        """
+        Merge ``other`` into this set and return the result.
+        """
         return toggle_set.merge(self, other)

-    def merge_into(self, other):
+    def merge_into(self, other: 'toggle_set') -> 'toggle_set':
+        """
+        Merge this set into ``other`` and return the result.
+        """
         return toggle_set.merge(other, self)

-    def add(self, item):
+    def add(self, item: str) -> None:
+        """
+        Add ``item``, flipping any previous toggle of the same name.
+        """
         if item not in self:
-            #Disable previously enabled item
+            # Disable previously enabled item
             if item.startswith('~') and item[1:] in self:
                 self.remove(item[1:])
-            #Enable previously disabled item
+            # Enable previously disabled item
             if not item.startswith('~') and ('~' + item) in self:
                 self.remove('~' + item)
             super(toggle_set, self).add(item)

-    def values(self):
+    def values(self) -> Set[str]:
         """
         returns a list of enabled items.
         """
         return {item for item in self if not item.startswith('~')}

-    def conflicts_with(self, other):
+    def conflicts_with(self, other: 'toggle_set') -> List[str]:
         """
         Checks if any items in ``other`` conflict with items already in this list.
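For reference, `prioritylist` (annotated above) iterates in ascending priority order because `_add_element()` keeps `self.priorities` sorted via `insort()`, while `add_before`/`add_after` position elements within a priority bucket. A small sketch:

```python
from wa.utils.types import prioritylist

pl = prioritylist()
pl.add('first', priority=1)
pl.add('last', priority=10)
pl.add_before('zeroth', 'first')  # same priority as 'first', earlier index
list(pl)                          # -> ['zeroth', 'first', 'last']
```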
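The `toggle_set` semantics above are easiest to see with a concrete merge; the item names here are illustrative, not taken from WA's plugin list:

```python
from wa.utils.types import toggle_set

defaults = toggle_set(['cpufreq', 'hotplug'])
overrides = toggle_set(['~hotplug', 'idle'])

merged = overrides.merge_into(defaults)  # '~hotplug' knocks out 'hotplug'
sorted(merged.values())                  # -> ['cpufreq', 'idle']
defaults.conflicts_with(overrides)       # -> ['~hotplug']
```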
@@ -468,7 +510,7 @@ def conflicts_with(self, other):
         Returns: A list of items in ``other`` that conflict with items in this list
         """
-        conflicts = []
+        conflicts: List[str] = []
         for item in other:
             if item.startswith('~') and item[1:] in self:
                 conflicts.append(item)
@@ -476,16 +518,16 @@
                 conflicts.append(item)
         return conflicts

-    def to_pod(self):
+    def to_pod(self) -> List[str]:
         return list(self.values())


 class ID(str):

-    def merge_with(self, other):
+    def merge_with(self, other: 'ID') -> str:
         return '_'.join([self, other])

-    def merge_into(self, other):
+    def merge_into(self, other: 'ID') -> str:
         return '_'.join([other, self])


@@ -499,30 +541,30 @@ class obj_dict(MutableMapping):

     """

     @staticmethod
-    def from_pod(pod):
+    def from_pod(pod: Any) -> 'obj_dict':
         return obj_dict(pod)

     # pylint: disable=super-init-not-called
-    def __init__(self, values=None, not_in_dict=None):
+    def __init__(self, values: Any = None, not_in_dict: Optional[List] = None):
         self.__dict__['dict'] = dict(values or {})
         self.__dict__['not_in_dict'] = not_in_dict if not_in_dict is not None else []

-    def to_pod(self):
+    def to_pod(self) -> Any:
         return self.__dict__['dict']

-    def __getitem__(self, key):
+    def __getitem__(self, key: str):
         if key in self.not_in_dict:
             msg = '"{}" is in the list keys that can only be accessed as attributes'
             raise KeyError(msg.format(key))
         return self.__dict__['dict'][key]

-    def __setitem__(self, key, value):
+    def __setitem__(self, key: str, value: Any):
         self.__dict__['dict'][key] = value

-    def __delitem__(self, key):
+    def __delitem__(self, key: str):
         del self.__dict__['dict'][key]

-    def __len__(self):
+    def __len__(self) -> int:
         return sum(1 for _ in self)

     def __iter__(self):
@@ -530,22 +572,22 @@
         if key not in self.__dict__['not_in_dict']:
             yield key

-    def __repr__(self):
+    def __repr__(self) -> str:
         return repr(dict(self))

-    def __str__(self):
+    def __str__(self) -> str:
         return str(dict(self))

-    def __setattr__(self, name, value):
+    def __setattr__(self, name: str, value: Any):
         self.__dict__['dict'][name] = value

-    def __delattr__(self, name):
+    def __delattr__(self, name: str) -> None:
         if name in self:
             del self.__dict__['dict'][name]
         else:
             raise AttributeError("No such attribute: " + name)

-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         if 'dict' not in self.__dict__:
             raise AttributeError("No such attribute: " + name)
         if name in self.__dict__['dict']:
@@ -563,30 +605,30 @@ class level(object):

     """

     @staticmethod
-    def from_pod(pod):
+    def from_pod(pod: Any) -> 'level':
         name, value_part = pod.split('(')
         return level(name, numeric(value_part.rstrip(')')))

-    def __init__(self, name, value):
+    def __init__(self, name: str, value: Any):
         self.name = caseless_string(name)
         self.value = numeric(value)

-    def to_pod(self):
+    def to_pod(self) -> str:
         return repr(self)

-    def __str__(self):
+    def __str__(self) -> str:
         return str(self.name)

-    def __repr__(self):
+    def __repr__(self) -> str:
         return '{}({})'.format(self.name, self.value)

-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.name)

     def __eq__(self, other):
         if isinstance(other, level):
             return self.value == other.value
-        elif isinstance(other, basestring):
+        elif isinstance(other, str):
             return self.name == other
         else:
             return self.value == other
@@ -594,7 +636,7 @@ def __lt__(self, other):
         if isinstance(other, level):
             return self.value < other.value
-        elif isinstance(other, basestring):
+        elif isinstance(other, str):
            return self.name < other
         else:
             return self.value < other

@@ -602,7 +644,7 @@ def __ne__(self, other):
         if isinstance(other, level):
             return self.value != other.value
-        elif isinstance(other, basestring):
+        elif isinstance(other, str):
             return self.name != other
         else:
             return self.value != other

@@ -610,16 +652,16 @@

 class _EnumMeta(type):

-    def __str__(cls):
+    def __str__(cls) -> str:
         return str(cls.levels)

-    def __getattr__(cls, name):
+    def __getattr__(cls, name: str):
         name = name.lower()
         if name in cls.__dict__:
             return cls.__dict__[name]


-def enum(args, start=0, step=1):
+def enum(args, start: int = 0, step: int = 1):
     """
     Creates a class with attributes named by the first argument.
     Each attribute is a ``level`` so they behave is integers in comparisons.
@@ -641,23 +683,23 @@ class MyEnum(object):

     """

-    class Enum(with_metaclass(_EnumMeta, object)):
+    class Enum(with_metaclass(_EnumMeta, object)):  # type:ignore

         @classmethod
-        def from_pod(cls, pod):
-            lv = level.from_pod(pod)
+        def from_pod(cls: Type, pod: Any) -> 'Enum':
+            lv: level = level.from_pod(pod)
             for enum_level in cls.levels:
                 if enum_level == lv:
                     return enum_level
-            msg = 'Unexpected value "{}" for enum.'
+            msg: str = 'Unexpected value "{}" for enum.'
             raise ValueError(msg.format(pod))

-        def __new__(cls, name):
+        def __new__(cls: Type, name: str) -> 'Enum':
             for attr_name in dir(cls):
                 if attr_name.startswith('__'):
                     continue

-                attr = getattr(cls, attr_name)
+                attr: Any = getattr(cls, attr_name)
                 if name == attr:
                     return attr

@@ -666,14 +708,14 @@
             except ValueError:
                 raise ValueError('Invalid enum value: {}'.format(repr(name)))

-    reserved = ['values', 'levels', 'names']
+    reserved: List[str] = ['values', 'levels', 'names']

-    levels = []
-    n = start
+    levels: List['level'] = []
+    n: int = start
     for v in args:
-        id_v = identifier(v)
+        id_v: str = identifier(v)
         if id_v in reserved:
-            message = 'Invalid enum level name "{}"; must not be in {}'
+            message: str = 'Invalid enum level name "{}"; must not be in {}'
             raise ValueError(message.format(v, reserved))
         name = caseless_string(id_v)
         lv = level(v, n)
@@ -698,7 +740,7 @@ class ParameterDict(dict):

     # Function to determine the appropriate prefix based on the parameters type
     @staticmethod
-    def _get_prefix(obj):
+    def _get_prefix(obj) -> str:
         if isinstance(obj, str):
             prefix = 's'
         elif isinstance(obj, float):
@@ -715,13 +757,13 @@

     # Function to add prefix and urlencode a provided parameter.
     @staticmethod
-    def _encode(obj):
+    def _encode(obj: Any) -> str:
         if isinstance(obj, list):
-            t = type(obj[0])
-            prefix = ParameterDict._get_prefix(obj[0]) + 'l'
+            t: Type = type(obj[0])
+            prefix: str = ParameterDict._get_prefix(obj[0]) + 'l'
             for item in obj:
                 if not isinstance(item, t):
-                    msg = 'Lists must only contain a single type, contains {} and {}'
+                    msg: str = 'Lists must only contain a single type, contains {} and {}'
                     raise ValueError(msg.format(t, type(item)))
             obj = '0newelement0'.join(str(x) for x in obj)
         else:
@@ -731,10 +773,10 @@

     # Function to decode a string and return a value of the original parameter type.
    # pylint: disable=too-many-return-statements
     @staticmethod
-    def _decode(string):
-        value_type = string[:1]
-        value_dimension = string[1:2]
-        value = unquote(string[2:])
+    def _decode(string: str):
+        value_type: str = string[:1]
+        value_dimension: str = string[1:2]
+        value: str = unquote(string[2:])

         if value_dimension == 's':
             if value_type == 's':
                 return str(value)
@@ -759,13 +801,13 @@ def __init__(self, *args, **kwargs):
             self.__setitem__(k, v)
         dict.__init__(self, *args)

-    def __setitem__(self, name, value):
+    def __setitem__(self, name: str, value: Any):
         dict.__setitem__(self, name, self._encode(value))

-    def __getitem__(self, name):
+    def __getitem__(self, name: str):
         return self._decode(dict.__getitem__(self, name))

-    def __contains__(self, item):
+    def __contains__(self, item: Any):
         return dict.__contains__(self, self._encode(item))

     def __iter__(self):
@@ -775,7 +817,7 @@ def iteritems(self):
         return self.__iter__()

     def get(self, name):
-        return self._decode(dict.get(self, name))
+        return self._decode(dict.get(self, name) or '')

     def pop(self, key):
         return self._decode(dict.pop(self, key))
@@ -807,10 +849,10 @@ class cpu_mask(object):
        sysfs-style string.
     """

     @staticmethod
-    def from_pod(pod):
+    def from_pod(pod: Any) -> 'cpu_mask':
         return cpu_mask(int(pod['cpu_mask']))

-    def __init__(self, cpus):
+    def __init__(self, cpus: Union[int, str, List, 'cpu_mask']):
         self._mask = 0
         if isinstance(cpus, int):
             self._mask = cpus
@@ -824,10 +866,10 @@
         elif isinstance(cpus, cpu_mask):
             self._mask = cpus._mask  # pylint: disable=protected-access
         else:
-            msg = 'Unknown conversion from {} to cpu mask'
+            msg: str = 'Unknown conversion from {} to cpu mask'
             raise ValueError(msg.format(cpus))

-    def __bool__(self):
+    def __bool__(self) -> bool:
         """Allow for use in comparisons to check if a mask has been set"""
         return bool(self._mask)

@@ -838,11 +880,11 @@

     __str__ = __repr__

-    def list(self):
+    def list(self) -> List:
         """Returns a list of the indexes of bits that are set in the mask."""
         return list(reversed(mask_to_list(self._mask)))

-    def mask(self, prefix=True):
+    def mask(self, prefix: bool = True):
         """Returns a hex representation of the mask with an optional prefix"""
         if prefix:
             return hex(self._mask)
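Returning to `ParameterDict` above: values are stored with a type prefix and URL-quoting (see `_encode()`), so only the decoded views are meaningful to callers. A round-trip sketch; it assumes the elided `_decode()` branches reverse the scalar and list encodings, which the visible code strongly suggests but does not show in full:

```python
from wa.utils.types import ParameterDict

pd = ParameterDict(threshold=0.75, tags=['a', 'b'])
pd['threshold']          # -> 0.75, decoded back to a float
pd['tags']               # -> ['a', 'b'], single-type lists round-trip
pd['mixed'] = [1, 'a']   # raises ValueError: lists must hold a single type
```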
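Finally, a `cpu_mask` sketch based on the constructor and accessors in the hunk above. The ascending order of `.list()` assumes `mask_to_list()` enumerates set bits from the most significant end, which the `reversed()` call suggests; the string and list constructor paths are elided from this hunk, so only the int path is shown:

```python
from wa.utils.types import cpu_mask

big_cores = cpu_mask(0b11110000)      # an int is taken as the raw mask value
big_cores.list()                      # indexes of set bits -> [4, 5, 6, 7]
big_cores.mask()                      # '0x'-prefixed hex string -> '0xf0'
bool(cpu_mask(0))                     # -> False: no CPUs selected
cpu_mask.from_pod({'cpu_mask': 240})  # round-trips from a POD dict
```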