-
Notifications
You must be signed in to change notification settings - Fork 493
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: embedding-aware attention #217
base: develop
Are you sure you want to change the base?
Conversation
Apply common mask to all embedded columns of a feature
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for sharing! I had a very quick look so I'm not sure I fully understood your approach.
Please correct me if I'm wrong:
- you create mask for each feature without taking embedding dimensions into account. Then give the same attention to each corresponding embedding dimension?
That's a way to make attention aware indeed. I'm not a big fan of the for loops it creates though.
reducing_matrix is a way of getting back the information without for loops.
I did not try any benchmark with this, but you seem to have interesting results on ForestCoverType. Have you tried #92 on forest cover type? I could try to make this branch up to date, will see if I have time this weekend.
pytorch_tabnet/tab_network.py
Outdated
mask_type=self.mask_type) | ||
attention = AttentiveTransformer( | ||
n_a, | ||
len(self.feature_embed_widths) if self.feature_embed_widths else self.input_dim, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I understand correctly don't you always have len(self.feature_embed_widths)==self.input_dim
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No - since TabNetNoEmbeddings
is just the portion of the network after embeddings have been done, the input_dim
is the post-embedding dimension.
@@ -40,6 +40,7 @@ def forward(self, x): | |||
|
|||
class TabNetNoEmbeddings(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As the name suggest, this class is supposed to be basic tabnet with no embeddings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup and I agree it's a nice distinction to keep! So my thought was to add an optional parameter (feature_embed_widths
) to this class where, if they want, the user can tell TabNet to treat multiple columns of the input as a single "feature" for attention purposes. By default (None
) TabNetNoEmbeddings
should work as before, treating every column independently. There are a couple of ways this API could pass in the information required, so the current one is a list of how wide each "feature" in your input is: E.g. [1, 1, 2, 3]
would mean:
input_dim
is 1+1+2+3=7n_features
is 4- First two columns are scalar features, next two are a feature with emb_dim 2, next three are a feature with emb_dim 3
f-string issue and unused imports
Pre-compute n_features count and expand mask matrices via indexing instead of for-loop concatenation.
Completely agree with you about the for loops and my concerns about that were why I measured execution time initially! Having thought about it for a while, I realised e.g. for This change seems to deliver a modest speedup on my setup, from 11min 41sec to 11min 32sec at 50 epochs (faster than the original develop-branch code without embedding-aware attention). Your understanding was correct by the way: not making any changes to the actual implementation of AttentionTransformer; just restricting its output dimension down from the number of columns to the number of underlying features... So matrix |
OK I tried several approaches but have not been able to get #92 to install on PyTorch v1.4 environments because of the However I was able to run a comparison on a PyTorch v1.6 environment.
Experimental ResultsValidation accuracy at 10 epochs
Training time at 10 epochs (seconds)
Validation accuracy at 50 epochs
Training time at 50 epochs (seconds)
Interpretation / ObservationsPer the tables:
I also noticed:
This is of course only one sample dataset, and only one hyperparameter configuration! My preference would be to push forward with an approach along the lines of this PR if we're comfortable it can deliver comparable performance to #92 - because it avoids introducing the extra CUDA-linked |
@athewsey amazing contribution! Thanks for this detailed analysis. I'll need to spend more time to dig deeper into the impact of your proposal, but I'm on board with the idea. My only concern is that we are currently working on self supervised pretraining, so it will imply some refactorization of the network part, I'll then need to adapt your proposal to the code so that it fits a bit better to the overall code. But I don't want to "steal" your contribution, so we'll need to figure out a way to do this together. We'll find out! Thanks! I'll get back to you! Let us know if you do a similar benchmark on other datasets! |
Sure thing, thanks! Happy to support on porting to another branch too if you have one that's looking stable & favoured - I'd expect this change should map fairly nicely since it's pretty self-contained and hopefully shouldn't have too much functional overlap, just touches same code sections. |
@athewsey I think it would be worth running for 200 epochs (not necessarily in a 5 fold setting, but with the original split of the paper, the same as we do in the Forest Cover Type notebook). Because what you showed is that we are converging faster (which is good) but not necessarily better. The final test accuracy should be able to reach 0.96~0.97 and after 50 epochs we are still far from this score. |
Hey sorry it's taken a while - now got some extra results from longer testing! I modified the train/val/test split to 60/20/20 in line with the example (previous tests were on 80/10/10)... But actually as commented before I've never seen the library reach ~96% on this dataset/hyperparam combo in PyTorch v1.6 - It got there in PyTorch v1.4, but on 1.6 always topped out at ~93% in my previous (200 epoch) tests. ...So I ran the tests on to 300 epochs and took measures every hundred, to see whether there were any easy gains to be had. I did repeat for 5 random seeds again, because (as we see below) the ranges all do overlap quite a bit so it didn't seem safe to just take a particular result for each branch. Code branches are the same as in the previous tests - not updated with any new commit merges. Validation accuracy at 100 epochs
Training time at 100 epochs (seconds)
Validation accuracy at 200 epochs
Training time at 200 epochs (seconds)
Validation accuracy at 300 epochs
Training time at 300 epochs (seconds)
Again the accuracy results all overlap significantly, but this candidate seems to come out top on average. The timings are notably switched from the short tests: unchanged develop branch comes out reliably fastest. The difference between the two PR candidates gets a bit clearer with the longer test. |
Updating feature branch
@Optimox sorry for the bump, but do you have plans of merging this? |
Hmmm I would probably need some more time to review carefully and convince myself that this is the way to solve the problem. But yes I'll have a look as soon as I can. @bratao Would you mind telling us a bit more about the improvements it gave on your particular problem and the problem itself? |
I use for a Fintech credit model. I got an extra 5% in RMSE compared to master ( I ported this patch to Regression) It still not beat a well tuned random forest for my use case, but is getting closer. |
This implementation amends the attention transformer output dimension to equal the number of features, instead of the number of post-embedding dimensions.
EmbeddingGenerator
is modified to keep a record of the number of dimensions that each feature was embedded to (in a deliberately agnostic way, because I've been experimenting with embedding scalar fields to multiple dimensions too).TabNetNoEmbeddings
is modified to use thisfeature_embed_widths
list to expand out the raw mask matrixM
(by features) to the embedding-compatibleM_x
- by replicating each feature's mask weights to however many columns it was embedded to.AbstractModel
is just to remove thereducing_matrix
altogether? But should get more familiar with this area of the code.Important limitations:
TabNetClassifier
fit & transform: As mentioned above I would need to drill further into the explainability to check there's no potential bugs introduced there. To my knowledge this change is agnostic to whether it's a regression or multi-task problem, but hopefully the CI tests will help confirm that 😂Testing and results:
On a Forest Cover Type based example , I've observed this change to improve validation set performance from approx:
...at essentially the same training speed (11min 41sec to 50 epochs for both pre- and post-change algorithms).
Specifically:
Area
andSoil_Type
features consolidated from the raw (one-hot) data to categorical fields, with embedding dimensions 2 and 3 respectively (representing 4 distinctArea
s, 40Soil_Types
)batch_size=16384, clip_value=2.0, epsilon=1e-15, gamma=1.5, lambda_sparse=0.0001, lr=0.02, max_epochs=50, model_type='classification', momentum=0.3, n_a=64, n_d=64, n_independent=2, n_shared=2, n_steps=5, patience=100, seed=1337, target='Cover_Type', virtual_batch_size=256
IMPORTANT: Please do not create a Pull Request without creating an issue first.
Any change needs to be discussed before proceeding. Failure to do so may result in the rejection of the pull request.
What kind of change does this PR introduce? feature
Does this PR introduce a breaking change?⚠️ Kinda
What needs to be documented once your changes are merged?
Closing issues
Hopefully #122, eventually