aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp52
1 files changed, 37 insertions, 15 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index d51fda525b..e89523981d 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -108,23 +108,45 @@ public:
}
};
-struct ARequantizeLayer32
+struct Requantize32
{
public:
- const int32_t *bias;
- size_t bias_multi_stride;
- int32_t a_offset;
- int32_t b_offset;
- int32_t c_offset;
- int32_t requant_shift;
- int32_t requant_mul;
- int32_t minval;
- int32_t maxval;
-
- ARequantizeLayer32() = default;
-
- ARequantizeLayer32(const int32_t *b, size_t bms, int32_t ao, int32_t bo, int32_t co, int32_t rs, int32_t rm, int32_t minv, int32_t maxv) :
- bias(b), bias_multi_stride(bms), a_offset(ao), b_offset(bo), c_offset(co), requant_shift(rs), requant_mul(rm), minval(minv), maxval(maxv)
+ const int32_t *bias = nullptr;
+ size_t bias_multi_stride = 0;
+ int32_t a_offset = 0;
+ int32_t b_offset = 0;
+ int32_t c_offset = 0;
+ bool per_channel_requant = false;
+ int32_t per_layer_shift = 0;
+ int32_t per_layer_mul = 0;
+ const int32_t *per_channel_shifts = nullptr;
+ const int32_t *per_channel_muls = nullptr;
+ int32_t minval = 0;
+ int32_t maxval = 0;
+
+ Requantize32() = default;
+
+ // Constructor for per-tensor quantization
+ Requantize32(const int32_t *bias, size_t bias_multi_stride,
+ int32_t a_offset, int32_t b_offset, int32_t c_offset,
+ int32_t requant_shift, int32_t requant_mul,
+ int32_t minv, int32_t maxv) :
+ bias(bias), bias_multi_stride(bias_multi_stride),
+ a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+ per_channel_requant(false), per_layer_shift(requant_shift), per_layer_mul(requant_mul),
+ minval(minv), maxval(maxv)
+ {
+ }
+
+ // Constructor for per-channel quantization
+ Requantize32(const int32_t *bias, size_t bias_multi_stride,
+ int32_t a_offset, int32_t b_offset, int32_t c_offset,
+ const int32_t *requant_shifts, const int32_t *requant_muls,
+ int32_t minv, int32_t maxv) :
+ bias(bias), bias_multi_stride(bias_multi_stride),
+ a_offset(a_offset), b_offset(b_offset), c_offset(c_offset),
+ per_channel_requant(true), per_channel_shifts(requant_shifts), per_channel_muls(requant_muls),
+ minval(minv), maxval(maxv)
{
}
};