
Sklearn Workflow Integration #5

Open
jxl26 wants to merge 16 commits into chemprop:main from jxl26:sklearn

Conversation

@jxl26 commented Dec 2, 2025

Description

Implement sklearn transformer and regressor modules that wrap Chemprop functionality, so that users can readily employ a Chemprop model as an sklearn estimator and use the sklearn library to validate and optimize it. Compatible with the latest Chemprop version.
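
A minimal usage sketch of what this enables (illustrative only; assumes the module exposes the ChempropRegressor discussed below, taking SMILES strings as X):

from sklearn.model_selection import cross_val_score
from chemprop_contrib.sklearn_integration.chemprop_estimator import ChempropRegressor

smiles = ["CCO", "CCN", "c1ccccc1", "CC(=O)O"]  # X: SMILES strings
targets = [0.5, 0.7, 1.2, 0.3]                  # y: regression targets

model = ChempropRegressor()
model.fit(smiles, targets)
preds = model.predict(smiles)

# Because the estimator follows the sklearn API, standard utilities apply:
scores = cross_val_score(ChempropRegressor(), smiles, targets, cv=2)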

Questions

N/A

Relevant Chemprop Issue

#1075

Checklist

  • License file included
  • Documentation provided (perhaps a Notebook or README)
  • Tests provided and passing on the stated Chemprop version

@KnathanM (Member) left a comment

Looks like your notebook needs to have its imports updated (the tests fail).

You could add this as the license file

MIT License

Copyright (c) 2025 Chemprop Dev Team

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

Comment thread on chemprop_contrib/sklearn_integration/chemprop_estimator.py (Outdated)
if val_loader is not None:
    trainer.fit(
        self.model, train_dataloaders=train_loader, val_dataloaders=val_loader
    )
else:
    trainer.fit(self.model, train_dataloaders=train_loader)
Member

I think you can always do trainer.fit(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader) even if val_loader is None, because that is the default for val_dataloaders.
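
That is, the branch could collapse to a single call (a sketch, mirroring the suggestion above):

# val_loader may be None; Lightning's Trainer.fit defaults val_dataloaders
# to None, so no branching is needed:
trainer.fit(self.model, train_dataloaders=train_loader, val_dataloaders=val_loader)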


def predict(self, X):
    if self.model is None:
        raise ValueError("The regressor has not been fitted.")
Member

Maybe a RuntimeError is more accurate? I don't know a lot about Python errors, though.

Author

Adopted.

Comment on lines +570 to +571
if not self.args.no_cache:
    test_set.cache = True
Member

The test set doesn't need to be cached as it is only ever used once.

Suggested change (delete these lines):

if not self.args.no_cache:
    test_set.cache = True

    test_set,
    batch_size=self.args.batch_size,
    num_workers=self.args.num_workers,
    collate_fn=pick_collate(test_set),
Member

Maybe you could use chemprop.data.build_dataloader instead of defining your own pick_collate function.
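
For example (a sketch, assuming build_dataloader accepts these keyword arguments):

from chemprop.data import build_dataloader

# build_dataloader selects the appropriate collate function for the
# dataset type, so a hand-rolled pick_collate is unnecessary.
dl = build_dataloader(
    test_set,
    batch_size=self.args.batch_size,
    num_workers=self.args.num_workers,
    shuffle=False,  # no need to shuffle a test set
)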

Author

Adopted.

    accelerator=self.args.accelerator, devices=1, enable_progress_bar=True
)
preds = eval_trainer.predict(
    self.model, dataloaders=dl, return_predictions=True
)
Member

The Lightning docs say:

return_predictions: Whether to return predictions.
    ``True`` by default except when an accelerator that spawns processes is used (not supported).

So I don't think you need to include this arg.
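
That is, the argument can simply be dropped (sketch):

preds = eval_trainer.predict(self.model, dataloaders=dl)  # return_predictions defaults to True here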



class ChempropEnsembleRegressor(ChempropRegressor):
    def __init__(self, ensemble_size: int = 5, **chemprop_kwargs):
Member

If you use **chemprop_kwargs, I don't think you can use this class with cross-validation utilities like cross_val_score and RandomizedSearchCV. These functions clone the estimator for each separate split of the data, and the clone method uses the __init__() signature to know which parameters to copy over. Here it looks like the ensemble regressor only takes two parameters, ensemble_size and chemprop_kwargs, but the estimator does not have an attribute self.chemprop_kwargs, so clone will error.

I think the structure that is more typical in sklearn is to make the sub-estimators outside the composite estimator and pass those in as parameters. Here is some documentation about that: https://scikit-learn.org/stable/modules/grid_search.html#composite-estimators-and-parameter-spaces. A sketch of that pattern follows this comment.

What I am not sure about is that the number of sub-estimators can change, so we would have to assign all of them to a single attribute, maybe as a list? Or maybe we just document that the ensemble estimator shouldn't be used with CV.
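
A sketch of the clone-friendly pattern (illustrative; ChempropRegressor stands in for the estimator defined in this PR):

from sklearn.base import BaseEstimator, RegressorMixin, clone

class ChempropEnsembleRegressor(BaseEstimator, RegressorMixin):
    # Every __init__ parameter is stored verbatim under the same name,
    # so clone() can read the signature and copy the values over.
    def __init__(self, estimator=None, ensemble_size: int = 5):
        self.estimator = estimator
        self.ensemble_size = ensemble_size

    def fit(self, X, y):
        # All derived state is built in fit(), with a trailing underscore
        # per sklearn convention, so cloning stays cheap and lossless.
        base = self.estimator if self.estimator is not None else ChempropRegressor()
        self.estimators_ = [clone(base) for _ in range(self.ensemble_size)]
        for est in self.estimators_:
            est.fit(X, y)
        return self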

Author

I'm a bit confused on this again. We discussed how the issue stems from the CV function directly mutating the arguments and not re-running __init__(), but doesn't this mean that our issue remains unsolved even if we explicitly list each argument and create the namespace in fit, since fit would refer to fields of self, and assignments to those fields are made in __init__()?

Member

Assignments to these fields are also made when cross-validation goes to fit a cloned copy of the estimator: https://github.com/scikit-learn/scikit-learn/blob/eec13ccc9c81027ce9387e1fce6f04fd22e80d4d/sklearn/model_selection/_validation.py#L821

This is why the signature of __init__() needs to match what parameters the estimator has. Instead of relying on __init__() to set the parameters of a cloned estimator, sklearn manually sets the parameters using the signature of __init__(). A simplified sketch of what clone() does is below.
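
Roughly, for any sklearn estimator instance (a simplified sketch of sklearn.base.clone):

import copy

# get_params() derives its keys from __init__'s signature; a fresh
# instance is then constructed from deep copies of those values, so
# __init__ must accept and store every parameter verbatim.
params = {k: copy.deepcopy(v) for k, v in estimator.get_params(deep=False).items()}
new_estimator = type(estimator)(**params)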


Author

Got it, now addressed, thanks!

Comment thread on chemprop_contrib/sklearn_integration/chemprop_estimator.py (Outdated)
if self.checkpoint is not None:
    if len(self.checkpoint) != self.ensemble_size:
        logger.warning(
            f"The number of models in ensemble for each splitting of data is set to {len(self.args.checkpoint)}."
Member

It would probably be good to add something to the warning along the lines of "the number of model checkpoints supplied is not equal to the specified ensemble size; got {len(self.args.checkpoint)} model checkpoints" to explain why the number of models is being set to that number.
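
For example (a sketch of the suggested wording):

logger.warning(
    f"The number of model checkpoints supplied ({len(self.args.checkpoint)}) "
    f"is not equal to the specified ensemble size ({self.ensemble_size}); "
    f"setting the number of models per data split to {len(self.args.checkpoint)}."
)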

@KnathanM (Member) left a comment

LGTM, thanks for your work on this. I hope that people will find this code useful when incorporating chemprop models in an sklearn workflow.

I am rerunning the tests before merging since it has been a while and chemprop has updated versions. Hopefully there are no problems.

"\n",
" | Name | Type | Params | Mode \n",
"-------------------------------------------------------------------------\n",
"0 | message_passing | MulticomponentMessagePassing | 252 K | train\n",
Member

I'm now noticing that this is using MulticomponentMessagePassing instead of BondMessagePassing despite being single-component. This is because make_datapoints (imported from chemprop and used in chemprop_estimator.py) always returns a list of lists. If the input is single-component, it is a list containing a single list. So instead of checking isinstance(X[0], list), we should check len(X) > 1 and otherwise do X = X[0], like the chemprop CLI does here and here.
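
A sketch of that check (assuming X is the list of lists returned by make_datapoints):

# make_datapoints returns one inner list of datapoints per component,
# so the length of the outer list tells us the component count.
multicomponent = len(X) > 1
if not multicomponent:
    X = X[0]  # single component: unwrap to the one inner list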

Member

I pushed some changes (b9b5061) that address this.

Author

Good call! The changes make sense to me, but they cause the examples in the notebook to fail with an sklearn complaint: "TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType)", likely due to a type mismatch induced by the fix. I will see if I can pin down the issue and let you know if I need help!

@KnathanM (Member)

See the linked PR on Chemprop for why the tests fail on this PR.
