aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/AdditionLayer.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/layers/AdditionLayer.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnn/layers/AdditionLayer.cpp')
-rw-r--r--src/armnn/layers/AdditionLayer.cpp40
1 files changed, 25 insertions, 15 deletions
diff --git a/src/armnn/layers/AdditionLayer.cpp b/src/armnn/layers/AdditionLayer.cpp
index 85d12eabcb..ab73a918db 100644
--- a/src/armnn/layers/AdditionLayer.cpp
+++ b/src/armnn/layers/AdditionLayer.cpp
@@ -28,41 +28,51 @@ AdditionLayer* AdditionLayer::Clone(Graph& graph) const
return CloneBase<AdditionLayer>(graph, GetName());
}
-void AdditionLayer::ValidateTensorShapesFromInputs()
+std::vector<TensorShape> AdditionLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
- auto& input0 = GetInputSlot(0).GetConnection()->GetTensorInfo();
- auto& input1 = GetInputSlot(1).GetConnection()->GetTensorInfo();
+ BOOST_ASSERT(inputShapes.size() == 2);
+ auto& input0 = inputShapes[0];
+ auto& input1 = inputShapes[1];
- // Get the max of the inputs
+ // Get the max of the inputs.
BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
unsigned int numDims = input0.GetNumDimensions();
std::vector<unsigned int> dims(numDims);
- // validate inputs are broadcast compatible
-#if !NDEBUG
for (unsigned int i = 0; i < numDims; i++)
{
- unsigned int dim0 = input0.GetShape()[i];
- unsigned int dim1 = input1.GetShape()[i];
+ unsigned int dim0 = input0[i];
+ unsigned int dim1 = input1[i];
+
+ // Validates inputs are broadcast compatible.
+#if !NDEBUG
if (dim0 != dim1)
{
BOOST_ASSERT_MSG(dim0 == 1 || dim1 == 1, "Dimensions should either match or one should be of size 1.");
}
- }
#endif
- for (unsigned int i = 0; i < numDims; i++)
- {
- unsigned int dim0 = input0.GetShape()[i];
- unsigned int dim1 = input1.GetShape()[i];
dims[i] = std::max(dim0, dim1);
}
- TensorShape outShape(numDims, dims.data());
+ return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) });
+}
+
+void AdditionLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ auto inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
+ });
+
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
ConditionalThrowIfNotEqual<LayerValidationException>(
"AdditionLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
GetOutputSlot(0).GetTensorInfo().GetShape(),
- outShape);
+ inferredShapes[0]);
}
} // namespace armnn