ArmNN
 21.05
SerializerTestUtils.hpp File Reference
#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <random>
#include <vector>
#include <boost/test/unit_test.hpp>

Go to the source code of this file.

Classes

class  LayerVerifierBase
 
class  LayerVerifierBaseWithDescriptor< Descriptor >
 
class  LayerVerifierBaseWithDescriptorAndConstants< Descriptor >
 

Functions

armnn::INetworkPtr DeserializeNetwork (const std::string &serializerString)
 
std::string SerializeNetwork (const armnn::INetwork &network)
 
void CompareConstTensor (const armnn::ConstTensor &tensor1, const armnn::ConstTensor &tensor2)
 
template<typename T >
void CompareConstTensorData (const void *data1, const void *data2, unsigned int numElements)
 

Function Documentation

◆ CompareConstTensor()

void CompareConstTensor ( const armnn::ConstTensor & tensor1,
const armnn::ConstTensor & tensor2 
)

Definition at line 115 of file SerializerTestUtils.cpp.

References armnn::Boolean, armnn::Float32, BaseTensor< MemoryType >::GetDataType(), armnn::GetDataTypeName(), BaseTensor< MemoryType >::GetMemoryArea(), BaseTensor< MemoryType >::GetNumElements(), BaseTensor< MemoryType >::GetShape(), armnn::QAsymmU8, armnn::QSymmS8, and armnn::Signed32.

Referenced by BOOST_AUTO_TEST_CASE(), and LayerVerifierBaseWithDescriptorAndConstants< Descriptor >::ExecuteStrategy().

116 {
117  BOOST_TEST(tensor1.GetShape() == tensor2.GetShape());
118  BOOST_TEST(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
119 
120  switch (tensor1.GetDataType())
121  {
122  case armnn::DataType::Float32:
123  CompareConstTensorData<const float*>(
124  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
125  break;
126  case armnn::DataType::QAsymmU8:
127  case armnn::DataType::Boolean:
128  CompareConstTensorData<const uint8_t*>(
129  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
130  break;
131  case armnn::DataType::QSymmS8:
132  CompareConstTensorData<const int8_t*>(
133  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
134  break;
135  case armnn::DataType::Signed32:
136  CompareConstTensorData<const int32_t*>(
137  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
138  break;
139  default:
140  // Note that Float16 is not yet implemented
141  BOOST_TEST_MESSAGE("Unexpected datatype");
142  BOOST_TEST(false);
143  }
144 }
const TensorShape & GetShape() const
Definition: Tensor.hpp:284
unsigned int GetNumElements() const
Definition: Tensor.hpp:290
MemoryType GetMemoryArea() const
Definition: Tensor.hpp:292
constexpr const char * GetDataTypeName(DataType dataType)
Definition: TypesUtils.hpp:191
DataType GetDataType() const
Definition: Tensor.hpp:287

◆ CompareConstTensorData()

template<typename T >
void CompareConstTensorData ( const void * data1,
const void * data2,
unsigned int numElements 
)

Definition at line 92 of file SerializerTestUtils.hpp.

93 {
94  T typedData1 = static_cast<T>(data1);
95  T typedData2 = static_cast<T>(data2);
96  BOOST_CHECK(typedData1);
97  BOOST_CHECK(typedData2);
98 
99  for (unsigned int i = 0; i < numElements; i++)
100  {
101  BOOST_TEST(typedData1[i] == typedData2[i]);
102  }
103 }

◆ DeserializeNetwork()

armnn::INetworkPtr DeserializeNetwork ( const std::string &  serializerString)

Definition at line 146 of file SerializerTestUtils.cpp.

Referenced by BOOST_AUTO_TEST_CASE(), and SerializeArgMinMaxTest().

147 {
148  std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
149  return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
150 }

◆ SerializeNetwork()

std::string SerializeNetwork ( const armnn::INetwork & network)

Definition at line 152 of file SerializerTestUtils.cpp.

References ISerializer::Create().

Referenced by BOOST_AUTO_TEST_CASE(), and SerializeArgMinMaxTest().

153 {
154  armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
155 
156  serializer->Serialize(network);
157 
158  std::stringstream stream;
159  serializer->SaveSerializedToStream(stream);
160 
161  std::string serializerString{stream.str()};
162  return serializerString;
163 }
std::unique_ptr< ISerializer, void(*)(ISerializer *serializer)> ISerializerPtr
Definition: ISerializer.hpp:15
static ISerializerPtr Create()
Definition: Serializer.cpp:35