Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions fiddle/_src/codegen/auto_config/code_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ class ModuleReference(BaseNameReference):
"""Reference to an imported module."""


@dataclasses.dataclass
class BuiltinReference(BaseNameReference):
"""Reference to an imported module."""


@dataclasses.dataclass
class FixtureReference(BaseNameReference):
"""Reference to another fixture."""
Expand All @@ -129,6 +134,14 @@ def __hash__(self):
return id(self)


@dataclasses.dataclass
class ParameterizedTypeExpression(CodegenNode):
"""Reference to a parameterized type like list[int]."""

base_expression: Any # Expression like BuiltinReference(Name("list"))
param_expressions: List[Any] # List of (positional) argument expressions


@dataclasses.dataclass
class ArgFactoryExpr(CodegenNode):
"""Represents a factory that should be interpreted as an argument factory.
Expand Down
60 changes: 60 additions & 0 deletions fiddle/_src/codegen/auto_config/import_manager_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# coding=utf-8
# Copyright 2022 The Fiddle-Config Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Small helper functions around the ImportManager.

This is a bit of cruft and should eventually be cleaned up.

Context: The import manager predates modern (auto_config) codegen, and is used
by legacy codegen and diff codegen. The latter is still pretty important and
needs to be supported.
"""

import logging
import typing
from typing import Any

from fiddle._src.codegen import import_manager as import_manager_lib
from fiddle._src.codegen.auto_config import code_ir


def _name_to_attribute_expression(name: str) -> code_ir.CodegenNode:
"""Converts a fully-qualified name to a code_ir node.

Args:
name: Output from the import manager.

Returns:
Codegen node.
"""
if "." not in name:
logging.warning(
"Expected to find a module in %s, but found none. This might be because"
" your module is from __main__, so we'll still emit code, but you might"
" need to fix imports for this symbol.",
name,
)
return code_ir.BaseNameReference(code_ir.Name(name))
base, *parts = name.split(".")
value = code_ir.ModuleReference(code_ir.Name(base))
for part in parts:
value = code_ir.AttributeExpression(value, part)
return typing.cast(code_ir.AttributeExpression, value)


def add(
value: Any, import_manager: import_manager_lib.ImportManager
) -> code_ir.CodegenNode:
return _name_to_attribute_expression(import_manager.add(value))
32 changes: 8 additions & 24 deletions fiddle/_src/codegen/auto_config/make_symbolic_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
import enum
import functools
import inspect
import typing
from typing import Any, Callable

from fiddle import arg_factory
from fiddle import daglish
from fiddle._src import config as config_lib
from fiddle._src.codegen.auto_config import code_ir
from fiddle._src.codegen.auto_config import import_manager_wrapper


def is_plain_symbol_or_enum_value(value: Any) -> bool:
Expand Down Expand Up @@ -67,16 +67,6 @@ def noop_history_comments(unused_buildable):
return code_ir.HistoryComments()


def _name_to_attribute_expression(name: str) -> code_ir.AttributeExpression:
if "." not in name:
raise ValueError(f"Could not parse symbol import {name}")
base, *parts = name.split(".")
value = code_ir.ModuleReference(code_ir.Name(base))
for part in parts:
value = code_ir.AttributeExpression(value, part)
return typing.cast(code_ir.AttributeExpression, value)


def replace_callables_and_configs_with_symbols(
task: code_ir.CodegenTask,
*,
Expand All @@ -98,7 +88,7 @@ def replace_callables_and_configs_with_symbols(
def _handle_partial(
value: config_lib.Partial,
state: daglish.State,
ir_for_symbol: code_ir.AttributeExpression,
ir_for_symbol: code_ir.CodegenNode,
):
"""Split-out helper method to handle Partial() nodes."""
arguments = config_lib.ordered_arguments(value)
Expand Down Expand Up @@ -131,9 +121,7 @@ def _handle_partial(

def _arg_factory_partial():
return code_ir.SymbolOrFixtureCall(
_name_to_attribute_expression(
task.import_manager.add(arg_factory.partial)
),
import_manager_wrapper.add(arg_factory.partial, task.import_manager),
positional_arg_expressions=[ir_for_symbol],
arg_expressions=arg_factory_args,
history_comments=format_history(value),
Expand All @@ -146,9 +134,7 @@ def _arg_factory_partial():
# the auto_config fixture's as_buildable() method. If we got rid of the
# functools.partial, then we couldn't configure any attributes.
return code_ir.SymbolOrFixtureCall(
_name_to_attribute_expression(
task.import_manager.add(functools.partial)
),
import_manager_wrapper.add(functools.partial, task.import_manager),
positional_arg_expressions=[ir_for_symbol],
arg_expressions=regular_args,
history_comments=format_history(value),
Expand All @@ -161,18 +147,16 @@ def _arg_factory_partial():
# which order, but we need to emit both decorators. Go with functools
# on the outer level.
return code_ir.SymbolOrFixtureCall(
_name_to_attribute_expression(
task.import_manager.add(functools.partial)
),
import_manager_wrapper.add(functools.partial, task.import_manager),
positional_arg_expressions=[_arg_factory_partial()],
arg_expressions=regular_args,
history_comments=format_history(value),
)

def traverse(value, state: daglish.State):
if isinstance(value, config_lib.Buildable):
ir_for_symbol = _name_to_attribute_expression(
task.import_manager.add(config_lib.get_callable(value))
ir_for_symbol = import_manager_wrapper.add(
config_lib.get_callable(value), task.import_manager
)
if isinstance(value, config_lib.Config):
all_tags = value.__argument_tags__
Expand Down Expand Up @@ -214,7 +198,7 @@ def traverse(value, state: daglish.State):
else:
raise TypeError(f"Unsupported Buildable {type(value)}")
elif is_plain_symbol_or_enum_value(value):
return _name_to_attribute_expression(task.import_manager.add(value))
return import_manager_wrapper.add(value, task.import_manager)
else:
return state.map_children(value)

Expand Down
22 changes: 6 additions & 16 deletions fiddle/_src/codegen/newcg_symbolic_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@
N.B. Please see codegen/auto_config for the auto_config version!!
"""

import typing
from typing import Callable

from fiddle import daglish
from fiddle._src import config as config_lib
from fiddle._src.codegen.auto_config import code_ir
from fiddle._src.codegen.auto_config import import_manager_wrapper
from fiddle._src.codegen.auto_config import make_symbolic_references as ac_make_symbolic_references

is_plain_symbol_or_enum_value = (
Expand Down Expand Up @@ -56,16 +56,6 @@ def import_symbols(task: code_ir.CodegenTask) -> None:
task.import_manager.add(value)


def _name_to_attribute_expression(name: str) -> code_ir.AttributeExpression:
if "." not in name:
raise ValueError(f"Could not parse symbol import {name}")
base, *parts = name.split(".")
value = code_ir.ModuleReference(code_ir.Name(base))
for part in parts:
value = code_ir.AttributeExpression(value, part)
return typing.cast(code_ir.AttributeExpression, value)


def replace_callables_and_configs_with_symbols(
task: code_ir.CodegenTask,
*,
Expand All @@ -84,11 +74,11 @@ def replace_callables_and_configs_with_symbols(

def traverse(value, state: daglish.State):
if isinstance(value, config_lib.Buildable):
ir_for_buildable_type = _name_to_attribute_expression(
task.import_manager.add(type(value))
ir_for_buildable_type = import_manager_wrapper.add(
type(value), task.import_manager
)
ir_for_symbol = _name_to_attribute_expression(
task.import_manager.add(config_lib.get_callable(value))
ir_for_symbol = import_manager_wrapper.add(
config_lib.get_callable(value), task.import_manager
)
all_tags = value.__argument_tags__
value = state.map_children(value)
Expand All @@ -113,7 +103,7 @@ def traverse(value, state: daglish.State):
history_comments=format_history(value),
)
elif is_plain_symbol_or_enum_value(value):
return _name_to_attribute_expression(task.import_manager.add(value))
return import_manager_wrapper.add(value, task.import_manager)
else:
return state.map_children(value)

Expand Down