ArmNN
 20.05
ArmNNQuantizerMain.cpp File Reference
#include "CommandLineProcessor.hpp"
#include <armnnDeserializer/IDeserializer.hpp>
#include <armnnQuantizer/INetworkQuantizer.hpp>
#include <armnnSerializer/ISerializer.hpp>
#include "QuantizationDataSet.hpp"
#include "QuantizationInput.hpp"
#include <algorithm>
#include <cstdint>
#include <exception>
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include <vector>

Go to the source code of this file.

Functions

int main (int argc, char *argv[])
 

Function Documentation

◆ main()

int main(int argc, char* argv[])

Definition at line 17 of file ArmNNQuantizerMain.cpp.

References ISerializer::Create(), IDeserializer::Create(), INetworkQuantizer::Create(), CommandLineProcessor::GetInputFileName(), CommandLineProcessor::GetOutputDirectoryName(), CommandLineProcessor::GetOutputFileName(), CommandLineProcessor::GetQuantizationDataSet(), CommandLineProcessor::GetQuantizationScheme(), InputLayerVisitor::GetTensorInfo(), CommandLineProcessor::HasPreservedDataType(), CommandLineProcessor::HasQuantizationData(), QuantizationDataSet::IsEmpty(), QuantizerOptions::m_ActivationFormat, QuantizerOptions::m_PreserveType, CommandLineProcessor::ProcessCommandLine(), armnn::QAsymmS8, armnn::QAsymmU8, and armnn::QSymmS16.

18 {
20  if (!cmdline.ProcessCommandLine(argc, argv))
21  {
22  return -1;
23  }
25  std::ifstream inputFileStream(cmdline.GetInputFileName(), std::ios::binary);
26  std::vector<std::uint8_t> binaryContent;
27  while (inputFileStream)
28  {
29  char c;
30  inputFileStream.get(c);
31  if (inputFileStream)
32  {
33  binaryContent.push_back(static_cast<std::uint8_t>(c));
34  }
35  }
36  inputFileStream.close();
37 
38  armnn::QuantizerOptions quantizerOptions;
39 
40  if (cmdline.GetQuantizationScheme() == "QAsymmS8")
41  {
43  }
44  else if (cmdline.GetQuantizationScheme() == "QSymmS16")
45  {
47  }
48  else
49  {
51  }
52 
53  quantizerOptions.m_PreserveType = cmdline.HasPreservedDataType();
54 
55  armnn::INetworkPtr network = parser->CreateNetworkFromBinary(binaryContent);
56  armnn::INetworkQuantizerPtr quantizer = armnn::INetworkQuantizer::Create(network.get(), quantizerOptions);
57 
58  if (cmdline.HasQuantizationData())
59  {
61  if (!dataSet.IsEmpty())
62  {
63  // Get the Input Tensor Infos
64  armnnQuantizer::InputLayerVisitor inputLayerVisitor;
65  network->Accept(inputLayerVisitor);
66 
67  for (armnnQuantizer::QuantizationInput quantizationInput : dataSet)
68  {
69  armnn::InputTensors inputTensors;
70  std::vector<std::vector<float>> inputData(quantizationInput.GetNumberOfInputs());
71  std::vector<armnn::LayerBindingId> layerBindingIds = quantizationInput.GetLayerBindingIds();
72  unsigned int count = 0;
73  for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds())
74  {
75  armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId);
76  inputData[count] = quantizationInput.GetDataForEntry(layerBindingId);
77  armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data());
78  inputTensors.push_back(std::make_pair(layerBindingId, inputTensor));
79  count++;
80  }
81  quantizer->Refine(inputTensors);
82  }
83  }
84  }
85 
86  armnn::INetworkPtr quantizedNetwork = quantizer->ExportNetwork();
88  serializer->Serialize(*quantizedNetwork);
89 
90  std::string output(cmdline.GetOutputDirectoryName());
91  output.append(cmdline.GetOutputFileName());
92  std::ofstream outputFileStream;
93  outputFileStream.open(output);
94  serializer->SaveSerializedToStream(outputFileStream);
95  outputFileStream.flush();
96  outputFileStream.close();
97 
98  return 0;
99 }
QuantizationDataSet is a structure which is created after parsing a quantization CSV file...
bool ProcessCommandLine(int argc, char *argv[])
std::unique_ptr< class INetworkQuantizer, void(*)(INetworkQuantizer *quantizer)> INetworkQuantizerPtr
static IDeserializerPtr Create()
std::vector< std::pair< LayerBindingId, class ConstTensor > > InputTensors
Definition: Tensor.hpp:225
QuantizationInput for specific pass ID, can list a corresponding raw data file for each LayerBindingI...
Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor ...
int LayerBindingId
Type of identifiers for bindable layers (inputs, outputs).
Definition: Types.hpp:171
armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId)
A tensor defined by a TensorInfo (shape and data type) and an immutable backing store.
Definition: Tensor.hpp:199
std::unique_ptr< IDeserializer, void(*)(IDeserializer *parser)> IDeserializerPtr
std::unique_ptr< ISerializer, void(*)(ISerializer *serializer)> ISerializerPtr
Definition: ISerializer.hpp:15
static ISerializerPtr Create()
std::unique_ptr< INetwork, void(*)(INetwork *network)> INetworkPtr
Definition: INetwork.hpp:101
static INetworkQuantizerPtr Create(INetwork *inputNetwork, const QuantizerOptions &options=QuantizerOptions())
Create Quantizer object wrapped in unique_ptr.