diff options
Diffstat (limited to 'src/graph')
-rw-r--r-- | src/graph/nodes/ConvolutionLayer.cpp | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/src/graph/nodes/ConvolutionLayer.cpp b/src/graph/nodes/ConvolutionLayer.cpp index d8089d804d..5b3a84a4ad 100644 --- a/src/graph/nodes/ConvolutionLayer.cpp +++ b/src/graph/nodes/ConvolutionLayer.cpp @@ -26,6 +26,7 @@ #include "arm_compute/graph/Error.h" #include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h" #include "arm_compute/runtime/CL/functions/CLDirectConvolutionLayer.h" +#include "arm_compute/runtime/CL/functions/CLWinogradConvolutionLayer.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEDirectConvolutionLayer.h" @@ -107,8 +108,14 @@ std::unique_ptr<arm_compute::IFunction> instantiate<TargetHint::OPENCL>(arm_comp const WeightsInfo &weights_info, ConvolutionMethodHint conv_method) { - if((conv_method == ConvolutionMethodHint::DIRECT) - && arm_compute::CLDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT + if((conv_method == ConvolutionMethodHint::WINOGRAD) + && arm_compute::CLWinogradConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT + { + ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLWinogradConvolutionLayer"); + return instantiate_direct_function<arm_compute::CLWinogradConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info); + } + else if((conv_method == ConvolutionMethodHint::DIRECT) + && arm_compute::CLDirectConvolutionLayer::validate(input->info(), weights->info(), biases != nullptr ? biases->info() : nullptr, output->info(), conv_info)) // NOLINT { ARM_COMPUTE_LOG_GRAPH_INFO("Instantiating CLDirectConvolutionLayer"); return instantiate_direct_function<arm_compute::CLDirectConvolutionLayer, arm_compute::ICLTensor, TargetHint::OPENCL>(input, weights, biases, output, conv_info); @@ -159,10 +166,7 @@ class GroupedConvolutionFunction final : public arm_compute::IFunction { public: /** Default Constructor */ - GroupedConvolutionFunction() - : _convolutions() - { - } + GroupedConvolutionFunction() = default; /** Default Destructor */ ~GroupedConvolutionFunction() final = default; /** Prevent instances from being copy constructed */ @@ -177,12 +181,12 @@ public: * * @param convolution Convolution function to add */ - void add_convolution_function(std::unique_ptr<IFunction> convolution) + void add_convolution_function(std::unique_ptr<IFunction> convolution) // NOLINT { _convolutions.emplace_back(std::move(convolution)); } - // Inherited methods overriden: + // Inherited methods overridden: void run() override { for(auto &c : _convolutions) @@ -192,7 +196,7 @@ public: } private: - std::vector<std::unique_ptr<IFunction>> _convolutions; + std::vector<std::unique_ptr<IFunction>> _convolutions{}; }; std::unique_ptr<arm_compute::IFunction> ConvolutionLayer::instantiate_node(GraphContext &ctx, ITensorObject *input, ITensorObject *output) |