diff options
Diffstat (limited to 'src/backends/aclCommon/test')
-rw-r--r-- | src/backends/aclCommon/test/MemCopyTestImpl.hpp | 29 | ||||
-rw-r--r-- | src/backends/aclCommon/test/MemCopyTests.cpp | 12 |
2 files changed, 21 insertions, 20 deletions
diff --git a/src/backends/aclCommon/test/MemCopyTestImpl.hpp b/src/backends/aclCommon/test/MemCopyTestImpl.hpp index 1f542d24b4..91ba4eae17 100644 --- a/src/backends/aclCommon/test/MemCopyTestImpl.hpp +++ b/src/backends/aclCommon/test/MemCopyTestImpl.hpp @@ -15,8 +15,6 @@ #include <test/TensorHelpers.hpp> -#include <boost/multi_array.hpp> - namespace { @@ -28,21 +26,20 @@ LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory, const std::array<unsigned int, 4> shapeData = { { 1u, 1u, 6u, 5u } }; const armnn::TensorShape tensorShape(4, shapeData.data()); const armnn::TensorInfo tensorInfo(tensorShape, dataType); - boost::multi_array<T, 4> inputData = MakeTensor<T, 4>(tensorInfo, std::vector<T>( - { - 1, 2, 3, 4, 5, - 6, 7, 8, 9, 10, - 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, - 21, 22, 23, 24, 25, - 26, 27, 28, 29, 30, - }) - ); + std::vector<T> inputData = + { + 1, 2, 3, 4, 5, + 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, + 21, 22, 23, 24, 25, + 26, 27, 28, 29, 30, + }; LayerTestResult<T, 4> ret(tensorInfo); - ret.outputExpected = inputData; + ret.m_ExpectedData = inputData; - boost::multi_array<T, 4> outputData(shapeData); + std::vector<T> actualOutput(tensorInfo.GetNumElements()); ARMNN_NO_DEPRECATE_WARN_BEGIN auto inputTensorHandle = srcWorkloadFactory.CreateTensorHandle(tensorInfo); @@ -71,8 +68,8 @@ LayerTestResult<T, 4> MemCopyTest(armnn::IWorkloadFactory& srcWorkloadFactory, dstWorkloadFactory.CreateMemCopy(memCopyQueueDesc, workloadInfo)->Execute(); - CopyDataFromITensorHandle(outputData.data(), workloadOutput.get()); - ret.output = outputData; + CopyDataFromITensorHandle(actualOutput.data(), workloadOutput.get()); + ret.m_ActualData = actualOutput; return ret; } diff --git a/src/backends/aclCommon/test/MemCopyTests.cpp b/src/backends/aclCommon/test/MemCopyTests.cpp index ffba19323a..7612cbfe28 100644 --- a/src/backends/aclCommon/test/MemCopyTests.cpp +++ b/src/backends/aclCommon/test/MemCopyTests.cpp @@ -48,7 +48,8 @@ BOOST_AUTO_TEST_CASE(CopyBetweenNeonAndGpu) { LayerTestResult<float, 4> result = MemCopyTest<armnn::NeonWorkloadFactory, armnn::ClWorkloadFactory, armnn::DataType::Float32>(false); - auto predResult = CompareTensors(result.output, result.outputExpected); + auto predResult = CompareTensors(result.m_ActualData, result.m_ExpectedData, + result.m_ActualShape, result.m_ExpectedShape); BOOST_TEST(predResult.m_Result, predResult.m_Message.str()); } @@ -56,7 +57,8 @@ BOOST_AUTO_TEST_CASE(CopyBetweenGpuAndNeon) { LayerTestResult<float, 4> result = MemCopyTest<armnn::ClWorkloadFactory, armnn::NeonWorkloadFactory, armnn::DataType::Float32>(false); - auto predResult = CompareTensors(result.output, result.outputExpected); + auto predResult = CompareTensors(result.m_ActualData, result.m_ExpectedData, + result.m_ActualShape, result.m_ExpectedShape); BOOST_TEST(predResult.m_Result, predResult.m_Message.str()); } @@ -64,7 +66,8 @@ BOOST_AUTO_TEST_CASE(CopyBetweenNeonAndGpuWithSubtensors) { LayerTestResult<float, 4> result = MemCopyTest<armnn::NeonWorkloadFactory, armnn::ClWorkloadFactory, armnn::DataType::Float32>(true); - auto predResult = CompareTensors(result.output, result.outputExpected); + auto predResult = CompareTensors(result.m_ActualData, result.m_ExpectedData, + result.m_ActualShape, result.m_ExpectedShape); BOOST_TEST(predResult.m_Result, predResult.m_Message.str()); } @@ -72,7 +75,8 @@ BOOST_AUTO_TEST_CASE(CopyBetweenGpuAndNeonWithSubtensors) { LayerTestResult<float, 4> result = MemCopyTest<armnn::ClWorkloadFactory, armnn::NeonWorkloadFactory, armnn::DataType::Float32>(true); - auto predResult = CompareTensors(result.output, result.outputExpected); + auto predResult = CompareTensors(result.m_ActualData, result.m_ExpectedData, + result.m_ActualShape, result.m_ExpectedShape); BOOST_TEST(predResult.m_Result, predResult.m_Message.str()); } |