aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/backends/WorkloadData.cpp35
-rw-r--r--src/armnn/backends/WorkloadData.hpp9
2 files changed, 44 insertions, 0 deletions
diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp
index c934a53a5e..3ed77dacdb 100644
--- a/src/armnn/backends/WorkloadData.cpp
+++ b/src/armnn/backends/WorkloadData.cpp
@@ -129,6 +129,18 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor,
}
}
+void ValidateTensorMaxNumElements(const TensorInfo& tensor,
+ std::string const& descName,
+ unsigned int maxNumElements,
+ std::string const& tensorName)
+{
+ if (tensor.GetNumElements() > maxNumElements)
+ {
+ throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " +
+ to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor.");
+ }
+}
+
//---------------------------------------------------------------
void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
const std::string& descName, std::string const& tensorName)
@@ -828,6 +840,29 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+
+ const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+ if (m_Keepdims)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
+ }
+ else if (m_Axis == nullptr)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
+ }
+ else
+ {
+ const TensorInfo& axis = m_Axis->GetTensorInfo();
+ ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis");
+ ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis");
+ unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements();
+ ValidateTensorNumDimensions(output,
+ "MeanQueueDescriptor",
+ outputDim > 0 ? outputDim : 1,
+ "output");
+ }
}
} //namespace armnn
diff --git a/src/armnn/backends/WorkloadData.hpp b/src/armnn/backends/WorkloadData.hpp
index f8f7e326dc..face761e73 100644
--- a/src/armnn/backends/WorkloadData.hpp
+++ b/src/armnn/backends/WorkloadData.hpp
@@ -199,6 +199,15 @@ struct SubtractionQueueDescriptor : QueueDescriptor
// Mean layer workload data.
struct MeanQueueDescriptor : QueueDescriptor
{
+ MeanQueueDescriptor()
+ : m_Axis(nullptr)
+ , m_Keepdims(false)
+ {
+ }
+
+ const ConstCpuTensorHandle* m_Axis;
+ bool m_Keepdims;
+
void Validate(const WorkloadInfo& workloadInfo) const;
};