diff options
Diffstat (limited to 'src/graph/nodes/EltwiseLayerNode.cpp')
-rw-r--r-- | src/graph/nodes/EltwiseLayerNode.cpp | 86 |
1 files changed, 78 insertions, 8 deletions
diff --git a/src/graph/nodes/EltwiseLayerNode.cpp b/src/graph/nodes/EltwiseLayerNode.cpp index 92d183e693..3f7a08e64d 100644 --- a/src/graph/nodes/EltwiseLayerNode.cpp +++ b/src/graph/nodes/EltwiseLayerNode.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 ARM Limited. + * Copyright (c) 2018-2020 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,7 @@ */ #include "arm_compute/graph/nodes/EltwiseLayerNode.h" +#include "arm_compute/core/TensorShape.h" #include "arm_compute/graph/Graph.h" #include "arm_compute/graph/INodeVisitor.h" @@ -30,8 +31,7 @@ namespace arm_compute { namespace graph { -EltwiseLayerNode::EltwiseLayerNode(const descriptors::EltwiseLayerDescriptor &descriptor) - : descriptor(descriptor) +EltwiseLayerNode::EltwiseLayerNode(const descriptors::EltwiseLayerDescriptor &descriptor) : descriptor(descriptor) { _input_edges.resize(2, EmptyEdgeID); _outputs.resize(1, NullTensorID); @@ -57,6 +57,11 @@ ActivationLayerInfo EltwiseLayerNode::fused_activation() const return descriptor.fused_activation; } +QuantizationInfo EltwiseLayerNode::output_quant_info() const +{ + return descriptor.out_quant_info; +} + void EltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation) { descriptor.fused_activation = fused_activation; @@ -64,7 +69,7 @@ void EltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation bool EltwiseLayerNode::forward_descriptors() { - if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID)) + if ((input_id(0) != NullTensorID) && (input_id(1) != NullTensorID) && (output_id(0) != NullTensorID)) { Tensor *dst = output(0); ARM_COMPUTE_ERROR_ON(dst == nullptr); @@ -78,12 +83,20 @@ TensorDescriptor EltwiseLayerNode::configure_output(size_t idx) const { ARM_COMPUTE_UNUSED(idx); - const Tensor *src = input(0); - ARM_COMPUTE_ERROR_ON(src == nullptr); + const Tensor *src1 = input(0); + ARM_COMPUTE_ERROR_ON(src1 == nullptr); - auto output_info = src->desc(); + const Tensor *src2 = input(1); + ARM_COMPUTE_ERROR_ON(src2 == nullptr); - if(!descriptor.out_quant_info.empty()) + auto output_info = src1->desc(); + + TensorShape out_shape = TensorShape::broadcast_shape(src1->desc().shape, src2->desc().shape); + ARM_COMPUTE_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + + output_info.set_shape(out_shape); + + if (!descriptor.out_quant_info.empty()) { output_info.set_quantization_info(descriptor.out_quant_info); } @@ -100,5 +113,62 @@ void EltwiseLayerNode::accept(INodeVisitor &v) { v.visit(*this); } + +UnaryEltwiseLayerNode::UnaryEltwiseLayerNode(const descriptors::UnaryEltwiseLayerDescriptor &descriptor) + : descriptor(descriptor) +{ + _input_edges.resize(1, EmptyEdgeID); + _outputs.resize(1, NullTensorID); +} + +descriptors::UnaryEltwiseLayerDescriptor UnaryEltwiseLayerNode::eltwise_descriptor() const +{ + return descriptor; +} + +void UnaryEltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation) +{ + descriptor.fused_activation = fused_activation; +} + +bool UnaryEltwiseLayerNode::forward_descriptors() +{ + if ((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID)) + { + Tensor *dst = output(0); + ARM_COMPUTE_ERROR_ON(dst == nullptr); + dst->desc() = configure_output(0); + return true; + } + return false; +} + +TensorDescriptor UnaryEltwiseLayerNode::configure_output(size_t idx) const +{ + ARM_COMPUTE_UNUSED(idx); + + const Tensor *src = input(0); + ARM_COMPUTE_ERROR_ON(src == nullptr); + + auto output_info = src->desc(); + + if (!descriptor.out_quant_info.empty()) + { + output_info.set_quantization_info(descriptor.out_quant_info); + } + + return output_info; +} + +NodeType UnaryEltwiseLayerNode::type() const +{ + return NodeType::UnaryEltwiseLayer; +} + +void UnaryEltwiseLayerNode::accept(INodeVisitor &v) +{ + v.visit(*this); +} + } // namespace graph } // namespace arm_compute |