ArmNN
 22.05.01
SerializerTestUtils.hpp File Reference
#include <armnn/Descriptors.hpp>
#include <armnn/INetwork.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnn/utility/IgnoreUnused.hpp>
#include <random>
#include <vector>
#include <cstdlib>
#include <doctest/doctest.h>

Go to the source code of this file.

Classes

class  LayerVerifierBase
 
class  LayerVerifierBaseWithDescriptor< Descriptor >
 
class  LayerVerifierBaseWithDescriptorAndConstants< Descriptor >
 

Functions

armnn::INetworkPtr DeserializeNetwork (const std::string &serializerString)
 
std::string SerializeNetwork (const armnn::INetwork &network)
 
void CompareConstTensor (const armnn::ConstTensor &tensor1, const armnn::ConstTensor &tensor2)
 
template<typename T >
void CompareConstTensorData (const void *data1, const void *data2, unsigned int numElements)
 

Function Documentation

◆ CompareConstTensor()

void CompareConstTensor ( const armnn::ConstTensor &  tensor1,
const armnn::ConstTensor &  tensor2 
)

Definition at line 122 of file SerializerTestUtils.cpp.

References armnn::Boolean, armnn::Float32, BaseTensor< MemoryType >::GetDataType(), armnn::GetDataTypeName(), BaseTensor< MemoryType >::GetMemoryArea(), BaseTensor< MemoryType >::GetNumElements(), BaseTensor< MemoryType >::GetShape(), armnn::QAsymmU8, armnn::QSymmS8, and armnn::Signed32.

Referenced by LayerVerifierBaseWithDescriptorAndConstants< Descriptor >::ExecuteStrategy(), and TEST_SUITE().

123 {
124  CHECK(tensor1.GetShape() == tensor2.GetShape());
125  CHECK(GetDataTypeName(tensor1.GetDataType()) == GetDataTypeName(tensor2.GetDataType()));
126 
127  switch (tensor1.GetDataType())
128  {
129  case armnn::DataType::Float32:
130  CompareConstTensorData<const float*>(
131  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
132  break;
133  case armnn::DataType::QAsymmU8:
134  case armnn::DataType::Boolean:
135  CompareConstTensorData<const uint8_t*>(
136  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
137  break;
138  case armnn::DataType::QSymmS8:
139  CompareConstTensorData<const int8_t*>(
140  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
141  break;
142  case armnn::DataType::Signed32:
143  CompareConstTensorData<const int32_t*>(
144  tensor1.GetMemoryArea(), tensor2.GetMemoryArea(), tensor1.GetNumElements());
145  break;
146  default:
147  // Note that Float16 is not yet implemented
148  MESSAGE("Unexpected datatype");
149  CHECK(false);
150  }
151 }
const TensorShape & GetShape() const
Definition: Tensor.hpp:297
unsigned int GetNumElements() const
Definition: Tensor.hpp:303
MemoryType GetMemoryArea() const
Definition: Tensor.hpp:305
constexpr const char * GetDataTypeName(DataType dataType)
Definition: TypesUtils.hpp:202
DataType GetDataType() const
Definition: Tensor.hpp:300

◆ CompareConstTensorData()

template<typename T >
void CompareConstTensorData ( const void *  data1,
const void *  data2,
unsigned int  numElements 
)

Definition at line 93 of file SerializerTestUtils.hpp.

94 {
95  T typedData1 = static_cast<T>(data1);
96  T typedData2 = static_cast<T>(data2);
97  CHECK(typedData1);
98  CHECK(typedData2);
99 
100  for (unsigned int i = 0; i < numElements; i++)
101  {
102  CHECK(typedData1[i] == typedData2[i]);
103  }
104 }

◆ DeserializeNetwork()

armnn::INetworkPtr DeserializeNetwork ( const std::string &  serializerString)

Definition at line 153 of file SerializerTestUtils.cpp.

Referenced by TEST_SUITE().

154 {
155  std::vector<std::uint8_t> const serializerVector{serializerString.begin(), serializerString.end()};
156  return IDeserializer::Create()->CreateNetworkFromBinary(serializerVector);
157 }

◆ SerializeNetwork()

std::string SerializeNetwork ( const armnn::INetwork &  network)

Definition at line 159 of file SerializerTestUtils.cpp.

References ISerializer::Create().

Referenced by TEST_SUITE().

160 {
161  armnnSerializer::ISerializerPtr serializer = armnnSerializer::ISerializer::Create();
162 
163  serializer->Serialize(network);
164 
165  std::stringstream stream;
166  serializer->SaveSerializedToStream(stream);
167 
168  std::string serializerString{stream.str()};
169  return serializerString;
170 }
std::unique_ptr< ISerializer, void(*)(ISerializer *serializer)> ISerializerPtr
Definition: ISerializer.hpp:15
static ISerializerPtr Create()
Definition: Serializer.cpp:35