Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp | 57
1 file changed, 56 insertions(+), 1 deletion(-)
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index f290cbd9cf..2fa0c92daf 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2443,7 +2443,7 @@ void QuantizedLstmQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) co
ValidateTensorQuantizationSpace(inputInfo, outputStateInInfo, descriptorName, "input", "outputStateIn");
ValidateTensorQuantizationSpace(inputInfo, outputStateOutInfo, descriptorName, "input", "outputStateOut");
ValidateTensorQuantizationSpace(cellStateInInfo, cellStateOutInfo, descriptorName, "cellStateIn", "cellStateOut");
-
+
// Infer number of batches, input size and output size from tensor dimensions
const uint32_t numBatches = inputInfo.GetShape()[0];
const uint32_t inputSize = inputInfo.GetShape()[1];
@@ -2584,4 +2584,59 @@ void AbsQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
 
+void SliceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"SliceQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ const unsigned int rank = inputTensorInfo.GetNumDimensions();
+ if (rank > 4)
+ {
+ throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than 4 are not supported.");
+ }
+
+ ValidateTensorNumDimensions(outputTensorInfo, descriptorName, rank, "output");
+
+ // Check if m_Begin and m_Size have the expected length
+ if (m_Parameters.m_Begin.size() != rank)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Length of begin offset descriptor must equal rank " + std::to_string(rank));
+ }
+ if (m_Parameters.m_Size.size() != rank)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Length of size descriptor must equal rank " + std::to_string(rank));
+ }
+
+ // Check if the shape of the output tensor matches m_Size
+ const TensorShape& outputShape = outputTensorInfo.GetShape();
+ for (unsigned int i = 0u; i < rank; ++i)
+ {
+ if (m_Parameters.m_Size[i] != outputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + ": Size descriptor does not match output tensor.");
+ }
+ }
+
+ // Check if the sum of begin offset and size in a given dimension
+ // does not exceed the size of corresponding input
+ const TensorShape& inputShape = inputTensorInfo.GetShape();
+ for(unsigned int i = 0u; i < rank; ++i)
+ {
+        if (m_Parameters.m_Begin[i] + m_Parameters.m_Size[i] > inputShape[i])
+ {
+ throw InvalidArgumentException(descriptorName + ": Sum of begin offset and size for dimension " +
+ std::to_string(i) + " exceeds input size.");
+ }
+ }
+}
+
} // namespace armnn
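
For reference, the bounds check introduced in SliceQueueDescriptor::Validate can be exercised outside ArmNN with a small standalone sketch. The helper below (ValidateSliceBounds, using plain std::vector in place of armnn::TensorShape) is a hypothetical illustration that only mirrors the validation rules from this diff; it is not part of the ArmNN API. Note the strict "greater than" comparison: a slice is allowed to end exactly at the edge of the input, as in dimension 2 of the first call below.

// Standalone sketch of the Slice begin/size validation from this commit.
// Names and types here are illustrative stand-ins, not ArmNN declarations.
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

void ValidateSliceBounds(const std::vector<uint32_t>& inputShape,
                         const std::vector<uint32_t>& begin,
                         const std::vector<uint32_t>& size)
{
    const std::size_t rank = inputShape.size();
    if (rank > 4)
    {
        throw std::invalid_argument("Input tensors with rank greater than 4 are not supported.");
    }
    if (begin.size() != rank || size.size() != rank)
    {
        throw std::invalid_argument("Begin and size must both have one entry per input dimension.");
    }
    for (std::size_t i = 0; i < rank; ++i)
    {
        // A slice may end exactly at the edge of the input,
        // so only begin + size strictly greater than the input extent is invalid.
        if (begin[i] + size[i] > inputShape[i])
        {
            throw std::invalid_argument("Sum of begin offset and size for dimension " +
                                        std::to_string(i) + " exceeds input size.");
        }
    }
}

int main()
{
    // Valid: every begin + size fits within the input, including the full extent of dimension 2.
    ValidateSliceBounds({4, 3, 2, 5}, {1, 0, 0, 2}, {2, 1, 2, 3});

    try
    {
        // Invalid: dimension 3 has begin 3 + size 3 = 6, which exceeds the input extent of 5.
        ValidateSliceBounds({4, 3, 2, 5}, {1, 0, 0, 3}, {2, 1, 2, 3});
    }
    catch (const std::invalid_argument& e)
    {
        std::cout << "Rejected as expected: " << e.what() << "\n";
    }
    return 0;
}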