aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b7317af9cd..adba86c79a 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1123,6 +1123,54 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
"SpaceToBatchNdQueueDescriptor");
}
+void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateNumInputs(workloadInfo, "SpaceToDepthQueueDescriptor", 1);
+ ValidateNumOutputs(workloadInfo, "SpaceToDepthQueueDescriptor", 1);
+
+ ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0],
+ "SpaceToDepthQueueDescriptor", 4, "input");
+ ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0],
+ "SpaceToDepthQueueDescriptor", 4, "output");
+
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SpaceToDepthQueueDescriptor");
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SpaceToDepthQueueDescriptor");
+
+ const TensorShape inputShape = workloadInfo.m_InputTensorInfos[0].GetShape();
+
+ unsigned int numInputElements = inputShape[0]
+ * inputShape[dimensionIndices.GetWidthIndex()]
+ * inputShape[dimensionIndices.GetHeightIndex()]
+ * inputShape[dimensionIndices.GetChannelsIndex()];
+
+ if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements)
+ {
+ throw InvalidArgumentException("SpaceToDepthQueueDescriptor: Input tensor has " +
+ to_string(numInputElements) + " but output tensor has " +
+ to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
+ }
+
+ if (inputShape[dimensionIndices.GetHeightIndex()] % m_Parameters.m_BlockSize != 0 ||
+ inputShape[dimensionIndices.GetWidthIndex()] % m_Parameters.m_BlockSize != 0)
+ {
+ throw InvalidArgumentException(
+ "Input shape must be divisible by block size in all spatial dimensions");
+ }
+}
+
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string floorQueueDescString = "FloorQueueDescriptor";