diff options
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/FullyConnectedLayerFixture.h | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h index 7f0ceadea1..6952b226da 100644 --- a/tests/validation/fixtures/FullyConnectedLayerFixture.h +++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h @@ -34,6 +34,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/Helpers.h" +#include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/FullyConnectedLayer.h" #include "tests/validation/reference/Utils.h" @@ -55,7 +56,7 @@ public: public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, - DataType data_type, QuantizationInfo quantization_info) + DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info) { ARM_COMPUTE_UNUSED(weights_shape); ARM_COMPUTE_UNUSED(bias_shape); @@ -63,6 +64,7 @@ public: _data_type = data_type; _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; _quantization_info = quantization_info; + _activation_info = activation_info; _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights); _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape); @@ -130,6 +132,7 @@ protected: FullyConnectedLayerInfo fc_info; fc_info.transpose_weights = transpose_weights; fc_info.are_weights_reshaped = !reshape_weights; + fc_info.activation_info = _activation_info; // Create and configure function. FunctionType fc; @@ -199,14 +202,15 @@ protected: fill(weights, 1); fill(bias, 2); - return reference::fully_connected_layer<T>(src, weights, bias, output_shape); + return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape), _activation_info, _quantization_info); } - TensorType _target{}; - SimpleTensor<T> _reference{}; - DataType _data_type{}; - DataType _bias_data_type{}; - QuantizationInfo _quantization_info{}; + TensorType _target{}; + SimpleTensor<T> _reference{}; + DataType _data_type{}; + DataType _bias_data_type{}; + QuantizationInfo _quantization_info{}; + ActivationLayerInfo _activation_info{}; }; template <typename TensorType, typename AccessorType, typename FunctionType, typename T> @@ -214,11 +218,12 @@ class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidatio { public: template <typename...> - void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type) + void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, + ActivationLayerInfo activation_info) { FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, - QuantizationInfo()); + QuantizationInfo(), activation_info); } }; @@ -228,11 +233,11 @@ class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayer public: template <typename...> void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, - QuantizationInfo quantization_info) + QuantizationInfo quantization_info, ActivationLayerInfo activation_info) { FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights, data_type, - quantization_info); + quantization_info, activation_info); } }; } // namespace validation |