From 5ce897f80a1a6ade8a07d61c7aaaf70d2aa5ee02 Mon Sep 17 00:00:00 2001
From: Georgios Pinitas
Date: Wed, 29 Apr 2020 11:44:10 +0100
Subject: COMPMID-3108: Add Winograd 3x3,4x4 FP16 support for NEON

Change-Id: I20680dc74a3d709297539e2132417308a7aecc9d
Signed-off-by: Georgios Pinitas
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3159
Reviewed-by: Michele Di Giorgio
Reviewed-by: Gian Marco Iodice
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
---
 .../NEON/functions/NEWinogradConvolutionLayer.cpp  | 259 +++++++++++++--------
 1 file changed, 160 insertions(+), 99 deletions(-)
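Context for the tile choice below, as standard Winograd arithmetic rather than anything stated in the patch itself: F(4x4, 3x3) computes each 4x4 block of outputs from a (4+3-1)x(4+3-1) = 6x6 transformed input tile with one element-wise multiplication per point, so per tile and channel the multiply count drops roughly fourfold compared with direct convolution:

\[
\frac{(4 \cdot 4)(3 \cdot 3)}{(4 + 3 - 1)^2} = \frac{144}{36} = 4
\]

The saved multiplications are traded for extra additions and constant scalings in the input/output transforms, which changes rounding behaviour and costs accuracy headroom in half precision; that is presumably why the patch registers the FP16 4x4,3x3 configuration in fast_math_winograd_f16, making it reachable only with enable_fast_math=true.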
diff --git a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
index 81190fbf0e..d567a18709 100644
--- a/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEWinogradConvolutionLayer.cpp
@@ -23,11 +23,11 @@
  */
 #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
 
+#include "arm_compute/core/CPP/Validate.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/NEON/kernels/NEWinogradConvolutionLayerKernel.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/Validate.h"
-#include "arm_compute/core/Validate.h"
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
 #include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
@@ -43,18 +43,32 @@ namespace
 inline Status validate_kernel_3x3(const Size2D input_dims, const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
-    if(input_dims.width > 4 && input_dims.height > 4)
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F16, DataType::F32);
+
+    if(input->data_type() == DataType::F32)
     {
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        if(input_dims.width > 4 && input_dims.height > 4)
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        }
+        else
+        {
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>::validate(input, input0, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>::validate(weights, input1, winograd_info)));
+            ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        }
     }
-    else
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    else if(input->data_type() == DataType::F16)
     {
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 2, 3, 3>::validate(input, input0, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 2, 3, 3>::validate(weights, input1, winograd_info)));
-        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 2, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<__fp16, 4, 4, 3, 3>::validate(input, input0, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<__fp16, 4, 4, 3, 3>::validate(weights, input1, winograd_info)));
+        ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<__fp16, 4, 4, 3, 3>::validate(batched_mm_output, biases, output, winograd_info)));
     }
+#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
 
     if(act_info.enabled())
     {
@@ -79,6 +93,7 @@ inline Status validate_kernel_5x5(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_3x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 6, 1, 3>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 6, 1, 3>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 6, 1, 3>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -92,6 +107,7 @@ inline Status validate_kernel_3x1(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_1x3(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 6, 1, 3, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 6, 1, 3, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 6, 1, 3, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -106,6 +122,7 @@ inline Status validate_kernel_1x3(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_5x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 4, 1, 5>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 4, 1, 5>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 4, 1, 5>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -118,6 +135,7 @@ inline Status validate_kernel_5x1(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_1x5(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 4, 1, 5, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 4, 1, 5, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 4, 1, 5, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -131,6 +149,7 @@ inline Status validate_kernel_1x5(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_7x1(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 1, 2, 1, 7>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 1, 2, 1, 7>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 1, 2, 1, 7>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -144,6 +163,7 @@ inline Status validate_kernel_7x1(const ITensorInfo *input, const TensorInfo *in
 inline Status validate_kernel_1x7(const ITensorInfo *input, const TensorInfo *input0, const TensorInfo *input1, const TensorInfo *batched_mm_output,
                                   const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const WinogradInfo &winograd_info, const ActivationLayerInfo &act_info)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::F32);
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformInputKernel<float, 2, 1, 7, 1>::validate(input, input0, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformWeightsKernel<float, 2, 1, 7, 1>::validate(weights, input1, winograd_info)));
     ARM_COMPUTE_RETURN_ON_ERROR((NEWinogradLayerTransformOutputKernel<float, 2, 1, 7, 1>::validate(batched_mm_output, biases, output, winograd_info)));
@@ -169,21 +189,27 @@ inline Tensor4DShape internal_get_input_shape(const arm_compute::ITensor *input)
 Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info)
 {
     ARM_COMPUTE_UNUSED(output);
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(conv_info.stride().first != 1 || conv_info.stride().second != 1, "Winograd layer only supports unit strides.");
 
     if(biases != nullptr)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, biases);
         ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
     }
-    return INEWinogradLayerTransformWeightsKernel<float>::validate(input, weights);
+    return INEWinogradLayerTransformWeightsKernel::validate(input, weights);
 }
 
-Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims)
+Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims, DataType data_type)
 {
     Size2D output_tile = Size2D{};
     if(kernel_dims == Size2D(3U, 3U))
     {
         output_tile = (input_dims.width <= 4 || input_dims.height <= 4) ? Size2D(2U, 2U) : Size2D(4U, 4U);
+        if(data_type == DataType::F16)
+        {
+            output_tile = Size2D(4U, 4U);
+        }
     }
     else if(kernel_dims == Size2D(5U, 5U))
     {
@@ -216,12 +242,17 @@ Size2D winograd_output_tile(const Size2D &input_dims, const Size2D &kernel_dims)
     return output_tile;
 }
 
-bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size)
+bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size, DataType data_type)
 {
     // Check if we want to configure a Winograd configuration which requires fast math
     using WinogradConfiguration = std::pair<std::pair<int, int>, std::pair<int, int>>;
 
-    const std::vector<WinogradConfiguration> fast_math_winograd =
+    const std::vector<WinogradConfiguration> fast_math_winograd_f16 =
+    {
+        WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(3, 3))
+    };
+
+    const std::vector<WinogradConfiguration> fast_math_winograd_f32 =
     {
         WinogradConfiguration(std::pair<int, int>(2, 2), std::pair<int, int>(5, 5)),
         WinogradConfiguration(std::pair<int, int>(4, 4), std::pair<int, int>(5, 5))
@@ -230,7 +261,15 @@ bool check_support_fast_math(const Size2D &output_tile, const Size2D &kernel_size
     auto p = std::make_pair(std::pair<int, int>(output_tile.width, output_tile.height),
                             std::pair<int, int>(kernel_size.width, kernel_size.height));
 
-    return std::find(fast_math_winograd.begin(), fast_math_winograd.end(), p) != fast_math_winograd.end();
+    switch(data_type)
+    {
+        case DataType::F16:
+            return std::find(fast_math_winograd_f16.begin(), fast_math_winograd_f16.end(), p) != fast_math_winograd_f16.end();
+        case DataType::F32:
+            return std::find(fast_math_winograd_f32.begin(), fast_math_winograd_f32.end(), p) != fast_math_winograd_f32.end();
+        default:
+            return false;
+    }
 }
 
 inline bool fuse_function_supported(const ActivationLayerInfo &act_info)
@@ -256,7 +295,6 @@ arm_gemm::Activation arm_gemm_activation_from_acl_activation(const ActivationLay
         }
     }
 }
-
 } //namespace
 
 NEWinogradConvolutionLayer::NEWinogradConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
@@ -278,14 +316,16 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
     const unsigned int height_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
     const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
 
-    const Size2D input_dims  = Size2D(input->info()->dimension(width_idx), input->info()->dimension(height_idx));
-    const Size2D kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx));
-    const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+    const Size2D   input_dims  = Size2D(input->info()->dimension(width_idx), input->info()->dimension(height_idx));
+    const Size2D   kernel_size = Size2D(weights->info()->dimension(width_idx), weights->info()->dimension(height_idx));
+    const DataType data_type   = input->info()->data_type();
+    const Size2D   output_tile = winograd_output_tile(input_dims, kernel_size, data_type);
 
     // Check if the Winograd configuration requires fast math
     if(!enable_fast_math)
     {
-        ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
+        ARM_COMPUTE_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type),
+                                 "This Winograd configuration requires enable_fast_math=true");
     }
 
     _weights     = weights;
@@ -293,101 +333,122 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
     _output      = output;
     _is_prepared = false;
 
-    std::unique_ptr<INEWinogradLayerTransformInputKernel<float>>   transform_input_kernel;
-    std::unique_ptr<INEWinogradLayerTransformWeightsKernel<float>> transform_weights_kernel;
-    std::unique_ptr<INEWinogradLayerTransformOutputKernel<float>>  transform_output_kernel;
-
     int n_gemms = 0;
     int N_BLOCK = 0; // Size of block used by GEMM.
 
-    if(kernel_size == Size2D(3, 3))
+    std::unique_ptr<INEWinogradLayerTransformInputKernel>   transform_input_kernel;
+    std::unique_ptr<INEWinogradLayerTransformWeightsKernel> transform_weights_kernel;
+    std::unique_ptr<INEWinogradLayerTransformOutputKernel>  transform_output_kernel;
+
+    if(data_type == DataType::F32)
     {
-        if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
+        if(kernel_size == Size2D(3, 3))
         {
-            using config = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
+            if(input->info()->dimension(width_idx) > 4 && input->info()->dimension(height_idx) > 4)
+            {
+                using config             = NEWinogradLayerConfiguration<float, float, 4, 4, 3, 3>;
+                transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+                transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+                transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+                n_gemms                  = config::WinogradBase::N_GEMMS;
+                N_BLOCK                  = config::WinogradConv::N_BLOCK;
+            }
+            else
+            {
+                using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
+                transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+                transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+                transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+                n_gemms                  = config::WinogradBase::N_GEMMS;
+                N_BLOCK                  = config::WinogradConv::N_BLOCK;
+            }
+        }
+        else if(kernel_size == Size2D(5, 5))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
             transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
             transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
             transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
             n_gemms                  = config::WinogradBase::N_GEMMS;
             N_BLOCK                  = config::WinogradConv::N_BLOCK;
         }
-        else
+        else if(kernel_size == Size2D(1, 3))
         {
-            using config = NEWinogradLayerConfiguration<float, float, 2, 2, 3, 3>;
+            using config             = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
             transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
             transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
             transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
             n_gemms                  = config::WinogradBase::N_GEMMS;
             N_BLOCK                  = config::WinogradConv::N_BLOCK;
         }
+        else if(kernel_size == Size2D(3, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(1, 5))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(5, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(1, 7))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else if(kernel_size == Size2D(7, 1))
+        {
+            using config             = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR("Not supported.");
+        }
     }
-    else if(kernel_size == Size2D(5, 5))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 2, 2, 5, 5>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 3))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 6, 1, 3, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(3, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 6, 1, 3>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 5))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 4, 1, 5, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(5, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 4, 1, 5>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(1, 7))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 2, 1, 7, 1>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else if(kernel_size == Size2D(7, 1))
-    {
-        using config             = NEWinogradLayerConfiguration<float, float, 1, 2, 1, 7>;
-        transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
-        transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
-        transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
-        n_gemms                  = config::WinogradBase::N_GEMMS;
-        N_BLOCK                  = config::WinogradConv::N_BLOCK;
-    }
-    else
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    else if(data_type == DataType::F16)
     {
-        ARM_COMPUTE_ERROR("Not supported.");
+        if(kernel_size == Size2D(3, 3))
+        {
+            using config             = NEWinogradLayerConfiguration<__fp16, __fp16, 4, 4, 3, 3>;
+            transform_input_kernel   = support::cpp14::make_unique<config::TransformInputKernel>();
+            transform_weights_kernel = support::cpp14::make_unique<config::TransformWeightsKernel>();
+            transform_output_kernel  = support::cpp14::make_unique<config::TransformOutputKernel>();
+            n_gemms                  = config::WinogradBase::N_GEMMS;
+            N_BLOCK                  = config::WinogradConv::N_BLOCK;
+        }
+        else
+        {
+            ARM_COMPUTE_ERROR("Not supported.");
+        }
     }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
     const PaddingType use_padding_type = (conv_info.pad_top() != 0u || conv_info.pad_left() != 0) ? PADDING_SAME : PADDING_VALID;
     const bool        use_same_padding = use_padding_type == PADDING_SAME;
 
@@ -397,7 +458,6 @@ void NEWinogradConvolutionLayer::configure(const ITensor *input, const ITensor *
     const int out_channels = output->info()->dimension(channel_idx);
 
     const Tensor4DShape in_shape(internal_get_input_shape(input));
-    const DataType      data_type      = input->info()->data_type();
     const size_t        data_type_size = input->info()->element_size();
     // Get the memory required to instantiate a new Winograd operator.
     constexpr size_t storage_alignment = 64;
@@ -592,14 +652,16 @@ Status NEWinogradConvolutionLayer::validate(const ITensorInfo *input, const ITen
     const size_t idx_height = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
 
     // Input shape, kernel size and output tile
-    const Size2D input_dims  = Size2D(input->dimension(idx_width), input->dimension(idx_height));
-    const Size2D kernel_size = Size2D(weights->dimension(idx_width), weights->dimension(idx_height));
-    const Size2D output_tile = winograd_output_tile(input_dims, kernel_size);
+    const Size2D   input_dims  = Size2D(input->dimension(idx_width), input->dimension(idx_height));
+    const Size2D   kernel_size = Size2D(weights->dimension(idx_width), weights->dimension(idx_height));
+    const DataType data_type   = input->data_type();
+    const Size2D   output_tile = winograd_output_tile(input_dims, kernel_size, data_type);
 
     // Check if the Winograd configuration requires fast math
     if(!enable_fast_math)
    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size), "This Winograd configuration requires enable_fast_math=true");
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(check_support_fast_math(output_tile, kernel_size, data_type),
+                                        "This Winograd configuration requires enable_fast_math=true");
     }
 
     const WinogradInfo winograd_info = WinogradInfo(output_tile,
@@ -706,5 +768,4 @@ void NEWinogradConvolutionLayer::prepare()
         _is_prepared = true;
     }
 }
-
 } // namespace arm_compute
--
cgit v1.2.1
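A minimal usage sketch of the path this patch enables, assuming the Arm Compute Library runtime API of this release (Tensor, TensorInfo, PadStrideInfo, default NCHW layout) and purely illustrative shapes; nothing below comes from the patch itself:

    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    int main()
    {
        // Illustrative NCHW shapes: a 56x56 feature map with 64 input and 64 output
        // channels, a 3x3 kernel, unit stride and padding 1 (Winograd requires unit strides).
        Tensor src{}, weights{}, bias{}, dst{};
        src.allocator()->init(TensorInfo(TensorShape(56U, 56U, 64U), 1, DataType::F16));
        weights.allocator()->init(TensorInfo(TensorShape(3U, 3U, 64U, 64U), 1, DataType::F16));
        bias.allocator()->init(TensorInfo(TensorShape(64U), 1, DataType::F16));
        dst.allocator()->init(TensorInfo(TensorShape(56U, 56U, 64U), 1, DataType::F16));

        // F16 with a 3x3 kernel always selects the 4x4 output tile, which the patch
        // classifies as a fast-math configuration, so enable_fast_math must be true;
        // validate() returns an error for this combination without it.
        NEWinogradConvolutionLayer conv{};
        conv.configure(&src, &weights, &bias, &dst, PadStrideInfo(1, 1, 1, 1),
                       ActivationLayerInfo(), true /* enable_fast_math */);

        src.allocator()->allocate();
        weights.allocator()->allocate();
        bias.allocator()->allocate();
        dst.allocator()->allocate();

        conv.run(); // first run triggers prepare(), i.e. the weight transform
        return 0;
    }

Note that the F16 branch is compiled only when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined, so this assumes a build targeting an FP16-capable core (for example arch=armv8.2-a in the library's scons terms); on other targets validation fails via ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED.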