aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ConvolutionLayerFixture.h
diff options
context:
space:
mode:
authorMoritz Pflanzer <moritz.pflanzer@arm.com>2017-09-08 09:53:14 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitcde1e8adeacea5c33a1682ef7b05a0ef643463b8 (patch)
tree47e58abdf5bb6ef39db362a2ac777c93b3f76666 /tests/validation/fixtures/ConvolutionLayerFixture.h
parent86b53339679e12c952a24a8845a5409ac3d52de6 (diff)
downloadComputeLibrary-cde1e8adeacea5c33a1682ef7b05a0ef643463b8.tar.gz
COMPMID-415: Add tests for ConvolutionLayer reshaped weights
Change-Id: I6c1209a2afafccba2cbdbcda16aceb3ae0cc7b4b Reviewed-on: http://mpd-gerrit.cambridge.arm.com/87000 Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--tests/validation/fixtures/ConvolutionLayerFixture.h125
1 files changed, 116 insertions, 9 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 87b11ac130..dd2df727e9 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -32,6 +32,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/CPP/ConvolutionLayer.h"
+#include "tests/validation/CPP/Utils.h"
#include "tests/validation/Helpers.h"
#include <random>
@@ -47,12 +48,12 @@ class ConvolutionValidationFixedPointFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, DataType data_type, int fractional_bits)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, bool reshape_weights, DataType data_type, int fractional_bits)
{
_fractional_bits = fractional_bits;
_data_type = data_type;
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, data_type, fractional_bits);
_reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type, fractional_bits);
}
@@ -75,17 +76,45 @@ protected:
}
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
- DataType data_type, int fixed_point_position)
+ bool reshape_weights, DataType data_type, int fixed_point_position)
{
+ WeightsInfo weights_info(!reshape_weights, weights_shape.x(), weights_shape.y(), weights_shape[3]);
+ TensorShape reshaped_weights_shape(weights_shape);
+
+ if(!reshape_weights)
+ {
+ // Check if its a "fully connected" convolution
+ const bool is_fully_connected_convolution = (output_shape.x() == 1 && output_shape.y() == 1);
+ reshaped_weights_shape.collapse(3);
+
+ if(bias_shape.total_size() > 0)
+ {
+ reshaped_weights_shape.set(0, reshaped_weights_shape.x() + 1);
+ }
+
+ if(is_fully_connected_convolution)
+ {
+ const size_t shape_x = reshaped_weights_shape.x();
+ reshaped_weights_shape.set(0, reshaped_weights_shape.y());
+ reshaped_weights_shape.set(1, shape_x);
+ }
+ else
+ {
+ const int interleave_width = 16 / data_size_from_type(data_type);
+ reshaped_weights_shape.set(0, reshaped_weights_shape.x() * interleave_width);
+ reshaped_weights_shape.set(1, static_cast<unsigned int>(std::ceil(reshaped_weights_shape.y() / static_cast<float>(interleave_width))));
+ }
+ }
+
// Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, fixed_point_position);
- TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1, fixed_point_position);
+ TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, data_type, 1, fixed_point_position);
TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1, fixed_point_position);
TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, fixed_point_position);
// Create and configure function
FunctionType conv;
- conv.configure(&src, &weights, &bias, &dst, info);
+ conv.configure(&src, &weights, &bias, &dst, info, weights_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(weights.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -105,8 +134,41 @@ protected:
// Fill tensors
fill(AccessorType(src), 0);
- fill(AccessorType(weights), 1);
- fill(AccessorType(bias), 2);
+
+ if(!reshape_weights)
+ {
+ const bool is_fully_connected_convolution = (output_shape.x() == 1 && output_shape.y() == 1);
+
+ TensorShape tmp_weights_shape(weights_shape);
+ SimpleTensor<T> tmp_weights(tmp_weights_shape, data_type, 1, fixed_point_position);
+ SimpleTensor<T> tmp_bias(bias_shape, data_type, 1, fixed_point_position);
+
+ // Fill with original shape
+ fill(tmp_weights, 1);
+ fill(tmp_bias, 2);
+
+ tmp_weights = linearise_weights(tmp_weights, &tmp_bias);
+
+ if(!is_fully_connected_convolution)
+ {
+ // Transpose with interleave
+ const int interleave_size = 16 / tmp_weights.element_size();
+ tmp_weights = transpose(std::move(tmp_weights), interleave_size);
+ }
+
+ AccessorType weights_accessor(weights);
+
+ for(int i = 0; i < tmp_weights.num_elements(); ++i)
+ {
+ Coordinates coord = index2coord(tmp_weights.shape(), i);
+ std::copy_n(static_cast<const T *>(tmp_weights(coord)), 1, static_cast<T *>(weights_accessor(coord)));
+ }
+ }
+ else
+ {
+ fill(AccessorType(weights), 1);
+ fill(AccessorType(bias), 2);
+ }
// Compute NEConvolutionLayer function
conv.run();
@@ -134,6 +196,51 @@ protected:
SimpleTensor<T> _reference{};
int _fractional_bits{};
DataType _data_type{};
+
+private:
+ template <typename U>
+ SimpleTensor<U> linearise_weights(const SimpleTensor<U> &weights, const SimpleTensor<U> *biases = nullptr)
+ {
+ TensorShape dst_shape(weights.shape());
+ dst_shape.collapse(3);
+
+ if(biases != nullptr)
+ {
+ dst_shape.set(0, dst_shape.x() + 1);
+ }
+
+ const size_t shape_x = dst_shape.x();
+ dst_shape.set(0, dst_shape.y());
+ dst_shape.set(1, shape_x);
+
+ SimpleTensor<U> dst(dst_shape, weights.data_type());
+
+ // Don't iterate over biases yet
+ for(int weights_idx = 0; weights_idx < weights.num_elements(); ++weights_idx)
+ {
+ Coordinates weights_coord = index2coord(weights.shape(), weights_idx);
+ const int dst_row = weights_idx % weights.shape().total_size_lower(3);
+ Coordinates dst_coord{ weights_coord[3], dst_row, weights_coord[4] };
+ const int dst_idx = coord2index(dst.shape(), dst_coord);
+
+ dst[dst_idx] = weights[weights_idx];
+ }
+
+ if(biases != nullptr)
+ {
+ // Fill last row with biases
+ for(int bias_idx = 0; bias_idx < biases->num_elements(); ++bias_idx)
+ {
+ Coordinates bias_coord = index2coord(biases->shape(), bias_idx);
+ Coordinates dst_coord{ bias_coord.x(), static_cast<int>(dst.shape().y()) - 1, bias_coord.y() };
+ int dst_idx = coord2index(dst.shape(), dst_coord);
+
+ dst[dst_idx] = (*biases)[bias_idx];
+ }
+ }
+
+ return dst;
+ }
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -141,9 +248,9 @@ class ConvolutionValidationFixture : public ConvolutionValidationFixedPointFixtu
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, DataType data_type)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, bool reshape_weights, DataType data_type)
{
- ConvolutionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, data_type, 0);
+ ConvolutionValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, info, reshape_weights, data_type, 0);
}
};
} // namespace validation