aboutsummaryrefslogtreecommitdiff
path: root/src/armnnQuantizer/QuantizationDataSet.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnQuantizer/QuantizationDataSet.cpp')
-rw-r--r--src/armnnQuantizer/QuantizationDataSet.cpp30
1 files changed, 30 insertions, 0 deletions
diff --git a/src/armnnQuantizer/QuantizationDataSet.cpp b/src/armnnQuantizer/QuantizationDataSet.cpp
index acd301a470..99fc021a51 100644
--- a/src/armnnQuantizer/QuantizationDataSet.cpp
+++ b/src/armnnQuantizer/QuantizationDataSet.cpp
@@ -47,6 +47,36 @@ QuantizationDataSet::~QuantizationDataSet()
{
}
+
+/// Visitor class implementation to gather the TensorInfo for LayerBindingID for creation of ConstTensor for Refine.
+
+void InputLayerStrategy::ExecuteStrategy(const armnn::IConnectableLayer* layer,
+ const armnn::BaseDescriptor& descriptor,
+ const std::vector<armnn::ConstTensor>& constants,
+ const char* name,
+ const armnn::LayerBindingId id)
+{
+ armnn::IgnoreUnused(name, descriptor, constants);
+
+ m_TensorInfos.emplace(id, layer->GetOutputSlot(0).GetTensorInfo());
+}
+
+
+
+
+armnn::TensorInfo InputLayerStrategy::GetTensorInfo(armnn::LayerBindingId layerBindingId)
+{
+ auto iterator = m_TensorInfos.find(layerBindingId);
+ if (iterator != m_TensorInfos.end())
+ {
+ return m_TensorInfos.at(layerBindingId);
+ }
+ else
+ {
+ throw armnn::Exception("Could not retrieve tensor info for binding ID " + std::to_string(layerBindingId));
+ }
+}
+
void InputLayerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer,
armnn::LayerBindingId id,
const char* name)