aboutsummaryrefslogtreecommitdiff
path: root/src/backends/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/WorkloadData.cpp')
-rw-r--r--src/backends/WorkloadData.cpp40
1 files changed, 40 insertions, 0 deletions
diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp
index ef31fbd1fb..495d4ecde9 100644
--- a/src/backends/WorkloadData.cpp
+++ b/src/backends/WorkloadData.cpp
@@ -741,6 +741,46 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
+void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
+ ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor");
+
+ ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input");
+ ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output");
+
+ if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements())
+ {
+ throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " +
+ to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " +
+ to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
+ }
+
+ if (m_Parameters.m_BlockShape.size() != 2)
+ {
+ throw InvalidArgumentException("Block Shape must contains 2 spatial dimensions");
+ }
+
+ if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
+ {
+ throw InvalidArgumentException("Pad List must contains the same number of dimensions as Block Shape.");
+ }
+
+ const TensorShape inputShape = workloadInfo.m_InputTensorInfos[0].GetShape();
+
+ std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
+ std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
+
+ if ((inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + heightPad.first + heightPad.second)
+ % m_Parameters.m_BlockShape[0] != 0 ||
+ (inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + widthPad.first + widthPad.second)
+ % m_Parameters.m_BlockShape[1] != 0)
+ {
+ throw InvalidArgumentException(
+ "Input shape after padding must be divisible by Block Shape in all spatial dimensions");
+ }
+}
+
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateSingleInput(workloadInfo, "FloorQueueDescriptor");