aboutsummaryrefslogtreecommitdiff
path: root/src/graph/mutators/DepthConcatSubTensorMutator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/mutators/DepthConcatSubTensorMutator.cpp')
-rw-r--r--src/graph/mutators/DepthConcatSubTensorMutator.cpp39
1 file changed, 22 insertions, 17 deletions
diff --git a/src/graph/mutators/DepthConcatSubTensorMutator.cpp b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
index 963b948432..1b7ee3c4a4 100644
--- a/src/graph/mutators/DepthConcatSubTensorMutator.cpp
+++ b/src/graph/mutators/DepthConcatSubTensorMutator.cpp
@@ -23,12 +23,12 @@
*/
#include "arm_compute/graph/mutators/DepthConcatSubTensorMutator.h"
-#include "arm_compute/graph/Graph.h"
-#include "arm_compute/graph/Logger.h"
-#include "arm_compute/graph/Utils.h"
#include "arm_compute/graph/algorithms/TopologicalSort.h"
#include "arm_compute/graph/backends/BackendRegistry.h"
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/Logger.h"
#include "arm_compute/graph/nodes/ConcatenateLayerNode.h"
+#include "arm_compute/graph/Utils.h"
#include "support/Cast.h"
#include "support/Iterable.h"
@@ -50,7 +50,7 @@ IGraphMutator::MutationType DepthConcatSubTensorMutator::type() const
void DepthConcatSubTensorMutator::mutate(Graph &g)
{
// Early exit if no Concatenation layers exist in graph
- if(g.nodes(NodeType::ConcatenateLayer).empty())
+ if (g.nodes(NodeType::ConcatenateLayer).empty())
{
return;
}
@@ -59,43 +59,48 @@ void DepthConcatSubTensorMutator::mutate(Graph &g)
std::vector<NodeID> topological_sorted_node_ids = dfs(g);
// Should be in reverse order of execution
- for(auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
+ for (auto &node_id : arm_compute::utils::iterable::reverse_iterate(topological_sorted_node_ids))
{
INode *node = g.node(node_id);
- if(node != nullptr && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
+ if (node != nullptr && node->type() == NodeType::ConcatenateLayer && node->output(0) != nullptr)
{
// Get output tensor
auto output_tensor = node->output(0);
// Check concatenation axis (Sub-tensor optimization is supported for concatenation axis >=2)
auto *concat_node = arm_compute::utils::cast::polymorphic_downcast<ConcatenateLayerNode *>(node);
- if(output_tensor == nullptr || get_dimension_idx(output_tensor->desc().layout, concat_node->concatenation_axis()) < 2)
+ if (output_tensor == nullptr ||
+ get_dimension_idx(output_tensor->desc().layout, concat_node->concatenation_axis()) < 2)
{
continue;
}
// Check that all tensor have the same target, valid inputs and same quantization info
- bool is_valid = std::all_of(node->input_edges().cbegin(), node->input_edges().cend(),
- [&](const EdgeID & eid)
- {
- return (g.edge(eid) != nullptr) && (g.edge(eid)->tensor() != nullptr) && (g.edge(eid)->tensor()->desc().target == output_tensor->desc().target)
- && (g.edge(eid)->tensor()->desc().quant_info == output_tensor->desc().quant_info);
- });
+ bool is_valid =
+ std::all_of(node->input_edges().cbegin(), node->input_edges().cend(),
+ [&](const EdgeID &eid)
+ {
+ return (g.edge(eid) != nullptr) && (g.edge(eid)->tensor() != nullptr) &&
+ (g.edge(eid)->tensor()->desc().target == output_tensor->desc().target) &&
+ (g.edge(eid)->tensor()->desc().quant_info == output_tensor->desc().quant_info);
+ });
// Create subtensors
- if(is_valid && is_target_supported(output_tensor->desc().target))
+ if (is_valid && is_target_supported(output_tensor->desc().target))
{
ARM_COMPUTE_LOG_GRAPH_VERBOSE("Using sub-tensors for the node with ID : "
<< node->id() << " and name : " << node->name() << std::endl);
// Create sub-tensor handles
unsigned depth = 0;
- for(unsigned int i = 0; i < node->input_edges().size(); ++i)
+ for (unsigned int i = 0; i < node->input_edges().size(); ++i)
{
auto input_tensor = node->input(i);
const auto input_shape = input_tensor->desc().shape;
- backends::IDeviceBackend &backend = backends::BackendRegistry::get().get_backend(input_tensor->desc().target);
- std::unique_ptr<ITensorHandle> handle = backend.create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
+ backends::IDeviceBackend &backend =
+ backends::BackendRegistry::get().get_backend(input_tensor->desc().target);
+ std::unique_ptr<ITensorHandle> handle =
+ backend.create_subtensor(output_tensor->handle(), input_shape, Coordinates(0, 0, depth), false);
input_tensor->set_handle(std::move(handle));
depth += input_shape.z();