aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp51
1 files changed, 51 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index c8c4f9aae4..52d14097af 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2642,4 +2642,55 @@ void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
+void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"DepthToSpaceQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ ValidateTensorNumDimensions(inputInfo, descriptorName, 4, "input");
+ ValidateTensorNumDimensions(outputInfo, descriptorName, 4, "output");
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ ValidateDataTypes(inputInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputInfo, supportedTypes, descriptorName);
+
+ ValidateTensorNumElementsMatch(inputInfo, outputInfo, descriptorName, "input", "output");
+
+ if (m_Parameters.m_BlockSize == 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
+ }
+
+ DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
+ const unsigned int wIndex = dimensionIndices.GetWidthIndex();
+ const unsigned int hIndex = dimensionIndices.GetHeightIndex();
+ const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
+
+ const TensorShape& outputShape = outputInfo.GetShape();
+ if (outputShape[hIndex] % m_Parameters.m_BlockSize != 0 || outputShape[wIndex] % m_Parameters.m_BlockSize != 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": Output width and height shape"
+ "must be divisible by block size.");
+ }
+
+ const TensorShape& inputShape = inputInfo.GetShape();
+ if (inputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": The depth of the input tensor"
+ "must be divisible by the square of block size." );
+ }
+}
+
} // namespace armnn