aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorthecha01 <theo.charalambous@arm.com>2020-07-28 17:45:07 +0100
committerMichele Di Giorgio <michele.digiorgio@arm.com>2020-08-25 10:17:55 +0000
commit96a14008af85725d067cdd8247023474581102ea (patch)
treea920ff3766133ec0012964d2fbfbdfa39aed1d5a
parent90251bc1e162925c4b85e8a2923af153af23da93 (diff)
downloadComputeLibrary-96a14008af85725d067cdd8247023474581102ea.tar.gz
Fix EltwiseLayerNode and QuantizationLayerNode
- Fixed issue where EltwiseLayerNode would base output shape off of first input tensor only
- Allow QuantizationLayerNode to use any quantized data type if specified in constructor

Signed-off-by: thecha01 <theo.charalambous@arm.com>
Change-Id: Ib93470316799028cd573592a3d79943493eaa093
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3737
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
-rw-r--r--arm_compute/graph/nodes/QuantizationLayerNode.h10
-rw-r--r--src/graph/nodes/EltwiseLayerNode.cpp17
-rw-r--r--src/graph/nodes/QuantizationLayerNode.cpp13
3 files changed, 32 insertions, 8 deletions
diff --git a/arm_compute/graph/nodes/QuantizationLayerNode.h b/arm_compute/graph/nodes/QuantizationLayerNode.h
index 94c718babb..e5d81afa0e 100644
--- a/arm_compute/graph/nodes/QuantizationLayerNode.h
+++ b/arm_compute/graph/nodes/QuantizationLayerNode.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,6 +40,13 @@ public:
*/
QuantizationLayerNode(QuantizationInfo out_quant_info);
+ /** Constructor
+ *
+ * @param[in] out_quant_info Output quantization info
+ * @param[in] out_data_type Output data type
+ */
+ QuantizationLayerNode(QuantizationInfo out_quant_info, DataType out_data_type);
+
// Inherited overridden methods:
NodeType type() const override;
bool forward_descriptors() override;
@@ -50,6 +57,7 @@ public:
private:
QuantizationInfo _out_quant_info;
+ DataType _out_data_type;
};
} // namespace graph
} // namespace arm_compute
diff --git a/src/graph/nodes/EltwiseLayerNode.cpp b/src/graph/nodes/EltwiseLayerNode.cpp
index 3149a9afef..4426e953ee 100644
--- a/src/graph/nodes/EltwiseLayerNode.cpp
+++ b/src/graph/nodes/EltwiseLayerNode.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/graph/nodes/EltwiseLayerNode.h"
+#include "arm_compute/core/TensorShape.h"
#include "arm_compute/graph/Graph.h"
#include "arm_compute/graph/INodeVisitor.h"
@@ -69,7 +70,7 @@ void EltwiseLayerNode::set_fused_activation(ActivationLayerInfo fused_activation
bool EltwiseLayerNode::forward_descriptors()
{
- if((input_id(0) != NullTensorID) && (output_id(0) != NullTensorID))
+ if((input_id(0) != NullTensorID) && (input_id(1) != NullTensorID) && (output_id(0) != NullTensorID))
{
Tensor *dst = output(0);
ARM_COMPUTE_ERROR_ON(dst == nullptr);
@@ -83,10 +84,18 @@ TensorDescriptor EltwiseLayerNode::configure_output(size_t idx) const
{
ARM_COMPUTE_UNUSED(idx);
- const Tensor *src = input(0);
- ARM_COMPUTE_ERROR_ON(src == nullptr);
+ const Tensor *src1 = input(0);
+ ARM_COMPUTE_ERROR_ON(src1 == nullptr);
- auto output_info = src->desc();
+ const Tensor *src2 = input(1);
+ ARM_COMPUTE_ERROR_ON(src2 == nullptr);
+
+ auto output_info = src1->desc();
+
+ TensorShape out_shape = TensorShape::broadcast_shape(src1->desc().shape, src2->desc().shape);
+ ARM_COMPUTE_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+
+ output_info.set_shape(out_shape);
if(!descriptor.out_quant_info.empty())
{
diff --git a/src/graph/nodes/QuantizationLayerNode.cpp b/src/graph/nodes/QuantizationLayerNode.cpp
index db70c2c312..08e2a4d961 100644
--- a/src/graph/nodes/QuantizationLayerNode.cpp
+++ b/src/graph/nodes/QuantizationLayerNode.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019 Arm Limited.
+ * Copyright (c) 2019-2020 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,8 +31,15 @@ namespace arm_compute
namespace graph
{
QuantizationLayerNode::QuantizationLayerNode(QuantizationInfo out_quant_info)
- : _out_quant_info(std::move(out_quant_info))
+ : QuantizationLayerNode(out_quant_info, DataType::QASYMM8)
{
+}
+
+QuantizationLayerNode::QuantizationLayerNode(QuantizationInfo out_quant_info, DataType out_data_type)
+ : _out_quant_info(std::move(out_quant_info)), _out_data_type(out_data_type)
+{
+ ARM_COMPUTE_ERROR_ON(!is_data_type_quantized(out_data_type));
+
_input_edges.resize(1, EmptyEdgeID);
_outputs.resize(1, NullTensorID);
}
@@ -58,7 +65,7 @@ TensorDescriptor QuantizationLayerNode::configure_output(size_t idx) const
ARM_COMPUTE_ERROR_ON(src == nullptr);
TensorDescriptor output_info = src->desc();
- output_info.data_type = DataType::QASYMM8;
+ output_info.data_type = _out_data_type;
output_info.quant_info = _out_quant_info;
return output_info;