aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/graph/GraphBuilder.h13
-rw-r--r--arm_compute/graph/TypePrinter.h3
-rw-r--r--arm_compute/graph/Types.h1
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h49
-rw-r--r--arm_compute/graph/backends/ValidateHelpers.h27
-rw-r--r--arm_compute/graph/frontend/Layers.h44
-rw-r--r--arm_compute/graph/nodes/GenerateProposalsLayerNode.h60
-rw-r--r--arm_compute/graph/nodes/Nodes.h1
-rw-r--r--arm_compute/graph/nodes/NodesFwd.h1
-rw-r--r--src/graph/GraphBuilder.cpp16
-rw-r--r--src/graph/backends/CL/CLFunctionsFactory.cpp2
-rw-r--r--src/graph/backends/CL/CLNodeValidator.cpp2
-rw-r--r--src/graph/backends/GLES/GCNodeValidator.cpp2
-rw-r--r--src/graph/backends/NEON/NENodeValidator.cpp2
-rw-r--r--src/graph/nodes/GenerateProposalsLayerNode.cpp101
15 files changed, 324 insertions, 0 deletions
diff --git a/arm_compute/graph/GraphBuilder.h b/arm_compute/graph/GraphBuilder.h
index 611e69dbb3..c501006ec6 100644
--- a/arm_compute/graph/GraphBuilder.h
+++ b/arm_compute/graph/GraphBuilder.h
@@ -240,6 +240,19 @@ public:
const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
const QuantizationInfo weights_quant_info = QuantizationInfo(),
const QuantizationInfo out_quant_info = QuantizationInfo());
+ /** Adds a generate proposals layer node to the graph
+ *
+ * @param[in] g Graph to add the layer to
+ * @param[in] params Common node parameters
+ * @param[in] scores Input scores to the generate proposals layer node as a NodeID-Index pair
+ * @param[in] deltas Input deltas to the generate proposals layer node as a NodeID-Index pair
+ * @param[in] anchors Input anchors to the generate proposals layer node as a NodeID-Index pair
+ * @param[in] info Generate proposals operation information
+ *
+ * @return Node ID of the created node, EmptyNodeID in case of error
+ */
+ static NodeID add_generate_proposals_node(Graph &g, NodeParams params, NodeIdxPair scores, NodeIdxPair deltas,
+ NodeIdxPair anchors, GenerateProposalsInfo info);
/** Adds a normalization layer node to the graph
*
* @param[in] g Graph to add the node to
diff --git a/arm_compute/graph/TypePrinter.h b/arm_compute/graph/TypePrinter.h
index 697ee94331..b7dc2bb284 100644
--- a/arm_compute/graph/TypePrinter.h
+++ b/arm_compute/graph/TypePrinter.h
@@ -95,6 +95,9 @@ inline ::std::ostream &operator<<(::std::ostream &os, const NodeType &node_type)
case NodeType::FullyConnectedLayer:
os << "FullyConnectedLayer";
break;
+ case NodeType::GenerateProposalsLayer:
+ os << "GenerateProposalsLayer";
+ break;
case NodeType::NormalizationLayer:
os << "NormalizationLayer";
break;
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index b9589b753c..ceee776aaa 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -135,6 +135,7 @@ enum class NodeType
EltwiseLayer,
FlattenLayer,
FullyConnectedLayer,
+ GenerateProposalsLayer,
NormalizationLayer,
NormalizePlanarYUVLayer,
PadLayer,
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index d235fe9f6f..082d43afdb 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -622,6 +622,55 @@ std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode
return std::move(func);
}
+/** Create a backend generate proposals layer function
+ *
+ * @tparam GenerateProposalsLayerFunction Backend generate proposals function
+ * @tparam TargetInfo Target-specific information
+ *
+ * @param[in] node Node to create the backend function for
+ * @param[in] ctx Graph context
+ *
+ * @return Backend generate proposals layer function
+ */
+template <typename GenerateProposalsLayerFunction, typename TargetInfo>
+std::unique_ptr<IFunction> create_generate_proposals_layer(GenerateProposalsLayerNode &node, GraphContext &ctx)
+{
+ validate_node<TargetInfo>(node, 3 /* expected inputs */, 3 /* expected outputs */);
+
+ // Extract IO and info
+ typename TargetInfo::TensorType *scores = get_backing_tensor<TargetInfo>(node.input(0));
+ typename TargetInfo::TensorType *deltas = get_backing_tensor<TargetInfo>(node.input(1));
+ typename TargetInfo::TensorType *anchors = get_backing_tensor<TargetInfo>(node.input(2));
+ typename TargetInfo::TensorType *proposals = get_backing_tensor<TargetInfo>(node.output(0));
+ typename TargetInfo::TensorType *scores_out = get_backing_tensor<TargetInfo>(node.output(1));
+ typename TargetInfo::TensorType *num_valid_proposals = get_backing_tensor<TargetInfo>(node.output(2));
+ const GenerateProposalsInfo info = node.info();
+
+ ARM_COMPUTE_ERROR_ON(scores == nullptr);
+ ARM_COMPUTE_ERROR_ON(deltas == nullptr);
+ ARM_COMPUTE_ERROR_ON(anchors == nullptr);
+ ARM_COMPUTE_ERROR_ON(proposals == nullptr);
+ ARM_COMPUTE_ERROR_ON(scores_out == nullptr);
+
+ // Create and configure function
+ auto func = support::cpp14::make_unique<GenerateProposalsLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType));
+ func->configure(scores, deltas, anchors, proposals, scores_out, num_valid_proposals, info);
+
+ // Log info
+ ARM_COMPUTE_LOG_GRAPH_INFO("Instantiated " << node.type()
+ << " Target " << TargetInfo::TargetType
+ << " Data Type: " << scores->info()->data_type()
+ << " Scores shape: " << scores->info()->tensor_shape()
+ << " Deltas shape: " << deltas->info()->tensor_shape()
+ << " Anchors shape: " << anchors->info()->tensor_shape()
+ << " Proposals shape: " << proposals->info()->tensor_shape()
+ << " Num valid proposals shape: " << num_valid_proposals->info()->tensor_shape()
+ << " Scores Out shape: " << scores_out->info()->tensor_shape()
+ << std::endl);
+
+ return std::move(func);
+}
+
/** Create a backend normalization layer function
*
* @tparam NormalizationLayerFunction Backend normalization function
diff --git a/arm_compute/graph/backends/ValidateHelpers.h b/arm_compute/graph/backends/ValidateHelpers.h
index 999ce190ab..169c795fb4 100644
--- a/arm_compute/graph/backends/ValidateHelpers.h
+++ b/arm_compute/graph/backends/ValidateHelpers.h
@@ -203,6 +203,33 @@ Status validate_depthwise_convolution_layer(DepthwiseConvolutionLayerNode &node)
return status;
}
+/** Validates a Generate Proposals layer node
+ *
+ * @tparam GenerateProposalsLayer Generate Proposals layer type
+ *
+ * @param[in] node Node to validate
+ *
+ * @return Status
+ */
+template <typename GenerateProposalsLayer>
+Status validate_generate_proposals_layer(GenerateProposalsLayerNode &node)
+{
+ ARM_COMPUTE_LOG_GRAPH_VERBOSE("Validating GenerateProposalsLayer node with ID : " << node.id() << " and Name: " << node.name() << std::endl);
+ ARM_COMPUTE_RETURN_ERROR_ON(node.num_inputs() != 3);
+ ARM_COMPUTE_RETURN_ERROR_ON(node.num_outputs() != 3);
+
+ // Extract IO and info
+ arm_compute::ITensorInfo *scores = detail::get_backing_tensor_info(node.input(0));
+ arm_compute::ITensorInfo *deltas = detail::get_backing_tensor_info(node.input(1));
+ arm_compute::ITensorInfo *anchors = detail::get_backing_tensor_info(node.input(2));
+ arm_compute::ITensorInfo *proposals = get_backing_tensor_info(node.output(0));
+ arm_compute::ITensorInfo *scores_out = get_backing_tensor_info(node.output(1));
+ arm_compute::ITensorInfo *num_valid_proposals = get_backing_tensor_info(node.output(2));
+ const GenerateProposalsInfo info = node.info();
+
+ return GenerateProposalsLayer::validate(scores, deltas, anchors, proposals, scores_out, num_valid_proposals, info);
+}
+
/** Validates a NormalizePlanarYUV layer node
*
* @tparam NormalizePlanarYUVLayer layer type
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index fa0656dcdc..56dcd88077 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -531,6 +531,12 @@ public:
{
}
+ /** Create layer and add to the given stream.
+ *
+ * @param[in] s Stream to add layer to.
+ *
+ * @return ID of the created node.
+ */
NodeID create_layer(IStream &s) override
{
NodeParams common_params = { name(), s.hints().target_hint };
@@ -549,6 +555,44 @@ private:
const QuantizationInfo _out_quant_info;
};
+/** Generate Proposals Layer */
+class GenerateProposalsLayer final : public ILayer
+{
+public:
+ /** Construct a generate proposals layer.
+ *
+ * @param[in] ss_scores Graph sub-stream for the scores.
+ * @param[in] ss_deltas Graph sub-stream for the deltas.
+ * @param[in] ss_anchors Graph sub-stream for the anchors.
+ * @param[in] info Generate Proposals operation information.
+ */
+ GenerateProposalsLayer(SubStream &&ss_scores, SubStream &&ss_deltas, SubStream &&ss_anchors, GenerateProposalsInfo info)
+ : _ss_scores(std::move(ss_scores)), _ss_deltas(std::move(ss_deltas)), _ss_anchors(std::move(ss_anchors)), _info(info)
+ {
+ }
+
+ /** Create layer and add to the given stream.
+ *
+ * @param[in] s Stream to add layer to.
+ *
+ * @return ID of the created node.
+ */
+ NodeID create_layer(IStream &s) override
+ {
+ NodeParams common_params = { name(), s.hints().target_hint };
+ NodeIdxPair scores = { _ss_scores.tail_node(), 0 };
+ NodeIdxPair deltas = { _ss_deltas.tail_node(), 0 };
+ NodeIdxPair anchors = { _ss_anchors.tail_node(), 0 };
+ return GraphBuilder::add_generate_proposals_node(s.graph(), common_params, scores, deltas, anchors, _info);
+ }
+
+private:
+ SubStream _ss_scores;
+ SubStream _ss_deltas;
+ SubStream _ss_anchors;
+ GenerateProposalsInfo _info;
+};
+
/** Normalization Layer */
class NormalizationLayer final : public ILayer
{
diff --git a/arm_compute/graph/nodes/GenerateProposalsLayerNode.h b/arm_compute/graph/nodes/GenerateProposalsLayerNode.h
new file mode 100644
index 0000000000..09fbb3ee15
--- /dev/null
+++ b/arm_compute/graph/nodes/GenerateProposalsLayerNode.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef __ARM_COMPUTE_GENERATE_PROPOSALS_NODE_H__
+#define __ARM_COMPUTE_GENERATE_PROPOSALS_NODE_H__
+
+#include "arm_compute/graph/INode.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+/** Generate Proposals Layer node */
+class GenerateProposalsLayerNode final : public INode
+{
+public:
+ /** Constructor
+ *
+ * @param[in] info Generate proposals operation information.
+ */
+ GenerateProposalsLayerNode(GenerateProposalsInfo &info);
+ /** GenerateProposalsInfo accessor
+ *
+ * @return GenerateProposalsInfo
+ */
+ const GenerateProposalsInfo &info() const;
+
+ // Inherited overridden methods:
+ NodeType type() const override;
+ bool forward_descriptors() override;
+ TensorDescriptor configure_output(size_t idx) const override;
+ void accept(INodeVisitor &v) override;
+
+private:
+ GenerateProposalsInfo _info;
+};
+} // namespace graph
+} // namespace arm_compute
+#endif /* __ARM_COMPUTE_GENERATE_PROPOSALS_NODE_H__ */
diff --git a/arm_compute/graph/nodes/Nodes.h b/arm_compute/graph/nodes/Nodes.h
index 6acebdc231..afd4e2cf67 100644
--- a/arm_compute/graph/nodes/Nodes.h
+++ b/arm_compute/graph/nodes/Nodes.h
@@ -37,6 +37,7 @@
#include "arm_compute/graph/nodes/EltwiseLayerNode.h"
#include "arm_compute/graph/nodes/FlattenLayerNode.h"
#include "arm_compute/graph/nodes/FullyConnectedLayerNode.h"
+#include "arm_compute/graph/nodes/GenerateProposalsLayerNode.h"
#include "arm_compute/graph/nodes/InputNode.h"
#include "arm_compute/graph/nodes/NormalizationLayerNode.h"
#include "arm_compute/graph/nodes/NormalizePlanarYUVLayerNode.h"
diff --git a/arm_compute/graph/nodes/NodesFwd.h b/arm_compute/graph/nodes/NodesFwd.h
index e7045579d3..929a4021ef 100644
--- a/arm_compute/graph/nodes/NodesFwd.h
+++ b/arm_compute/graph/nodes/NodesFwd.h
@@ -43,6 +43,7 @@ class DummyNode;
class EltwiseLayerNode;
class FlattenLayerNode;
class FullyConnectedLayerNode;
+class GenerateProposalsLayerNode;
class InputNode;
class NormalizationLayerNode;
class NormalizePlanarYUVLayerNode;
diff --git a/src/graph/GraphBuilder.cpp b/src/graph/GraphBuilder.cpp
index 1441786377..7870fb10ea 100644
--- a/src/graph/GraphBuilder.cpp
+++ b/src/graph/GraphBuilder.cpp
@@ -432,6 +432,22 @@ NodeID GraphBuilder::add_fully_connected_layer(Graph &g, NodeParams params, Node
return fc_nid;
}
+NodeID GraphBuilder::add_generate_proposals_node(Graph &g, NodeParams params, NodeIdxPair scores, NodeIdxPair deltas, NodeIdxPair anchors, GenerateProposalsInfo info)
+{
+ CHECK_NODEIDX_PAIR(scores, g);
+ CHECK_NODEIDX_PAIR(deltas, g);
+ CHECK_NODEIDX_PAIR(anchors, g);
+
+ NodeID nid = g.add_node<GenerateProposalsLayerNode>(info);
+
+ g.add_connection(scores.node_id, scores.index, nid, 0);
+ g.add_connection(deltas.node_id, deltas.index, nid, 1);
+ g.add_connection(anchors.node_id, anchors.index, nid, 2);
+
+ set_node_params(g, nid, params);
+ return nid;
+}
+
NodeID GraphBuilder::add_normalization_node(Graph &g, NodeParams params, NodeIdxPair input, NormalizationLayerInfo norm_info)
{
return create_simple_single_input_output_node<NormalizationLayerNode>(g, params, input, norm_info);
diff --git a/src/graph/backends/CL/CLFunctionsFactory.cpp b/src/graph/backends/CL/CLFunctionsFactory.cpp
index ea4e89e5c1..d627a39557 100644
--- a/src/graph/backends/CL/CLFunctionsFactory.cpp
+++ b/src/graph/backends/CL/CLFunctionsFactory.cpp
@@ -101,6 +101,8 @@ std::unique_ptr<IFunction> CLFunctionFactory::create(INode *node, GraphContext &
return detail::create_flatten_layer<CLFlattenLayer, CLTargetInfo>(*polymorphic_downcast<FlattenLayerNode *>(node));
case NodeType::FullyConnectedLayer:
return detail::create_fully_connected_layer<CLFullyConnectedLayer, CLTargetInfo>(*polymorphic_downcast<FullyConnectedLayerNode *>(node), ctx);
+ case NodeType::GenerateProposalsLayer:
+ return detail::create_generate_proposals_layer<CLGenerateProposalsLayer, CLTargetInfo>(*polymorphic_downcast<GenerateProposalsLayerNode *>(node), ctx);
case NodeType::NormalizationLayer:
return detail::create_normalization_layer<CLNormalizationLayer, CLTargetInfo>(*polymorphic_downcast<NormalizationLayerNode *>(node), ctx);
case NodeType::NormalizePlanarYUVLayer:
diff --git a/src/graph/backends/CL/CLNodeValidator.cpp b/src/graph/backends/CL/CLNodeValidator.cpp
index 2a3121fa65..9cbf4bb3eb 100644
--- a/src/graph/backends/CL/CLNodeValidator.cpp
+++ b/src/graph/backends/CL/CLNodeValidator.cpp
@@ -59,6 +59,8 @@ Status CLNodeValidator::validate(INode *node)
case NodeType::DepthwiseConvolutionLayer:
return detail::validate_depthwise_convolution_layer<CLDepthwiseConvolutionLayer,
CLDepthwiseConvolutionLayer3x3>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ case NodeType::GenerateProposalsLayer:
+ return detail::validate_generate_proposals_layer<CLGenerateProposalsLayer>(*polymorphic_downcast<GenerateProposalsLayerNode *>(node));
case NodeType::NormalizePlanarYUVLayer:
return detail::validate_normalize_planar_yuv_layer<CLNormalizePlanarYUVLayer>(*polymorphic_downcast<NormalizePlanarYUVLayerNode *>(node));
case NodeType::PadLayer:
diff --git a/src/graph/backends/GLES/GCNodeValidator.cpp b/src/graph/backends/GLES/GCNodeValidator.cpp
index 29f317b881..e5ba66205f 100644
--- a/src/graph/backends/GLES/GCNodeValidator.cpp
+++ b/src/graph/backends/GLES/GCNodeValidator.cpp
@@ -113,6 +113,8 @@ Status GCNodeValidator::validate(INode *node)
return validate_depthwise_convolution_layer(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
case NodeType::FlattenLayer:
return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : FlattenLayer");
+ case NodeType::GenerateProposalsLayer:
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : GenerateProposalsLayer");
case NodeType::NormalizePlanarYUVLayer:
return detail::validate_normalize_planar_yuv_layer<GCNormalizePlanarYUVLayer>(*polymorphic_downcast<NormalizePlanarYUVLayerNode *>(node));
case NodeType::PadLayer:
diff --git a/src/graph/backends/NEON/NENodeValidator.cpp b/src/graph/backends/NEON/NENodeValidator.cpp
index bf2225c02a..606cdf8291 100644
--- a/src/graph/backends/NEON/NENodeValidator.cpp
+++ b/src/graph/backends/NEON/NENodeValidator.cpp
@@ -59,6 +59,8 @@ Status NENodeValidator::validate(INode *node)
case NodeType::DepthwiseConvolutionLayer:
return detail::validate_depthwise_convolution_layer<NEDepthwiseConvolutionLayer,
NEDepthwiseConvolutionLayer3x3>(*polymorphic_downcast<DepthwiseConvolutionLayerNode *>(node));
+ case NodeType::GenerateProposalsLayer:
+ return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : GenerateProposalsLayer");
case NodeType::NormalizePlanarYUVLayer:
return ARM_COMPUTE_CREATE_ERROR(arm_compute::ErrorCode::RUNTIME_ERROR, "Unsupported operation : NormalizePlanarYUVLayer");
case NodeType::PadLayer:
diff --git a/src/graph/nodes/GenerateProposalsLayerNode.cpp b/src/graph/nodes/GenerateProposalsLayerNode.cpp
new file mode 100644
index 0000000000..f5a3c02dd5
--- /dev/null
+++ b/src/graph/nodes/GenerateProposalsLayerNode.cpp
@@ -0,0 +1,101 @@
+/*
+ * Copyright (c) 2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/graph/nodes/GenerateProposalsLayerNode.h"
+
+#include "arm_compute/graph/Graph.h"
+#include "arm_compute/graph/INodeVisitor.h"
+
+#include "arm_compute/core/Helpers.h"
+
+namespace arm_compute
+{
+namespace graph
+{
+GenerateProposalsLayerNode::GenerateProposalsLayerNode(GenerateProposalsInfo &info)
+ : _info(info)
+{
+ _input_edges.resize(3, EmptyEdgeID);
+ _outputs.resize(3, NullTensorID);
+}
+
+const GenerateProposalsInfo &GenerateProposalsLayerNode::info() const
+{
+ return _info;
+}
+
+bool GenerateProposalsLayerNode::forward_descriptors()
+{
+ if((input_id(0) != NullTensorID) && (input_id(1) != NullTensorID) && (input_id(2) != NullTensorID) && (output_id(0) != NullTensorID) && (output_id(1) != NullTensorID)
+ && (output_id(2) != NullTensorID))
+ {
+ for(unsigned int i = 0; i < 3; ++i)
+ {
+ Tensor *dst = output(i);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(i);
+ }
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor GenerateProposalsLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_ERROR_ON(idx > 3);
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+ TensorDescriptor output_desc = src->desc();
+
+ switch(idx)
+ {
+ case 0:
+ // Configure proposals output
+ output_desc.shape = TensorShape(5, src->desc().shape.total_size());
+ break;
+ case 1:
+ // Configure scores_out output
+ output_desc.shape = TensorShape(src->desc().shape.total_size());
+ break;
+ case 2:
+ // Configure num_valid_proposals
+ output_desc.shape = TensorShape(1);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported output index");
+ }
+ return output_desc;
+}
+
+NodeType GenerateProposalsLayerNode::type() const
+{
+ return NodeType::GenerateProposalsLayer;
+}
+
+void GenerateProposalsLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+} // namespace graph
+} // namespace arm_compute