ArmNN
 21.11
ClImportTensorHandleFactoryTests.cpp File Reference
#include <armnn/utility/Assert.hpp>
#include <cl/ClImportTensorHandleFactory.hpp>
#include <doctest/doctest.h>

Go to the source code of this file.

Functions

 TEST_SUITE ("ClImportTensorHandleFactoryTests")
 

Function Documentation

◆ TEST_SUITE()

TEST_SUITE ( "ClImportTensorHandleFactoryTests"  )

Definition at line 12 of file ClImportTensorHandleFactoryTests.cpp.

References ARMNN_ASSERT, ClImportTensorHandleFactory::CreateSubTensorHandle(), ClImportTensorHandleFactory::CreateTensorHandle(), armnn::Float32, armnn::Malloc, armnn::NCHW, and armnn::NHWC.

13 {
14 using namespace armnn;
15 
16 TEST_CASE("ImportTensorFactoryAskedToCreateManagedTensorThrowsException")
17 {
18  // Build the import factory, passing MemorySource::Malloc for both of its flag arguments.
19  const auto mallocFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
20  ClImportTensorHandleFactory factory(mallocFlags, mallocFlags);
21  TensorInfo info;
22  // The factory only imports tensor memory, so requesting a managed handle (trailing argument
23  // set to true) must raise InvalidArgumentException from both CreateTensorHandle overloads.
24  REQUIRE_THROWS_AS(factory.CreateTensorHandle(info, true), InvalidArgumentException);
25  REQUIRE_THROWS_AS(factory.CreateTensorHandle(info, DataLayout::NCHW, true), InvalidArgumentException);
26 }
27 
28 TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle")
29 {
30  // Create the factory to import tensors, using MemorySource::Malloc for both flag arguments.
31  const auto mallocFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
32  ClImportTensorHandleFactory factory(mallocFlags, mallocFlags);
33  TensorShape tensorShape{ 6, 7, 8, 9 };
34  TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
35  // Start with the TensorInfo factory method. Results are verified with doctest macros rather than
36  // ARMNN_ASSERT, which may be defined as empty; REQUIRE stops the test before a null handle is used.
37  auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
38  REQUIRE(tensorHandle);
39  CHECK(tensorHandle->GetImportFlags() == mallocFlags);
40  CHECK(tensorHandle->GetShape() == tensorShape);
41 
42  // Same method but explicitly specifying isManaged = false.
43  tensorHandle = factory.CreateTensorHandle(tensorInfo, false);
44  REQUIRE(tensorHandle);
45  CHECK(tensorHandle->GetImportFlags() == mallocFlags);
46  CHECK(tensorHandle->GetShape() == tensorShape);
47 
48  // Now try TensorInfo and DataLayout factory method.
49  tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC);
50  REQUIRE(tensorHandle);
51  CHECK(tensorHandle->GetImportFlags() == mallocFlags);
52  CHECK(tensorHandle->GetShape() == tensorShape);
53 }
54 
55 TEST_CASE("CreateSubtensorOfImportTensor")
56 {
57  // Create the factory to import tensors. doctest macros (not ARMNN_ASSERT) verify the results below.
58  ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
59  static_cast<MemorySourceFlags>(MemorySource::Malloc));
60  // Create a standard import tensor.
61  TensorShape tensorShape{ 224, 224, 1, 1 };
62  TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
63  auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
64  REQUIRE(tensorHandle);
65  // Use the factory to create a 16x16 sub tensor starting at an offset of 1x1.
66  TensorShape subTensorShape{ 16, 16, 1, 1 };
67  uint32_t origin[4] = { 1, 1, 0, 0 };
68  auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
69  REQUIRE(subTensor);
70  CHECK(subTensor->GetShape() == subTensorShape);
71  CHECK(subTensor->GetParent() == tensorHandle.get());
72 }
73 
74 TEST_CASE("CreateSubtensorNonZeroXYIsInvalid")
75 {
76  // Create the factory to import tensors. doctest macros (not ARMNN_ASSERT) verify the results below.
77  ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
78  static_cast<MemorySourceFlags>(MemorySource::Malloc));
79  // Create a standard import tensor.
80  TensorShape tensorShape{ 224, 224, 1, 1 };
81  TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
82  auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
83  REQUIRE(tensorHandle);
84  // Use the factory to create a 16x16 sub tensor.
85  TensorShape subTensorShape{ 16, 16, 1, 1 };
86  // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our
87  // check "(coords.x() != 0 || coords.y() != 0)", so we expect a nullptr.
88  uint32_t origin[4] = { 0, 0, 1, 1 };
89  auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
90  CHECK(subTensor == nullptr);
91 }
92 
93 TEST_CASE("CreateSubtensorXYMustMatchParent")
94 {
95  // Create the factory to import tensors. doctest macros (not ARMNN_ASSERT) verify the results below.
96  ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
97  static_cast<MemorySourceFlags>(MemorySource::Malloc));
98  // Create a standard import tensor.
99  TensorShape tensorShape{ 224, 224, 1, 1 };
100  TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
101  auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
102  REQUIRE(tensorHandle);
103  // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different.
104  TensorShape subTensorShape{ 16, 16, 2, 2 };
105  // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y())), so expect nullptr.
106  uint32_t origin[4] = { 1, 1, 0, 0 };
107  auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
108  CHECK(subTensor == nullptr);
109 }
110 
111 TEST_CASE("CreateSubtensorMustBeSmallerThanParent")
112 {
113  // Create the factory to import tensors. doctest macros (not ARMNN_ASSERT) verify the results below.
114  ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
115  static_cast<MemorySourceFlags>(MemorySource::Malloc));
116  // Create a standard import tensor.
117  TensorShape tensorShape{ 224, 224, 1, 1 };
118  TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
119  auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
120  REQUIRE(tensorHandle);
121  // Ask for a subtensor that's the same size as the parent. This should result in a nullptr.
122  TensorShape subTensorShape{ 224, 224, 1, 1 };
123  uint32_t origin[4] = { 1, 1, 0, 0 };
124  auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
125  CHECK(subTensor == nullptr);
126 }
127 
128 }
unsigned int MemorySourceFlags
Copyright (c) 2021 ARM Limited and Contributors.
#define ARMNN_ASSERT(COND)
Definition: Assert.hpp:14
This factory creates ClImportTensorHandles that refer to imported memory tensors. ...