aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp17
1 files changed, 15 insertions, 2 deletions
diff --git a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
index 59244c876c..f525d93e83 100644
--- a/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
+++ b/src/core/NEON/kernels/NEDirectConvolutionLayerKernel.cpp
@@ -987,6 +987,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *weights,
ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(width_idx) != weights->dimension(height_idx));
ARM_COMPUTE_RETURN_ERROR_ON(weights->num_dimensions() > 4);
ARM_COMPUTE_RETURN_ERROR_ON(data_layout == DataLayout::NHWC && input->data_type() != DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON((weights->dimension(width_idx) > 3) && (input->data_type() == DataType::F16));
// Checks performed when output is configured
if(output->total_size() != 0)
@@ -1051,8 +1052,6 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
break;
}
case 3:
- case 5:
- {
switch(input->data_type())
{
case DataType::F32:
@@ -1071,6 +1070,20 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
ARM_COMPUTE_ERROR("Data type not supported.");
break;
}
+ break;
+ case 5:
+ {
+ switch(input->data_type())
+ {
+ case DataType::F32:
+ num_weight_elems_read_per_row = 4 + kernel_size - 1;
+ num_elems_read_per_iteration = 12;
+ num_elems_written_per_iteration = 16 >> conv_stride_x;
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Data type not supported.");
+ break;
+ }
}
break;
default: