Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rmsnorm backward simple baseline kernel #763

Closed
wants to merge 1 commit into from

Conversation

ngc92
Copy link
Contributor

@ngc92 ngc92 commented Sep 21, 2024

baseline kernel and cpu versions


float o = 0.0f;
for (int i = 0; i < C; i++) {
o += weight[i] * dout_bt[i] * inp_bt[i];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this! As a result of this PR I have found a bug in my own RMSNorm impl. where I forgot to multiply by the weight in the backward pass. Spent days chasing it.

PS: I've ran this rmsnorm_backward_cpu through my test cases that verify against torch and all is passing.

Thanks again!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    void rms_backward(
        float* _In, float* _δ_Out, float* _δ_In,
        float* _δ_Gamma, float* _Gamma,
        int N, float EPSILON) {
        float rms = 0;
        for (int i = 0; i < N; i++) {
            rms += _In[i] * _In[i];
        }
        rms = std.sqrtf(rms / N + EPSILON);
        float rrms = 1.0f / rms;
        float rnrms = 1.0f / (N * rms * rms);
        float δ = 0;
        for (int i = 0; i < N; i++) {
            δ += _δ_Out[i] * _In[i] * rrms * **_Gamma[i]**; // This is what I missed!!
        }
        for (int i = 0; i < N; i++) {
            _δ_In[i] += _δ_Out[i] * _Gamma[i] * rrms - δ * _In[i] * rnrms;
            _δ_Gamma[i] += _δ_Out[i] * _In[i] * rrms;
        }
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great to hear that this passes an independent set of tests. I did some testing, but wasn't yet fully sure that my implementation was correct.

@ngc92 ngc92 closed this Sep 26, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants