aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/validation/NEON/DepthwiseConvolutionLayer.cpp2
-rw-r--r--tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h82
-rw-r--r--tests/validation/reference/DepthwiseConvolutionLayer.cpp82
-rw-r--r--tests/validation/reference/DepthwiseConvolutionLayer.h4
4 files changed, 116 insertions, 54 deletions
diff --git a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
index c62c07bdfd..6392906037 100644
--- a/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/NEON/DepthwiseConvolutionLayer.cpp
@@ -53,7 +53,7 @@ RelativeTolerance<half_float::half> tolerance_f16(half_float::half(0.01)); /**<
constexpr float tolerance_num = 0.05f; /**< Tolerance number */
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-const auto depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 5 });
+const auto depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 5 });
const auto large_depth_multipliers = framework::dataset::make("DepthMultiplier", { 1, 2, 5, 8 });
//Activation Functions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
index 2c9b31866b..85930eb95e 100644
--- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -48,32 +48,34 @@ namespace validation
{
using namespace arm_compute::misc::shape_calculator;
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class DepthwiseConvolutionLayerValidationGenericFixture : public framework::Fixture
{
public:
- using TBias = typename std::conditional<std::is_same<typename std::decay<T>::type, uint8_t>::value, int32_t, T>::type;
+ using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
public:
template <typename...>
- void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type,
- QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, ActivationLayerInfo act_info)
+ void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation,
+ unsigned int depth_multiplier, DataType input_data_type, DataType weights_data_type,
+ QuantizationInfo input_quantization_info, QuantizationInfo weights_quantization_info, QuantizationInfo output_quantization_info,
+ DataLayout data_layout, ActivationLayerInfo act_info)
{
- const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
+ const DataType bias_data_type = is_data_type_quantized(input_data_type) ? DataType::S32 : input_data_type;
TensorShape weights_shape(kernel_size.width, kernel_size.height);
- const TensorInfo in_info(in_shape, 1, data_type);
- const TensorInfo we_info(weights_shape, 1, data_type);
+ const TensorInfo in_info(in_shape, 1, input_data_type);
+ const TensorInfo we_info(weights_shape, 1, weights_data_type);
const TensorShape out_shape = compute_depthwise_convolution_shape(in_info, we_info, pad_stride_info, depth_multiplier, dilation);
weights_shape.set(2, out_shape.z());
const TensorShape biases_shape(weights_shape[2]);
_target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, dilation, depth_multiplier,
- data_type, bias_data_type, input_quantization_info, output_quantization_info, data_layout, act_info);
+ input_data_type, weights_data_type, bias_data_type, input_quantization_info, weights_quantization_info, output_quantization_info, data_layout, act_info);
_reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, dilation, depth_multiplier,
- data_type, bias_data_type, input_quantization_info, output_quantization_info, act_info);
+ input_data_type, weights_data_type, bias_data_type, input_quantization_info, weights_quantization_info, output_quantization_info, act_info);
}
protected:
@@ -88,6 +90,12 @@ protected:
library->fill(tensor, distribution, i);
break;
}
+ case DataType::QSYMM8_PER_CHANNEL:
+ {
+ std::uniform_int_distribution<int8_t> distribution(-10, 10);
+ library->fill(tensor, distribution, i);
+ break;
+ }
case DataType::F32:
case DataType::F16:
{
@@ -107,9 +115,8 @@ protected:
}
TensorType compute_target(TensorShape input_shape, TensorShape weights_shape, TensorShape biases_shape, TensorShape output_shape, PadStrideInfo &pad_stride_info, Size2D dilation,
- unsigned int depth_multiplier,
- const DataType data_type, const DataType bias_data_type,
- const QuantizationInfo &input_quantization_info, const QuantizationInfo &output_quantization_info,
+ unsigned int depth_multiplier, const DataType input_data_type, const DataType weights_data_type, const DataType bias_data_type,
+ const QuantizationInfo &input_quantization_info, const QuantizationInfo &weights_quantization_info, const QuantizationInfo &output_quantization_info,
const DataLayout data_layout, const ActivationLayerInfo &act_info)
{
if(data_layout == DataLayout::NHWC)
@@ -120,10 +127,10 @@ protected:
}
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, input_quantization_info, data_layout);
- TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, input_quantization_info, data_layout);
+ TensorType src = create_tensor<TensorType>(input_shape, input_data_type, 1, input_quantization_info, data_layout);
+ TensorType weights = create_tensor<TensorType>(weights_shape, weights_data_type, 1, weights_quantization_info, data_layout);
TensorType biases = create_tensor<TensorType>(biases_shape, bias_data_type, 1, input_quantization_info, data_layout);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, output_quantization_info, data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, input_data_type, 1, output_quantization_info, data_layout);
// Create Depthwise Convolution configure function
FunctionType dwc;
@@ -157,14 +164,13 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape,
- const PadStrideInfo &pad_stride_info,
- const Size2D &dilation, unsigned int depth_multiplier,
- const DataType data_type, const DataType bias_data_type,
- const QuantizationInfo &input_quantization_info, const QuantizationInfo &output_quantization_info,
+ const PadStrideInfo &pad_stride_info, const Size2D &dilation, unsigned int depth_multiplier,
+ const DataType input_data_type, const DataType weights_data_type, const DataType bias_data_type,
+ const QuantizationInfo &input_quantization_info, const QuantizationInfo &weights_quantization_info, const QuantizationInfo &output_quantization_info,
const ActivationLayerInfo &act_info)
{
- SimpleTensor<T> src{ in_shape, data_type, 1, input_quantization_info };
- SimpleTensor<T> weights{ weights_shape, data_type, 1, input_quantization_info };
+ SimpleTensor<T> src{ in_shape, input_data_type, 1, input_quantization_info };
+ SimpleTensor<TW> weights{ weights_shape, weights_data_type, 1, weights_quantization_info };
SimpleTensor<TBias> biases{ biases_shape, bias_data_type, 1, input_quantization_info };
fill(src, 0);
@@ -180,20 +186,21 @@ protected:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type, DataLayout data_layout,
ActivationLayerInfo act_info)
{
- DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier,
- data_type, QuantizationInfo(), QuantizationInfo(), data_layout, act_info);
+ DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier,
+ data_type, data_type, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(),
+ data_layout, act_info);
}
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DepthwiseConvolutionLayerNativeValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DepthwiseConvolutionLayerNativeValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
@@ -302,7 +309,7 @@ protected:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DepthwiseConvolutionLayerNativeConfigurableValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DepthwiseConvolutionLayerNativeConfigurableValidationFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
@@ -423,15 +430,32 @@ protected:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
+class DepthwiseConvolutionLayerValidationQuantizedFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>
{
public:
template <typename...>
void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType data_type,
QuantizationInfo input_quantization_info, QuantizationInfo output_quantization_info, DataLayout data_layout, ActivationLayerInfo act_info)
{
- DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier,
- data_type, input_quantization_info, output_quantization_info, data_layout, act_info);
+ DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, T>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier, data_type,
+ data_type, input_quantization_info, input_quantization_info, output_quantization_info,
+ data_layout, act_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class DepthwiseConvolutionLayerValidationQuantizedPerChannelFixture : public DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+ template <typename...>
+ void setup(TensorShape in_shape, Size2D kernel_size, PadStrideInfo pad_stride_info, Size2D dilation, unsigned int depth_multiplier, DataType input_data_type, DataType weights_data_type,
+ QuantizationInfo input_quantization_info, QuantizationInfo weights_quantization_info, QuantizationInfo output_quantization_info,
+ DataLayout data_layout, ActivationLayerInfo act_info)
+ {
+ DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(in_shape, kernel_size, pad_stride_info, dilation, depth_multiplier,
+ input_data_type, weights_data_type,
+ input_quantization_info, weights_quantization_info, output_quantization_info,
+ data_layout, act_info);
}
};
} // namespace validation
diff --git a/tests/validation/reference/DepthwiseConvolutionLayer.cpp b/tests/validation/reference/DepthwiseConvolutionLayer.cpp
index b1d2b923f7..7458f815b8 100644
--- a/tests/validation/reference/DepthwiseConvolutionLayer.cpp
+++ b/tests/validation/reference/DepthwiseConvolutionLayer.cpp
@@ -40,7 +40,9 @@ namespace validation
{
namespace reference
{
-/** Perform a depthwise convolution
+namespace
+{
+/** Perform a depthwise convolution for floating-point types
*
* - Three dimensions tensors
* - Third dimention is number of channels
@@ -48,9 +50,9 @@ namespace reference
* - Padding, stride and output shape "match"
*
*/
-template <typename T, typename TB>
-SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info,
- unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+template <typename T>
+SimpleTensor<T> depthwise_convolution_fp(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<T> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info,
+ unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
{
ARM_COMPUTE_UNUSED(out_quant_info);
@@ -114,7 +116,7 @@ SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTe
}
}
- dst[out_pos++] = saturate_cast<T>(val + *static_cast<const TB *>(biases(Coordinates(out_z))));
+ dst[out_pos++] = saturate_cast<T>(val + *static_cast<const T *>(biases(Coordinates(out_z))));
}
}
}
@@ -124,26 +126,32 @@ SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTe
return dst;
}
-template <>
-SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &biases, const TensorShape &dst_shape,
- const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+/** Perform a quantized depthwise convolution
+ *
+ * - Three-dimensional tensors
+ * - Third dimension is the number of channels
+ * - Depths of input tensor and filter are equal
+ * - Padding, stride and output shape "match"
+ * - QASYMM8 input, output
+ * - QASYMM8 or QSYMM8_PER_CHANNEL filter
+ *
+ */
+template <typename T, typename TW, typename TB>
+SimpleTensor<T> depthwise_convolution_quantized(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<int32_t> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
{
// if no explicit quantization has been set you the same as src
const QuantizationInfo &dst_qinfo = out_quant_info.uniform().empty() ? src.quantization_info() : out_quant_info;
- SimpleTensor<uint8_t> dst{ dst_shape, src.data_type(), 1, dst_qinfo };
+ SimpleTensor<T> dst{ dst_shape, src.data_type(), 1, dst_qinfo };
// Create reference
const int input_offset = -src.quantization_info().uniform().offset;
const float input_scale = src.quantization_info().uniform().scale;
const int weights_offset = -weights.quantization_info().uniform().offset;
- const float weights_scale = weights.quantization_info().uniform().scale;
const int output_offset = dst_qinfo.uniform().offset;
const float output_scale = dst_qinfo.uniform().scale;
- int output_multiplier = 0;
- int output_shift = 0;
- const float multiplier = input_scale * weights_scale / output_scale;
- arm_compute::quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+ const std::vector<float> weights_scale_vec = weights.quantization_info().scale();
// Compute reference
const int filter_width = weights.shape().x();
@@ -173,11 +181,19 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
const int maximum_x = input_width + pad_left + pad_right - static_cast<int>(patch_width);
const int maximum_y = input_height + pad_top + pad_bottom - static_cast<int>(patch_height);
+ const bool is_quantized_per_channel = is_data_type_quantized_per_channel(weights.data_type());
+
int out_pos = 0;
for(int r = 0; r < num_batches; ++r)
{
for(int z = 0; z < input_depth; ++z)
{
+ int output_multiplier = 0;
+ int output_shift = 0;
+ const float weights_scale = (is_quantized_per_channel) ? weights_scale_vec[z] : weights_scale_vec[0];
+ const float multiplier = input_scale * weights_scale / output_scale;
+ arm_compute::quantization::calculate_quantized_multiplier_less_than_one(multiplier, &output_multiplier, &output_shift);
+
for(unsigned int m = 0; m < depth_multiplier; ++m)
{
const int out_z = z * depth_multiplier + m;
@@ -197,8 +213,8 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
{
coords.set(0, i);
coords.set(1, j);
- const auto in_val = tensor_elem_at<uint8_t>(src, coords, BorderMode::CONSTANT, -input_offset);
- const uint8_t w_val = *(weights.data() + filter_offset);
+ const auto in_val = tensor_elem_at<T>(src, coords, BorderMode::CONSTANT, -input_offset);
+ const TW w_val = *(weights.data() + filter_offset);
val += (in_val + input_offset) * (w_val + weights_offset);
++filter_offset;
}
@@ -206,8 +222,7 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
val += bias_val;
val = asymm_rounding_divide_by_pow2(asymm_int_mult(val, output_multiplier), output_shift);
val += output_offset;
- val = std::max<int32_t>(val, 0);
- val = std::min<int32_t>(val, 255);
+ val = utility::clamp<int32_t>(val, 0, 255);
// Store the result
dst[out_pos++] = val;
@@ -219,12 +234,35 @@ SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, co
return dst;
}
+} // namespace
+
+template <>
+SimpleTensor<float> depthwise_convolution(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+{
+ return depthwise_convolution_fp(src, weights, biases, dst_shape, conv_info, depth_multiplier, dilation, out_quant_info);
+}
-template SimpleTensor<float> depthwise_convolution(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &biases, const TensorShape &dst_shape,
- const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info);
+template <>
+SimpleTensor<half> depthwise_convolution(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+{
+ return depthwise_convolution_fp(src, weights, biases, dst_shape, conv_info, depth_multiplier, dilation, out_quant_info);
+}
-template SimpleTensor<half> depthwise_convolution(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &biases, const TensorShape &dst_shape,
- const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info);
+template <>
+SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+{
+ return depthwise_convolution_quantized<uint8_t, uint8_t, int32_t>(src, weights, biases, dst_shape, conv_info, depth_multiplier, dilation, out_quant_info);
+}
+
+template <>
+SimpleTensor<uint8_t> depthwise_convolution(const SimpleTensor<uint8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &biases, const TensorShape &dst_shape,
+ const PadStrideInfo &conv_info, unsigned int depth_multiplier, const Size2D &dilation, const QuantizationInfo &out_quant_info)
+{
+ return depthwise_convolution_quantized<uint8_t, int8_t, int32_t>(src, weights, biases, dst_shape, conv_info, depth_multiplier, dilation, out_quant_info);
+}
} // namespace reference
} // namespace validation
} // namespace test
diff --git a/tests/validation/reference/DepthwiseConvolutionLayer.h b/tests/validation/reference/DepthwiseConvolutionLayer.h
index ee323fa8df..38a225a1ae 100644
--- a/tests/validation/reference/DepthwiseConvolutionLayer.h
+++ b/tests/validation/reference/DepthwiseConvolutionLayer.h
@@ -35,8 +35,8 @@ namespace validation
{
namespace reference
{
-template <typename T, typename TB>
-SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info,
+template <typename T, typename TW, typename TB>
+SimpleTensor<T> depthwise_convolution(const SimpleTensor<T> &src, const SimpleTensor<TW> &weights, const SimpleTensor<TB> &biases, const TensorShape &dst_shape, const PadStrideInfo &conv_info,
unsigned int depth_multiplier, const Size2D &dilation = Size2D(1U, 1U), const QuantizationInfo &out_quant_info = QuantizationInfo(0.0f, 0));
} // namespace reference
} // namespace validation