aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ElementwiseBaseLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ElementwiseBaseLayer.cpp')
-rw-r--r--src/armnn/layers/ElementwiseBaseLayer.cpp29
1 files changed, 18 insertions, 11 deletions
diff --git a/src/armnn/layers/ElementwiseBaseLayer.cpp b/src/armnn/layers/ElementwiseBaseLayer.cpp
index a169d31b2d..87093f684e 100644
--- a/src/armnn/layers/ElementwiseBaseLayer.cpp
+++ b/src/armnn/layers/ElementwiseBaseLayer.cpp
@@ -13,11 +13,12 @@
namespace armnn
{
-ElementwiseBaseLayer::ElementwiseBaseLayer(unsigned int numInputSlots, unsigned int numOutputSlots,
- LayerType type, const char* name)
+ElementwiseBaseLayer::ElementwiseBaseLayer(unsigned int numInputSlots,
+ unsigned int numOutputSlots,
+ LayerType type,
+ const char* name)
: Layer(numInputSlots, numOutputSlots, type, name)
-{
-}
+{}
std::vector<TensorShape> ElementwiseBaseLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
{
@@ -27,7 +28,15 @@ std::vector<TensorShape> ElementwiseBaseLayer::InferOutputShapes(const std::vect
if (m_ShapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
{
- ARMNN_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
+ if (input0.GetNumDimensions() != input1.GetNumDimensions())
+ {
+ std::stringstream errorMessage;
+ errorMessage << GetLayerTypeAsCString(GetType()) << " layer \"" << GetName() << "\": ";
+ errorMessage << "The tensor inputs to an element-wise operator are expected to have equal number of "
+ "dimensions. First = "
+ << input0.GetNumDimensions() << " second = " << input1.GetNumDimensions();
+ throw InvalidArgumentException(errorMessage.str(), CHECK_LOCATION());
+ }
}
else if (m_ShapeInferenceMethod == ShapeInferenceMethod::InferAndValidate &&
inputShapes[0].GetNumDimensions() < inputShapes[1].GetNumDimensions())
@@ -36,7 +45,7 @@ std::vector<TensorShape> ElementwiseBaseLayer::InferOutputShapes(const std::vect
input0 = inputShapes[1];
}
- unsigned int numDims = input0.GetNumDimensions();
+ unsigned int numDims = input0.GetNumDimensions();
unsigned int shiftedDims = input0.GetNumDimensions() - input1.GetNumDimensions();
// Get the max of the inputs.
@@ -72,10 +81,8 @@ void ElementwiseBaseLayer::ValidateTensorShapesFromInputs()
VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
- auto inferredShapes = InferOutputShapes({
- GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
- GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
- });
+ auto inferredShapes = InferOutputShapes({ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
ARMNN_ASSERT(inferredShapes.size() == 1);
@@ -87,4 +94,4 @@ void ElementwiseBaseLayer::ExecuteStrategy(IStrategy& strategy) const
strategy.ExecuteStrategy(this, BaseDescriptor(), {}, GetName());
}
-} // namespace armnn
+} // namespace armnn