about summary refs log tree commit diff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp  85
1 files changed, 85 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 324c1debc0..878602391c 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -582,6 +582,91 @@ void ConcatQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
//---------------------------------------------------------------
+void StackQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumOutputs(workloadInfo, "StackQueueDescriptor", 1);
+
+ if (m_Parameters.m_NumInputs != workloadInfo.m_InputTensorInfos.size())
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Must have the defined number of input tensors.");
+ }
+
+ // All inputs must have the same shape, which is defined in parameters
+ const TensorShape& inputShape = m_Parameters.m_InputShape;
+ for (unsigned int i = 0; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+ {
+ if (workloadInfo.m_InputTensorInfos[i].GetShape() != inputShape)
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: All input tensor shapes "
+ "must match the defined shape.");
+ }
+ }
+
+ // m_Axis is 0-based and may take values from 0 to the number of input dimensions (inclusive),
+ // since the output tensor has an additional dimension.
+ if (m_Parameters.m_Axis > inputShape.GetNumDimensions())
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Axis may not be greater "
+ "than the number of input dimensions.");
+ }
+
+ // Output shape must be as inferred from the input shape
+ const TensorShape& outputShape = workloadInfo.m_OutputTensorInfos[0].GetShape();
+ for (unsigned int i = 0; i < m_Parameters.m_Axis; ++i)
+ {
+ if (outputShape[i] != inputShape[i])
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+ }
+
+ if (outputShape[m_Parameters.m_Axis] != m_Parameters.m_NumInputs)
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+
+ for (unsigned int i = m_Parameters.m_Axis + 1; i < inputShape.GetNumDimensions() + 1; ++i)
+ {
+ if (outputShape[i] != inputShape[i-1])
+ {
+ throw InvalidArgumentException("StackQueueDescriptor: Output tensor must "
+ "match shape inferred from input tensor.");
+ }
+ }
+
+ // Check the supported data types
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::Boolean,
+ DataType::Signed32,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "StackQueueDescriptor");
+
+ for (unsigned int i = 1; i < workloadInfo.m_InputTensorInfos.size(); ++i)
+ {
+ ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[i],
+ "StackQueueDescriptor",
+ "InputTensor[0]",
+ "InputTensor[" + std::to_string(i) + "]");
+ }
+ ValidateTensorDataTypesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_OutputTensorInfos[0],
+ "StackQueueDescriptor",
+ "InputTensor[0]",
+ "OutputTensor[0]");
+}
+
+//---------------------------------------------------------------
void FullyConnectedQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateNumInputs(workloadInfo, "FullyConnectedQueueDescriptor", 1);