aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ElementwiseBaseLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ElementwiseBaseLayer.cpp')
-rw-r--r--src/armnn/layers/ElementwiseBaseLayer.cpp19
1 files changed, 2 insertions, 17 deletions
diff --git a/src/armnn/layers/ElementwiseBaseLayer.cpp b/src/armnn/layers/ElementwiseBaseLayer.cpp
index 87093f684e..87fddfea79 100644
--- a/src/armnn/layers/ElementwiseBaseLayer.cpp
+++ b/src/armnn/layers/ElementwiseBaseLayer.cpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2018,2020-2021,2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -26,20 +26,7 @@ std::vector<TensorShape> ElementwiseBaseLayer::InferOutputShapes(const std::vect
TensorShape input0 = inputShapes[0];
TensorShape input1 = inputShapes[1];
- if (m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
- {
- if (input0.GetNumDimensions() != input1.GetNumDimensions())
- {
- std::stringstream errorMessage;
- errorMessage << GetLayerTypeAsCString(GetType()) << " layer \"" << GetName() << "\": ";
- errorMessage << "The tensor inputs to an element-wise operator are expected to have equal number of "
- "dimensions. First = "
- << input0.GetNumDimensions() << " second = " << input1.GetNumDimensions();
- throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
- }
- }
- else if (m_ShapeInferenceMethod == ShapeInferenceMethod::InferAndValidate &&
- inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions())
+ if (inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions())
{
input1 = inputShapes[0];
input0 = inputShapes[1];
@@ -55,11 +42,9 @@ std::vector<TensorShape> ElementwiseBaseLayer::InferOutputShapes(const std::vect
unsigned int dim0 = input0[i];
unsigned int dim1 = input1[i - shiftedDims];
-#if !NDEBUG
// Validate inputs are broadcast compatible.
ARMNN_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1,
"Dimensions should either match or one should be of size 1.");
-#endif
dims[i] = std::max(dim0, dim1);
}