Grok 10.0.5
dot-inl.h
Go to the documentation of this file.
1// Copyright 2021 Google LLC
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// Include guard (still compiled once per target)
17#include <cmath>
18
19#if defined(HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_) == \
20 defined(HWY_TARGET_TOGGLE)
21#ifdef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
22#undef HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
23#else
24#define HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
25#endif
26
27#include "hwy/highway.h"
28
30namespace hwy {
31namespace HWY_NAMESPACE {
32
33struct Dot {
34 // Specify zero or more of these, ORed together, as the kAssumptions template
35 // argument to Compute. Each one may improve performance or reduce code size,
36 // at the cost of additional requirements on the arguments.
38 // num_elements is at least N, which may be up to HWY_MAX_BYTES / sizeof(T).
40 // num_elements is divisible by N (a power of two, so this can be used if
41 // the problem size is known to be a power of two >= HWY_MAX_BYTES /
42 // sizeof(T)).
44 // RoundUpTo(num_elements, N) elements are accessible; their value does not
45 // matter (will be treated as if they were zero).
47 };
48
49 // Returns sum{pa[i] * pb[i]} for float or double inputs. Aligning the
50 // pointers to a multiple of N elements is helpful but not required.
51 template <int kAssumptions, class D, typename T = TFromD<D>,
52 HWY_IF_NOT_LANE_SIZE_D(D, 2)>
53 static HWY_INLINE T Compute(const D d, const T* const HWY_RESTRICT pa,
54 const T* const HWY_RESTRICT pb,
55 const size_t num_elements) {
56 static_assert(IsFloat<T>(), "MulAdd requires float type");
57 using V = decltype(Zero(d));
58
59 const size_t N = Lanes(d);
60 size_t i = 0;
61
62 constexpr bool kIsAtLeastOneVector =
63 (kAssumptions & kAtLeastOneVector) != 0;
64 constexpr bool kIsMultipleOfVector =
65 (kAssumptions & kMultipleOfVector) != 0;
66 constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
67
68 // Won't be able to do a full vector load without padding => scalar loop.
69 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
70 HWY_UNLIKELY(num_elements < N)) {
71 // Only 2x unroll to avoid excessive code size.
72 T sum0 = T(0);
73 T sum1 = T(0);
74 for (; i + 2 <= num_elements; i += 2) {
75 sum0 += pa[i + 0] * pb[i + 0];
76 sum1 += pa[i + 1] * pb[i + 1];
77 }
78 if (i < num_elements) {
79 sum1 += pa[i] * pb[i];
80 }
81 return sum0 + sum1;
82 }
83
84 // Compiler doesn't make independent sum* accumulators, so unroll manually.
85 // 2 FMA ports * 4 cycle latency = up to 8 in-flight, but that is excessive
86 // for unaligned inputs (each unaligned pointer halves the throughput
87 // because it occupies both L1 load ports for a cycle). We cannot have
88 // arrays of vectors on RVV/SVE, so always unroll 4x.
89 V sum0 = Zero(d);
90 V sum1 = Zero(d);
91 V sum2 = Zero(d);
92 V sum3 = Zero(d);
93
94 // Main loop: unrolled
95 for (; i + 4 * N <= num_elements; /* i += 4 * N */) { // incr in loop
96 const auto a0 = LoadU(d, pa + i);
97 const auto b0 = LoadU(d, pb + i);
98 i += N;
99 sum0 = MulAdd(a0, b0, sum0);
100 const auto a1 = LoadU(d, pa + i);
101 const auto b1 = LoadU(d, pb + i);
102 i += N;
103 sum1 = MulAdd(a1, b1, sum1);
104 const auto a2 = LoadU(d, pa + i);
105 const auto b2 = LoadU(d, pb + i);
106 i += N;
107 sum2 = MulAdd(a2, b2, sum2);
108 const auto a3 = LoadU(d, pa + i);
109 const auto b3 = LoadU(d, pb + i);
110 i += N;
111 sum3 = MulAdd(a3, b3, sum3);
112 }
113
114 // Up to 3 iterations of whole vectors
115 for (; i + N <= num_elements; i += N) {
116 const auto a = LoadU(d, pa + i);
117 const auto b = LoadU(d, pb + i);
118 sum0 = MulAdd(a, b, sum0);
119 }
120
121 if (!kIsMultipleOfVector) {
122 const size_t remaining = num_elements - i;
123 if (remaining != 0) {
124 if (kIsPaddedToVector) {
125 const auto mask = FirstN(d, remaining);
126 const auto a = LoadU(d, pa + i);
127 const auto b = LoadU(d, pb + i);
128 sum1 = MulAdd(IfThenElseZero(mask, a), IfThenElseZero(mask, b), sum1);
129 } else {
130 // Unaligned load such that the last element is in the highest lane -
131 // ensures we do not touch any elements outside the valid range.
132 // If we get here, then num_elements >= N.
133 HWY_DASSERT(i >= N);
134 i += remaining - N;
135 const auto skip = FirstN(d, N - remaining);
136 const auto a = LoadU(d, pa + i); // always unaligned
137 const auto b = LoadU(d, pb + i);
138 sum1 = MulAdd(IfThenZeroElse(skip, a), IfThenZeroElse(skip, b), sum1);
139 }
140 }
141 } // kMultipleOfVector
142
143 // Reduction tree: sum of all accumulators by pairs, then across lanes.
144 sum0 = Add(sum0, sum1);
145 sum2 = Add(sum2, sum3);
146 sum0 = Add(sum0, sum2);
147 return GetLane(SumOfLanes(d, sum0));
148 }
149
150 // Returns sum{pa[i] * pb[i]} for bfloat16 inputs. Aligning the pointers to a
151 // multiple of N elements is helpful but not required.
152 template <int kAssumptions, class D>
153 static HWY_INLINE float Compute(const D d,
154 const bfloat16_t* const HWY_RESTRICT pa,
155 const bfloat16_t* const HWY_RESTRICT pb,
156 const size_t num_elements) {
157 const RebindToUnsigned<D> du16;
158 const Repartition<float, D> df32;
159
160 using V = decltype(Zero(df32));
161 const size_t N = Lanes(d);
162 size_t i = 0;
163
164 constexpr bool kIsAtLeastOneVector =
165 (kAssumptions & kAtLeastOneVector) != 0;
166 constexpr bool kIsMultipleOfVector =
167 (kAssumptions & kMultipleOfVector) != 0;
168 constexpr bool kIsPaddedToVector = (kAssumptions & kPaddedToVector) != 0;
169
170 // Won't be able to do a full vector load without padding => scalar loop.
171 if (!kIsAtLeastOneVector && !kIsMultipleOfVector && !kIsPaddedToVector &&
172 HWY_UNLIKELY(num_elements < N)) {
173 float sum0 = 0.0f; // Only 2x unroll to avoid excessive code size for..
174 float sum1 = 0.0f; // this unlikely(?) case.
175 for (; i + 2 <= num_elements; i += 2) {
176 sum0 += F32FromBF16(pa[i + 0]) * F32FromBF16(pb[i + 0]);
177 sum1 += F32FromBF16(pa[i + 1]) * F32FromBF16(pb[i + 1]);
178 }
179 if (i < num_elements) {
180 sum1 += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
181 }
182 return sum0 + sum1;
183 }
184
185 // See comment in the other Compute() overload. Unroll 2x, but we need
186 // twice as many sums for ReorderWidenMulAccumulate.
187 V sum0 = Zero(df32);
188 V sum1 = Zero(df32);
189 V sum2 = Zero(df32);
190 V sum3 = Zero(df32);
191
192 // Main loop: unrolled
193 for (; i + 2 * N <= num_elements; /* i += 2 * N */) { // incr in loop
194 const auto a0 = LoadU(d, pa + i);
195 const auto b0 = LoadU(d, pb + i);
196 i += N;
197 sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
198 const auto a1 = LoadU(d, pa + i);
199 const auto b1 = LoadU(d, pb + i);
200 i += N;
201 sum2 = ReorderWidenMulAccumulate(df32, a1, b1, sum2, sum3);
202 }
203
204 // Possibly one more iteration of whole vectors
205 if (i + N <= num_elements) {
206 const auto a0 = LoadU(d, pa + i);
207 const auto b0 = LoadU(d, pb + i);
208 i += N;
209 sum0 = ReorderWidenMulAccumulate(df32, a0, b0, sum0, sum1);
210 }
211
212 if (!kIsMultipleOfVector) {
213 const size_t remaining = num_elements - i;
214 if (remaining != 0) {
215 if (kIsPaddedToVector) {
216 const auto mask = FirstN(du16, remaining);
217 const auto va = LoadU(d, pa + i);
218 const auto vb = LoadU(d, pb + i);
219 const auto a16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, va)));
220 const auto b16 = BitCast(d, IfThenElseZero(mask, BitCast(du16, vb)));
221 sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
222
223 } else {
224 // Unaligned load such that the last element is in the highest lane -
225 // ensures we do not touch any elements outside the valid range.
226 // If we get here, then num_elements >= N.
227 HWY_DASSERT(i >= N);
228 i += remaining - N;
229 const auto skip = FirstN(du16, N - remaining);
230 const auto va = LoadU(d, pa + i); // always unaligned
231 const auto vb = LoadU(d, pb + i);
232 const auto a16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, va)));
233 const auto b16 = BitCast(d, IfThenZeroElse(skip, BitCast(du16, vb)));
234 sum2 = ReorderWidenMulAccumulate(df32, a16, b16, sum2, sum3);
235 }
236 }
237 } // kMultipleOfVector
238
239 // Reduction tree: sum of all accumulators by pairs, then across lanes.
240 sum0 = Add(sum0, sum1);
241 sum2 = Add(sum2, sum3);
242 sum0 = Add(sum0, sum2);
243 return GetLane(SumOfLanes(df32, sum0));
244 }
245};
246
247// NOLINTNEXTLINE(google-readability-namespace-comments)
248} // namespace HWY_NAMESPACE
249} // namespace hwy
251
252#endif // HIGHWAY_HWY_CONTRIB_DOT_DOT_INL_H_
#define HWY_RESTRICT
Definition: base.h:64
#define HWY_INLINE
Definition: base.h:70
#define HWY_DASSERT(condition)
Definition: base.h:238
#define HWY_UNLIKELY(expr)
Definition: base.h:76
HWY_AFTER_NAMESPACE()
HWY_BEFORE_NAMESPACE()
HWY_INLINE Vec128< T, N > Add(hwy::NonFloatTag, Vec128< T, N > a, Vec128< T, N > b)
Definition: emu128-inl.h:535
d
Definition: rvv-inl.h:1998
HWY_API Mask128< T, N > FirstN(const Simd< T, N, 0 > d, size_t num)
Definition: arm_neon-inl.h:2456
HWY_API Vec128< float, N > MulAdd(const Vec128< float, N > mul, const Vec128< float, N > x, const Vec128< float, N > add)
Definition: arm_neon-inl.h:1799
HWY_API Vec128< T, N > SumOfLanes(Simd< T, N, 0 >, const Vec128< T, N > v)
Definition: arm_neon-inl.h:5334
Rebind< MakeUnsigned< TFromD< D > >, D > RebindToUnsigned
Definition: ops/shared-inl.h:212
HWY_API Vec128< float, N > ReorderWidenMulAccumulate(Simd< float, N, 0 > df32, Vec128< bfloat16_t, 2 *N > a, Vec128< bfloat16_t, 2 *N > b, const Vec128< float, N > sum0, Vec128< float, N > &sum1)
Definition: arm_neon-inl.h:4288
HWY_API Vec128< T, N > IfThenElseZero(const Mask128< T, N > mask, const Vec128< T, N > yes)
Definition: arm_neon-inl.h:2253
HWY_API constexpr size_t Lanes(Simd< T, N, kPow2 >)
Definition: arm_sve-inl.h:243
HWY_API Vec128< uint8_t > LoadU(Full128< uint8_t >, const uint8_t *HWY_RESTRICT unaligned)
Definition: arm_neon-inl.h:2591
HWY_API Vec128< T, N > BitCast(Simd< T, N, 0 > d, Vec128< FromT, N *sizeof(T)/sizeof(FromT)> v)
Definition: arm_neon-inl.h:997
HWY_API Vec128< T, N > Zero(Simd< T, N, 0 > d)
Definition: arm_neon-inl.h:1020
HWY_API Vec128< T, N > IfThenZeroElse(const Mask128< T, N > mask, const Vec128< T, N > no)
Definition: arm_neon-inl.h:2260
HWY_API TFromV< V > GetLane(const V v)
Definition: arm_neon-inl.h:1076
typename D::template Repartition< T > Repartition
Definition: ops/shared-inl.h:218
N
Definition: rvv-inl.h:1998
Definition: aligned_allocator.h:27
HWY_API float F32FromBF16(bfloat16_t bf)
Definition: base.h:975
#define HWY_NAMESPACE
Definition: set_macros-inl.h:82
Definition: dot-inl.h:33
static HWY_INLINE T Compute(const D d, const T *const HWY_RESTRICT pa, const T *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:53
static HWY_INLINE float Compute(const D d, const bfloat16_t *const HWY_RESTRICT pa, const bfloat16_t *const HWY_RESTRICT pb, const size_t num_elements)
Definition: dot-inl.h:153
Assumptions
Definition: dot-inl.h:37
@ kMultipleOfVector
Definition: dot-inl.h:43
@ kPaddedToVector
Definition: dot-inl.h:46
@ kAtLeastOneVector
Definition: dot-inl.h:39
Definition: base.h:296