// Compiler-detection flag; its value is copied into the kCompilerClang constant further down.
// Both a 1 and a 0 definition appear here (twice). The preprocessor conditionals that choose
// between them are not visible in this excerpt.
25#define COMPILER_CLANG 1
28#define COMPILER_CLANG 0
33#define COMPILER_CLANG 1
36#define COMPILER_CLANG 0
// Give __ARM_ARCH_7A__ a value of 0 when it is not already defined, so it can be used as an
// ordinary operand in the ARMv7a constant expression below.
45#ifndef __ARM_ARCH_7A__
46#define __ARM_ARCH_7A__ 0
49constexpr bool ARMv7a = (__ARM_ARCH == 7 && __ARM_ARCH_7A__);
50constexpr bool kCompilerClang = COMPILER_CLANG;
54[[gnu::always_inline]]
static constexpr int32_t multiply_32x32_rshift32(int32_t a, int32_t b) {
56 if constexpr (ARMv7a && !kCompilerClang) {
58 asm(
"smmul %0, %1, %2" :
"=r"(out) :
"r"(a),
"r"(b));
63 return (int32_t)(((int64_t)a * b) >> 32);
67[[gnu::always_inline]]
static constexpr int32_t multiply_32x32_rshift32_rounded(int32_t a, int32_t b) {
69 if constexpr (ARMv7a && !kCompilerClang) {
71 asm(
"smmulr %0, %1, %2" :
"=r"(out) :
"r"(a),
"r"(b));
76 return (int32_t)(((int64_t)a * b + 0x80000000) >> 32);
81[[gnu::always_inline]]
static constexpr int32_t q31_mult(int32_t a, int32_t b) {
82 return multiply_32x32_rshift32(a, b) << 1;
87[[gnu::always_inline]]
static constexpr int32_t q31_mult_rounded(int32_t a, int32_t b) {
88 return multiply_32x32_rshift32_rounded(a, b) << 1;
92[[gnu::always_inline]]
static constexpr int32_t multiply_accumulate_32x32_rshift32_rounded(int32_t sum, int32_t a,
95 if constexpr (ARMv7a && !kCompilerClang) {
97 asm(
"smmlar %0, %1, %2, %3" :
"=r"(out) :
"r"(a),
"r"(b),
"r"(sum));
102 return (int32_t)(((((int64_t)sum) << 32) + ((int64_t)a * b) + 0x80000000) >> 32);
106[[gnu::always_inline]]
static constexpr int32_t multiply_accumulate_32x32_rshift32(int32_t sum, int32_t a, int32_t b) {
108 if constexpr (ARMv7a && !kCompilerClang) {
110 asm(
"smmla %0, %1, %2, %3" :
"=r"(out) :
"r"(a),
"r"(b),
"r"(sum));
115 return (int32_t)(((((int64_t)sum) << 32) + ((int64_t)a * b)) >> 32);
119[[gnu::always_inline]]
static constexpr int32_t multiply_subtract_32x32_rshift32_rounded(int32_t sum, int32_t a,
122 if constexpr (ARMv7a && !kCompilerClang) {
124 asm(
"smmlsr %0, %1, %2, %3" :
"=r"(out) :
"r"(a),
"r"(b),
"r"(sum));
129 return (int32_t)((((((int64_t)sum) << 32) - ((int64_t)a * b)) + 0x80000000) >> 32);
133template <
size_t bits>
135[[gnu::always_inline]]
static constexpr int32_t signed_saturate(int32_t val) {
137 if constexpr (ARMv7a) {
139 asm(
"ssat %0, %1, %2" :
"=r"(out) :
"I"(bits),
"r"(val));
144 return std::clamp<int32_t>(val, -(1LL << (bits - 1)), (1LL << (bits - 1)) - 1);
150template <
size_t bits>
152[[gnu::always_inline]]
static constexpr uint32_t unsigned_saturate(uint32_t val) {
155 if constexpr (ARMv7a) {
157 asm(
"usat %0, %1, %2" :
"=r"(out) :
"I"(bits),
"r"(val));
162 return std::clamp<uint32_t>(val, 0u, (1uL << bits) - 1);
165template <
size_t shift,
size_t bits = 32>
166requires(shift < 32 && bits <= 32)
167[[gnu::always_inline]]
static constexpr int32_t shift_left_saturate(int32_t val) {
169 if constexpr (ARMv7a) {
171 asm(
"ssat %0, %1, %2, LSL %3" :
"=r"(out) :
"I"(bits),
"r"(val),
"I"(shift));
176 return std::clamp<int64_t>((int64_t)val << shift, -(1LL << (bits - 1)), (1LL << (bits - 1)) - 1);
179template <
size_t shift,
size_t bits = 32>
180requires(shift < 32 && bits <= 32)
181[[gnu::always_inline]]
static constexpr uint32_t shift_left_saturate(uint32_t val) {
183 if constexpr (ARMv7a) {
185 asm(
"usat %0, %1, %2, LSL %3" :
"=r"(out) :
"I"(bits),
"r"(val),
"I"(shift));
189 return std::clamp<uint64_t>((uint64_t)val << shift, 0u, (1uLL << bits) - 1);
192[[gnu::always_inline]]
static constexpr int32_t add_saturate(int32_t a, int32_t b) {
194 if constexpr (ARMv7a) {
196 asm(
"qadd %0, %1, %2" :
"=r"(out) :
"r"(a),
"r"(b));
200 return std::clamp<int64_t>(a + b, -(1LL << 31), (1LL << 31) - 1);
203[[gnu::always_inline]]
static constexpr int32_t subtract_saturate(int32_t a, int32_t b) {
205 if constexpr (ARMv7a) {
207 asm(
"qsub %0, %1, %2" :
"=r"(out) :
"r"(a),
"r"(b));
211 return std::clamp<int64_t>(a - b, -(1LL << 31), (1LL << 31) - 1);
220static inline int32_t q31_from_float(
float value) {
221 asm(
"vcvt.s32.f32 %0, %0, #31" :
"=t"(value) :
"t"(value));
222 return std::bit_cast<int32_t>(value);
228static inline float int32_to_float(int32_t value) {
229 asm(
"vcvt.f32.s32 %0, %0, #31" :
"=t"(value) :
"t"(value));
230 return std::bit_cast<float>(value);