diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 25 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/WorkloadDataValidation.cpp | 2 |
2 files changed, 19 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index d9779e4e37..ea84c0b9f2 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -850,13 +850,13 @@ void ConstantQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const // Check the supported data types std::vector<DataType> supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::Signed32, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; + { + DataType::Float32, + DataType::Float16, + DataType::Signed32, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ConstantQueueDescriptor"); } @@ -872,6 +872,17 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " + to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements."); } + + // Check the supported data types + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor"); + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], supportedTypes, "ReshapeQueueDescriptor"); } void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 119eb7df90..067cca8319 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -447,7 +447,7 @@ BOOST_AUTO_TEST_CASE(ReshapeQueueDescriptor_Validate_MismatchingNumElements) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // InvalidArgumentException is expected, because the number of elements don't match. - BOOST_CHECK_THROW(RefReshapeFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefReshapeWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } |