ArmNN
 22.08
NetworkExecutionUtils.hpp File Reference
#include <armnn/Logging.hpp>
#include <armnn/utility/StringUtils.hpp>
#include <armnn/utility/NumericCast.hpp>
#include <armnn/BackendRegistry.hpp>
#include <iostream>
#include <fstream>
#include <iomanip>
#include <iterator>

Go to the source code of this file.

Classes

struct  OutputWriteInfo
 

Functions

bool CheckInferenceTimeThreshold (const std::chrono::duration< double, std::milli > &duration, const double &thresholdTime)
 Given a measured duration and a threshold time, tell the user whether we succeeded or not. More...
 
bool CheckRequestedBackendsAreValid (const std::vector< armnn::BackendId > &backendIds, armnn::Optional< std::string &> invalidBackendIds=armnn::EmptyOptional())
 
std::vector< unsigned int > ParseArray (std::istream &stream)
 
std::vector< std::string > ParseStringList (const std::string &inputString, const char *delimiter)
 Splits a given string at every occurrence of delimiter into a vector of strings. More...
 
template<typename T >
std::vector< float > DequantizeArray (const void *array, unsigned int numElements, float scale, int32_t offset)
 Dequantize an array of a given type. More...
 
void LogAndThrow (std::string eMsg)
 
bool ValidatePath (const std::string &file, const bool expectFile)
 Verifies if the given string is a valid path. More...
 
bool ValidatePaths (const std::vector< std::string > &fileVec, const bool expectFile)
 Verifies if a given vector of strings are valid paths. More...
 
template<typename Integer , typename std::enable_if_t< std::is_integral< Integer >::value > * = nullptr>
std::function< Integer(const std::string &)> GetParseElementFunc ()
 Returns a function that parses a value of the given type from a string. More...
 
template<typename Float , std::enable_if_t< std::is_floating_point< Float >::value > * = nullptr>
std::function< Float(const std::string &)> GetParseElementFunc ()
 
template<typename T >
void PopulateTensorWithData (T *tensor, const unsigned int numElements, const armnn::Optional< std::string > &dataFile, const std::string &inputName)
 
template<typename T >
void WriteToFile (const std::string &outputTensorFileName, const std::string &outputName, const T *const array, const unsigned int numElements)
 
template<typename T >
void PrintTensor (OutputWriteInfo &info, const char *formatString)
 
template<typename T >
void PrintQuantizedTensor (OutputWriteInfo &info)
 
template<typename T , typename TParseElementFunc >
std::vector< T > ParseArrayImpl (std::istream &stream, TParseElementFunc parseElementFunc, const char *chars="\t ,:")
 
template<typename T >
float ComputeRMSE (const void *expected, const void *actual, const size_t size)
 Compute the root-mean-square error (RMSE) More...
 

Function Documentation

◆ CheckInferenceTimeThreshold()

bool CheckInferenceTimeThreshold ( const std::chrono::duration< double, std::milli > &  duration,
const double &  thresholdTime 
)

Given a measured duration and a threshold time, tell the user whether we succeeded or not.

Parameters
duration – the measured inference duration.
thresholdTime – the threshold time in milliseconds. A value of 0.0 means no threshold was supplied and the check is skipped.
Returns
false if the measured time exceeded the threshold, true otherwise.

Definition at line 17 of file NetworkExecutionUtils.cpp.

References ARMNN_LOG.

Referenced by ArmNNExecutor::ArmNNExecutor(), and TfLiteExecutor::Execute().

19 {
20  ARMNN_LOG(info) << "Inference time: " << std::setprecision(2)
21  << std::fixed << duration.count() << " ms\n";
22  // If thresholdTime == 0.0 (default), then it hasn't been supplied at command line
23  if (thresholdTime != 0.0)
24  {
25  ARMNN_LOG(info) << "Threshold time: " << std::setprecision(2)
26  << std::fixed << thresholdTime << " ms";
27  auto thresholdMinusInference = thresholdTime - duration.count();
28  ARMNN_LOG(info) << "Threshold time - Inference time: " << std::setprecision(2)
29  << std::fixed << thresholdMinusInference << " ms" << "\n";
30  if (thresholdMinusInference < 0)
31  {
32  std::string errorMessage = "Elapsed inference time is greater than provided threshold time.";
33  ARMNN_LOG(fatal) << errorMessage;
34  return false;
35  }
36  }
37  return true;
38 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205

◆ CheckRequestedBackendsAreValid()

bool CheckRequestedBackendsAreValid ( const std::vector< armnn::BackendId > &  backendIds,
armnn::Optional< std::string &>  invalidBackendIds = armnn::EmptyOptional() 
)
inline

Definition at line 28 of file NetworkExecutionUtils.hpp.

References armnn::BackendRegistryInstance(), BackendRegistry::GetBackendIds(), ParseArray(), and ParseStringList().

Referenced by InferenceModel< IParser, TDataType >::InferenceModel(), main(), and ExecuteNetworkParams::ValidateParams().

30 {
31  if (backendIds.empty())
32  {
33  return false;
34  }
35 
37 
38  bool allValid = true;
39  for (const auto& backendId : backendIds)
40  {
41  if (std::find(validBackendIds.begin(), validBackendIds.end(), backendId) == validBackendIds.end())
42  {
43  allValid = false;
44  if (invalidBackendIds)
45  {
46  if (!invalidBackendIds.value().empty())
47  {
48  invalidBackendIds.value() += ", ";
49  }
50  invalidBackendIds.value() += backendId;
51  }
52  }
53  }
54  return allValid;
55 }
BackendIdSet GetBackendIds() const
std::unordered_set< BackendId > BackendIdSet
Definition: BackendId.hpp:193
BackendRegistry & BackendRegistryInstance()

◆ ComputeRMSE()

float ComputeRMSE ( const void *  expected,
const void *  actual,
const size_t  size 
)

Compute the root-mean-square error (RMSE)

Parameters
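expected – type-erased pointer to the expected values (read as const T*)
actual – type-erased pointer to the actual values (read as const T*)
size – size of the tensor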
Returns
float the RMSE

Definition at line 278 of file NetworkExecutionUtils.hpp.

279 {
280  auto typedExpected = reinterpret_cast<const T*>(expected);
281  auto typedActual = reinterpret_cast<const T*>(actual);
282 
283  T errorSum = 0;
284 
285  for (unsigned int i = 0; i < size; i++)
286  {
287  if (std::abs(typedExpected[i] - typedActual[i]) != 0)
288  {
289  std::cout << "";
290  }
291  errorSum += std::pow(std::abs(typedExpected[i] - typedActual[i]), 2);
292  }
293 
294  float rmse = std::sqrt(armnn::numeric_cast<float>(errorSum) / armnn::numeric_cast<float>(size / sizeof(T)));
295  return rmse;
296 }

◆ DequantizeArray()

std::vector<float> DequantizeArray ( const void *  array,
unsigned int  numElements,
float  scale,
int32_t  offset 
)

Dequantize an array of a given type.

Parameters
array – type-erased array to dequantize
numElements – number of elements in the array
scale – quantization scale passed to armnn::Dequantize()
offset – quantization offset passed to armnn::Dequantize()

Definition at line 67 of file NetworkExecutionUtils.hpp.

References armnn::Dequantize(), LogAndThrow(), ValidatePath(), and ValidatePaths().

68 {
69  const T* quantizedArray = reinterpret_cast<const T*>(array);
70  std::vector<float> dequantizedVector;
71  dequantizedVector.reserve(numElements);
72  for (unsigned int i = 0; i < numElements; ++i)
73  {
74  float f = armnn::Dequantize(*(quantizedArray + i), scale, offset);
75  dequantizedVector.push_back(f);
76  }
77  return dequantizedVector;
78 }
float Dequantize(QuantizedType value, float scale, int32_t offset)
Dequantize an 8-bit data type into a floating point data type.
Definition: TypesUtils.cpp:46

◆ GetParseElementFunc() [1/2]

std::function<Integer(const std::string&)> GetParseElementFunc ( )

Returns a function that parses a value of the given type from a string.

Definition at line 100 of file NetworkExecutionUtils.hpp.

References armnn::numeric_cast().

101 {
102  return [](const std::string& s) { return armnn::numeric_cast<Integer>(std::stoi(s)); };
103 }
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:35

◆ GetParseElementFunc() [2/2]

std::function<Float(const std::string&)> GetParseElementFunc ( )

Definition at line 106 of file NetworkExecutionUtils.hpp.

107 {
108  return [](const std::string& s) { return std::stof(s); };
109 }

◆ LogAndThrow()

void LogAndThrow ( std::string  eMsg)

Definition at line 75 of file NetworkExecutionUtils.cpp.

References ARMNN_LOG.

Referenced by ArmNNExecutor::CompareAndPrintResult(), DequantizeArray(), TfLiteExecutor::Execute(), ArmNNExecutor::PrintNetworkInfo(), and TfLiteExecutor::TfLiteExecutor().

76 {
77  ARMNN_LOG(error) << eMsg;
78  throw armnn::Exception(eMsg);
79 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
Base class for all ArmNN exceptions so that users can filter to just those.
Definition: Exceptions.hpp:46

◆ ParseArray()

std::vector<unsigned int> ParseArray ( std::istream &  stream)

Definition at line 55 of file NetworkExecutionUtils.cpp.

References armnn::numeric_cast().

Referenced by CheckRequestedBackendsAreValid(), and ProgramOptions::ParseOptions().

56 {
57  return ParseArrayImpl<unsigned int>(
58  stream,
59  [](const std::string& s) { return armnn::numeric_cast<unsigned int>(std::stoi(s)); });
60 }
std::enable_if_t< std::is_unsigned< Source >::value &&std::is_unsigned< Dest >::value, Dest > numeric_cast(Source source)
Definition: NumericCast.hpp:35

◆ ParseArrayImpl()

std::vector<T> ParseArrayImpl ( std::istream &  stream,
TParseElementFunc  parseElementFunc,
const char *  chars = "\t ,:" 
)

Definition at line 245 of file NetworkExecutionUtils.hpp.

References ARMNN_LOG, and armnn::stringUtils::StringTokenizer().

245  :")
246 {
247  std::vector<T> result;
248  // Processes line-by-line.
249  std::string line;
250  while (std::getline(stream, line))
251  {
252  std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, chars);
253  for (const std::string& token : tokens)
254  {
255  if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
256  {
257  try
258  {
259  result.push_back(parseElementFunc(token));
260  }
261  catch (const std::exception&)
262  {
263  ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
264  }
265  }
266  }
267  }
268 
269  return result;
270 }

◆ ParseStringList()

std::vector<std::string> ParseStringList ( const std::string &  inputString,
const char *  delimiter 
)

Splits a given string at every occurrence of delimiter into a vector of strings.

Definition at line 10 of file NetworkExecutionUtils.cpp.

References armnn::stringUtils::StringTrimCopy().

Referenced by CheckRequestedBackendsAreValid(), GetBackendIDs(), and ProgramOptions::ParseOptions().

11 {
12  std::stringstream stream(inputString);
13  return ParseArrayImpl<std::string>(stream, [](const std::string& s) {
14  return armnn::stringUtils::StringTrimCopy(s); }, delimiter);
15 }
std::string StringTrimCopy(const std::string &str, const std::string &chars="\t\n\v\f\r ")
Trim from both the start and the end of a string, returns a trimmed copy of the string.
Definition: StringUtils.hpp:88

◆ PopulateTensorWithData()

void PopulateTensorWithData ( T *  tensor,
const unsigned int  numElements,
const armnn::Optional< std::string > &  dataFile,
const std::string &  inputName 
)

Definition at line 112 of file NetworkExecutionUtils.hpp.

References ARMNN_LOG, OptionalBase::has_value(), armnn::stringUtils::StringTokenizer(), and OptionalReferenceSwitch< IsReference, T >::value().

116 {
117  const bool readFromFile = dataFile.has_value() && !dataFile.value().empty();
118 
119  std::ifstream inputTensorFile;
120  if (!readFromFile)
121  {
122  std::fill(tensor, tensor + numElements, 0);
123  return;
124  }
125  else
126  {
127  inputTensorFile = std::ifstream(dataFile.value());
128  }
129 
130  auto parseElementFunc = GetParseElementFunc<T>();
131  std::string line;
132  unsigned int index = 0;
133  while (std::getline(inputTensorFile, line))
134  {
135  std::vector<std::string> tokens = armnn::stringUtils::StringTokenizer(line, "\t ,:");
136  for (const std::string& token : tokens)
137  {
138  if (!token.empty()) // See https://stackoverflow.com/questions/10437406/
139  {
140  try
141  {
142  if (index == numElements)
143  {
144  ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << dataFile.value()
145  << "\" does not match number of elements: " << numElements
146  << " for input \"" << inputName << "\".";
147  }
148  *(tensor + index) = parseElementFunc(token);
149  index++;
150  }
151  catch (const std::exception&)
152  {
153  ARMNN_LOG(error) << "'" << token << "' is not a valid number. It has been ignored.";
154  }
155  }
156  }
157  }
158 
159  if (index != numElements)
160  {
161  ARMNN_LOG(error) << "Number of elements: " << (index +1) << " in file \"" << inputName
162  << "\" does not match number of elements: " << numElements
163  << " for input \"" << inputName << "\".";
164  }
165 }
std::vector< std::string > StringTokenizer(const std::string &str, const char *delimiters, bool tokenCompression=true)
Function to take a string and a list of delimiters and split the string into tokens based on those de...
Definition: StringUtils.hpp:23
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205
bool has_value() const noexcept
Definition: Optional.hpp:53

◆ PrintQuantizedTensor()

void PrintQuantizedTensor ( OutputWriteInfo info)

Definition at line 218 of file NetworkExecutionUtils.hpp.

References OptionalBase::has_value(), OutputWriteInfo::m_OutputName, OutputWriteInfo::m_OutputTensorFile, OutputWriteInfo::m_PrintTensor, OutputWriteInfo::m_Tensor, OptionalReferenceSwitch< IsReference, T >::value(), and WriteToFile().

219 {
220  std::vector<float> dequantizedValues;
221  auto tensor = info.m_Tensor;
222  dequantizedValues = DequantizeArray<T>(tensor.GetMemoryArea(),
223  tensor.GetNumElements(),
224  tensor.GetInfo().GetQuantizationScale(),
225  tensor.GetInfo().GetQuantizationOffset());
226 
227  if (info.m_OutputTensorFile.has_value())
228  {
230  info.m_OutputName,
231  dequantizedValues.data(),
232  tensor.GetNumElements());
233  }
234 
235  if (info.m_PrintTensor)
236  {
237  std::for_each(dequantizedValues.begin(), dequantizedValues.end(), [&](float value)
238  {
239  printf("%f ", value);
240  });
241  }
242 }
const std::string & m_OutputName
const armnn::Tensor & m_Tensor
bool has_value() const noexcept
Definition: Optional.hpp:53
const armnn::Optional< std::string > & m_OutputTensorFile
void WriteToFile(const std::string &outputTensorFileName, const std::string &outputName, const T *const array, const unsigned int numElements)

◆ PrintTensor()

void PrintTensor ( OutputWriteInfo info,
const char *  formatString 
)

Definition at line 196 of file NetworkExecutionUtils.hpp.

References BaseTensor< MemoryType >::GetMemoryArea(), BaseTensor< MemoryType >::GetNumElements(), OptionalBase::has_value(), OutputWriteInfo::m_OutputName, OutputWriteInfo::m_OutputTensorFile, OutputWriteInfo::m_PrintTensor, OutputWriteInfo::m_Tensor, OptionalReferenceSwitch< IsReference, T >::value(), and WriteToFile().

197 {
198  const T* array = reinterpret_cast<const T*>(info.m_Tensor.GetMemoryArea());
199 
200  if (info.m_OutputTensorFile.has_value())
201  {
203  info.m_OutputName,
204  array,
205  info.m_Tensor.GetNumElements());
206  }
207 
208  if (info.m_PrintTensor)
209  {
210  for (unsigned int i = 0; i < info.m_Tensor.GetNumElements(); i++)
211  {
212  printf(formatString, array[i]);
213  }
214  }
215 }
unsigned int GetNumElements() const
Definition: Tensor.hpp:303
MemoryType GetMemoryArea() const
Definition: Tensor.hpp:305
const std::string & m_OutputName
const armnn::Tensor & m_Tensor
bool has_value() const noexcept
Definition: Optional.hpp:53
const armnn::Optional< std::string > & m_OutputTensorFile
void WriteToFile(const std::string &outputTensorFileName, const std::string &outputName, const T *const array, const unsigned int numElements)

◆ ValidatePath()

bool ValidatePath ( const std::string &  file,
const bool  expectFile 
)

Verifies if the given string is a valid path.

Reports invalid paths to std::cerr.

Parameters
file – (string) A string containing the path to check
expectFile – (bool) If true, checks for a regular file.
Returns
bool - True if the given string is a valid path, false otherwise.

Definition at line 40 of file NetworkExecutionUtils.cpp.

Referenced by CheckClTuningParameter(), DequantizeArray(), and ValidatePaths().

41 {
42  if (!fs::exists(file))
43  {
44  std::cerr << "Given file path '" << file << "' does not exist" << std::endl;
45  return false;
46  }
47  if (!fs::is_regular_file(file) && expectFile)
48  {
49  std::cerr << "Given file path '" << file << "' is not a regular file" << std::endl;
50  return false;
51  }
52  return true;
53 }

◆ ValidatePaths()

bool ValidatePaths ( const std::vector< std::string > &  fileVec,
const bool  expectFile 
)

Verifies if a given vector of strings are valid paths.

Reports invalid paths to std::cerr.

Parameters
fileVec – (vector of string) A vector of strings containing the paths to check
expectFile – (bool) If true, checks for a regular file.
Returns
bool - True if all given strings are valid paths, false otherwise.

Definition at line 62 of file NetworkExecutionUtils.cpp.

References ValidatePath().

Referenced by DequantizeArray(), and ExecuteNetworkParams::ValidateParams().

63 {
64  bool allPathsValid = true;
65  for (auto const& file : fileVec)
66  {
67  if(!ValidatePath(file, expectFile))
68  {
69  allPathsValid = false;
70  }
71  }
72  return allPathsValid;
73 }
bool ValidatePath(const std::string &file, const bool expectFile)
Verifies if the given string is a valid path.

◆ WriteToFile()

void WriteToFile ( const std::string &  outputTensorFileName,
const std::string &  outputName,
const T *const  array,
const unsigned int  numElements 
)

Definition at line 168 of file NetworkExecutionUtils.hpp.

References ARMNN_LOG.

Referenced by PrintQuantizedTensor(), and PrintTensor().

172 {
173  std::ofstream outputTensorFile;
174  outputTensorFile.open(outputTensorFileName, std::ofstream::out | std::ofstream::trunc);
175  if (outputTensorFile.is_open())
176  {
177  outputTensorFile << outputName << ": ";
178  std::copy(array, array + numElements, std::ostream_iterator<T>(outputTensorFile, " "));
179  }
180  else
181  {
182  ARMNN_LOG(info) << "Output Tensor File: " << outputTensorFileName << " could not be opened!";
183  }
184  outputTensorFile.close();
185 }
#define ARMNN_LOG(severity)
Definition: Logging.hpp:205