From 652bde553f506caac4c563988dc9baf746f9584d Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Wed, 10 Jan 2018 15:33:28 +0000 Subject: COMPMID-674 - Create Google InceptionV3 example Change-Id: I389e0d4104b7dde60b7cdd612a83f3328517e44c Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/115804 Tested-by: Jenkins Reviewed-by: Anthony Barbier --- src/graph/nodes/ConvolutionLayer.cpp | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) (limited to 'src/graph/nodes/ConvolutionLayer.cpp') diff --git a/src/graph/nodes/ConvolutionLayer.cpp b/src/graph/nodes/ConvolutionLayer.cpp index 53d06ea75f..f292b893ed 100644 --- a/src/graph/nodes/ConvolutionLayer.cpp +++ b/src/graph/nodes/ConvolutionLayer.cpp @@ -106,13 +106,16 @@ std::unique_ptr instantiate(arm_comp const WeightsInfo &weights_info, ConvolutionMethodHint conv_method) { - if(conv_method == ConvolutionMethodHint::GEMM) + if((conv_method == ConvolutionMethodHint::DIRECT) + && arm_compute::CLDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT { - return instantiate_function(input, weights, biases, output, conv_info, weights_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLDirectConvolutionLayer"); + return instantiate_direct_function(input, weights, biases, output, conv_info); } else { - return instantiate_direct_function(input, weights, biases, output, conv_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); + return instantiate_function(input, weights, biases, output, conv_info, weights_info); } } @@ -122,13 +125,16 @@ std::unique_ptr instantiate(arm_comput const WeightsInfo &weights_info, ConvolutionMethodHint conv_method) { - if(conv_method == ConvolutionMethodHint::GEMM) + if((conv_method == ConvolutionMethodHint::DIRECT) + && arm_compute::NEDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT { - return instantiate_function(input, weights, biases, output, conv_info, weights_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEDirectConvolutionLayer"); + return instantiate_direct_function(input, weights, biases, output, conv_info); } else { - return instantiate_direct_function(input, weights, biases, output, conv_info); + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); + return instantiate_function(input, weights, biases, output, conv_info, weights_info); } } } // namespace @@ -258,12 +264,10 @@ std::unique_ptr ConvolutionLayer::instantiate_convolutio std::unique_ptr func; if(_target_hint == TargetHint::OPENCL) { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); func = instantiate(input, _weights.tensor(), _biases.tensor(), output, _conv_info, _weights_info, conv_method_hint); } else { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); func = instantiate(input, _weights.tensor(), _biases.tensor(), output, _conv_info, _weights_info, conv_method_hint); } return func; @@ -325,12 +329,10 @@ std::unique_ptr ConvolutionLayer::instantiate_grouped_co // Instantiate convolution function if(_target_hint == TargetHint::OPENCL) { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLConvolutionLayer"); func = instantiate(_is[i].tensor(), _ws[i].tensor(), _bs[i].tensor(), _os[i].tensor(), _conv_info, _weights_info, conv_method_hint); } else { - ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating NEConvolutionLayer"); func = instantiate(_is[i].tensor(), _ws[i].tensor(), _bs[i].tensor(), _os[i].tensor(), _conv_info, _weights_info, conv_method_hint); } -- cgit v1.2.1