aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp')
-rw-r--r--src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp57
1 files changed, 32 insertions, 25 deletions
diff --git a/src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp
index c6636554ea..143f9e06b1 100644
--- a/src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp
+++ b/src/backends/backendsCommon/test/layerTests/DetectionPostProcessTestImpl.hpp
@@ -155,23 +155,15 @@ void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
- auto boxEncodings = MakeTensor<T, 3>(boxEncodingsInfo, boxEncodingsData);
- auto scores = MakeTensor<T, 3>(scoresInfo, scoresData);
- auto anchors = MakeTensor<T, 2>(anchorsInfo, anchorsData);
-
armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
- armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
+ armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
- LayerTestResult<float, 3> detectionBoxesResult(detectionBoxesInfo);
- detectionBoxesResult.outputExpected = MakeTensor<float, 3>(detectionBoxesInfo, expectedDetectionBoxes);
- LayerTestResult<float, 2> detectionClassesResult(detectionClassesInfo);
- detectionClassesResult.outputExpected = MakeTensor<float, 2>(detectionClassesInfo, expectedDetectionClasses);
- LayerTestResult<float, 2> detectionScoresResult(detectionScoresInfo);
- detectionScoresResult.outputExpected = MakeTensor<float, 2>(detectionScoresInfo, expectedDetectionScores);
- LayerTestResult<float, 1> numDetectionsResult(numDetectionInfo);
- numDetectionsResult.outputExpected = MakeTensor<float, 1>(numDetectionInfo, expectedNumDetections);
+ std::vector<float> actualDetectionBoxesOutput(detectionBoxesInfo.GetNumElements());
+ std::vector<float> actualDetectionClassesOutput(detectionClassesInfo.GetNumElements());
+ std::vector<float> actualDetectionScoresOutput(detectionScoresInfo.GetNumElements());
+ std::vector<float> actualNumDetectionOutput(numDetectionInfo.GetNumElements());
auto boxedHandle = tensorHandleFactory.CreateTensorHandle(boxEncodingsInfo);
auto scoreshandle = tensorHandleFactory.CreateTensorHandle(scoresInfo);
@@ -182,7 +174,7 @@ void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
auto numDetectionHandle = tensorHandleFactory.CreateTensorHandle(numDetectionInfo);
armnn::ScopedTensorHandle anchorsTensor(anchorsInfo);
- AllocateAndCopyDataToITensorHandle(&anchorsTensor, &anchors[0][0]);
+ AllocateAndCopyDataToITensorHandle(&anchorsTensor, anchorsData.data());
armnn::DetectionPostProcessQueueDescriptor data;
data.m_Parameters.m_UseRegularNms = useRegularNms;
@@ -200,7 +192,7 @@ void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
armnn::WorkloadInfo info;
AddInputToWorkload(data, info, boxEncodingsInfo, boxedHandle.get());
- AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
+ AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
@@ -215,23 +207,38 @@ void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
outputScoresHandle->Allocate();
numDetectionHandle->Allocate();
- CopyDataToITensorHandle(boxedHandle.get(), boxEncodings.origin());
- CopyDataToITensorHandle(scoreshandle.get(), scores.origin());
+ CopyDataToITensorHandle(boxedHandle.get(), boxEncodingsData.data());
+ CopyDataToITensorHandle(scoreshandle.get(), scoresData.data());
workload->Execute();
- CopyDataFromITensorHandle(detectionBoxesResult.output.origin(), outputBoxesHandle.get());
- CopyDataFromITensorHandle(detectionClassesResult.output.origin(), classesHandle.get());
- CopyDataFromITensorHandle(detectionScoresResult.output.origin(), outputScoresHandle.get());
- CopyDataFromITensorHandle(numDetectionsResult.output.origin(), numDetectionHandle.get());
+ CopyDataFromITensorHandle(actualDetectionBoxesOutput.data(), outputBoxesHandle.get());
+ CopyDataFromITensorHandle(actualDetectionClassesOutput.data(), classesHandle.get());
+ CopyDataFromITensorHandle(actualDetectionScoresOutput.data(), outputScoresHandle.get());
+ CopyDataFromITensorHandle(actualNumDetectionOutput.data(), numDetectionHandle.get());
- auto result = CompareTensors(detectionBoxesResult.output, detectionBoxesResult.outputExpected);
+ auto result = CompareTensors(actualDetectionBoxesOutput,
+ expectedDetectionBoxes,
+ outputBoxesHandle->GetShape(),
+ detectionBoxesInfo.GetShape());
BOOST_TEST(result.m_Result, result.m_Message.str());
- result = CompareTensors(detectionClassesResult.output, detectionClassesResult.outputExpected);
+
+ result = CompareTensors(actualDetectionClassesOutput,
+ expectedDetectionClasses,
+ classesHandle->GetShape(),
+ detectionClassesInfo.GetShape());
BOOST_TEST(result.m_Result, result.m_Message.str());
- result = CompareTensors(detectionScoresResult.output, detectionScoresResult.outputExpected);
+
+ result = CompareTensors(actualDetectionScoresOutput,
+ expectedDetectionScores,
+ outputScoresHandle->GetShape(),
+ detectionScoresInfo.GetShape());
BOOST_TEST(result.m_Result, result.m_Message.str());
- result = CompareTensors(numDetectionsResult.output, numDetectionsResult.outputExpected);
+
+ result = CompareTensors(actualNumDetectionOutput,
+ expectedNumDetections,
+ numDetectionHandle->GetShape(),
+ numDetectionInfo.GetShape());
BOOST_TEST(result.m_Result, result.m_Message.str());
}