aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp36
1 files changed, 36 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 89277d798f..ea0e5c82b8 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -15,6 +15,7 @@
#include <boost/format.hpp>
#include <boost/numeric/conversion/cast.hpp>
+#include <TensorUtils.hpp>
using namespace armnnUtils;
@@ -485,6 +486,41 @@ void ArgMinMaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
};
ValidateDataTypes(inputTensorInfo, supportedInputTypes, descriptorName);
+
+ auto inputShape = inputTensorInfo.GetShape();
+ auto outputShape = outputTensorInfo.GetShape();
+
+ auto inputNumDimensions = inputShape.GetNumDimensions();
+ auto unsignedAxis = armnnUtils::GetUnsignedAxis(inputNumDimensions, m_Parameters.m_Axis);
+
+ const std::string outputShapeError{": Output tensor shape does not match shape inferred from input tensor."};
+
+ // 1D input shape results in scalar output shape
+ if (inputShape.GetNumDimensions() == 1)
+ {
+ if (outputShape.GetNumDimensions() != 1 && outputShape[0] != 1)
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+ else
+ {
+ for (unsigned int i = 0; i < unsignedAxis; ++i)
+ {
+ if (outputShape[i] != inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+
+ for (auto i = unsignedAxis + 1; i < inputNumDimensions; ++i)
+ {
+ if (outputShape[i - 1] != inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + outputShapeError);
+ }
+ }
+ }
}
void SoftmaxQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const