diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ConvolutionLayerFixture.h | 35 |
1 files changed, 18 insertions, 17 deletions
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h index c58a0a2c91..63e6dc9377 100644 --- a/tests/validation/fixtures/ConvolutionLayerFixture.h +++ b/tests/validation/fixtures/ConvolutionLayerFixture.h @@ -122,14 +122,14 @@ protected: { case DataType::QASYMM8: { - std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f); std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; } case DataType::QASYMM8_SIGNED: { - std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); + std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f); std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second); library->fill(tensor, distribution, i); break; @@ -476,7 +476,7 @@ inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_c } } -template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType> +template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType, bool enable_fast_math> class VariableWeightsFixtureBaseClass : public framework::Fixture { public: @@ -581,14 +581,14 @@ protected: SimpleTensor<ScalarType> _reference{}; }; -template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType> -class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType> +template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType, bool enable_fast_math> +class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType, enable_fast_math> { void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info, const PadStrideInfo &conv_info, const Size2D &dilation) { - this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation); + this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation, ActivationLayerInfo(), enable_fast_math); // Allocate input tensors auto src = create_tensor<TensorClass>(src_tensor_info); @@ -624,8 +624,8 @@ class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<Convolutio } }; -template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType> -class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType> +template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType, bool enable_fast_math> +class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType, enable_fast_math> { void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info, const PadStrideInfo &conv_info, @@ -644,7 +644,7 @@ class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass // Allocate destination tensor this->_target = create_tensor<TensorClass>(dst_tensor_info); this->_target.allocator()->allocate(); - this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation); + this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation, ActivationLayerInfo(), enable_fast_math); // Prepare source and biases that are left unchanged. this->fill(AccessorType(src), 0); this->fill(AccessorType(bias), 1); @@ -664,7 +664,7 @@ class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass } }; -template <typename ConvolutionClass> +template <typename ConvolutionClass, bool enable_fast_math> class HasOptImplFixture : public framework::Fixture { public: @@ -672,14 +672,15 @@ public: void setup(DataType data_type, arm_compute::WeightFormat query_weight_format) { auto conv = std::make_unique<ConvolutionClass>(); - const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC); - const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, data_type, DataLayout::NHWC); - const auto bias_info = TensorInfo(TensorShape(3U), 1, data_type, DataLayout::NHWC); - auto dst_info = TensorInfo(TensorShape(1U, 7U, 3U), 1, data_type, DataLayout::NHWC); - const auto conv_info = PadStrideInfo(1, 1, 0, 0, 2, 2, DimensionRoundingType::FLOOR); - const WeightsInfo weights_info(false, 3U, 3U, 1U, false, query_weight_format); + const auto src_info = TensorInfo(TensorShape(56U, 56U, 64U), 1, data_type, DataLayout::NHWC); + const auto weight_info = TensorInfo(TensorShape(64, 3U, 3U, 64U), 1, enable_fast_math ? DataType::BFLOAT16 : data_type, DataLayout::NHWC); + const auto bias_info = TensorInfo(TensorShape(64U), 1, data_type, DataLayout::NHWC); + auto dst_info = TensorInfo(TensorShape(56U, 56U, 64U), 1, data_type, DataLayout::NHWC); + const auto conv_info = PadStrideInfo(1, 1, 1, 1, 1, 1, DimensionRoundingType::FLOOR); + const WeightsInfo weights_info(false, 3U, 3U, 64U, false, query_weight_format); _kernel_found = bool(ConvolutionClass::has_opt_impl(_computed_weight_format, &src_info, &weight_info, - &bias_info, &dst_info, conv_info, weights_info)); + &bias_info, &dst_info, conv_info, weights_info, + /*dilation*/ Size2D(1U, 1U), /*act_info*/ ActivationLayerInfo(), enable_fast_math)); } protected: |