diff options
Diffstat (limited to 'src/graph/nodes/ConvolutionLayer.cpp')
-rw-r--r-- | src/graph/nodes/ConvolutionLayer.cpp | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/src/graph/nodes/ConvolutionLayer.cpp b/src/graph/nodes/ConvolutionLayer.cpp index 53d06ea75f..f292b893ed 100644 --- a/src/graph/nodes/ConvolutionLayer.cpp +++ b/src/graph/nodes/ConvolutionLayer.cpp @@ -106,13 +106,16 @@ std::unique_ptr<arm_compute::IFunction> instantiate<TargetHint::OPENCL>(arm_comp const WeightsInfo &weights_info, ConvolutionMethodHint conv_method) { - if(conv_method == ConvolutionMethodHint::GEMM) + if((conv_method == ConvolutionMethodHint::DIRECT) + && arm_compute::CLDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT { - return instantiate_function<arm_compute::CLConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info, weights_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLDirectConvolutionLayer"); + return instantiate_direct_function<arm_compute::CLDirectConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info); } else { - return instantiate_direct_function<arm_compute::CLDirectConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); + return instantiate_function<arm_compute::CLConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info, weights_info); } } @@ -122,13 +125,16 @@ std::unique_ptr<arm_compute::IFunction> instantiate<TargetHint::NEON>(arm_comput const WeightsInfo &weights_info, ConvolutionMethodHint conv_method) { - if(conv_method == ConvolutionMethodHint::GEMM) + if((conv_method == ConvolutionMethodHint::DIRECT) + && arm_compute::NEDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT { - return instantiate_function<arm_compute::NEConvolutionLayer, arm_compute::ITensor, TargetHint::NEON>(input, weights, biases, output, conv_info, weights_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEDirectConvolutionLayer"); + return instantiate_direct_function<arm_compute::NEDirectConvolutionLayer, arm_compute::ITensor, TargetHint::NEON>(input, weights, biases, output, conv_info); } else { - return instantiate_direct_function<arm_compute::NEDirectConvolutionLayer, arm_compute::ITensor, TargetHint::NEON>(input, weights, biases, output, conv_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); + return instantiate_function<arm_compute::NEConvolutionLayer, arm_compute::ITensor, TargetHint::NEON>(input, weights, biases, output, conv_info, weights_info); } } } // namespace @@ -258,12 +264,10 @@ std::unique_ptr<arm_compute::IFunction> ConvolutionLayer::instantiate_convolutio std::unique_ptr<arm_compute::IFunction> func; if(_target_hint == TargetHint::OPENCL) { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); func = instantiate<TargetHint::OPENCL>(input, _weights.tensor(), _biases.tensor(), output, _conv_info, _weights_info, conv_method_hint); } else { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); func = instantiate<TargetHint::NEON>(input, _weights.tensor(), _biases.tensor(), output, _conv_info, _weights_info, conv_method_hint); } return func; @@ -325,12 +329,10 @@ std::unique_ptr<arm_compute::IFunction> ConvolutionLayer::instantiate_grouped_co // Instantiate convolution function if(_target_hint == TargetHint::OPENCL) { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); func = instantiate<TargetHint::OPENCL>(_is[i].tensor(), _ws[i].tensor(), _bs[i].tensor(), _os[i].tensor(), _conv_info, _weights_info, conv_method_hint); } else { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); func = instantiate<TargetHint::NEON>(_is[i].tensor(), _ws[i].tensor(), _bs[i].tensor(), _os[i].tensor(), _conv_info, _weights_info, conv_method_hint); } |