diff options
Diffstat (limited to 'src/runtime/CL/functions/CLDeconvolutionLayer.cpp')
-rw-r--r-- | src/runtime/CL/functions/CLDeconvolutionLayer.cpp | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp index ea7f3e75f7..56e9dae45d 100644 --- a/src/runtime/CL/functions/CLDeconvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLDeconvolutionLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -146,11 +146,6 @@ DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensor return DeconvolutionMethod::UPSCALE_CONV2D; } - if(input->data_layout() == DataLayout::NHWC) - { - return DeconvolutionMethod::DIRECT; - } - const DataLayout data_layout = input->data_layout(); const size_t idx_w = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); @@ -158,7 +153,14 @@ DeconvolutionMethod CLDeconvolutionLayer::get_deconvolution_method(const ITensor if(weights->dimension(idx_w) != deconv_info.stride().first || weights->dimension(idx_h) != deconv_info.stride().second) { - return DeconvolutionMethod::UPSCALE_CONV2D; + if(input->data_layout() == DataLayout::NHWC) + { + return DeconvolutionMethod::DIRECT; + } + else + { + return DeconvolutionMethod::UPSCALE_CONV2D; + } } return DeconvolutionMethod::GEMM; |