diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 47 |
1 files changed, 26 insertions, 21 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index 37ab326a28..cb7a5c456e 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1420,25 +1420,43 @@ flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector( return fbVector; } -flatbuffers::Offset<serializer::ConstTensor> - SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor) +flatbuffers::Offset<TensorInfo> SerializerVisitor::CreateTensorInfo(const armnn::TensorInfo& tensorInfo) { - armnn::TensorInfo tensorInfo = constTensor.GetInfo(); - // Get the dimensions std::vector<unsigned int> shape; - for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim) { shape.push_back(tensorInfo.GetShape()[dim]); } + if (tensorInfo.HasPerAxisQuantization()) + { + // Create FlatBuffer TensorInfo + auto flatBufferTensorInfo = + serializer::CreateTensorInfo(m_flatBufferBuilder, + m_flatBufferBuilder.CreateVector(shape), + GetFlatBufferDataType(tensorInfo.GetDataType()), + tensorInfo.GetQuantizationScales()[0], + tensorInfo.GetQuantizationOffset(), + m_flatBufferBuilder.CreateVector(tensorInfo.GetQuantizationScales()), + tensorInfo.GetQuantizationDim().value()); + return flatBufferTensorInfo; + } + // Create FlatBuffer TensorInfo auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder, m_flatBufferBuilder.CreateVector(shape), GetFlatBufferDataType(tensorInfo.GetDataType()), tensorInfo.GetQuantizationScale(), tensorInfo.GetQuantizationOffset()); + return flatBufferTensorInfo; +} + +flatbuffers::Offset<serializer::ConstTensor> + SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor) +{ + armnn::TensorInfo tensorInfo = constTensor.GetInfo(); + flatbuffers::Offset<void> fbPayload; switch (tensorInfo.GetDataType()) @@ -1471,6 +1489,7 @@ flatbuffers::Offset<serializer::ConstTensor> fbPayload = flatBuffersData.o; break; } + case armnn::DataType::QSymmS8: case armnn::DataType::QAsymmU8: case armnn::DataType::Boolean: default: @@ -1484,7 +1503,7 @@ flatbuffers::Offset<serializer::ConstTensor> } flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor( m_flatBufferBuilder, - flatBufferTensorInfo, + CreateTensorInfo(tensorInfo), GetFlatBufferConstTensorData(tensorInfo.GetDataType()), fbPayload); return flatBufferConstTensor; @@ -1533,24 +1552,10 @@ std::vector<fb::Offset<serializer::OutputSlot>> const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex); const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo(); - // Get the dimensions - std::vector<unsigned int> shape; - for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim) - { - shape.push_back(tensorInfo.GetShape()[dim]); - } - - // Create FlatBuffer TensorInfo - auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder, - m_flatBufferBuilder.CreateVector(shape), - GetFlatBufferDataType(tensorInfo.GetDataType()), - tensorInfo.GetQuantizationScale(), - tensorInfo.GetQuantizationOffset()); - // Create FlatBuffer Outputslot outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder, slotIndex, - flatBufferTensorInfo)); + CreateTensorInfo(tensorInfo))); } return outputSlots; } |