diff options
Diffstat (limited to 'src/graph/mutators/NodeExecutionMethodMutator.cpp')
-rw-r--r-- | src/graph/mutators/NodeExecutionMethodMutator.cpp | 46 |
1 files changed, 25 insertions, 21 deletions
diff --git a/src/graph/mutators/NodeExecutionMethodMutator.cpp b/src/graph/mutators/NodeExecutionMethodMutator.cpp index 72e2645dd2..588befecae 100644 --- a/src/graph/mutators/NodeExecutionMethodMutator.cpp +++ b/src/graph/mutators/NodeExecutionMethodMutator.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2018-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,13 +23,13 @@ */ #include "arm_compute/graph/mutators/NodeExecutionMethodMutator.h" +#include "arm_compute/graph/backends/BackendRegistry.h" #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/Logger.h" -#include "arm_compute/graph/Utils.h" -#include "arm_compute/graph/backends/BackendRegistry.h" #include "arm_compute/graph/nodes/Nodes.h" +#include "arm_compute/graph/Utils.h" -#include "arm_compute/core/utils/misc/Cast.h" +#include "support/Cast.h" namespace arm_compute { @@ -49,17 +49,17 @@ template <typename Setter> void set_default_on_invalid_method(Graph &g, NodeType node_type, Setter &&setter) { const std::vector<NodeID> &node_ids = g.nodes(node_type); - for(auto &node_id : node_ids) + for (auto &node_id : node_ids) { INode *node = g.node(node_id); - if(node != nullptr) + if (node != nullptr) { // Validate node backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(node->assigned_target()); Status status = backend.validate_node(*node); // Set default execution method in case of failure - if(!bool(status)) + if (!bool(status)) { setter(node); } @@ -81,22 +81,26 @@ IGraphMutator::MutationType NodeExecutionMethodMutator::type() const void NodeExecutionMethodMutator::mutate(Graph &g) { // Convolution Layer - set_default_on_invalid_method(g, NodeType::ConvolutionLayer, [](INode * n) - { - ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : " - << n->id() << " and Name: " << n->name() << std::endl); - auto *casted_node = arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(n); - casted_node->set_convolution_method(ConvolutionMethod::Default); - }); + set_default_on_invalid_method(g, NodeType::ConvolutionLayer, + [](INode *n) + { + ARM_COMPUTE_LOG_GRAPH_INFO("Switched ConvolutionLayer method of node with ID : " + << n->id() << " and Name: " << n->name() << std::endl); + auto *casted_node = + arm_compute::utils::cast::polymorphic_downcast<ConvolutionLayerNode *>(n); + casted_node->set_convolution_method(ConvolutionMethod::Default); + }); // Depthwise Convolution Layer - set_default_on_invalid_method(g, NodeType::DepthwiseConvolutionLayer, [](INode * n) - { - ARM_COMPUTE_LOG_GRAPH_INFO("Switched Depthwise ConvolutionLayer method of node with ID : " - << n->id() << " and Name: " << n->name() << std::endl); - auto *casted_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(n); - casted_node->set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default); - }); + set_default_on_invalid_method( + g, NodeType::DepthwiseConvolutionLayer, + [](INode *n) + { + ARM_COMPUTE_LOG_GRAPH_INFO("Switched Depthwise ConvolutionLayer method of node with ID : " + << n->id() << " and Name: " << n->name() << std::endl); + auto *casted_node = arm_compute::utils::cast::polymorphic_downcast<DepthwiseConvolutionLayerNode *>(n); + casted_node->set_depthwise_convolution_method(DepthwiseConvolutionMethod::Default); + }); } } // namespace graph } // namespace arm_compute |