author    Giorgio Arena <giorgio.arena@arm.com>    2018-03-16 14:02:34 +0000
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:50:48 +0000
commit    c0f54434383f945d95f95549c1c4b0d5f5d2caff (patch)
tree      c4dadc7d83fa9dccef8cd7e85b31223266946093 /src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
parent    3c520c5a6ca9352560828fdf389d31e38b85afeb (diff)
download  ComputeLibrary-c0f54434383f945d95f95549c1c4b0d5f5d2caff.tar.gz
COMPMID-808 Add NHWC data format support for NEON direct convolution
Change-Id: I5d4cc3d5b0d25f3fe4ed998c0f15b1b8e260a43a
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/125697
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp')
-rw-r--r-- src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp | 435
1 file changed, 291 insertions, 144 deletions
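In NHWC the channel axis is the innermost, contiguous dimension, whereas the NCHW path this kernel previously assumed keeps the width axis innermost. A minimal sketch of the two linearisations (the helper names and dense packing are illustrative assumptions, not library code):

#include <cstddef>

// Offset of element (n, c, h, w) in a densely packed tensor under each layout.
static std::size_t offset_nchw(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                               std::size_t C, std::size_t H, std::size_t W)
{
    return ((n * C + c) * H + h) * W + w; // width innermost: one channel's row is contiguous
}

static std::size_t offset_nhwc(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                               std::size_t C, std::size_t H, std::size_t W)
{
    return ((n * H + h) * W + w) * C + c; // channels innermost: one pixel's channels are contiguous
}

Contiguous channels are what let the new NHWC convolver in this patch vectorise its innermost loop along the channel axis.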
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 285ec2d0a0..5eafdf0363 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -33,6 +33,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include <algorithm>
#include <arm_neon.h>
@@ -663,6 +664,118 @@ void accumulate_results<3>(qint16_t *buffer, const qint16x8x2_t &values)
vst1_qs16(buffer, vqadd_qs16(vld1_qs16(buffer), vget_low_s16(values.val[0])));
}
+template <typename T1>
+class convolver_nhwc
+{
+public:
+ static void convolve(const Window &window, int kernel_size, unsigned int num_elems_read_per_iteration,
+ const ITensor *input, const ITensor *weights, ITensor *output, const PadStrideInfo &conv_info)
+ {
+ const int input_width = input->info()->dimension(0);
+ const int input_depth = input->info()->dimension(2);
+ const int input_stride_x = input->info()->strides_in_bytes().x();
+ const int input_stride_y = input->info()->strides_in_bytes().y();
+ const int input_stride_z = input->info()->strides_in_bytes().z();
+ const int output_stride_x = output->info()->strides_in_bytes().x();
+ const int kernel_stride_x = weights->info()->strides_in_bytes().x();
+ const int kernel_stride_y = weights->info()->strides_in_bytes().y();
+ const int kernel_stride_z = weights->info()->strides_in_bytes().z();
+ const int conv_pad_top = conv_info.pad_top();
+ const unsigned int conv_stride_x = std::get<0>(conv_info.stride());
+ const unsigned int conv_stride_y = std::get<1>(conv_info.stride());
+ const T1 zero = 0;
+
+ // Setup input window for the input iterator
+ Window window_in = window;
+ window_in.set(Window::DimX, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimY, Window::Dimension(0, 0, 0));
+ window_in.set(Window::DimZ, Window::Dimension(0, 0, 0));
+
+ // Setup output window for the output iterator
+ Window window_out = window;
+ window_out.set(Window::DimX, Window::Dimension(0, 1, 1));
+
+ // Setup window for the weights iterator
+ Window window_k = calculate_max_window(*weights->info(), Steps());
+ window_k.set(Window::DimX, Window::Dimension(0, 1, 1));
+ window_k.set(Window::DimY, Window::Dimension(0, 1, 1));
+ window_k.set(Window::DimZ, Window::Dimension(0, 1, 1));
+ window_k.set(3, Window::Dimension(0, weights->info()->dimension(3), 1));
+
+ Iterator in(input, window_in);
+ Iterator out(output, window_out);
+ Iterator k(weights, window_k);
+
+ execute_window_loop(window_k, [&](const Coordinates & id_k)
+ {
+ execute_window_loop(window_out, [&](const Coordinates & id)
+ {
+ const auto in_y = static_cast<int>(id.y() * conv_stride_x - conv_info.pad_left());
+ const auto in_z = static_cast<int>(id.z() * conv_stride_y - conv_pad_top);
+
+ const uint8_t *in_ptr = in.ptr() + in_y * input_stride_y + in_z * input_stride_z;
+ uint8_t *out_ptr = out.ptr() + id_k[3] * output_stride_x;
+
+ T1 out_val = 0;
+
+ auto in_addr_base0 = in_ptr;
+ auto we_addr_base0 = k.ptr();
+
+ for(int z = 0; z < kernel_size; ++z, in_addr_base0 += input_stride_z, we_addr_base0 += kernel_stride_z)
+ {
+ const int in_z = id.z() * conv_stride_y + z - conv_pad_top;
+
+ if(in_z >= 0 && in_z < input_depth) // If false, pad top/bottom
+ {
+ auto in_addr_base1 = in_addr_base0;
+ auto we_addr_base1 = we_addr_base0;
+
+ for(int y = 0; y < kernel_size; ++y, in_addr_base1 += input_stride_y, we_addr_base1 += kernel_stride_y)
+ {
+ auto out_values = internal_vdupq_n(zero);
+
+ int x = 0;
+ int no_leftover = input_width - num_elems_read_per_iteration;
+
+ for(; x < no_leftover; x += num_elems_read_per_iteration)
+ {
+ const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
+ const auto in_values = internal_vld1q<1>(in_addr);
+
+ const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
+ const auto we_values = internal_vld1q<1>(we_addr);
+
+ out_values = internal_vmlal(out_values, in_values, we_values, 0);
+ }
+
+ out_val += out_values[0];
+ out_val += out_values[1];
+ out_val += out_values[2];
+ out_val += out_values[3];
+
+ // Leftover
+ for(; x < input_width; ++x)
+ {
+ const auto in_addr = reinterpret_cast<const T1 *>(in_addr_base1 + x * input_stride_x);
+ const auto in_value = *(in_addr);
+
+ const auto we_addr = reinterpret_cast<const T1 *>(we_addr_base1 + x * kernel_stride_x);
+ const auto we_value = *(we_addr);
+
+ out_val += in_value * we_value;
+ }
+ }
+ }
+ }
+
+ *(reinterpret_cast<T1 *>(out_ptr)) = out_val;
+ },
+ in, out);
+ },
+ k);
+ }
+};
+
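// For orientation: under NHWC, tensor dimension 0 is the channel axis, so the
// "input_width"/x loop above actually runs along input channels, and the NEON
// main loop plus its scalar leftover form a dot product along C for every
// (y, z) tap of the kernel. A scalar restatement of the computation, using
// hypothetical dense arrays and names (illustration only, not library code):
void direct_conv_nhwc_reference(const float *in, const float *wgt, float *out,
                                int H, int W, int C, int K, int O, int stride, int pad)
{
    const int Ho = (H + 2 * pad - K) / stride + 1;
    const int Wo = (W + 2 * pad - K) / stride + 1;
    for(int oy = 0; oy < Ho; ++oy)
        for(int ox = 0; ox < Wo; ++ox)
            for(int o = 0; o < O; ++o)
            {
                float acc = 0.f;
                for(int ky = 0; ky < K; ++ky)
                    for(int kx = 0; kx < K; ++kx)
                    {
                        const int iy = oy * stride + ky - pad;
                        const int ix = ox * stride + kx - pad;
                        if(iy < 0 || iy >= H || ix < 0 || ix >= W)
                        {
                            continue; // zero padding outside the image
                        }
                        for(int c = 0; c < C; ++c) // dot product along channels
                        {
                            acc += in[(iy * W + ix) * C + c] * wgt[((o * K + ky) * K + kx) * C + c];
                        }
                    }
                out[(oy * Wo + ox) * O + o] = acc; // one output pixel, one output channel
            }
}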
template <typename T1, typename T2, unsigned int stridex>
class convolver_3x3
{
@@ -1003,35 +1116,28 @@ inline void convolve_5x5(const Window &window, unsigned int num_elems_read_per_i
}
}
-inline TensorShape get_convolved_dimensions(const ITensorInfo *input, const ITensorInfo *weights, const int kernel_size, const PadStrideInfo &conv_info)
-{
- unsigned int output_width = 0;
- unsigned int output_height = 0;
- std::tie(output_width, output_height) = scaled_dimensions(input->dimension(0), input->dimension(1), kernel_size, kernel_size, conv_info);
-
- TensorShape output_shape = input->tensor_shape();
- output_shape.set(0, output_width);
- output_shape.set(1, output_height);
- output_shape.set(2, weights->dimension(3));
-
- return output_shape;
-}
-
Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, const PadStrideInfo &conv_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+ ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+ const DataLayout data_layout = input->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
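// The three indices above resolve differently per layout, since the library
// stores shapes innermost-first (NCHW -> [W, H, C, N], NHWC -> [C, W, H, N]).
// A hedged self-check, assuming the helper is visible as in this version of
// the library:
//
//   assert(get_data_layout_dimension_index(DataLayout::NCHW, DataLayoutDimension::WIDTH)   == 0);
//   assert(get_data_layout_dimension_index(DataLayout::NCHW, DataLayoutDimension::CHANNEL) == 2);
//   assert(get_data_layout_dimension_index(DataLayout::NHWC, DataLayoutDimension::CHANNEL) == 0);
//   assert(get_data_layout_dimension_index(DataLayout::NHWC, DataLayoutDimension::WIDTH)   == 1);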
ARM_COMPUTE_RETURN_ERROR_ON_MSG(std::get<0>(conv_info.stride()) > 3, "Strides larger than 3 not supported.");
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(2) != input->dimension(2));
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->dimension(1));
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(channel_idx) != input->dimension(channel_idx));
+ ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
+ ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
// Checks performed when output is configured
if(output->total_size() != 0)
{
- TensorShape output_shape = get_convolved_dimensions(input, weights, weights->dimension(0), conv_info);
+ TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input, *weights, conv_info);
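// compute_deep_convolution_shape applies the usual convolution output-size
// arithmetic (shown here for the default FLOOR rounding; a sketch, not the
// library source):
//
//   out_w = (in_w + pad_left + pad_right  - kernel_w) / stride_x + 1
//   out_h = (in_h + pad_top  + pad_bottom - kernel_h) / stride_y + 1
//
// e.g. in_w = 56, kernel_w = 3, pad 1/1, stride 1 -> (56 + 2 - 3) / 1 + 1 = 56,
// while the channel dimension becomes the number of filters, weights->dimension(3).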
DataType data_type = input->data_type();
if(is_data_type_fixed_point(data_type))
@@ -1050,101 +1156,127 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *weights, ITensorInfo *output, const PadStrideInfo &conv_info, unsigned int &num_weight_elems_read_per_row,
unsigned int &num_elems_read_per_iteration, unsigned int &num_elems_written_per_iteration, BorderSize &border_size)
{
+ ARM_COMPUTE_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+
+ const DataLayout data_layout = input->data_layout();
+ const int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+
// Calculate right and bottom border
- unsigned int kernel_size = weights->dimension(0);
+ unsigned int kernel_size = weights->dimension(width_idx);
const int conv_stride_x = std::get<0>(conv_info.stride());
const int conv_stride_y = std::get<1>(conv_info.stride());
- const int input_width = input->dimension(0);
+ const int input_width = input->dimension(width_idx);
+
+ Window win{};
+ bool window_changed = false;
- switch(kernel_size)
+ if(data_layout == DataLayout::NCHW)
{
- case 1:
+ switch(kernel_size)
{
- switch(input->data_type())
+ case 1:
{
+ switch(input->data_type())
+ {
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
+ case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- case DataType::QS16:
- num_elems_written_per_iteration = 8;
- break;
- case DataType::F32:
- if(run_optim_small_tensor_info(input))
- {
+ case DataType::QS8:
+ case DataType::QS16:
num_elems_written_per_iteration = 8;
- }
- else
- {
- num_elems_written_per_iteration = 4;
- }
- break;
- default:
- ARM_COMPUTE_ERROR("Data type not supported.");
- break;
+ break;
+ case DataType::F32:
+ if(run_optim_small_tensor_info(input))
+ {
+ num_elems_written_per_iteration = 8;
+ }
+ else
+ {
+ num_elems_written_per_iteration = 4;
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
+ }
+ num_weight_elems_read_per_row = kernel_size;
+ num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
+ break;
}
- num_weight_elems_read_per_row = kernel_size;
- num_elems_read_per_iteration = conv_stride_x * num_elems_written_per_iteration;
- break;
- }
- case 3:
- case 5:
- {
- switch(input->data_type())
+ case 3:
+ case 5:
{
- case DataType::F32:
- num_weight_elems_read_per_row = 4 + kernel_size - 1;
- num_elems_read_per_iteration = 12;
- num_elems_written_per_iteration = 16 >> conv_stride_x;
- break;
+ switch(input->data_type())
+ {
+ case DataType::F32:
+ num_weight_elems_read_per_row = 4 + kernel_size - 1;
+ num_elems_read_per_iteration = 12;
+ num_elems_written_per_iteration = 16 >> conv_stride_x;
+ break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
+ case DataType::F16:
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- case DataType::QS8:
- case DataType::QS16:
- num_weight_elems_read_per_row = 8 + kernel_size - 1;
- num_elems_read_per_iteration = 24;
- num_elems_written_per_iteration = 32 >> conv_stride_x;
- break;
- default:
- ARM_COMPUTE_ERROR("Data type not supported.");
- break;
+ case DataType::QS8:
+ case DataType::QS16:
+ num_weight_elems_read_per_row = 8 + kernel_size - 1;
+ num_elems_read_per_iteration = 24;
+ num_elems_written_per_iteration = 32 >> conv_stride_x;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
+ }
}
- }
- break;
- default:
- {
- ARM_COMPUTE_ERROR("Not implemented");
break;
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not implemented");
+ break;
+ }
}
- }
- // Calculate right pad
- int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
- int end_x = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
- int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
+ // Calculate right pad
+ int start_x = kernel_size / 2 - static_cast<int>(conv_info.pad_left());
+ int end_x = ceil_to_multiple(static_cast<int>(output->dimension(0)), num_elems_written_per_iteration) * conv_stride_x;
+ int upper_bound_w = ceil_to_multiple(start_x + end_x, num_elems_read_per_iteration) - input_width;
- // Calculate border
- const unsigned int conv_pad_left = conv_info.pad_left();
- const unsigned int conv_pad_top = conv_info.pad_top();
- const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
- const unsigned int conv_pad_bottom = conv_info.pad_bottom();
+ // Calculate border
+ const unsigned int conv_pad_left = conv_info.pad_left();
+ const unsigned int conv_pad_top = conv_info.pad_top();
+ const unsigned int conv_pad_right = std::max(upper_bound_w, 0);
+ const unsigned int conv_pad_bottom = conv_info.pad_bottom();
+
+ border_size.left = conv_pad_left;
+ border_size.top = conv_pad_top;
+ border_size.right = conv_pad_right;
+ border_size.bottom = conv_pad_bottom;
+
+ // Configure window
+ win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
+
+ AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
+ num_elems_read_per_iteration, kernel_size,
+ conv_stride_x, conv_stride_y);
+ AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
+ AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
+ window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
+ output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ }
+ else
+ {
+ border_size.left = 0;
+ border_size.top = conv_info.pad_left();
+ border_size.right = 0;
+ border_size.bottom = conv_info.pad_right();
- border_size.left = conv_pad_left;
- border_size.top = conv_pad_top;
- border_size.right = conv_pad_right;
- border_size.bottom = conv_pad_bottom;
+ num_elems_read_per_iteration = 16 / element_size_from_data_type(input->data_type());
- // Configure window
- Window win = calculate_max_window(*output, Steps(num_elems_written_per_iteration));
+ win = calculate_max_window(*output, Steps());
- AccessWindowRectangle input_access(input, -conv_pad_left, -conv_pad_top,
- num_elems_read_per_iteration, kernel_size,
- conv_stride_x, conv_stride_y);
- AccessWindowStatic weights_access(weights, 0, 0, num_weight_elems_read_per_row, kernel_size);
- AccessWindowHorizontal output_access(output, 0, num_elems_written_per_iteration);
- bool window_changed = update_window_and_padding(win, input_access, weights_access, output_access);
- output_access.set_valid_region(win, ValidRegion(Coordinates(), output->tensor_shape()));
+ AccessWindowRectangle input_access(input, 0, -border_size.top, num_elems_read_per_iteration, kernel_size, 1.f, conv_stride_x);
+ AccessWindowRectangle weights_access(weights, 0, 0, num_elems_read_per_iteration, kernel_size);
+ window_changed = update_window_and_padding(win, input_access, weights_access);
+ }
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
return std::make_pair(err, win);
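The NHWC branch above swaps the border's meaning: with channels innermost, the window's y and z dimensions traverse the image's width and height, so the left/right convolution padding lands on the tensor's top/bottom border, and the per-iteration read count is simply how many elements fit in one 16-byte NEON register. A hedged restatement of that arithmetic (illustrative helper, not library code):

// Elements per 16-byte NEON q register for a given element size.
unsigned int elems_per_q_register(unsigned int element_size_in_bytes)
{
    return 16U / element_size_in_bytes; // 4 for F32, 8 for F16, 16 for 8-bit types
}

Since only F32 is supported on this path, num_elems_read_per_iteration comes out as 4.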
@@ -1170,7 +1302,7 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
_weights = weights;
_output = output;
_conv_info = conv_info;
- _kernel_size = weights->info()->dimension(0);
+ _kernel_size = weights->info()->dimension(get_data_layout_dimension_index(weights->info()->data_layout(), DataLayoutDimension::WIDTH));
const unsigned int conv_pad_left = conv_info.pad_left();
const unsigned int conv_pad_top = conv_info.pad_top();
@@ -1179,7 +1311,7 @@ void NEDirectConvolutionLayerKernel::configure(const ITensor *input, const ITens
_border_size = BorderSize(conv_pad_top, conv_pad_right, conv_pad_bottom, conv_pad_left);
// Get convolved dimensions
- TensorShape output_shape = get_convolved_dimensions(input->info(), weights->info(), _kernel_size, conv_info);
+ TensorShape output_shape = misc::shape_calculator::compute_deep_convolution_shape(*input->info(), *weights->info(), conv_info);
DataType data_type = input->info()->data_type();
@@ -1229,73 +1361,88 @@ void NEDirectConvolutionLayerKernel::run(const Window &window, const ThreadInfo
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
- const int kernel_size = _weights->info()->dimension(0);
+ const int kernel_size = _weights->info()->dimension(get_data_layout_dimension_index(_weights->info()->data_layout(), DataLayoutDimension::WIDTH));
- switch(kernel_size)
+ if(_input->info()->data_layout() == DataLayout::NCHW)
{
- case 1:
+ switch(kernel_size)
{
- switch(_input->info()->data_type())
+ case 1:
{
- case DataType::QS8:
- convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
- case DataType::QS16:
- convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
- case DataType::F32:
- convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
+ switch(_input->info()->data_type())
+ {
+ case DataType::QS8:
+ convolve_1x1<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ case DataType::QS16:
+ convolve_1x1<qint16_t, qint32_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ case DataType::F32:
+ convolve_1x1<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
+ case DataType::F16:
+ convolve_1x1<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ break;
}
- break;
- }
- case 3:
- {
- switch(_input->info()->data_type())
+ case 3:
{
- case DataType::QS8:
- convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
- case DataType::F32:
- convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
+ switch(_input->info()->data_type())
+ {
+ case DataType::QS8:
+ convolve_3x3<qint8_t, qint16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ case DataType::F32:
+ convolve_3x3<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- case DataType::F16:
- convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
+ case DataType::F16:
+ convolve_3x3<float16_t, float16_t>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
- default:
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ break;
}
- break;
- }
- case 5:
- {
- switch(_input->info()->data_type())
+ case 5:
+ {
+ switch(_input->info()->data_type())
+ {
+ case DataType::F32:
+ convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
+ }
+ break;
+ }
+
+ default:
{
- case DataType::F32:
- convolve_5x5<float, float>(window, _num_elems_read_per_iteration, _num_elems_written_per_iteration, _input, _weights, _output, _conv_info);
- break;
- default:
- ARM_COMPUTE_ERROR("Data type not supported");
- break;
+ ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
+ break;
}
- break;
}
-
- default:
+ }
+ else
+ {
+ switch(_input->info()->data_type())
{
- ARM_COMPUTE_ERROR("Only kernel sizes 1x1, 3x3 and 5x5 are supported.");
- break;
+ case DataType::F32:
+ convolver_nhwc<float>::convolve(window, kernel_size, _num_elems_read_per_iteration, _input, _weights, _output, _conv_info);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported");
+ break;
}
}
}
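For context, a minimal end-to-end sketch of exercising the new path through the runtime wrapper. Shapes are hypothetical, and it assumes NEDirectConvolutionLayer and TensorInfo::set_data_layout behave as in this version of the library; not a verbatim library example:

#include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

void run_nhwc_direct_conv()
{
    Tensor src, weights, dst;

    // NHWC shapes are given innermost-first: (C, W, H) and (C, Kw, Kh, O)
    TensorInfo src_info(TensorShape(64U, 56U, 56U), 1, DataType::F32);
    TensorInfo wgt_info(TensorShape(64U, 3U, 3U, 128U), 1, DataType::F32);
    TensorInfo dst_info(TensorShape(128U, 56U, 56U), 1, DataType::F32);
    src_info.set_data_layout(DataLayout::NHWC);
    wgt_info.set_data_layout(DataLayout::NHWC);
    dst_info.set_data_layout(DataLayout::NHWC);

    src.allocator()->init(src_info);
    weights.allocator()->init(wgt_info);
    dst.allocator()->init(dst_info);

    // 3x3 kernel, stride 1, pad 1 keeps the 56x56 spatial size
    NEDirectConvolutionLayer conv;
    conv.configure(&src, &weights, nullptr, &dst, PadStrideInfo(1, 1, 1, 1));

    src.allocator()->allocate();
    weights.allocator()->allocate();
    dst.allocator()->allocate();

    // ... fill src and weights, then:
    conv.run();
}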