author      Giorgio Arena <giorgio.arena@arm.com>    2021-10-19 15:45:57 +0100
committer   Sheri Zhang <sheri.zhang@arm.com>        2021-10-20 15:54:24 +0000
commit      51847d5dd9cad6bc81673642a01fd531def44311 (patch)
tree        1c5b79334d054141308c1e03e05749ac33ed5d34 /tests
parent      7a8cf1707e45ea011e0db5c0b3091c381ccd387f (diff)
Implement CLDirectConv3DKernel - uint8/int8
Resolve COMPMID-4663

Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Change-Id: I5c3c1cffed5385c06b789543318f7f4d6096987e
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6468
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sheri Zhang <sheri.zhang@arm.com>
Diffstat (limited to 'tests')
-rw-r--r--  tests/validation/CL/Convolution3D.cpp                    122
-rw-r--r--  tests/validation/fixtures/DirectConvolution3DFixture.h    51
-rw-r--r--  tests/validation/reference/Conv3D.cpp                    110
-rw-r--r--  tests/validation/reference/Conv3D.h                        4
4 files changed, 253 insertions(+), 34 deletions(-)
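
For orientation before the diff itself: with source quantization (scale s_i, offset o_i), weights quantization (s_w, o_w), destination quantization (s_o, o_o) and an S32 bias b, the quantized reference added below computes, per output element, roughly

    q_out = \mathrm{clamp}\!\left(\mathrm{round}\!\left(\frac{s_i\, s_w}{s_o}\Big(\sum (q_{in} - o_i)(q_w - o_w) + b\Big)\right) + o_o\right)

with the clamp taken over the 8-bit output type's range.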
diff --git a/tests/validation/CL/Convolution3D.cpp b/tests/validation/CL/Convolution3D.cpp
index 75e2e99b03..381aacc465 100644
--- a/tests/validation/CL/Convolution3D.cpp
+++ b/tests/validation/CL/Convolution3D.cpp
@@ -38,10 +38,11 @@ namespace validation
{
namespace
{
-RelativeTolerance<half> tolerance_fp16(half(0.2)); /**< Tolerance for floating point tests */
-RelativeTolerance<float> tolerance_fp32(0.05f); /**< Tolerance for floating point tests */
-constexpr float abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests*/
-constexpr float tolerance_num = 0.07f; /**< Tolerance number */
+RelativeTolerance<half> tolerance_fp16(half(0.2)); /**< Tolerance for floating point tests */
+RelativeTolerance<float> tolerance_fp32(0.05f); /**< Tolerance for floating point tests */
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1); /**< Tolerance for quantized tests */
+constexpr float abs_tolerance_f32(0.0001f); /**< Absolute tolerance for FP32 tests*/
+constexpr float tolerance_num = 0.07f; /**< Tolerance number */
} // namespace
TEST_SUITE(CL)
@@ -165,6 +166,8 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(zip(zip(zi
template <typename T>
using CLDirectConvolution3DFixture = DirectConvolution3DValidationFixture<CLTensor, CLAccessor, CLConv3D, T>;
+template <typename T>
+using CLDirectConvolution3DQuantizedFixture = DirectConvolution3DValidationQuantizedFixture<CLTensor, CLAccessor, CLConv3D, T>;
TEST_SUITE(NDHWC)
TEST_SUITE(FP16)
@@ -266,6 +269,117 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLDirectConvolution3DFixture<float>, framework:
// clang-format on
// *INDENT-ON*
TEST_SUITE_END() // FP32
+
+TEST_SUITE(QASYMM8)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLDirectConvolution3DQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(7U, 5U, 3U, 13U, 3U),
+ TensorShape(15U, 7U, 11U, 7U),
+ TensorShape(19U, 5U, 16U, 4U),
+ TensorShape(13U, 5U, 17U, 2U)
+ }),
+ framework::dataset::make("StrideX", { 1, 3, 2, 1 })),
+ framework::dataset::make("StrideY", { 2, 1, 3, 1 })),
+ framework::dataset::make("StrideZ", { 3, 2, 1, 1 })),
+ framework::dataset::make("PadX", { 0, 2, 1, 0 })),
+ framework::dataset::make("PadY", { 1, 0, 2, 0 })),
+ framework::dataset::make("PadZ", { 2, 1, 0, 0 })),
+ framework::dataset::make("KernelWidth", { 3, 7, 5, 1 })),
+ framework::dataset::make("KernelHeight", { 5, 3, 7, 1 })),
+ framework::dataset::make("KernelDepth", { 7, 5, 3, 1 })),
+ framework::dataset::make("NumKernels", { 5, 3, 1, 11 })),
+ framework::dataset::make("HasBias", { true, true, true, false })),
+ framework::dataset::make("Activation", ActivationLayerInfo())),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataLayout", DataLayout::NDHWC)),
+ framework::dataset::make("SrcQuantizationInfo", QuantizationInfo(0.1f, 10))),
+ framework::dataset::make("WeightsQuantizationInfo", QuantizationInfo(0.3f, 20))),
+ framework::dataset::make("DstQuantizationInfo", QuantizationInfo(0.2f, 5))))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLDirectConvolution3DQuantizedFixture<uint8_t>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(400U, 400U, 200U, 11U) }),
+ framework::dataset::make("StrideX", { 1 })),
+ framework::dataset::make("StrideY", { 1 })),
+ framework::dataset::make("StrideZ", { 1 })),
+ framework::dataset::make("PadX", { 1 })),
+ framework::dataset::make("PadY", { 1 })),
+ framework::dataset::make("PadZ", { 1 })),
+ framework::dataset::make("KernelWidth", { 9 })),
+ framework::dataset::make("KernelHeight", { 9 })),
+ framework::dataset::make("KernelDepth", { 9 })),
+ framework::dataset::make("NumKernels", { 300 })),
+ framework::dataset::make("HasBias", { true })),
+ framework::dataset::make("Activation", ActivationLayerInfo())),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataLayout", DataLayout::NDHWC)),
+ framework::dataset::make("SrcQuantizationInfo", QuantizationInfo(0.1f, 10))),
+ framework::dataset::make("WeightsQuantizationInfo", QuantizationInfo(0.3f, 20))),
+ framework::dataset::make("DstQuantizationInfo", QuantizationInfo(0.2f, 5))))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+
+TEST_SUITE_END() // QASYMM8
+
+TEST_SUITE(QASYMM8_SIGNED)
+FIXTURE_DATA_TEST_CASE(RunSmall, CLDirectConvolution3DQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(combine(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(7U, 5U, 3U, 13U, 3U),
+ TensorShape(15U, 7U, 11U, 7U),
+ TensorShape(19U, 5U, 16U, 4U),
+ TensorShape(13U, 5U, 17U, 2U)
+ }),
+ framework::dataset::make("StrideX", { 1, 3, 2, 1 })),
+ framework::dataset::make("StrideY", { 2, 1, 3, 1 })),
+ framework::dataset::make("StrideZ", { 3, 2, 1, 1 })),
+ framework::dataset::make("PadX", { 0, 2, 1, 0 })),
+ framework::dataset::make("PadY", { 1, 0, 2, 0 })),
+ framework::dataset::make("PadZ", { 2, 1, 0, 0 })),
+ framework::dataset::make("KernelWidth", { 3, 7, 5, 1 })),
+ framework::dataset::make("KernelHeight", { 5, 3, 7, 1 })),
+ framework::dataset::make("KernelDepth", { 7, 5, 3, 1 })),
+ framework::dataset::make("NumKernels", { 5, 3, 1, 11 })),
+ framework::dataset::make("HasBias", { true, true, true, false })),
+ framework::dataset::make("Activation", ActivationLayerInfo())),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", DataLayout::NDHWC)),
+ framework::dataset::make("SrcQuantizationInfo", QuantizationInfo(0.1f, 10))),
+ framework::dataset::make("WeightsQuantizationInfo", QuantizationInfo(0.3f, 20))),
+ framework::dataset::make("DstQuantizationInfo", QuantizationInfo(0.2f, 5))))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, CLDirectConvolution3DQuantizedFixture<int8_t>, framework::DatasetMode::NIGHTLY,
+ combine(combine(combine(combine(combine(combine(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(zip(
+ framework::dataset::make("InputShape", { TensorShape(400U, 400U, 200U, 11U) }),
+ framework::dataset::make("StrideX", { 1 })),
+ framework::dataset::make("StrideY", { 1 })),
+ framework::dataset::make("StrideZ", { 1 })),
+ framework::dataset::make("PadX", { 1 })),
+ framework::dataset::make("PadY", { 1 })),
+ framework::dataset::make("PadZ", { 1 })),
+ framework::dataset::make("KernelWidth", { 9 })),
+ framework::dataset::make("KernelHeight", { 9 })),
+ framework::dataset::make("KernelDepth", { 9 })),
+ framework::dataset::make("NumKernels", { 300 })),
+ framework::dataset::make("HasBias", { true })),
+ framework::dataset::make("Activation", ActivationLayerInfo())),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", DataLayout::NDHWC)),
+ framework::dataset::make("SrcQuantizationInfo", QuantizationInfo(0.1f, 10))),
+ framework::dataset::make("WeightsQuantizationInfo", QuantizationInfo(0.3f, 20))),
+ framework::dataset::make("DstQuantizationInfo", QuantizationInfo(0.2f, 5))))
+{
+ validate(CLAccessor(_target), _reference, tolerance_qasymm8);
+}
+
+TEST_SUITE_END() // QASYMM8_SIGNED
+
TEST_SUITE_END() // NDHWC
TEST_SUITE_END() // DirectConvolution3D
TEST_SUITE_END() // CL
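
A note on the quantized tolerance used above: AbsoluteTolerance<uint8_t>(1) accepts an element-wise difference of at most one quantized step between the CL output and the reference, which absorbs off-by-one rounding in the requantization. A minimal sketch of what that comparison amounts to (within_one_lsb is an illustrative helper, not framework code):

    #include <cstdint>
    #include <cstdlib>

    // True when two QASYMM8 values differ by at most one least-significant bit.
    bool within_one_lsb(uint8_t target, uint8_t reference)
    {
        return std::abs(static_cast<int>(target) - static_cast<int>(reference)) <= 1;
    }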
diff --git a/tests/validation/fixtures/DirectConvolution3DFixture.h b/tests/validation/fixtures/DirectConvolution3DFixture.h
index 3a675ac6d3..2250dcaeb0 100644
--- a/tests/validation/fixtures/DirectConvolution3DFixture.h
+++ b/tests/validation/fixtures/DirectConvolution3DFixture.h
@@ -40,19 +40,23 @@ template <typename TensorType, typename AccessorType, typename FunctionType, typ
class DirectConvolution3DValidationGenericFixture : public framework::Fixture
{
public:
+ using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+
template <typename...>
void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
- unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout)
+ unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
+ const QuantizationInfo &src_qinfo = QuantizationInfo(), const QuantizationInfo &weights_qinfo = QuantizationInfo(), const QuantizationInfo &dst_qinfo = QuantizationInfo())
{
ARM_COMPUTE_ERROR_ON(data_layout != DataLayout::NDHWC);
const TensorShape weights_shape(num_kernels, input_shape[0], kernel_width, kernel_height, kernel_depth);
const TensorShape bias_shape(num_kernels);
+ const DataType bias_data_type = is_data_type_quantized(data_type) ? DataType::S32 : data_type;
const Conv3dInfo conv3d_info(Size3D(stride_x, stride_y, stride_z), Padding3D(pad_x, pad_y, pad_z), act_info, Size3D(1U, 1U, 1U), DimensionRoundingType::FLOOR, false);
const TensorShape output_shape = compute_conv3d_shape(input_shape, weights_shape, conv3d_info);
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, data_layout);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, bias_data_type, data_layout, src_qinfo, weights_qinfo, dst_qinfo);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, conv3d_info, has_bias, data_type, bias_data_type, src_qinfo, weights_qinfo, dst_qinfo);
}
protected:
@@ -79,13 +83,14 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const Conv3dInfo &conv3d_info,
- bool has_bias, const DataType &data_type, const DataLayout &data_layout)
+ bool has_bias, const DataType &data_type, const DataType &bias_data_type, const DataLayout &data_layout, const QuantizationInfo &src_qinfo,
+ const QuantizationInfo &weights_qinfo, const QuantizationInfo &dst_qinfo)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, QuantizationInfo(), data_layout);
- TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, QuantizationInfo(), data_layout);
- TensorType bias = has_bias ? create_tensor<TensorType>(bias_shape, data_type, 1, QuantizationInfo()) : TensorType();
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, QuantizationInfo(), data_layout);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, src_qinfo, data_layout);
+ TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, weights_qinfo, data_layout);
+ TensorType bias = has_bias ? create_tensor<TensorType>(bias_shape, bias_data_type, 1, QuantizationInfo()) : TensorType();
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, dst_qinfo, data_layout);
// Create and configure function
FunctionType conv{};
@@ -122,14 +127,15 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const Conv3dInfo &conv3d_info,
- bool has_bias, const DataType &data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape,
+ const Conv3dInfo &conv3d_info, bool has_bias, const DataType &data_type, const DataType &bias_data_type, const QuantizationInfo &src_qinfo,
+ const QuantizationInfo &weights_qinfo, const QuantizationInfo &dst_qinfo)
{
// Create reference
- SimpleTensor<T> src{ input_shape, data_type };
- SimpleTensor<T> weights{ weights_shape, data_type };
- SimpleTensor<T> bias{ bias_shape, data_type };
- SimpleTensor<T> dst{ output_shape, data_type };
+ SimpleTensor<T> src{ input_shape, data_type, 1, src_qinfo };
+ SimpleTensor<T> weights{ weights_shape, data_type, 1, weights_qinfo };
+ SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
+ SimpleTensor<T> dst{ output_shape, data_type, 1, dst_qinfo };
// Fill reference
fill(src, 0);
@@ -140,7 +146,7 @@ protected:
fill(bias, 2);
}
- return reference::activation_layer(reference::conv3d<T>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
+ return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
}
TensorType _target{};
@@ -159,6 +165,21 @@ public:
kernel_depth, num_kernels, has_bias, act_info, data_type, data_layout);
}
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class DirectConvolution3DValidationQuantizedFixture : public DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
+ unsigned int num_kernels, bool has_bias, ActivationLayerInfo act_info, DataType data_type, DataLayout data_layout, QuantizationInfo src_qinfo, QuantizationInfo weights_qinfo,
+ QuantizationInfo dst_qinfo)
+ {
+ DirectConvolution3DValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, stride_x, stride_y, stride_z, pad_x, pad_y, pad_z, kernel_width, kernel_height,
+ kernel_depth, num_kernels, has_bias, act_info, data_type, data_layout, src_qinfo,
+ weights_qinfo, dst_qinfo);
+ }
+};
} // namespace validation
} // namespace test
} // namespace arm_compute
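
The main change in this fixture is the TBias alias: quantized element types (uint8_t / int8_t) get an S32 bias tensor, while floating-point types keep their own element type. A standalone sketch of the same selection, assuming only those two 8-bit types are quantized here (BiasType is an illustrative name, not the fixture's):

    #include <cstdint>
    #include <type_traits>

    // Pick int32_t for 8-bit quantized element types, otherwise reuse the element type.
    template <typename T>
    using BiasType = typename std::conditional<std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value,
                                               int32_t, T>::type;

    static_assert(std::is_same<BiasType<uint8_t>, int32_t>::value, "quantized bias is S32");
    static_assert(std::is_same<BiasType<float>, float>::value, "float bias keeps its element type");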
diff --git a/tests/validation/reference/Conv3D.cpp b/tests/validation/reference/Conv3D.cpp
index ad61105b36..706059d1cb 100644
--- a/tests/validation/reference/Conv3D.cpp
+++ b/tests/validation/reference/Conv3D.cpp
@@ -22,7 +22,11 @@
* SOFTWARE.
*/
#include "Conv3D.h"
+
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "support/Requires.h"
+#include "tests/validation/reference/UtilsQuantizedAsymm.h"
// Source/Destination Tensor shape indices (N D H W C)
constexpr unsigned int batch_dim = 4u;
@@ -52,11 +56,14 @@ inline bool is_valid_pixel(int i, int min, int max)
{
return (i >= min && i < max);
}
+
// Evaluate the weights against an element in a given tensor.
-template <typename T>
-T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const Size3D &dilation, int batch,
- int z_start, int y_start, int x_start, int ch_out)
+template < typename T, typename TB, typename std::enable_if < validation::is_floating_point<T>::value &&validation::is_floating_point<TB>::value, int >::type = 0 >
+T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
+ int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
{
+ ARM_COMPUTE_UNUSED(oq_info);
+
const unsigned int weights_width = weights.shape()[weights_width_dim];
const unsigned int weights_height = weights.shape()[weights_height_dim];
const unsigned int weights_depth = weights.shape()[weights_depth_dim];
@@ -101,12 +108,89 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
}
}
}
- return total;
+
+ const TB *b_ptr = bias.data();
+ TB bias_value = b_ptr[ch_out];
+
+ return total + bias_value;
}
+
+template < typename T, typename TB, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
+ int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
+{
+ const unsigned int weights_width = weights.shape()[weights_width_dim];
+ const unsigned int weights_height = weights.shape()[weights_height_dim];
+ const unsigned int weights_depth = weights.shape()[weights_depth_dim];
+
+ const unsigned int src_channels = src.shape()[channel_dim];
+ const unsigned int src_width = src.shape()[width_dim];
+ const unsigned int src_height = src.shape()[height_dim];
+ const unsigned int src_depth = src.shape()[depth_dim];
+
+ const UniformQuantizationInfo iq_info = src.quantization_info().uniform();
+ const UniformQuantizationInfo wq_info = weights.quantization_info().uniform();
+
+ const int input_offset = -iq_info.offset;
+ const float input_scale = iq_info.scale;
+ int weights_offset = -wq_info.offset;
+ float weights_scale = wq_info.scale;
+ const int output_offset = oq_info.offset;
+ const float output_scale = oq_info.scale;
+
+ int output_multiplier = 0;
+ int output_shift = 0;
+ const float multiplier = input_scale * weights_scale / output_scale;
+ arm_compute::quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+
+ int32_t total(0);
+ for(unsigned int weight_d = 0; weight_d < weights_depth; ++weight_d)
+ {
+ const int idx_z = z_start + dilation.depth * weight_d;
+ for(unsigned int weight_y = 0; weight_y < weights_height; ++weight_y)
+ {
+ const int idx_y = y_start + dilation.height * weight_y;
+ for(unsigned int weight_x = 0; weight_x < weights_width; ++weight_x)
+ {
+ const int idx_x = x_start + dilation.width * weight_x;
+
+ //Check if the point is within padding
+ const bool is_x_valid = is_valid_pixel(idx_x, 0, src_width);
+ const bool is_y_valid = is_valid_pixel(idx_y, 0, src_height);
+ const bool is_z_valid = is_valid_pixel(idx_z, 0, src_depth);
+ const bool is_invalid_pixel = !(is_x_valid && is_y_valid && is_z_valid);
+ if(is_invalid_pixel)
+ {
+ continue;
+ }
+
+ for(unsigned int ch_in = 0; ch_in < src_channels; ++ch_in)
+ {
+ const T *in_ptr = src.data();
+ const T *w_ptr = weights.data();
+
+ const int in_offset = coord2index(src.shape(), Coordinates{ ch_in, idx_x, idx_y, idx_z, batch });
+ const int weight_offset = coord2index(weights.shape(), Coordinates{ ch_out, ch_in, weight_x, weight_y, weight_d });
+ T input_value = in_ptr[in_offset];
+ T weight_value = w_ptr[weight_offset];
+ total += ((input_value + input_offset) * (weight_value + weights_offset));
+ }
+ }
+ }
+ }
+
+ const TB *b_ptr = bias.data();
+ TB bias_value = b_ptr[ch_out];
+
+ total += bias_value;
+
+ return validation::quantize_down_scale_by_fixedpoint(total, output_multiplier, output_shift, output_offset,
+ std::numeric_limits<T>::lowest(), std::numeric_limits<T>::max());
}
+} // namespace
-template <typename T>
-SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<T> &bias, SimpleTensor<T> &dst, const Conv3dInfo &conv3d_info)
+template <typename T, typename TB>
+SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const Conv3dInfo &conv3d_info)
{
// Compute reference
const unsigned int batch_size = src.shape()[batch_dim];
@@ -150,14 +234,10 @@ SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weight
const int x_start = (x_out * stride_x) - pad_left;
for(unsigned int ch_out = 0; ch_out < dst_channels; ++ch_out)
{
- T weighted_value = calculate_conv3d<T>(src, weights, conv3d_info.dilation, batch, z_start,
- y_start, x_start, ch_out);
- T *out_ptr = dst.data();
- const T *b_ptr = bias.data();
- T bias_value(0);
+ T *out_ptr = dst.data();
+
const int out_offset = coord2index(dst.shape(), Coordinates{ ch_out, x_out, y_out, z_out, batch });
- bias_value = b_ptr[ch_out];
- out_ptr[out_offset] = weighted_value + bias_value;
+ out_ptr[out_offset] = calculate_conv3d<T, TB>(src, weights, bias, conv3d_info.dilation, batch, z_start, y_start, x_start, ch_out, dst.quantization_info().uniform());
}
}
}
@@ -170,6 +250,10 @@ template SimpleTensor<float> conv3d(const SimpleTensor<float> &src, const Simple
const Conv3dInfo &conv3d_info);
template SimpleTensor<half> conv3d(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, SimpleTensor<half> &dst,
const Conv3dInfo &conv3d_info);
+template SimpleTensor<uint8_t> conv3d(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<uint8_t> &dst,
+ const Conv3dInfo &conv3d_info);
+template SimpleTensor<int8_t> conv3d(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<int8_t> &dst,
+ const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation
} // namespace test
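
The quantized overload of calculate_conv3d above accumulates in int32 with the source and weights offsets applied, adds the S32 bias, and only then rescales into the output's quantized domain via calculate_quantized_multiplier() and quantize_down_scale_by_fixedpoint(). A simplified sketch of that final step, using a float rescale in place of the fixed-point multiplier/shift pair for readability (requantize_to_qasymm8_signed is a hypothetical helper, not library code):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    // 'acc' is the int32 sum of (q_in - o_in) * (q_w - o_w) terms plus the S32 bias.
    int8_t requantize_to_qasymm8_signed(int32_t acc,
                                        float input_scale, float weights_scale,
                                        float output_scale, int32_t output_offset)
    {
        const float   multiplier = input_scale * weights_scale / output_scale;
        const int32_t rescaled   = static_cast<int32_t>(std::lround(acc * multiplier)) + output_offset;
        return static_cast<int8_t>(std::min<int32_t>(127, std::max<int32_t>(-128, rescaled)));
    }

The actual reference keeps this step integer-only by pre-computing the fixed-point multiplier and shift from the three scales.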
diff --git a/tests/validation/reference/Conv3D.h b/tests/validation/reference/Conv3D.h
index ade8a2c242..e3674f4bfb 100644
--- a/tests/validation/reference/Conv3D.h
+++ b/tests/validation/reference/Conv3D.h
@@ -37,8 +37,8 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<T> &bias, SimpleTensor<T> &dst,
+template <typename T, typename TB>
+SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst,
const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation