From a855af10a486c53c2271361cb87f349eca64b749 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Mon, 16 Jul 2018 17:20:38 +0100 Subject: COMPMID-1401 Implement NEFullyConnectedLayer for QASYMM8 Change-Id: I0404df6d369855e2f458f2db8f26e81c80a1ee87 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140148 Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier Reviewed-by: Gian Marco Iodice Tested-by: Jenkins --- .../fixtures/FullyConnectedLayerFixture.h | 54 ++++++++-------------- 1 file changed, 18 insertions(+), 36 deletions(-) (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h') diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 18321480f8..49c3be0c2e 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -45,7 +45,7 @@ namespace test { namespace validation { -template +template class FullyConnectedLayerValidationGenericFixture : public framework::Fixture { public: @@ -103,8 +103,8 @@ protected: // -----------+-----------+--------------------------- // transpose | | *** // -----------+-----------+--------------------------- - // !transpose | transpose | transpose & - // | | transpose1xW (if required) + // !transpose | transpose | transpose + // | | // // ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same if(!reshape_weights || !transpose_weights) @@ -112,16 +112,6 @@ protected: const size_t shape_x = reshaped_weights_shape.x(); reshaped_weights_shape.set(0, reshaped_weights_shape.y()); reshaped_weights_shape.set(1, shape_x); - - // Weights have to be passed reshaped - // Transpose 1xW for batched version - if(!reshape_weights && output_shape.y() > 1 && run_interleave) - { - const int transpose_width = 16 / data_size_from_type(_data_type); - const float shape_x = reshaped_weights_shape.x(); - reshaped_weights_shape.set(0, reshaped_weights_shape.y() * transpose_width); - reshaped_weights_shape.set(1, static_cast(std::ceil(shape_x / transpose_width))); - } } // Create tensors @@ -170,14 +160,6 @@ protected: // Transpose elementwise tmp = transpose(tmp); - // Reshape weights for batched runs - if(!reshape_weights && output_shape.y() > 1 && run_interleave) - { - // Transpose with interleave - const int interleave_size = 16 / tmp.element_size(); - tmp = transpose(tmp, interleave_size); - } - AccessorType weights_accessor(weights); for(int i = 0; i < tmp.num_elements(); ++i) @@ -222,43 +204,43 @@ protected: QuantizationInfo _quantization_info{}; }; -template -class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture +template +class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - QuantizationInfo()); + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + QuantizationInfo()); } }; -template -class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture +template +class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) { - FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - QuantizationInfo()); + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + QuantizationInfo()); } }; -template -class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture +template +class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture { public: template void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, QuantizationInfo quantization_info) { - FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, - reshape_weights, data_type, - quantization_info); + FullyConnectedLayerValidationGenericFixture::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, + reshape_weights, data_type, + quantization_info); } }; } // namespace validation -- cgit v1.2.1