Hi Tokusumi,
Thanks for your great work and for kindly sharing it. I'm learning how to use the pytest package and how FLOPs are calculated by studying your assertion conditions.
Now I want to add support for the MultiHeadAttention layer, write a test case for it, and submit it to your repository, but I've run into some problems:
I studied your test_attention function. In that assert condition [0], why is the FLOPs count of the softmax 5 * Tq * Tv?
[0] keras-flops/tests/test_flops.py, line 385 at 5c130bf
I've attached the log, by the way:

```
Profile:
node name | # float_ops
TFProfRoot (--/6.90k flops)
model_4/attention/MatMul (3.20k/3.20k flops)
model_4/attention/MatMul_1 (3.20k/3.20k flops)
model_4/attention/Softmax (500/500 flops)
```
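For context, here is how I read those numbers against the assertion. This is only a sketch of my reading; I am assuming the test uses Tq = Tv = 10 and dim = 16, which is consistent with the logged values:

```python
# Sanity check: reconstruct the profiled Attention FLOPs from the test's formulas,
# assuming Tq = Tv = 10 and dim = 16 (consistent with the log above).
Tq, Tv, dim = 10, 10, 16

scores_matmul = 2 * Tq * Tv * dim  # Q @ K^T, one MAC counted as 2 FLOPs -> 3200
output_matmul = 2 * Tq * Tv * dim  # attention weights @ V -> 3200
softmax = 5 * Tq * Tv              # the factor of 5 is what I am asking about -> 500

print(scores_matmul, output_matmul, softmax)    # 3200 3200 500
print(scores_matmul + output_matmul + softmax)  # 6900, matches TFProfRoot (--/6.90k flops)
```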
I wrote a toy test case for MultiHeadAttention as follows:
```python
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, MultiHeadAttention
from keras_flops import get_flops


def test_multiheadattention():
    Tq = 10   # sequence length
    dim = 16  # feature dimension
    q_shape = (Tq, dim)
    q = Input(q_shape)
    # Self-attention: the same tensor is passed as query and value
    x = MultiHeadAttention(num_heads=1, key_dim=2)(q, q)
    model = Model(inputs=q, outputs=x)
    flops = get_flops(model, batch_size=1)
    print(f"{flops}")
```
```
Profile:
node name | # float_ops
TFProfRoot (--/740 flops)
model/multi_head_attention/softmax/Softmax (500/500 flops)
model/multi_head_attention/attention_output/add (160/160 flops)
model/multi_head_attention/Mul (20/20 flops)
model/multi_head_attention/key/add (20/20 flops)
model/multi_head_attention/query/add (20/20 flops)
model/multi_head_attention/value/add (20/20 flops)
```
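As a quick check, the ops listed in this log already account for the entire 740-FLOP total:

```python
# The ops reported by the profiler for the MultiHeadAttention toy model, summed.
reported = {
    "softmax/Softmax": 500,
    "attention_output/add": 160,
    "Mul": 20,
    "key/add": 20,
    "query/add": 20,
    "value/add": 20,
}
print(sum(reported.values()))  # 740, equal to the TFProfRoot total
```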
What I expected is the following:
The query, key, and value inputs are each in M_{10, 16}. First, the query and the key are each projected to M_{10, 2} by an M_{16, 2} weight matrix, and each projection costs 10 * 16 * 2 * 2 [converting MACs to FLOPs] FLOPs. The attention matrix then requires 10 * 2 * 10 * 2 [converting MACs to FLOPs] FLOPs. Finally, applying the attention weights to the values needs 10 * 10 * 16 * 2 [converting MACs to FLOPs] FLOPs. I did not count the softmax operation this time.
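Written out as code, my hand count looks like this (my own estimate of the matmul cost only, with the softmax deliberately left out, not what the profiler actually reports):

```python
# My expected FLOP count for the toy MultiHeadAttention model above
# (hand estimate of the matmuls only; softmax is excluded on purpose).
Tq, dim, key_dim = 10, 16, 2
MAC_TO_FLOPS = 2  # one multiply-accumulate counted as 2 FLOPs

query_proj = Tq * dim * key_dim * MAC_TO_FLOPS  # (10,16) @ (16,2) -> 640
key_proj   = Tq * dim * key_dim * MAC_TO_FLOPS  # (10,16) @ (16,2) -> 640
scores     = Tq * key_dim * Tq * MAC_TO_FLOPS   # (10,2) @ (2,10) -> 400
weighted_v = Tq * Tq * dim * MAC_TO_FLOPS       # (10,10) @ (10,16) -> 3200

expected = query_proj + key_proj + scores + weighted_v
print(expected)  # 4880, far more than the 740 FLOPs the profiler reports
```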
Is there a fundamental flaw in my reasoning, or am I missing something?