Skip to content

Commit

Permalink
Exp fix
Browse files Browse the repository at this point in the history
  • Loading branch information
nyLiao committed Oct 7, 2024
1 parent 34a5c9b commit f7e246b
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 14 deletions.
24 changes: 21 additions & 3 deletions benchmark/run_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import uuid
from copy import deepcopy

from pyg_spectral.nn import get_model_regi, get_conv_regi
from pyg_spectral.nn import get_model_regi, get_conv_subregi
from pyg_spectral.nn.parse_args import compose_param
from pyg_spectral.utils import CallableDict

Expand Down Expand Up @@ -45,27 +45,45 @@ def __init__(self, data_loader, model_loader, args, res_logger = None):
}[self.model_loader.get_trn(args)]

def _get_suggest(self, trial, key):
r"""Get the suggested value and format of the hyperparameter.
Args:
trial (optuna.Trial): The trial object.
key (str): The hyperparameter key in :obj:`args.param`.
Returns:
val (Any | list[Any]): The suggested value.
fmt (Callable): The format of the suggested value.
"""
def parse_param(val):
r"""From :class:`ParamTuple` to suggest trial value and format.
Args:
val (ParamTuple | list[ParamTuple]): registry entry.
Returns:
val (Any | list[Any]): The suggested value.
fmt (Callable): The format of the suggested value
"""
if isinstance(val, list):
fmt = val[0][-1]
val = [getattr(trial, 'suggest_'+func)(key+'-'+str(i), *fargs, **fkwargs) for i, (func, fargs, fkwargs, _) in enumerate(val)]
return val, fmt
func, fargs, fkwargs, fmt = val
return getattr(trial, 'suggest_'+func)(key, *fargs, **fkwargs), fmt

# Alias compose models
# Alias param for compose models
if (self.args.model in compose_param and
self.model_loader.conv_repr in compose_param[self.args.model] and
key in compose_param[self.args.model][self.model_loader.conv_repr]):
return parse_param(compose_param[self.args.model][self.model_loader.conv_repr](key, self.args))

# Param of trainer and model level
single_param = SingleGraphLoader_Trial.param | ModelLoader_Trial.param | self.trn_cls.param
single_param = CallableDict(single_param)
single_param |= get_model_regi(self.args.model, 'param')
if key in single_param:
return parse_param(single_param(key, self.args))

return parse_param(get_conv_regi(self.args.conv, 'param')(key, self.args))
# Param of conv level
return parse_param(get_conv_subregi(self.args.conv, 'param', key, self.args))

def __call__(self, trial):
args = deepcopy(self.args)
Expand Down
10 changes: 4 additions & 6 deletions benchmark/scripts/runfb-iter.sh
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,14 @@ PARLIST="dropout_lin,lr_lin,wd_lin"
PARLIST="normg,dropout_conv,$PARLIST"
# Linear
python run_param.py --data $data --model $model --conv AdjiConv --param $PARLIST "${ARGS_P[@]}" \
--theta_scheme ones --beta 1.0
--beta 1.0
python run_single.py --data $data --model $model --conv AdjiConv "${ARGS_S[@]}" \
--theta_scheme ones --beta 1.0
--beta 1.0

PARLIST="beta,$PARLIST"
# PPR
python run_param.py --data $data --model $model --conv AdjResConv --param $PARLIST "${ARGS_P[@]}" \
--theta_scheme appr
python run_single.py --data $data --model $model --conv AdjResConv "${ARGS_S[@]}" \
--theta_scheme appr
python run_param.py --data $data --model $model --conv AdjResConv --param $PARLIST "${ARGS_P[@]}"
python run_single.py --data $data --model $model --conv AdjResConv "${ARGS_S[@]}"

: '
# ========== PyG
Expand Down
2 changes: 1 addition & 1 deletion pyg_spectral/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .parse_args import (
get_model_regi, get_conv_regi,
get_model_regi, get_conv_regi, get_conv_subregi,
get_nn_name, set_pargs
)
2 changes: 2 additions & 0 deletions pyg_spectral/nn/models/acm_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ACMGNN(BaseNN):
# FEATURE: separate arch
name = 'Iterative'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
pargs = ['theta_scheme']

def init_conv(self,
conv: str,
Expand Down Expand Up @@ -83,6 +84,7 @@ class ACMGNNDec(BaseNN):
# FEATURE: separate arch
name = 'DecoupledVar'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
pargs = ['theta_scheme']

def init_conv(self,
conv: str,
Expand Down
2 changes: 0 additions & 2 deletions pyg_spectral/nn/models/iterative.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ def init_conv(self,

class IterativeFixed(Iterative):
name = 'IterativeFixed'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])


class IterativeFixedCompose(IterativeCompose):
name = 'IterativeFixed'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
25 changes: 23 additions & 2 deletions pyg_spectral/nn/parse_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ def update_regi(regi, new_regi):
model_regi = BaseNN.register_classes()
model_regi = update_regi(model_regi, model_regi_pyg)

conv_regi = CallableDict.to_callableVal(conv_regi, ['pargs_default', 'param'])
conv_regi = CallableDict.to_callableVal(conv_regi, reckeys=['pargs_default', 'param'])
r'''Fields:
* name (CallableDict[str, str]): Conv class logging path name.
* pargs (CallableDict[str, list[str]]): Conv arguments from argparse.
* pargs_default (dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
* param (dict[str, CallableDict[str, ParamTuple]]): Conv parameters to tune.
'''
model_regi = CallableDict.to_callableVal(model_regi, ['pargs_default', 'param'])
model_regi = CallableDict.to_callableVal(model_regi, reckeys=['pargs_default', 'param'])
r'''Fields:
name (CallableDict[str, str]): Model class logging path name.
conv_name (CallableDict[str, Callable[[str, Any], str]]): Wrap conv logging path name.
Expand All @@ -41,6 +41,11 @@ def update_regi(regi, new_regi):
'ACMConv-2-low-high': 'FBGNNII',
'ACMConv-1-low-high-id': 'ACMGNNI',
'ACMConv-2-low-high-id': 'ACMGNNII',},
'ACMGNNDec': {
'ACMConv-1-low-high': 'FBGNNI',
'ACMConv-2-low-high': 'FBGNNII',
'ACMConv-1-low-high-id': 'ACMGNNI',
'ACMConv-2-low-high-id': 'ACMGNNII',},
'DecoupledFixedCompose': {
'AdjiConv,AdjiConv-ones,ones': 'FAGNN',
'Adji2Conv,Adji2Conv-gaussian,gaussian': 'G2CN',
Expand Down Expand Up @@ -104,6 +109,22 @@ def get_conv_regi(conv: str, k: str, args=None) -> str:
return conv_regi[k](conv, args) if args else conv_regi[k][conv]


def get_conv_subregi(conv: str, k: str, pargs: str, args=None) -> str:
r"""Getter for calling a sub-CallableDict in :attr:`conv_regi`.
Args:
conv: The name of the convolution.
k: The key in :attr:`conv_regi`.
pargs: The key in the sub-CallableDict.
args: Configuration arguments.
Returns:
value (str): The value of the sub-CallableDict.
"""
if ',' in conv:
return [conv_regi[k][channel](pargs, args) for channel in conv.split(',')]
return conv_regi[k][conv](pargs, args) if args else conv_regi[k][conv][pargs]


def get_nn_name(model: str, conv: str, args) -> str:
r"""Parse model+conv name for logging path from argparse input.
Expand Down

0 comments on commit f7e246b

Please sign in to comment.