aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/DepthConcatenateLayerNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/DepthConcatenateLayerNode.cpp')
-rw-r--r--src/graph/nodes/DepthConcatenateLayerNode.cpp38
1 files changed, 15 insertions, 23 deletions
diff --git a/src/graph/nodes/DepthConcatenateLayerNode.cpp b/src/graph/nodes/DepthConcatenateLayerNode.cpp
index 1c0539744f..08cccc1ff1 100644
--- a/src/graph/nodes/DepthConcatenateLayerNode.cpp
+++ b/src/graph/nodes/DepthConcatenateLayerNode.cpp
@@ -34,7 +34,7 @@ namespace graph
DepthConcatenateLayerNode::DepthConcatenateLayerNode(unsigned int total_nodes)
: _total_nodes(total_nodes), _is_enabled(true)
{
- _input_edges.resize(total_nodes, EmptyEdgeID);
+ _input_edges.resize(_total_nodes, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
@@ -48,28 +48,28 @@ bool DepthConcatenateLayerNode::is_enabled() const
return _is_enabled;
}
-TensorShape DepthConcatenateLayerNode::compute_output_shape(const std::vector<TensorShape> &input_shapes)
+TensorDescriptor DepthConcatenateLayerNode::compute_output_descriptor(const std::vector<TensorDescriptor> &input_descriptors)
{
- ARM_COMPUTE_ERROR_ON(input_shapes.size() == 0);
+ ARM_COMPUTE_ERROR_ON(input_descriptors.size() == 0);
- TensorShape output_shape = input_shapes[0];
+ TensorDescriptor output_descriptor = input_descriptors[0];
size_t max_x = 0;
size_t max_y = 0;
size_t depth = 0;
- for(const auto &shape : input_shapes)
+ for(const auto &input_descriptor : input_descriptors)
{
- max_x = std::max(shape.x(), max_x);
- max_y = std::max(shape.y(), max_y);
- depth += shape.z();
+ max_x = std::max(input_descriptor.shape.x(), max_x);
+ max_y = std::max(input_descriptor.shape.y(), max_y);
+ depth += input_descriptor.shape.z();
}
- output_shape.set(0, max_x);
- output_shape.set(1, max_y);
- output_shape.set(2, depth);
+ output_descriptor.shape.set(0, max_x);
+ output_descriptor.shape.set(1, max_y);
+ output_descriptor.shape.set(2, depth);
- return output_shape;
+ return output_descriptor;
}
bool DepthConcatenateLayerNode::forward_descriptors()
@@ -99,27 +99,19 @@ TensorDescriptor DepthConcatenateLayerNode::configure_output(size_t idx) const
if(are_all_inputs_set)
{
- std::vector<TensorShape> inputs_shapes;
+ std::vector<TensorDescriptor> inputs_descriptors;
for(unsigned int i = 0; i < _input_edges.size(); ++i)
{
const Tensor *t = _graph->tensor(input_id(i));
ARM_COMPUTE_ERROR_ON(t == nullptr);
- inputs_shapes.push_back(t->desc().shape);
+ inputs_descriptors.push_back(t->desc());
}
- output_info = input(0)->desc();
- TensorShape output_shape = compute_output_shape(inputs_shapes);
- output_info.shape = output_shape;
+ output_info = compute_output_descriptor(inputs_descriptors);
}
return output_info;
}
-Status DepthConcatenateLayerNode::validate()
-{
- ARM_COMPUTE_UNUSED(_total_nodes);
- return Status{};
-}
-
NodeType DepthConcatenateLayerNode::type() const
{
return NodeType::DepthConcatenateLayer;