aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Giorgio Arena <giorgio.arena@arm.com> 2018-02-16 11:01:04 +0000
committer Anthony Barbier <anthony.barbier@arm.com> 2018-11-02 16:47:18 +0000
commit 287b570b86ba40a801136aded140b83435ca9314 (patch)
tree 08f1462dd8b28020d2aaa72509d21b8a90005cc2
parent d267b05aaaec9b462a8c988c7b5fcebd5776c72f (diff)
download ComputeLibrary-287b570b86ba40a801136aded140b83435ca9314.tar.gz
COMPMID-853 Use tile 2 for CL depthwise convolution QASYMM8
Change-Id: I91f6a0b057f5eb84c6ac7db5abbc05c7520ed5d2 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/120760 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r-- src/core/CL/cl_kernels/depthwise_convolution_quantized.cl 328
-rw-r--r-- src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.cpp 43
-rw-r--r-- tests/benchmark/fixtures/DepthwiseConvolutionLayerFixture.h 6
3 files changed, 179 insertions, 198 deletions
diff --git a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
index e4345817fc..b2527a4c7d 100644
--- a/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
+++ b/src/core/CL/cl_kernels/depthwise_convolution_quantized.cl
@@ -24,158 +24,45 @@
#include "helpers_asymm.h"
-#if defined(CONV_STRIDE_X)
+#if defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y) && defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT)
+
+#if CONV_STRIDE_X > 3
+#error "Stride X not supported"
+#endif /* CONV_STRIDE_X > 3 */
#if CONV_STRIDE_X == 1
-#define convolution1x3 convolution1x3_stride_1
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ int8 temp0 = CONVERT(vload8(0, first_value), int8); \
+ int2 temp1 = CONVERT(vload2(0, (first_value + 8 * sizeof(uchar))), int2); \
+ \
+ left = CONVERT(temp0.s01234567, int8); \
+ middle = CONVERT((int8)(temp0.s1234, temp0.s567, temp1.s0), int8); \
+ right = CONVERT((int8)(temp0.s2345, temp0.s67, temp1.s01), int8); \
+ })
#elif CONV_STRIDE_X == 2
-#define convolution1x3 convolution1x3_stride_2
-#elif CONV_STRIDE_X == 3
-#define convolution1x3 convolution1x3_stride_3
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ int16 temp0 = CONVERT(vload16(0, first_value), int16); \
+ int temp1 = CONVERT(*(first_value + 16 * sizeof(uchar)), int); \
+ \
+ left = CONVERT(temp0.s02468ace, int8); \
+ middle = CONVERT(temp0.s13579bdf, int8); \
+ right = CONVERT((int8)(temp0.s2468, temp0.sace, temp1), int8); \
+ })
#else /* CONV_STRIDE_X */
-#error "Stride not supported"
+#define GET_VALUES(first_value, left, middle, right) \
+ ({ \
+ int16 temp0 = CONVERT(vload16(0, first_value), int16); \
+ int8 temp1 = CONVERT(vload8(0, (first_value + 16 * sizeof(uchar))), int8); \
+ \
+ left = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8); \
+ middle = CONVERT((int8)(temp0.s147a, temp0.sd, temp1.s036), int8); \
+ right = CONVERT((int8)(temp0.s258b, temp0.se, temp1.s147), int8); \
+ })
#endif /* CONV_STRIDE_X */
-/** Compute a 1D horizontal convolution of size 3 and stride 1 for uchar type.
- *
- * @param[in] left_pixel Pointer to the left pixel.
- * @param[in] left_coeff Weight of the left pixel
- * @param[in] middle_coeff Weight of the middle pixel
- * @param[in] right_coeff Weight of the right pixel
- * @param[in] input_offset Quantized offset of zero point of the input tensor data range
- * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range
- *
- * @return a int8 containing 8 convoluted values.
- */
-inline int8 convolution1x3_stride_1(__global const uchar *left_pixel,
- const int left_coeff,
- const int middle_coeff,
- const int right_coeff,
- const int input_offset,
- const int weight_offset)
-{
- int8 temp0 = CONVERT(vload8(0, left_pixel), int8);
- int2 temp1 = CONVERT(vload2(0, (left_pixel + 8 * sizeof(uchar))), int2);
-
- int8 left = CONVERT(temp0.s01234567, int8);
- int8 middle = CONVERT((int8)(temp0.s1234, temp0.s567, temp1.s0), int8);
- int8 right = CONVERT((int8)(temp0.s2345, temp0.s67, temp1.s01), int8);
-
- return (left + input_offset) * (int8)(left_coeff + weight_offset) + (middle + input_offset) * (int8)(middle_coeff + weight_offset) + (right + input_offset) * (int8)(right_coeff + weight_offset);
-}
-
-/** Compute a 1D horizontal convolution of size 3 and stride 2 for uchar type.
- *
- * @param[in] left_pixel Pointer to the left pixel.
- * @param[in] left_coeff Weight of the left pixel
- * @param[in] middle_coeff Weight of the middle pixel
- * @param[in] right_coeff Weight of the right pixel
- * @param[in] input_offset Quantized offset of zero point of the input tensor data range
- * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range
- *
- * @return a int8 containing 8 convoluted values.
- */
-inline int8 convolution1x3_stride_2(__global const uchar *left_pixel,
- const int left_coeff,
- const int middle_coeff,
- const int right_coeff,
- const int input_offset,
- const int weight_offset)
-{
- int16 temp0 = CONVERT(vload16(0, left_pixel), int16);
- int temp1 = CONVERT(*(left_pixel + 16 * sizeof(uchar)), int);
-
- int8 left = CONVERT(temp0.s02468ace, int8);
- int8 middle = CONVERT(temp0.s13579bdf, int8);
- int8 right = CONVERT((int8)(temp0.s2468, temp0.sace, temp1), int8);
-
- return (left + input_offset) * (int8)(left_coeff + weight_offset) + (middle + input_offset) * (int8)(middle_coeff + weight_offset) + (right + input_offset) * (int8)(right_coeff + weight_offset);
-}
-
-/** Compute a 1D horizontal convolution of size 3 and stride 3 for uchar type.
- *
- * @param[in] left_pixel Pointer to the left pixel.
- * @param[in] left_coeff Weight of the left pixel
- * @param[in] middle_coeff Weight of the middle pixel
- * @param[in] right_coeff Weight of the right pixel
- * @param[in] input_offset Quantized offset of zero point of the input tensor data range
- * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range
- *
- * @return a int8 containing 8 convoluted values.
- */
-inline int8 convolution1x3_stride_3(__global const uchar *left_pixel,
- const int left_coeff,
- const int middle_coeff,
- const int right_coeff,
- const int input_offset,
- const int weight_offset)
-{
- int16 temp0 = CONVERT(vload16(0, left_pixel), int16);
- int8 temp1 = CONVERT(vload8(0, (left_pixel + 16 * sizeof(uchar))), int8);
-
- int8 left = CONVERT((int8)(temp0.s0369, temp0.scf, temp1.s25), int8);
- int8 middle = CONVERT((int8)(temp0.s147a, temp0.sd, temp1.s036), int8);
- int8 right = CONVERT((int8)(temp0.s258b, temp0.se, temp1.s147), int8);
-
- return (left + input_offset) * (int8)(left_coeff + weight_offset) + (middle + input_offset) * (int8)(middle_coeff + weight_offset) + (right + input_offset) * (int8)(right_coeff + weight_offset);
-}
-
-/** Apply a 3x3 convolution matrix to a single channel QASYMM8 input image and return the result.
- *
- * Convolution matrix layout:
- *
- * [ mat0, mat1, mat2 ]\n
- * [ mat3, mat4, mat5 ]\n
- * [ mat6, mat7, mat8 ]\n
- *
- * @param[in] src A pointer to source Image structure
- * @param[in] mat0 Coefficient from the convolution matrix
- * @param[in] mat1 Coefficient from the convolution matrix
- * @param[in] mat2 Coefficient from the convolution matrix
- * @param[in] mat3 Coefficient from the convolution matrix
- * @param[in] mat4 Coefficient from the convolution matrix
- * @param[in] mat5 Coefficient from the convolution matrix
- * @param[in] mat6 Coefficient from the convolution matrix
- * @param[in] mat7 Coefficient from the convolution matrix
- * @param[in] mat8 Coefficient from the convolution matrix
- * @param[in] input_offset Quantized offset of zero point of the input tensor data range
- * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range
- * @param[in] output_offset Quantized offset of zero point of the output tensor data range
- * @param[in] output_multiplier Output scale multiplier
- * @param[in] output_shift Output scale divisor exponent
- * @param[in] bias (Optional) Bias value
- *
- * @return a uchar8 containing 8 convoluted values.
- */
-inline uchar8 convolution3x3(
- Image *src,
- const uchar mat0, const uchar mat1, const uchar mat2,
- const uchar mat3, const uchar mat4, const uchar mat5,
- const uchar mat6, const uchar mat7, const uchar mat8,
- const int input_offset, const int weight_offset, const int output_offset,
- const int output_multiplier, const int output_shift
-#if defined(HAS_BIAS)
- ,
- const int bias
-#endif //defined(HAS_BIAS)
-)
-{
- int8 pixels;
-
- pixels = convolution1x3(offset(src, 0, 0), mat0, mat1, mat2, input_offset, weight_offset);
- pixels += convolution1x3(offset(src, 0, 1), mat3, mat4, mat5, input_offset, weight_offset);
- pixels += convolution1x3(offset(src, 0, 2), mat6, mat7, mat8, input_offset, weight_offset);
-#if defined(HAS_BIAS)
- pixels += (int8)(bias);
-#endif //defined(HAS_BIAS)
-
- pixels = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(pixels, output_multiplier, output_shift, 8);
- pixels = pixels + output_offset;
-
- return CONVERT_SAT(pixels, uchar8);
-}
-
-/** This function computes the horizontal integral of the image.
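+/** This function computes the 3x3 depthwise convolution for QASYMM8 tensors, with the quantization offsets, output multiplier and output shift supplied as compile-time definitions.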
*
* @param[in] src_ptr Pointer to the source image. Supported data types: QASYMM8
* @param[in] src_stride_x Stride of the source image in X dimension (in bytes)
@@ -205,11 +92,6 @@ inline uchar8 convolution3x3(
* @param[in] biases_stride_x (Optional) Stride of the biases vector in X dimension (in bytes)
* @param[in] biases_step_x (Optional) biases_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] biases_offset_first_element_in_bytes (Optional) The offset of the first element in the biases vector
- * @param[in] input_offset Quantized offset of zero point of the input tensor data range
- * @param[in] weight_offset Quantized offset of zero point of the weights tensor data range
- * @param[in] output_offset Quantized offset of zero point of the output tensor data range
- * @param[in] output_multiplier Output scale multiplier
- * @param[in] output_shift Output scale divisor exponent
*/
__kernel void depthwise_convolution_3x3_quantized(
@@ -217,41 +99,139 @@ __kernel void depthwise_convolution_3x3_quantized(
TENSOR3D_DECLARATION(dst),
TENSOR3D_DECLARATION(weights),
#if defined(HAS_BIAS)
- VECTOR_DECLARATION(biases),
+ VECTOR_DECLARATION(biases)
#endif //defined(HAS_BIAS)
- int input_offset,
- int weight_offset,
- int output_offset,
- int output_multiplier,
- int output_shift)
+)
{
Image src = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(src);
Image dst = CONVERT_TENSOR3D_TO_IMAGE_STRUCT(dst);
Tensor3D weights = CONVERT_TO_TENSOR3D_STRUCT(weights);
#if defined(HAS_BIAS)
Vector biases = CONVERT_TO_VECTOR_STRUCT_NO_STEP(biases);
-#endif //defined(HAS_BIAS)
-
- uchar3 offset = (uchar3)(0, 1, 2) * (uchar3)weights_stride_y;
- uchar3 weights_values0 = vload3(0, weights.ptr + offset.s0);
- uchar3 weights_values1 = vload3(0, weights.ptr + offset.s1);
- uchar3 weights_values2 = vload3(0, weights.ptr + offset.s2);
-#if defined(HAS_BIAS)
int bias_value = *((__global int *)(vector_offset(&biases, get_global_id(2))));
#endif //defined(HAS_BIAS)
- uchar8 pixels = convolution3x3(&src, weights_values0.s0, weights_values0.s1, weights_values0.s2,
- weights_values1.s0, weights_values1.s1, weights_values1.s2,
- weights_values2.s0, weights_values2.s1, weights_values2.s2,
- input_offset, weight_offset, output_offset,
- output_multiplier, output_shift
+ uchar3 w0 = vload3(0, weights.ptr + 0 * weights_stride_y);
+ uchar3 w1 = vload3(0, weights.ptr + 1 * weights_stride_y);
+ uchar3 w2 = vload3(0, weights.ptr + 2 * weights_stride_y);
+
+ int8 values0 = 0;
+ int8 sum0 = 0;
+#if CONV_STRIDE_Y == 1
+ int8 values1 = 0;
+ int8 sum1 = 0;
+#endif /* CONV_STRIDE_Y == 1 */
+
+ // Row0
+ int8 left, middle, right;
+ GET_VALUES(src.ptr + 0 * src_stride_y, left, middle, right);
+ values0 += left * (int8)(w0.s0);
+ values0 += middle * (int8)(w0.s1);
+ values0 += right * (int8)(w0.s2);
+
+#if WEIGHTS_OFFSET != 0
+ sum0 += left + middle + right;
+#endif /* WEIGHTS_OFFSET != 0 */
+
+ // Row1
+ GET_VALUES(src.ptr + 1 * src_stride_y, left, middle, right);
+ values0 += left * (int8)(w1.s0);
+ values0 += middle * (int8)(w1.s1);
+ values0 += right * (int8)(w1.s2);
+#if CONV_STRIDE_Y == 1
+ values1 += left * (int8)(w0.s0);
+ values1 += middle * (int8)(w0.s1);
+ values1 += right * (int8)(w0.s2);
+#endif /* CONV_STRIDE_Y == 1 */
+
+#if WEIGHTS_OFFSET != 0
+ int8 tmp = left + middle + right;
+ sum0 += tmp;
+#if CONV_STRIDE_Y == 1
+ sum1 += tmp;
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* WEIGHTS_OFFSET != 0 */
+
+ // Row2
+ GET_VALUES(src.ptr + 2 * src_stride_y, left, middle, right);
+ values0 += left * (int8)(w2.s0);
+ values0 += middle * (int8)(w2.s1);
+ values0 += right * (int8)(w2.s2);
+#if CONV_STRIDE_Y == 1
+ values1 += left * (int8)(w1.s0);
+ values1 += middle * (int8)(w1.s1);
+ values1 += right * (int8)(w1.s2);
+#endif /* CONV_STRIDE_Y == 1 */
+
+#if WEIGHTS_OFFSET != 0
+ tmp = left + middle + right;
+ sum0 += tmp;
+#if CONV_STRIDE_Y == 1
+ sum1 += tmp;
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if CONV_STRIDE_Y == 1
+ // Row3
+ GET_VALUES(src.ptr + 3 * src_stride_y, left, middle, right);
+ values1 += left * (int8)(w2.s0);
+ values1 += middle * (int8)(w2.s1);
+ values1 += right * (int8)(w2.s2);
+
+#if WEIGHTS_OFFSET != 0
+ sum1 += left + middle + right;
+#endif /* WEIGHTS_OFFSET != 0 */
+#endif /* CONV_STRIDE_Y == 1 */
+
#if defined(HAS_BIAS)
- ,
- bias_value
+ values0 += (int8)(bias_value);
+#if CONV_STRIDE_Y == 1
+ values1 += (int8)(bias_value);
+#endif /* CONV_STRIDE_Y == 1 */
#endif //defined(HAS_BIAS)
- );
- vstore8(pixels, 0, dst.ptr);
+#if WEIGHTS_OFFSET != 0
+ values0 += sum0 * (int8)(WEIGHTS_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += sum1 * (int8)(WEIGHTS_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* WEIGHTS_OFFSET != 0 */
+
+#if INPUT_OFFSET != 0
+ ushort sum_weights = 0;
+ ushort3 tmp_we = convert_ushort3(w0) + convert_ushort3(w1) + convert_ushort3(w2);
+ sum_weights += tmp_we.s0 + tmp_we.s1 + tmp_we.s2;
+ values0 += sum_weights * (int8)(INPUT_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += sum_weights * (int8)(INPUT_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* INPUT_OFFSET != 0 */
+
+#if K_OFFSET != 0
+ values0 += (int8)(K_OFFSET);
+#if CONV_STRIDE_Y == 1
+ values1 += (int8)(K_OFFSET);
+#endif /* CONV_STRIDE_Y == 1 */
+#endif /* K_OFFSET != 0 */
+
+ values0 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values0, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
+ values0 += (int8)OUTPUT_OFFSET;
+ uchar8 res0 = convert_uchar8_sat(values0);
+ res0 = max(res0, (uchar8)0);
+ res0 = min(res0, (uchar8)255);
+
+ vstore8(res0, 0, dst.ptr);
+#if CONV_STRIDE_Y == 1
+
+ values1 = ASYMM_MULT_BY_QUANT_MULTIPLIER_LESS_THAN_ONE(values1, OUTPUT_MULTIPLIER, OUTPUT_SHIFT, 8);
+ values1 += (int8)OUTPUT_OFFSET;
+ uchar8 res1 = convert_uchar8_sat(values1);
+ res1 = max(res1, (uchar8)0);
+ res1 = min(res1, (uchar8)255);
+
+ vstore8(res1, 0, dst.ptr + dst_stride_y);
+#endif /* CONV_STRIDE_Y == 1 */
}
-#endif //defined(CONV_STRIDE_X) \ No newline at end of file
+
+#endif /* defined(CONV_STRIDE_X) && defined(CONV_STRIDE_Y) && defined(WEIGHTS_OFFSET) && defined(INPUT_OFFSET) && defined(K_OFFSET) && defined(OUTPUT_OFFSET) && defined(OUTPUT_MULTIPLIER) && defined(OUTPUT_SHIFT) */
diff --git a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.cpp b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.cpp
index 2a60f60723..3613419273 100644
--- a/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.cpp
+++ b/src/core/CL/kernels/CLDepthwiseConvolutionLayer3x3Kernel.cpp
@@ -55,9 +55,11 @@ void CLDepthwiseConvolutionLayer3x3Kernel::configure(const ICLTensor *input, con
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
ARM_COMPUTE_ERROR_ON(weights->info()->dimension(0) != 3 || weights->info()->dimension(1) != 3);
+ bool is_qasymm = is_data_type_quantized_asymmetric(input->info()->data_type());
+
if(biases != nullptr)
{
- if(is_data_type_quantized_asymmetric(weights->info()->data_type()))
+ if(is_qasymm)
{
ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(biases, 1, DataType::S32);
}
@@ -98,6 +100,22 @@ void CLDepthwiseConvolutionLayer3x3Kernel::configure(const ICLTensor *input, con
build_opts.add_option("-DCONV_STRIDE_X=" + support::cpp11::to_string(_conv_stride_x));
build_opts.add_option_if(_biases != nullptr, "-DHAS_BIAS");
+ if(is_qasymm)
+ {
+ float multiplier = _input->info()->quantization_info().scale * _weights->info()->quantization_info().scale / _output->info()->quantization_info().scale;
+ int output_multiplier = 0;
+ int output_shift = 0;
+ quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
+ build_opts.add_option("-DCONV_STRIDE_Y=" + support::cpp11::to_string(_conv_stride_y));
+ build_opts.add_option("-DINPUT_OFFSET=" + support::cpp11::to_string(-_input->info()->quantization_info().offset));
+ build_opts.add_option("-DWEIGHTS_OFFSET=" + support::cpp11::to_string(-_weights->info()->quantization_info().offset));
+ build_opts.add_option("-DOUTPUT_OFFSET=" + support::cpp11::to_string(_output->info()->quantization_info().offset));
+ build_opts.add_option("-DK_OFFSET=" + support::cpp11::to_string(9 * input->info()->quantization_info().offset * weights->info()->quantization_info().offset));
+ build_opts.add_option("-DOUTPUT_MULTIPLIER=" + support::cpp11::to_string(output_multiplier));
+ build_opts.add_option("-DOUTPUT_SHIFT=" + support::cpp11::to_string(output_shift));
+ }
+
// Configure the local work size for Bifrost with a value obtained
// via exhaustive autotuning for the MobileNets tensor shapes.
const GPUTarget gpu_target = get_arch_from_target(get_target());
@@ -145,11 +163,11 @@ void CLDepthwiseConvolutionLayer3x3Kernel::configure(const ICLTensor *input, con
}
else
{
- kernel_name = is_data_type_quantized_asymmetric(_input->info()->data_type()) ? "depthwise_convolution_3x3_quantized" : "depthwise_convolution_3x3";
+ kernel_name = is_qasymm ? "depthwise_convolution_3x3_quantized" : "depthwise_convolution_3x3";
num_elems_written_per_iteration_x = 8 / data_size_from_type(input->info()->data_type());
- num_elems_written_per_iteration_y = 1;
+ num_elems_written_per_iteration_y = (is_qasymm && _conv_stride_y < 3) ? (2 / _conv_stride_y) : 1;
num_elems_read_per_iteration_x = 3 + (num_elems_written_per_iteration_x - 1) * _conv_stride_x;
- num_elems_read_per_iteration_y = 3;
+ num_elems_read_per_iteration_y = num_elems_written_per_iteration_y + 2;
}
// Calculate right and bottom border
@@ -175,23 +193,6 @@ void CLDepthwiseConvolutionLayer3x3Kernel::configure(const ICLTensor *input, con
_kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
- // Set static arguments
- if(is_data_type_quantized_asymmetric(_input->info()->data_type()))
- {
- float multiplier = _input->info()->quantization_info().scale * _weights->info()->quantization_info().scale / _output->info()->quantization_info().scale;
- int output_multiplier = 0;
- int output_shift = 0;
- quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
-
- unsigned int idx = 3 * num_arguments_per_3D_tensor() + ((_biases != nullptr) ? num_arguments_per_1D_tensor() : 0);
-
- _kernel.setArg(idx++, -_input->info()->quantization_info().offset);
- _kernel.setArg(idx++, -_weights->info()->quantization_info().offset);
- _kernel.setArg(idx++, _output->info()->quantization_info().offset);
- _kernel.setArg(idx++, output_multiplier);
- _kernel.setArg(idx++, output_shift);
- }
-
// Set config_id for enabling LWS tuning
_config_id = kernel_name;
_config_id += "_";
diff --git a/tests/benchmark/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/benchmark/fixtures/DepthwiseConvolutionLayerFixture.h
index 8283b4d514..a156f4bc6f 100644
--- a/tests/benchmark/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/benchmark/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -48,10 +48,10 @@ public:
dst_shape.set(3 /* batch */, batches);
// Create tensors
- src = create_tensor<TensorType>(src_shape, data_type, 1, fixed_point_position);
- weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position);
+ src = create_tensor<TensorType>(src_shape, data_type, 1, fixed_point_position, QuantizationInfo(0.5f, 10));
+ weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position, QuantizationInfo(0.5f, 10));
biases = create_tensor<TensorType>(TensorShape(weights_shape[2]), is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type, 1, fixed_point_position);
- dst = create_tensor<TensorType>(dst_shape, data_type, 1, fixed_point_position);
+ dst = create_tensor<TensorType>(dst_shape, data_type, 1, fixed_point_position, QuantizationInfo(0.5f, 10));
// Create and configure function
depth_conv.configure(&src, &weights, &biases, &dst, info);