aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
authorTracy Narine <tracy.narine@arm.com>2023-07-13 16:50:54 +0100
committerTracy Narine <tracy.narine@arm.com>2023-07-17 14:19:36 +0100
commitbb8d7591a35bd95480b39001f8b7e41a6671f3a6 (patch)
treeabf2871aa1bb86378f423df405164b0d4521db3f /src/backends/backendsCommon/WorkloadData.cpp
parent688268328c69e7d4181cdd31fe4717c80a6d1685 (diff)
downloadarmnn-bb8d7591a35bd95480b39001f8b7e41a6671f3a6.tar.gz
IVGCVSW-7879 Change REVERSE_V2 from LayerWithParameters with 1 input, to Layer with 2 inputs
* Changing ReverseV2 to use two inputs * This is required by the backends * The ReverseV2Descriptor was removed * Tests updated * Added a Run<> templatefor inputs with different data types Signed-off-by: Tracy Narine <tracy.narine@arm.com> Change-Id: I22f947de829b4b3da6bda3a74f4ffdef4052cc25
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp63
1 files changed, 27 insertions, 36 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index a26aaf490b..bd3c7c2760 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1617,18 +1617,35 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const {
const std::string descriptorName{"ReverseV2QueueDescriptor"};
- ValidateNumInputs(workloadInfo, descriptorName, 1);
+ // Backend restriction
+ const unsigned int maxDimensions = 4;
+
+ ValidateNumInputs(workloadInfo, descriptorName, 2);
ValidateNumOutputs(workloadInfo, descriptorName, 1);
const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& axisTensorInfo = workloadInfo.m_InputTensorInfos[1];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
- if (inputTensorNumDimensions > m_Parameters.m_MaxDimension)
+ const auto inputTensorNumDimensions = inputTensorInfo.GetNumDimensions();
+ if (inputTensorNumDimensions > maxDimensions)
{
throw InvalidArgumentException(descriptorName +
": Input tensors with rank greater than " +
- std::to_string(m_Parameters.m_MaxDimension) + " are not supported.");
+ std::to_string(maxDimensions) + " are not supported.");
+ }
+
+ const auto axisTensorNumDimensions = axisTensorInfo.GetNumDimensions();
+ if (axisTensorNumDimensions > maxDimensions)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": More than " + std::to_string(maxDimensions) + " axes cannot be specified.");
+ }
+
+ if (axisTensorNumDimensions > inputTensorNumDimensions)
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": More axes specified than the number of axes on the input tensor.");
}
std::vector<DataType> supportedTypes =
@@ -1642,44 +1659,18 @@ void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
};
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
- ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
- ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
- if (m_Parameters.m_Axis.size() > inputTensorNumDimensions)
- {
- throw InvalidArgumentException(descriptorName + ": More axes specified than is on the input tensor.");
- }
- if (m_Parameters.m_Axis.size() > m_Parameters.m_MaxDimension)
+ std::vector<DataType> axisSupportedTypes =
{
- throw InvalidArgumentException(descriptorName +
- ": More than " + std::to_string(m_Parameters.m_MaxDimension) + " axes cannot be specified.");
- }
+ DataType::Signed32,
+ };
- if (! m_Parameters.m_Axis.empty())
- {
- // First check that we have unique axis values
- auto checkAxis = m_Parameters.m_Axis;
- std::sort(checkAxis.begin(), checkAxis.end());
- auto lastUnique = std::unique(checkAxis.begin(), checkAxis.end());
- if (lastUnique != checkAxis.end())
- {
- throw InvalidArgumentException(descriptorName + ": Axes values must be unique.");
- }
+ ValidateDataTypes(axisTensorInfo, axisSupportedTypes, descriptorName);
- // Next check that the axes values are in range: [-rank, rank]
- const auto minmax =
- std::minmax_element(std::begin(m_Parameters.m_Axis), std::end(m_Parameters.m_Axis));
- if (((*minmax.first) < int32_t(-inputTensorNumDimensions)) ||
- ((*minmax.second) >= int32_t (inputTensorNumDimensions)))
- {
- throw InvalidArgumentException(descriptorName +
- ": Axes values must in range [-" + std::to_string(inputTensorNumDimensions) + "," +
- std::to_string(inputTensorNumDimensions) + "].");
- }
- }
+ ValidateTensorShapesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+ ValidateTensorDataTypesMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
}
-
void FakeQuantizationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
const std::string descriptorName{"FakeQuantizationQueueDescriptor"};