aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Layer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/Layer.cpp')
-rw-r--r--src/armnn/Layer.cpp41
1 files changed, 36 insertions, 5 deletions
diff --git a/src/armnn/Layer.cpp b/src/armnn/Layer.cpp
index 3241b5024e..b1d495244d 100644
--- a/src/armnn/Layer.cpp
+++ b/src/armnn/Layer.cpp
@@ -11,6 +11,8 @@
#include <armnn/utility/NumericCast.hpp>
+#include <armnnUtils/TensorUtils.hpp>
+
#include <client/include/IProfilingService.hpp>
#include <fmt/format.h>
@@ -425,11 +427,40 @@ void Layer::ValidateAndCopyShape(const TensorShape& outputShape,
{
if (shapeInferenceMethod == ShapeInferenceMethod::ValidateOnly)
{
- ConditionalThrowIfNotEqual<LayerValidationException>(
- layerName + ": TensorShape set on OutputSlot[0] does not match the inferred shape.",
- outputShape,
- inferredShape);
- return;
+ if (m_AllowExpandedDims)
+ {
+ std::vector<unsigned int> outputDims = armnnUtils::SqueezeDims(outputShape);
+ std::vector<unsigned int> inferredDims = armnnUtils::SqueezeDims(inferredShape);
+
+ if (outputDims.size() != inferredDims.size())
+ {
+ std::stringstream ss;
+ ss << layerName << ": TensorShape set on OutputSlot[" << outputSlotIndex <<
+ "] does not match the inferred shape. ";
+ ss << outputShape << " != " << inferredShape;
+ throw LayerValidationException(ss.str());
+ }
+ for (unsigned int i = 0; i < outputDims.size(); ++i)
+ {
+ if (outputDims[i] != inferredDims[i])
+ {
+ std::stringstream ss;
+ ss << layerName << ": TensorShape set on OutputSlot[" << outputSlotIndex <<
+ "] does not match the inferred shape at dimension index [";
+ ss << i << "] " << outputShape << " != " << inferredShape;
+ throw LayerValidationException(ss.str());
+ }
+ }
+ return;
+ }
+ else
+ {
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ layerName + ": TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ outputShape,
+ inferredShape);
+ return;
+ }
}
if (outputShape.GetDimensionality() == Dimensionality::Specified)