Triton & 九齿 2024 冬季作业 #9
voltjia
announced in
Announcements
Replies: 1 comment 1 reply
-
感觉九齿的符号定义类似于Sympy,或者说早年的静态图神经网络框架的网络定义,就是符号运算后还是一个符号,编译后有了真实数据后会带入编译后的图进行运算(因为现在才继续看九齿的课,所以没能及时在在线直播的评论区吐槽,就转Github讨论区了,若不合时宜可直接删除此评论,就是Sympy的例子更加直观,可体验,作为示范可以加深理解) |
Beta Was this translation helpful? Give feedback.
1 reply
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
-
本作业的目标是让学员实现一个基于 Triton 的注意力计算内核,并确保它与 PyTorch 中的
scaled_dot_product_attention
函数的输出一致。具体来说,作业要求实现一个带有is_causal
参数的注意力计算内核,该参数控制是否使用因果注意力。你需要通过 Triton 实现这一功能,并验证其正确性。作业内容
1. 理解注意力机制
在注意力机制中,给定查询(Query)、键(Key)和值(Value),计算过程通常为:
其中,因果注意力(Causal Attention)通过掩蔽未来时间步(确保当前位置只与之前的位置进行交互)来实现。
PyTorch 中的
scaled_dot_product_attention
实现了该机制。你将基于 Triton 实现相同的计算。感兴趣的同学可以阅读 Attention is All You Need。
2. 实现步骤
给定
attention(query, key, value, is_causal=False, scale=None)
函数的签名,你需要实现核心的注意力计算。在这里:query
,key
,value
是注意力机制中的三个张量,分别表示查询、键和值;is_causal
用于控制是否应用因果掩蔽;scale
控制是否进行缩放处理,通常在计算点积时除以键的维度的平方根。你可以参考 FlashAttention 等算法进行实现。
在
test_attention_kernel
中,我们提供了一个与 PyTorch 结果对比的测试函数compare_results
。你需要确保你的 Triton 实现与 PyTorch 的scaled_dot_product_attention
输出一致。请使用以下命令运行测试:3. 提交要求
attention
函数签名;attention
函数体以调用 Triton 内核;测试代码
希望这个作业帮助大家更好地理解 Triton 与深度学习中的并行计算。
Beta Was this translation helpful? Give feedback.
All reactions