diff options
author | Giorgio Arena <giorgio.arena@arm.com> | 2020-02-07 13:46:45 +0000 |
---|---|---|
committer | Giorgio Arena <giorgio.arena@arm.com> | 2020-03-02 15:51:39 +0000 |
commit | 1856ff7ebb29e04c3549b74d7ced336111cbf05e (patch) | |
tree | c94654f0d8535930a81712bf7aadffd757c82577 /tests/validation/fixtures | |
parent | 3c4bf0c4eab5ead756c472f17ddf008b882cc905 (diff) | |
download | ComputeLibrary-1856ff7ebb29e04c3549b74d7ced336111cbf05e.tar.gz |
COMPMID-3097 Fuse activation with fully connected layer CL
Change-Id: I447030e69b9e565f2f81529a41af8c5e7ece7ecf
Signed-off-by: Giorgio Arena <giorgio.arena@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2702
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/FullyConnectedLayerFixture.h | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 7f0ceadea1..6952b226da 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -34,6 +34,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/Helpers.h" +#include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/FullyConnectedLayer.h" #include "tests/validation/reference/Utils.h" @@ -55,7 +56,7 @@ public: public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, - DataType data_type, QuantizationInfo quantization_info) + DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info) { ARM_COMPUTE_UNUSED(weights_shape); ARM_COMPUTE_UNUSED(bias_shape); @@ -63,6 +64,7 @@ public: _data_type = data_type; _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; _quantization_info = quantization_info; + _activation_info = activation_info; _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights); _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape); @@ -130,6 +132,7 @@ protected: FullyConnectedLayerInfo fc_info; fc_info.transpose_weights = transpose_weights; fc_info.are_weights_reshaped = !reshape_weights; + fc_info.activation_info = _activation_info; // Create and configure function. FunctionType fc; @@ -199,14 +202,15 @@ protected: fill(weights, 1); fill(bias, 2); - return reference::fully_connected_layer<T>(src, weights, bias, output_shape); + return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape), _activation_info, _quantization_info); } - TensorType _target{}; - SimpleTensor<T> _reference{}; - DataType _data_type{}; - DataType _bias_data_type{}; - QuantizationInfo _quantization_info{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; + DataType _data_type{}; + DataType _bias_data_type{}; + QuantizationInfo _quantization_info{}; + ActivationLayerInfo _activation_info{}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> @@ -214,11 +218,12 @@ class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidatio { public: template <typename...> - void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, + ActivationLayerInfo activation_info) { FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, - QuantizationInfo()); + QuantizationInfo(), activation_info); } }; @@ -228,11 +233,11 @@ class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayer public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, - QuantizationInfo quantization_info) + QuantizationInfo quantization_info, ActivationLayerInfo activation_info) { FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, - quantization_info); + quantization_info, activation_info); } }; } // namespace validation |