diff options
Diffstat (limited to 'src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp index d9c21150df..5ac7a7a7c6 100644 --- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp @@ -278,7 +278,9 @@ void CLDepthwiseConvolutionLayer::configure(ICLTensor *input, const ICLTensor *w const size_t idx_w = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::WIDTH); const size_t idx_h = get_data_layout_dimension_index(input->info()->data_layout(), DataLayoutDimension::HEIGHT); - const bool can_run_optimised_3x3_kernel = (weights->info()->dimension(idx_w) == 3) && (weights->info()->dimension(idx_h) == 3); + const GPUTarget gpu_target = CLScheduler::get().target(); + const bool can_run_optimised_3x3_kernel = (weights->info()->dimension(idx_w) == 3) && (weights->info()->dimension(idx_h) == 3) && (is_data_type_float(input->info()->data_type()) + || (get_arch_from_target(gpu_target) == GPUTarget::MIDGARD)); _needs_permute = false; _is_prepared = false; @@ -347,7 +349,9 @@ Status CLDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITe ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (dilation.x() - 1) > input->dimension(idx_w) + conv_info.pad_left() + conv_info.pad_right()); ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (dilation.y() - 1) > input->dimension(idx_h) + conv_info.pad_top() + conv_info.pad_bottom()); - const bool can_run_optimised_3x3_kernel = (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3); + const GPUTarget gpu_target = CLScheduler::get().target(); + const bool can_run_optimised_3x3_kernel = (weights->dimension(idx_w) == 3) && (weights->dimension(idx_h) == 3) && (is_data_type_float(input->data_type()) + || (get_arch_from_target(gpu_target) == GPUTarget::MIDGARD)); if(!can_run_optimised_3x3_kernel) { |