diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 63 |
1 files changed, 27 insertions, 36 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index a26aaf490b..bd3c7c2760 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -1617,18 +1617,35 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const { const std::string descriptorName{"ReverseV2QueueDescriptor"}; - ValidateNumInputs(workloadInfo, descriptorName, 1); + // Backend restriction + const unsigned int maxDimensions = 4; + + ValidateNumInputs(workloadInfo, descriptorName, 2); ValidateNumOutputs(workloadInfo, descriptorName, 1); const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1]; const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; - auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions(); - if (inputTensorNumDimensions > m_Parameters.m_MaxDimension) + const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions(); + if (inputTensorNumDimensions > maxDimensions) { throw InvalidArgumentException(descriptorName + ": Input tensors with rank greater than " + - std::to_string(m_Parameters.m_MaxDimension) + " are not supported."); + std::to_string(maxDimensions) + " are not supported."); + } + + const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions(); + if (axisTensorNumDimensions > maxDimensions) + { + throw InvalidArgumentException(descriptorName + + ": More than " + std::to_string(maxDimensions) + " axes cannot be specified."); + } + + if (axisTensorNumDimensions > inputTensorNumDimensions) + { + throw InvalidArgumentException(descriptorName + + ": More axes specified than the number of axes on the input tensor."); } std::vector<DataType> supportedTypes = @@ -1642,44 +1659,18 @@ void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const }; ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName); - ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); - ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); - if (m_Parameters.m_Axis.size() > inputTensorNumDimensions) - { - throw InvalidArgumentException(descriptorName + ": More axes specified than is on the input tensor."); - } - if (m_Parameters.m_Axis.size() > m_Parameters.m_MaxDimension) + std::vector<DataType> axisSupportedTypes = { - throw InvalidArgumentException(descriptorName + - ": More than " + std::to_string(m_Parameters.m_MaxDimension) + " axes cannot be specified."); - } + DataType::Signed32, + }; - if (! m_Parameters.m_Axis.empty()) - { - // First check that we have unique axis values - auto checkAxis = m_Parameters.m_Axis; - std::sort(checkAxis.begin(), checkAxis.end()); - auto lastUnique = std::unique(checkAxis.begin(), checkAxis.end()); - if (lastUnique != checkAxis.end()) - { - throw InvalidArgumentException(descriptorName + ": Axes values must be unique."); - } + ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName); - // Next check that the axes values are in range: [-rank, rank] - const auto minmax = - std::minmax_element(std::begin(m_Parameters.m_Axis), std::end(m_Parameters.m_Axis)); - if (((*minmax.first) < int32_t(-inputTensorNumDimensions)) || - ((*minmax.second) >= int32_t (inputTensorNumDimensions))) - { - throw InvalidArgumentException(descriptorName + - ": Axes values must in range [-" + std::to_string(inputTensorNumDimensions) + "," + - std::to_string(inputTensorNumDimensions) + "]."); - } - } + ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); + ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output"); } - void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { const std::string descriptorName{"FakeQuantizationQueueDescriptor"}; |