aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-03-26 16:20:05 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:49:37 +0000
commit1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (patch)
treedc299cf46073d2bdd5a3a0252935ede216cf332e /tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
parent9373c8b2650f34b2804d3685588bad8e408ebe63 (diff)
downloadComputeLibrary-1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9.tar.gz
COMPMID-812 Add NHWC data format support for NEON depthwise convolution (optimized case).
Change-Id: Icdfd6c02ed526daf4f59a4b76c7bbc1bc48fde74 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125938 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h37
1 files changed, 22 insertions, 15 deletions
diff --git a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
index df5436fcf7..ccdd443999 100644
--- a/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/DepthwiseConvolutionLayerFixture.h
@@ -52,15 +52,22 @@ public:
public:
template <typename...>
- void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info)
+ void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info, DataLayout data_layout)
{
_quantization_info = quantization_info;
_data_type = data_type;
const TensorShape biases_shape(weights_shape[2]);
const DataType bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
- _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info);
- _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info);
+ if(data_layout == DataLayout::NHWC)
+ {
+ permute(in_shape, PermutationVector(2U, 0U, 1U));
+ permute(weights_shape, PermutationVector(2U, 0U, 1U));
+ permute(out_shape, PermutationVector(2U, 0U, 1U));
+ }
+
+ _target = compute_target(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
+ _reference = compute_reference(in_shape, weights_shape, biases_shape, out_shape, pad_stride_info, data_type, bias_data_type, quantization_info, data_layout);
}
protected:
@@ -94,13 +101,13 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &output_shape, PadStrideInfo &pad_stride_info,
- const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info)
+ const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, quantization_info);
- TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, 0, quantization_info);
- TensorType biases = create_tensor<TensorType>(biases_shape, bias_data_type, 1, 0, quantization_info);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, quantization_info);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, quantization_info, data_layout);
+ TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, 0, quantization_info, data_layout);
+ TensorType biases = create_tensor<TensorType>(biases_shape, bias_data_type, 1, 0, quantization_info, data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, quantization_info, data_layout);
// Create Depthwise Convolution configure function
FunctionType dwc;
@@ -134,11 +141,11 @@ protected:
}
SimpleTensor<T> compute_reference(const TensorShape &in_shape, const TensorShape &weights_shape, const TensorShape &biases_shape, const TensorShape &out_shape, const PadStrideInfo &pad_stride_info,
- const DataType data_type, const DataType bias_data_type, QuantizationInfo quantization_info)
+ const DataType data_type, const DataType bias_data_type, const QuantizationInfo quantization_info, const DataLayout data_layout)
{
- SimpleTensor<T> src{ in_shape, data_type, 1, 0, quantization_info };
- SimpleTensor<T> weights{ weights_shape, data_type, 1, 0, quantization_info };
- SimpleTensor<TBias> biases{ biases_shape, bias_data_type, 1, 0, quantization_info };
+ SimpleTensor<T> src{ in_shape, data_type, 1, 0, quantization_info, data_layout };
+ SimpleTensor<T> weights{ weights_shape, data_type, 1, 0, quantization_info, data_layout };
+ SimpleTensor<TBias> biases{ biases_shape, bias_data_type, 1, 0, quantization_info, data_layout };
fill(src, 0);
fill(weights, 1);
@@ -158,10 +165,10 @@ class DepthwiseConvolutionLayerValidationFixture : public DepthwiseConvolutionLa
{
public:
template <typename...>
- void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type)
+ void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, DataLayout data_layout)
{
DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
- data_type, QuantizationInfo());
+ data_type, QuantizationInfo(), data_layout);
}
};
@@ -173,7 +180,7 @@ public:
void setup(TensorShape in_shape, TensorShape weights_shape, TensorShape out_shape, PadStrideInfo pad_stride_info, DataType data_type, QuantizationInfo quantization_info)
{
DepthwiseConvolutionLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(in_shape, weights_shape, out_shape, pad_stride_info,
- data_type, quantization_info);
+ data_type, quantization_info, DataLayout::NCHW);
}
};
} // namespace validation