Update paged-attention-metal/attention/paged_attention.metal
Browse files
paged-attention-metal/attention/paged_attention.metal
CHANGED
|
@@ -683,21 +683,26 @@ inline Bfloat8_ fp8_convert<Bfloat8_, Uchar8_>(const thread Uchar8_ &in,
|
|
| 683 |
// TODO(EricLBuehler): optimize with vectorization
|
| 684 |
template <int THREAD_GROUP_SIZE, typename Vec, int N>
|
| 685 |
inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
|
| 686 |
-
// Compute the parallel products for Q*K^T (treat vector lanes separately).
|
| 687 |
using A_vec = typename FloatVec<Vec>::Type;
|
| 688 |
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
|
|
|
|
| 689 |
#pragma unroll
|
| 690 |
for (int ii = 1; ii < N; ++ii) {
|
| 691 |
qk_vec = fma(q[ii], k[ii], qk_vec);
|
| 692 |
}
|
| 693 |
|
| 694 |
-
//
|
| 695 |
-
float
|
|
|
|
|
|
|
| 696 |
#pragma unroll
|
| 697 |
-
for (int mask =
|
| 698 |
-
|
| 699 |
}
|
| 700 |
-
|
|
|
|
|
|
|
|
|
|
| 701 |
}
|
| 702 |
|
| 703 |
template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
|
|
|
|
| 683 |
// TODO(EricLBuehler): optimize with vectorization
|
| 684 |
// Dot product of the fragment arrays `q` and `k`, reduced across the
// THREAD_GROUP_SIZE lanes that cooperate on it.
//
// Template parameters:
//   THREAD_GROUP_SIZE - number of lanes whose partial sums are combined.
//                       NOTE(review): assumed to be a power of two no larger
//                       than the SIMD-group width, with each group occupying
//                       an aligned run of adjacent lanes - confirm against
//                       the calling kernel.
//   Vec               - packed vector type of each fragment.
//   N                 - number of Vec fragments held by each lane.
//
// Parameters:
//   q - N fragments read from threadgroup memory.
//   k - N fragments read from thread-private memory.
//
// Returns the sum of all element-wise products over the thread group; after
// the XOR butterfly every lane of the group holds the same value.
template <int THREAD_GROUP_SIZE, typename Vec, int N>
inline float qk_dot_(const threadgroup Vec (&q)[N], const thread Vec (&k)[N]) {
  // Accumulate the per-element products in the float-typed counterpart of Vec.
  using A_vec = typename FloatVec<Vec>::Type;
  A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
  for (int ii = 1; ii < N; ++ii) {
    qk_vec = fma(q[ii], k[ii], qk_vec);
  }

  // Collapse this lane's vector of partial products into a single scalar.
  float qk = sum(qk_vec);

  // Butterfly reduction limited to the THREAD_GROUP_SIZE lanes of this group.
  // XOR masks below THREAD_GROUP_SIZE never reach outside an aligned group,
  // so partial sums belonging to other groups in the same SIMD group are not
  // mixed in (which a reduction over the full SIMD width would do).
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += simd_shuffle_xor(qk, mask);
  }

  // No broadcast from lane 0: each lane already holds its own group's total,
  // and lane 0 of the SIMD group would only be correct for the first group.
  return qk;
}
|
| 707 |
|
| 708 |
template <typename T, int THREAD_GROUP_SIZE> struct Qk_dot {
|