aboutsummaryrefslogtreecommitdiff
path: root/src/graph/nodes/EltwiseLayerNode.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/graph/nodes/EltwiseLayerNode.cpp')
-rw-r--r--src/graph/nodes/EltwiseLayerNode.cpp86
1 files changed, 78 insertions, 8 deletions
diff --git a/src/graph/nodes/EltwiseLayerNode.cpp b/src/graph/nodes/EltwiseLayerNode.cpp
index 92d183e693..3f7a08e64d 100644
--- a/src/graph/nodes/EltwiseLayerNode.cpp
+++ b/src/graph/nodes/EltwiseLayerNode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 ARM Limited.
+ * Copyright (c) 2018-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/graph/nodes/EltwiseLayerNode.h"
+#include "arm_compute/core/TensorShape.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
@@ -30,8 +31,7 @@ namespace arm_compute
{
namespace graph
{
-EltwiseLayerNode::EltwiseLayerNode(const descriptors::EltwiseLayerDescriptor &descriptor)
- : descriptor(descriptor)
+EltwiseLayerNode::EltwiseLayerNode(const descriptors::EltwiseLayerDescriptor &descriptor) : descriptor(descriptor)
{
_input_edges.resize(2, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
@@ -57,6 +57,11 @@ ActivationLayerInfo EltwiseLayerNode::fused_activation() const
return descriptor.fused_activation;
}
+QuantizationInfo EltwiseLayerNode::output_quant_info() const
+{
+ return descriptor.out_quant_info;
+}
+
void EltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation)
{
descriptor.fused_activation = fused_activation;
@@ -64,7 +69,7 @@ void EltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation
bool EltwiseLayerNode::forward_descriptors()
{
- if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID))
+ if ((input_id(0) != NullTensorID) && (input_id(1) != NullTensorID) && (output_id(0) != NullTensorID))
{
Tensor *dst = output(0);
ARM_COMPUTE_ERROR_ON(dst == nullptr);
@@ -78,12 +83,20 @@ TensorDescriptor EltwiseLayerNode::configure_output(size_t idx) const
{
ARM_COMPUTE_UNUSED(idx);
- const Tensor *src = input(0);
- ARM_COMPUTE_ERROR_ON(src == nullptr);
+ const Tensor *src1 = input(0);
+ ARM_COMPUTE_ERROR_ON(src1 == nullptr);
- auto output_info = src->desc();
+ const Tensor *src2 = input(1);
+ ARM_COMPUTE_ERROR_ON(src2 == nullptr);
- if(!descriptor.out_quant_info.empty())
+ auto output_info = src1->desc();
+
+ TensorShape out_shape = TensorShape::broadcast_shape(src1->desc().shape, src2->desc().shape);
+ ARM_COMPUTE_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+
+ output_info.set_shape(out_shape);
+
+ if (!descriptor.out_quant_info.empty())
{
output_info.set_quantization_info(descriptor.out_quant_info);
}
@@ -100,5 +113,62 @@ void EltwiseLayerNode::accept(INodeVisitor &v)
{
v.visit(*this);
}
+
+UnaryEltwiseLayerNode::UnaryEltwiseLayerNode(const descriptors::UnaryEltwiseLayerDescriptor &descriptor)
+ : descriptor(descriptor)
+{
+ _input_edges.resize(1, EmptyEdgeID);
+ _outputs.resize(1, NullTensorID);
+}
+
+descriptors::UnaryEltwiseLayerDescriptor UnaryEltwiseLayerNode::eltwise_descriptor() const
+{
+ return descriptor;
+}
+
+void UnaryEltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation)
+{
+ descriptor.fused_activation = fused_activation;
+}
+
+bool UnaryEltwiseLayerNode::forward_descriptors()
+{
+ if ((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID))
+ {
+ Tensor *dst = output(0);
+ ARM_COMPUTE_ERROR_ON(dst == nullptr);
+ dst->desc() = configure_output(0);
+ return true;
+ }
+ return false;
+}
+
+TensorDescriptor UnaryEltwiseLayerNode::configure_output(size_t idx) const
+{
+ ARM_COMPUTE_UNUSED(idx);
+
+ const Tensor *src = input(0);
+ ARM_COMPUTE_ERROR_ON(src == nullptr);
+
+ auto output_info = src->desc();
+
+ if (!descriptor.out_quant_info.empty())
+ {
+ output_info.set_quantization_info(descriptor.out_quant_info);
+ }
+
+ return output_info;
+}
+
+NodeType UnaryEltwiseLayerNode::type() const
+{
+ return NodeType::UnaryEltwiseLayer;
+}
+
+void UnaryEltwiseLayerNode::accept(INodeVisitor &v)
+{
+ v.visit(*this);
+}
+
} // namespace graph
} // namespace arm_compute