diff --git a/openupgradelib/openupgrade.py b/openupgradelib/openupgrade.py index d534f98b..1350f8c5 100644 --- a/openupgradelib/openupgrade.py +++ b/openupgradelib/openupgrade.py @@ -4,6 +4,7 @@ # Copyright Odoo Community Association (OCA) # License AGPL-3.0 or later (https://www.gnu.org/licenses/agpl.html). +import ast import inspect import json import logging as _logging_module @@ -215,6 +216,7 @@ def do_raise(error): "cow_templates_mark_if_equal_to_upstream", "cow_templates_replicate_upstream", "clean_transient_models", + "update_domain", ] @@ -3753,3 +3755,95 @@ def clean_transient_models(cr): cr.execute(query) except Exception as e: logger.warning("Failed to clean transient table %s\n%s", table_name, str(e)) + + +def update_domain(domain, update_name_func=None, update_value_func=None): + """ + Given a string containing a domain, this function + returns a text representation of it modified as executed by the + update_* functions passed. + + :param domain: an Odoo domain like "[('some_field', '=', value)]" + :param update_name_func: function being passed a field name as string, returning a replacement string. + + Example: + + With:: + + def rename_company_id(name): + if name == 'company_id': + return 'company_id2' + return name + + calling:: + + update_domain("[('company_id', '=', 42), ('other_field', '=', value)]", update_name_func=rename_company_id) + + returns:: + + "[('company_id2', '=', 42), ('other_field', '=', value)]" + + :param update_value_func: function being passed a domain field name as string, and a value as string, int, list or ast.Expression if the right hand side of a domain leaf cannot be parsed as one of the former. Return one of the aformentioned types to replace the right hand side of the domain leaf. + + Example: + + With:: + + def change_company_id(name, value): + if name == 'company_id': + if isinstance(value, int) and value == 42: + return 43 + elif isinstance(value, (list, tuple)) and 42 in value: + return [val if val != 42 else 43 for val in value] + return value + + calling:: + + update_domain("[('company_id', '=', 42), ('company_id', 'in', [41, 42]), ('other_field', '=', value)]", update_value_func=change_company_id) + + returns:: + + "[('company_id', '=', 43), ('company_id', 'in', (41, 43)), ('other_field', '=', value)]" + """ + + class Transformer(ast.NodeTransformer): + def visit_List(self, node): + + for element in node.elts: + + if not isinstance(element, (ast.List, ast.Tuple)): + continue + + left = element.elts[0] + if not isinstance(left, ast.Constant) or len(element.elts) != 3: + continue + + if update_name_func: + new_left = update_name_func(left.value) + element.elts[0] = self._get_ast_for_value(new_left or left) + + if update_value_func: + new_right = update_value_func( + left.value, self._get_literal_value_or_ast(element.elts[2]) + ) + element.elts[2] = self._get_ast_for_value( + new_right or element.elts[2] + ) + + return node + + def _get_literal_value_or_ast(self, node): + try: + return ast.literal_eval(node) + except: + return node + + def _get_ast_for_value(self, value): + if isinstance(value, (int, str, float)): + return ast.Constant(value) + elif isinstance(value, (list, tuple)): + return ast.Tuple([self._get_ast_for_value(val) for val in value]) + return value + + expr = ast.parse(domain, mode="eval") + return ast.unparse(Transformer().visit(expr))