aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp47
1 files changed, 39 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index bd3c7c2760..7efca9de50 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -207,9 +207,9 @@ void ValidateBiasTensorQuantization(const TensorInfo& biasTensor,
//---------------------------------------------------------------
void ValidateTensors(const std::vector<ITensorHandle*>& vec,
- unsigned int numExpected,
- const std::string& descName,
- const std::string& varName)
+ unsigned int numExpected,
+ const std::string& descName,
+ const std::string& varName)
{
if (vec.empty() && numExpected > 0)
{
@@ -433,9 +433,9 @@ void QueueDescriptor::ValidateTensorNumDimensions(const TensorInfo& tensor,
//---------------------------------------------------------------
void QueueDescriptor::ValidateTensorNumDimNumElem(const TensorInfo& tensorInfo,
- unsigned int numDimension,
- unsigned int numElements,
- std::string const& tensorName) const
+ unsigned int numDimension,
+ unsigned int numElements,
+ std::string const& tensorName) const
{
const std::string functionName{"ValidateTensorNumDimNumElem"};
ValidateTensorNumDimensions(tensorInfo, functionName, numDimension, tensorName);
@@ -1614,7 +1614,8 @@ void ResizeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
}
}
-void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const {
+void ReverseV2QueueDescriptor::Validate(const WorkloadInfo &workloadInfo) const
+{
const std::string descriptorName{"ReverseV2QueueDescriptor"};
// Backend restriction
@@ -2948,7 +2949,6 @@ void ShapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
DataType::Float32,
DataType::QAsymmS8,
DataType::QAsymmU8,
- DataType::QAsymmS8,
DataType::QSymmS8,
DataType::QSymmS16,
DataType::Signed32
@@ -4378,5 +4378,36 @@ void BatchMatMulQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
}
}
+void TileQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string& descriptorName{"TileQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 1);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ std::vector<DataType> supportedTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QAsymmS8,
+ DataType::QAsymmU8,
+ DataType::QSymmS8,
+ DataType::QSymmS16,
+ DataType::Signed32
+ };
+
+ // Multiples length must be the same as the number of dimensions in input.
+ if (m_Parameters.m_Multiples.size() != inputTensorInfo.GetNumDimensions())
+ {
+ throw InvalidArgumentException(descriptorName +
+ ": Multiples length is not same as the number of dimensions in Input.");
+ }
+
+ ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
+ ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+}
} // namespace armnn \ No newline at end of file