diff options
Diffstat (limited to 'src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp')
-rw-r--r-- | src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp | 28 |
1 files changed, 13 insertions, 15 deletions
diff --git a/src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp b/src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp index fee40fd257..46be3a122d 100644 --- a/src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp +++ b/src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp @@ -1,10 +1,8 @@ // -// Copyright © 2021 Arm Ltd. All rights reserved. +// Copyright © 2021, 2024 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // -#include <armnn/utility/Assert.hpp> - #include <cl/ClImportTensorHandleFactory.hpp> #include <doctest/doctest.h> @@ -35,21 +33,21 @@ TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle") // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is // passed through correctly. auto tensorHandle = factory.CreateTensorHandle(tensorInfo); - ARMNN_ASSERT(tensorHandle); - ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); - ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); + CHECK(tensorHandle); + CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); + CHECK(tensorHandle->GetShape() == tensorShape); // Same method but explicitly specifying isManaged = false. tensorHandle = factory.CreateTensorHandle(tensorInfo, false); CHECK(tensorHandle); - ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); - ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); + CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); + CHECK(tensorHandle->GetShape() == tensorShape); // Now try TensorInfo and DataLayout factory method. tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC); CHECK(tensorHandle); - ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); - ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape); + CHECK(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc)); + CHECK(tensorHandle->GetShape() == tensorShape); } TEST_CASE("CreateSubtensorOfImportTensor") @@ -67,8 +65,8 @@ TEST_CASE("CreateSubtensorOfImportTensor") uint32_t origin[4] = { 1, 1, 0, 0 }; auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); CHECK(subTensor); - ARMNN_ASSERT(subTensor->GetShape() == subTensorShape); - ARMNN_ASSERT(subTensor->GetParent() == tensorHandle.get()); + CHECK(subTensor->GetShape() == subTensorShape); + CHECK(subTensor->GetParent() == tensorHandle.get()); } TEST_CASE("CreateSubtensorNonZeroXYIsInvalid") @@ -87,7 +85,7 @@ TEST_CASE("CreateSubtensorNonZeroXYIsInvalid") uint32_t origin[4] = { 0, 0, 1, 1 }; auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); // We expect a nullptr. - ARMNN_ASSERT(subTensor == nullptr); + CHECK(subTensor == nullptr); } TEST_CASE("CreateSubtensorXYMustMatchParent") @@ -105,7 +103,7 @@ TEST_CASE("CreateSubtensorXYMustMatchParent") uint32_t origin[4] = { 1, 1, 0, 0 }; auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); // We expect a nullptr. - ARMNN_ASSERT(subTensor == nullptr); + CHECK(subTensor == nullptr); } TEST_CASE("CreateSubtensorMustBeSmallerThanParent") @@ -122,7 +120,7 @@ TEST_CASE("CreateSubtensorMustBeSmallerThanParent") uint32_t origin[4] = { 1, 1, 0, 0 }; // This should result in a nullptr. auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin); - ARMNN_ASSERT(subTensor == nullptr); + CHECK(subTensor == nullptr); } } |