Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions src/transformers/utils/backbone_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,10 @@ def load_backbone(config):
use_pretrained_backbone = getattr(config, "use_pretrained_backbone", None)
backbone_checkpoint = getattr(config, "backbone", None)
backbone_kwargs = getattr(config, "backbone_kwargs", None)
backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs
if backbone_kwargs is None:
backbone_kwargs = {}

# Fast path: check mutually exclusive config keys as soon as possible to avoid unnecessary imports/instantiation

if backbone_kwargs and backbone_config is not None:
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")
Expand All @@ -332,31 +335,35 @@ def load_backbone(config):

# If any of the following are set, then the config passed in is from a model which contains a backbone.
if backbone_config is None and use_timm_backbone is None and backbone_checkpoint is None:
return AutoBackbone.from_config(config=config, **backbone_kwargs)
return AutoBackbone.from_config(config=config)

# Avoid redundant variable assignments in branches when possible.
# No need to check for mutually exclusive input combinations again, logic covers all cases.

# config from the parent model that has a backbone

# config from the parent model that has a backbone
if use_timm_backbone:
if backbone_checkpoint is None:
raise ValueError("config.backbone must be set if use_timm_backbone is True")
# Because of how timm backbones were originally added to models, we need to pass in use_pretrained_backbone
# to determine whether to load the pretrained weights.
backbone = AutoBackbone.from_pretrained(
# The kwargs dict is always passed as-is, so merge all args up front.
return AutoBackbone.from_pretrained(
backbone_checkpoint,
use_timm_backbone=use_timm_backbone,
use_pretrained_backbone=use_pretrained_backbone,
**backbone_kwargs,
)
elif use_pretrained_backbone:
if use_pretrained_backbone:
if backbone_checkpoint is None:
raise ValueError("config.backbone must be set if use_pretrained_backbone is True")
backbone = AutoBackbone.from_pretrained(backbone_checkpoint, **backbone_kwargs)
else:
if backbone_config is None and backbone_checkpoint is None:
raise ValueError("Either config.backbone_config or config.backbone must be set")
if backbone_config is None:
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
backbone = AutoBackbone.from_config(config=backbone_config)
return backbone
return AutoBackbone.from_pretrained(backbone_checkpoint, **backbone_kwargs)

if backbone_config is None and backbone_checkpoint is None:
raise ValueError("Either config.backbone_config or config.backbone must be set")
if backbone_config is None:
# Only call from_pretrained if strictly necessary
backbone_config = AutoConfig.from_pretrained(backbone_checkpoint, **backbone_kwargs)
return AutoBackbone.from_config(config=backbone_config)


def verify_backbone_config_arguments(
Expand Down