Diffstat (limited to 'tests/validation/Helpers.cpp')
 tests/validation/Helpers.cpp | 37 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)
diff --git a/tests/validation/Helpers.cpp b/tests/validation/Helpers.cpp
index 2f273e7042..cb4d87601c 100644
--- a/tests/validation/Helpers.cpp
+++ b/tests/validation/Helpers.cpp
@@ -352,8 +352,43 @@ void add_padding_x(std::initializer_list<ITensor *> tensors, const DataLayout &d
}
}
+QuantizationHint suggest_conv_dst_q_info_and_bias(const QuantizationInfo &in_q_info,
+ const QuantizationInfo &weight_q_info,
+ int32_t height,
+ int32_t width,
+ int32_t channels,
+ DataType data_type,
+ float bias_fraction)
+{
+ /** Quantization Setup of convolution
+ *
+ * Just like any other multiply-accumulate, the 2D convolution operation
+ * multiplies and accumulates the input and weight tensors. The accumulation
+ * runs over three dimensions: the height, width and channels of the weight
+ * tensor.
+ *
+ * The formula for simple convolution can be written as:
+ * C = sum_h sum_w sum_c(I[h_offset + h, w_offset + w, c] * W[h, w, c])
+ *
+ * Here, h_offset and w_offset are the starting positions in the image. Effects
+ * of paddings are ignored. This accumulation reduces to something like
+ *
+ * C = sum_m(I_index * W_hwc)
+ * where m ranges over the height x width x channels accumulated terms.
+ *
+ * Non-unit strides and/or dilations do not change the probabilistic nature of
+ * this sum because the number of accumulated terms always equals the size of
+ * the weight tensor.
+ *
+ * Padding may affect this summation, but it is a boundary condition and is
+ * therefore neglected for brevity.
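+ *
+ * As a rough worked example (the shapes and quantization arguments below are
+ * illustrative, not taken from any particular test): a 3x3 kernel over 64
+ * input channels accumulates 3 * 3 * 64 = 576 products per output element,
+ * so a call such as
+ *
+ *   suggest_conv_dst_q_info_and_bias(in_q, w_q, 3, 3, 64, DataType::QASYMM8_SIGNED, 0.5f);
+ *
+ * picks the destination quantization and bias range as if for a dot product
+ * of length 576.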
+ */
+
+ return suggest_mac_dst_q_info_and_bias(in_q_info, weight_q_info, height * width * channels, data_type, bias_fraction);
+}
+
QuantizationHint suggest_matmul_dst_q_info_and_bias(const QuantizationInfo &lhs_q_info,
- const QuantizationInfo &rhs_q_info, int32_t m, int32_t n, int32_t k, DataType data_type,
+ const QuantizationInfo &rhs_q_info,
+ int32_t m, int32_t n, int32_t k, DataType data_type,
float bias_fraction)
{
ARM_COMPUTE_UNUSED(m, n);