diff options
Diffstat (limited to 'src/armnnQuantizer')
-rw-r--r-- | src/armnnQuantizer/ArmNNQuantizerMain.cpp | 6 | ||||
-rw-r--r-- | src/armnnQuantizer/QuantizationDataSet.cpp | 30 | ||||
-rw-r--r-- | src/armnnQuantizer/QuantizationDataSet.hpp | 16 |
3 files changed, 49 insertions, 3 deletions
diff --git a/src/armnnQuantizer/ArmNNQuantizerMain.cpp b/src/armnnQuantizer/ArmNNQuantizerMain.cpp index 219363edbb..49652efe25 100644 --- a/src/armnnQuantizer/ArmNNQuantizerMain.cpp +++ b/src/armnnQuantizer/ArmNNQuantizerMain.cpp @@ -61,8 +61,8 @@ int main(int argc, char* argv[]) if (!dataSet.IsEmpty()) { // Get the Input Tensor Infos - armnnQuantizer::InputLayerVisitor inputLayerVisitor; - network->Accept(inputLayerVisitor); + armnnQuantizer::InputLayerStrategy inputLayerStrategy; + network->ExecuteStrategy(inputLayerStrategy); for (armnnQuantizer::QuantizationInput quantizationInput : dataSet) { @@ -72,7 +72,7 @@ int main(int argc, char* argv[]) unsigned int count = 0; for (armnn::LayerBindingId layerBindingId : quantizationInput.GetLayerBindingIds()) { - armnn::TensorInfo tensorInfo = inputLayerVisitor.GetTensorInfo(layerBindingId); + armnn::TensorInfo tensorInfo = inputLayerStrategy.GetTensorInfo(layerBindingId); inputData[count] = quantizationInput.GetDataForEntry(layerBindingId); armnn::ConstTensor inputTensor(tensorInfo, inputData[count].data()); inputTensors.push_back(std::make_pair(layerBindingId, inputTensor)); diff --git a/src/armnnQuantizer/QuantizationDataSet.cpp b/src/armnnQuantizer/QuantizationDataSet.cpp index acd301a470..99fc021a51 100644 --- a/src/armnnQuantizer/QuantizationDataSet.cpp +++ b/src/armnnQuantizer/QuantizationDataSet.cpp @@ -47,6 +47,36 @@ QuantizationDataSet::~QuantizationDataSet() { } + +/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine. + +void InputLayerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id) +{ + armnn::IgnoreUnused(name, descriptor, constants); + + m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo()); +} + + + + +armnn::TensorInfo InputLayerStrategy::GetTensorInfo(armnn::LayerBindingId layerBindingId) +{ + auto iterator = m_TensorInfos.find(layerBindingId); + if (iterator != m_TensorInfos.end()) + { + return m_TensorInfos.at(layerBindingId); + } + else + { + throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId)); + } +} + void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, armnn::LayerBindingId id, const char* name) diff --git a/src/armnnQuantizer/QuantizationDataSet.hpp b/src/armnnQuantizer/QuantizationDataSet.hpp index 3a97630ccf..47b893a7f7 100644 --- a/src/armnnQuantizer/QuantizationDataSet.hpp +++ b/src/armnnQuantizer/QuantizationDataSet.hpp @@ -43,6 +43,22 @@ private: }; /// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine. +class InputLayerStrategy : public armnn::IStrategy +{ +public: + virtual void ExecuteStrategy(const armnn::IConnectableLayer* layer, + const armnn::BaseDescriptor& descriptor, + const std::vector<armnn::ConstTensor>& constants, + const char* name, + const armnn::LayerBindingId id = 0) override; + + armnn::TensorInfo GetTensorInfo(armnn::LayerBindingId); +private: + std::map<armnn::LayerBindingId, armnn::TensorInfo> m_TensorInfos; +}; + + +/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine. class InputLayerVisitor : public armnn::LayerVisitorBase<armnn::VisitorNoThrowPolicy> { public: |