aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp63
1 files changed, 27 insertions, 36 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a26aaf490b..bd3c7c2760 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1617,18 +1617,35 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const {
const std::string descriptorName{"ReverseV2QueueDescriptor"};
- ValidateNumInputs(workloadInfo, descriptorName, 1);
+ // Backend restriction
+ const unsigned int maxDimensions = 4;
+
+ ValidateNumInputs(workloadInfo, descriptorName, 2);
ValidateNumOutputs(workloadInfo, descriptorName, 1);
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
- if (inputTensorNumDimensions > m_Parameters.m_MaxDimension)
+ const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
+ if (inputTensorNumDimensions > maxDimensions)
{
throw InvalidArgumentException(descriptorName +
": Input tensors with rank greater than " +
- std::to_string(m_Parameters.m_MaxDimension) + " are not supported.");
+ std::to_string(maxDimensions) + " are not supported.");
+ }
+
+ const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
+ if (axisTensorNumDimensions > maxDimensions)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
+ }
+
+ if (axisTensorNumDimensions > inputTensorNumDimensions)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": More axes specified than the number of axes on the input tensor.");
}
std::vector<DataType> supportedTypes =
@@ -1642,44 +1659,18 @@ void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
- ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
- if (m_Parameters.m_Axis.size() > inputTensorNumDimensions)
- {
- throw InvalidArgumentException(descriptorName + ": More axes specified than is on the input tensor.");
- }
- if (m_Parameters.m_Axis.size() > m_Parameters.m_MaxDimension)
+ std::vector<DataType> axisSupportedTypes =
{
- throw InvalidArgumentException(descriptorName +
- ": More than " + std::to_string(m_Parameters.m_MaxDimension) + " axes cannot be specified.");
- }
+ DataType::Signed32,
+ };
- if (! m_Parameters.m_Axis.empty())
- {
- // First check that we have unique axis values
- auto checkAxis = m_Parameters.m_Axis;
- std::sort(checkAxis.begin(), checkAxis.end());
- auto lastUnique = std::unique(checkAxis.begin(), checkAxis.end());
- if (lastUnique != checkAxis.end())
- {
- throw InvalidArgumentException(descriptorName + ": Axes values must be unique.");
- }
+ ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
- // Next check that the axes values are in range: [-rank, rank]
- const auto minmax =
- std::minmax_element(std::begin(m_Parameters.m_Axis), std::end(m_Parameters.m_Axis));
- if (((*minmax.first) < int32_t(-inputTensorNumDimensions)) ||
- ((*minmax.second) >= int32_t (inputTensorNumDimensions)))
- {
- throw InvalidArgumentException(descriptorName +
- ": Axes values must in range [-" + std::to_string(inputTensorNumDimensions) + "," +
- std::to_string(inputTensorNumDimensions) + "].");
- }
- }
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
-
void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"FakeQuantizationQueueDescriptor"};