diff --git a/odoo_module_migrate/migration_scripts/migrate_130_140.py b/odoo_module_migrate/migration_scripts/migrate_130_140.py index e98c2e42..50c73c91 100644 --- a/odoo_module_migrate/migration_scripts/migrate_130_140.py +++ b/odoo_module_migrate/migration_scripts/migrate_130_140.py @@ -7,6 +7,7 @@ from pathlib import Path import lxml.etree as et from odoo_module_migrate.base_migration_script import BaseMigrationScript +from ..tools import get_files def src_model_new_value(field_elem, model_dot_name): @@ -112,15 +113,6 @@ def _reformat_file(file_path: Path): return file_path -def _get_files(module_path, reformat_file_ext): - """Get files to be reformatted.""" - file_paths = list() - if not module_path.is_dir(): - raise Exception(f"'{module_path}' is not a directory") - file_paths.extend(module_path.rglob("*" + reformat_file_ext)) - return file_paths - - def reformat_deprecated_tags( logger, module_path, module_name, manifest_path, migration_steps, tools ): @@ -130,8 +122,8 @@ def reformat_deprecated_tags( they have to be substituted by the `record` tag. """ - reformat_file_ext = ".xml" - file_paths = _get_files(module_path, reformat_file_ext) + reformat_file_ext = [".xml"] + file_paths = get_files(module_path, reformat_file_ext) logger.debug(f"{reformat_file_ext} files found:\n" f"{list(map(str, file_paths))}") reformatted_files = list() diff --git a/odoo_module_migrate/migration_scripts/migrate_160_170.py b/odoo_module_migrate/migration_scripts/migrate_160_170.py index 96421dab..5078d540 100644 --- a/odoo_module_migrate/migration_scripts/migrate_160_170.py +++ b/odoo_module_migrate/migration_scripts/migrate_160_170.py @@ -7,13 +7,19 @@ import os import ast from typing import Any +from ..tools import get_files +import re empty_list = ast.parse("[]").body[0].value +### Abstract visitor class ### + + class AbstractVisitor(ast.NodeVisitor): - def __init__(self) -> None: + def __init__(self, all_code=None) -> None: # ((line, line_end, col_offset, end_col_offset), replace_by) NO OVERLAPS + self.all_code = all_code self.change_todo = [] def post_process(self, all_code: str, file: str) -> str: @@ -45,6 +51,73 @@ def add_change(self, old_node: ast.AST, new_node: ast.AST | str): self.change_todo.append((position, ast.unparse(new_node))) +def apply_visitor_scripts(logger, filename): + with open(filename, mode="rt") as file: + new_all = all_code = file.read() + if ".read_group(" in all_code or "._read_group(" in all_code: + for Step in Steps_visitor: + visitor = Step() + try: + visitor.visit(ast.parse(new_all)) + except Exception: + logger.error( + f"ERROR in {filename} at step {visitor.__class__}: \n{new_all}" + ) + raise + new_all = visitor.post_process(new_all, filename) + if new_all == all_code: + logger.warning( + "read_group detected but not changed in file %s" % filename + ) + else: + logger.info("Script read_group replace applied in file %s" % filename) + with open(filename, mode="wt") as file: + file.write(new_all) + + if "def name_get(" in all_code: + visitor = VisitorNameGet(all_code) + try: + visitor.visit(ast.parse(new_all)) + except Exception: + logger.error( + f"ERROR in {filename} at step {visitor.__class__}: \n{new_all}" + ) + raise + new_all = visitor.post_process(new_all, filename) + if new_all == all_code: + logger.warning( + "name_get detected but not changed in file: %s" % filename + ) + else: + logger.warning( + "Script name_get replace applied in file %s. " + "You probably need to add manually some dependencies to the new computed method." + % filename + ) + with open(filename, mode="wt") as file: + file.write(new_all) + + +def _reformat_files( + logger, module_path, module_name, manifest_path, migration_steps, tools +): + """Reformat read_group method in py files.""" + + reformat_file_ext = [".py"] + file_paths = get_files(module_path, reformat_file_ext) + logger.debug(f"{reformat_file_ext} files found:\n" f"{list(map(str, file_paths))}") + + reformatted_files = list() + for file_path in file_paths: + reformatted_file = apply_visitor_scripts(logger, file_path) + if reformatted_file: + reformatted_files.append(reformatted_file) + logger.debug("Reformatted files:\n" f"{list(reformatted_files)}") + + +### Refactor _read_group script ### + + class VisitorToPrivateReadGroup(AbstractVisitor): def post_process(self, all_code: str, file: str) -> str: all_lines = all_code.split("\n") @@ -213,36 +286,7 @@ def visit_Call(self, node: ast.Call) -> Any: ] -def replace_read_group_signature(logger, filename): - with open(filename, mode="rt") as file: - new_all = all_code = file.read() - if ".read_group(" in all_code or "._read_group(" in all_code: - for Step in Steps_visitor: - visitor = Step() - try: - visitor.visit(ast.parse(new_all)) - except Exception: - logger.info( - f"ERROR in {filename} at step {visitor.__class__}: \n{new_all}" - ) - raise - new_all = visitor.post_process(new_all, filename) - if new_all == all_code: - logger.info("read_group detected but not changed in file %s" % filename) - - if new_all != all_code: - logger.info("Script read_group replace applied in file %s" % filename) - with open(filename, mode="wt") as file: - file.write(new_all) - - -def _get_files(module_path, reformat_file_ext): - """Get files to be reformatted.""" - file_paths = list() - if not module_path.is_dir(): - raise Exception(f"'{module_path}' is not a directory") - file_paths.extend(module_path.rglob("*" + reformat_file_ext)) - return file_paths +### Warning open_form_view ### def _check_open_form_view(logger, file_path: Path): @@ -406,8 +450,8 @@ def attrs_to_attributes(attrs_string): def _check_open_form( logger, module_path, module_name, manifest_path, migration_steps, tools ): - reformat_file_ext = ".xml" - file_paths = _get_files(module_path, reformat_file_ext) + reformat_file_ext = [".xml"] + file_paths = get_files(module_path, reformat_file_ext) logger.debug(f"{reformat_file_ext} files found:\n" f"{list(map(str, file_paths))}") for file_path in file_paths: @@ -425,27 +469,68 @@ def _move_attrs_to_attributes( _move_attrs_to_attributes_view(logger, file_path) -def _reformat_read_group( - logger, module_path, module_name, manifest_path, migration_steps, tools -): - """Reformat read_group method in py files.""" +### name_get script ### - reformat_file_ext = ".py" - file_paths = _get_files(module_path, reformat_file_ext) - logger.debug(f"{reformat_file_ext} files found:\n" f"{list(map(str, file_paths))}") +list_comprehension_re = re.compile( + r"""def name_get\(self\):(.*) + return \[\s*\(.*id,\s*(.+)\s*\)\s*for\s*(\w+)\s*in\s*self\s*\]""", + re.S, +) +list_comprehension_into = r"""@api.depends('') + def _compute_display_name(self):\g<1> + for \g<3> in self: + \g<3>.display_name = \g<2>""" - reformatted_files = list() - for file_path in file_paths: - reformatted_file = replace_read_group_signature(logger, file_path) - if reformatted_file: - reformatted_files.append(reformatted_file) - logger.debug("Reformatted files:\n" f"{list(reformatted_files)}") +list_append_re = re.compile( + r"""def name_get\(self\):.+ + return ([A-Za-z_]+)""", + re.S, +) + + +class VisitorNameGet(AbstractVisitor): + def post_process(self, all_code: str, file: str) -> str: + for (old_str, new_str) in self.change_todo: + all_code = all_code.replace(old_str, new_str) + return all_code + + def add_change(self, old_str, new_str): + self.change_todo.append((old_str, new_str)) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: + if node.name == "name_get": + # Replace fields by aggregate and orderby by order + code = ast.get_source_segment(self.all_code, node) + new_code = list_comprehension_re.sub(list_comprehension_into, code, 1) + + if new_code != code: + self.add_change(code, new_code) + elif groups := list_append_re.fullmatch(new_code): + variable_to_remove = groups.group(1) + new_code = new_code.replace(f"\n {variable_to_remove} = []", "") + new_code = new_code.replace( + "def name_get(self):", "def _compute_display_name(self):" + ) + new_code = new_code.replace( + "super().name_get()", "super()._compute_display_name()" + ) + new_code = new_code.replace( + f"\n return {variable_to_remove}", "" + ) + new_code = re.sub( + variable_to_remove + r".append\(\((.+).id, (.*),?\)\)", + r"\g<1>.display_name = \g<2>", + new_code, + ) + self.add_change(code, new_code) + else: + print(f"Cannot find a way to convert:\n{code}") class MigrationScript(BaseMigrationScript): _GLOBAL_FUNCTIONS = [ _check_open_form, - _reformat_read_group, + _reformat_files, _move_attrs_to_attributes, ] diff --git a/tests/data_result/module_160_170/models/res_partner.py b/tests/data_result/module_160_170/models/res_partner.py index 38863a15..baaee281 100644 --- a/tests/data_result/module_160_170/models/res_partner.py +++ b/tests/data_result/module_160_170/models/res_partner.py @@ -7,6 +7,7 @@ class ResPartner(models.Model): test_field_1 = fields.Boolean() task_ids = fields.One2many('project.task') task_count = fields.Integer(compute='_compute_task_count', string='# Tasks') + last_name = fields.Char() def _compute_task_count(self): # retrieve all children partners and prefetch 'parent_id' on them @@ -21,3 +22,10 @@ def _compute_task_count(self): group_dependent = self.env['project.task']._read_group([ ('depend_on_ids', 'in', task_data.ids), ], ['depend_on_ids'], ['__count']) + + def _compute_display_name(self): + if self.test_field_1: + return super()._compute_display_name() + for partner in self: + name = partner.name + ' ' + partner.last_name + partner.display_name = name diff --git a/tests/data_template/module_160/models/res_partner.py b/tests/data_template/module_160/models/res_partner.py index 487a2e3f..2392b426 100644 --- a/tests/data_template/module_160/models/res_partner.py +++ b/tests/data_template/module_160/models/res_partner.py @@ -7,6 +7,7 @@ class ResPartner(models.Model): test_field_1 = fields.Boolean() task_ids = fields.One2many('project.task') task_count = fields.Integer(compute='_compute_task_count', string='# Tasks') + last_name = fields.Char() def _compute_task_count(self): # retrieve all children partners and prefetch 'parent_id' on them @@ -21,3 +22,12 @@ def _compute_task_count(self): group_dependent = self.env['project.task']._read_group([ ('depend_on_ids', 'in', task_data.ids), ], ['depend_on_ids'], ['depend_on_ids']) + + def name_get(self): + result = [] + if self.test_field_1: + return super().name_get() + for partner in self: + name = partner.name + ' ' + partner.last_name + result.append((partner.id, name)) + return result