diff options
Diffstat (limited to 'src/armnnSerializer/Serializer.cpp')
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 109 |
1 files changed, 108 insertions, 1 deletions
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index b229ae7e3f..f475be1015 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -91,6 +91,44 @@ void SerializerVisitor::VisitAdditionLayer(const IConnectableLayer* layer, const CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer); } +// Build FlatBuffer for Convolution2dLayer +void SerializerVisitor::VisitConvolution2dLayer(const IConnectableLayer* layer, + const Convolution2dDescriptor& descriptor, + const ConstTensor& weights, + const Optional<ConstTensor>& biases, + const char* name) +{ + // Create FlatBuffer BaseLayer + auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d); + + auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder, + descriptor.m_PadLeft, + descriptor.m_PadRight, + descriptor.m_PadTop, + descriptor.m_PadBottom, + descriptor.m_StrideX, + descriptor.m_StrideY, + descriptor.m_BiasEnabled, + GetFlatBufferDataLayout(descriptor.m_DataLayout)); + auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights); + flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo; + + if (biases.has_value()) + { + flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value()); + } + + // Create the FlatBuffer Convolution2dLayer + auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder, + flatBufferBaseLayer, + flatBufferDescriptor, + flatBufferWeightsConstTensorInfo, + flatBufferBiasesConstTensorInfo); + + // Add the AnyLayer to the FlatBufferLayers + CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer); +} + // Build FlatBuffer for Multiplication Layer void SerializerVisitor::VisitMultiplicationLayer(const IConnectableLayer* layer, const char* name) { @@ -200,9 +238,78 @@ void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, c m_serializedLayers.push_back(anyLayer); } +template <typename T> +flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size) +{ + const T* buffer = reinterpret_cast<const T*>(memory); + std::vector<T> vector(buffer, buffer + (size / sizeof(T))); + auto fbVector = m_flatBufferBuilder.CreateVector(vector); + return fbVector; +} + +flatbuffers::Offset<serializer::ConstTensor> SerializerVisitor::CreateConstTensorInfo(const ConstTensor& constTensor) +{ + TensorInfo tensorInfo = constTensor.GetInfo(); + + // Get the dimensions + std::vector<unsigned int> shape; + + for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim) + { + shape.push_back(tensorInfo.GetShape()[dim]); + } + + // Create FlatBuffer TensorInfo + auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder, + m_flatBufferBuilder.CreateVector(shape), + GetFlatBufferDataType(tensorInfo.GetDataType()), + tensorInfo.GetQuantizationScale(), + tensorInfo.GetQuantizationOffset()); + flatbuffers::Offset<void> fbPayload; + + switch (tensorInfo.GetDataType()) + { + case DataType::Float32: + case DataType::Signed32: + { + auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes()); + flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData( + m_flatBufferBuilder, + fbVector); + fbPayload = flatBuffersData.o; + break; + } + case DataType::Float16: + { + auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes()); + flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData( + m_flatBufferBuilder, + fbVector); + fbPayload = flatBuffersData.o; + break; + } + case DataType::QuantisedAsymm8: + case DataType::Boolean: + default: + { + auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes()); + flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData( + m_flatBufferBuilder, + fbVector); + fbPayload = flatBuffersData.o; + } + } + flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor( + m_flatBufferBuilder, + flatBufferTensorInfo, + GetFlatBufferConstTensorData(tensorInfo.GetDataType()), + fbPayload); + return flatBufferConstTensor; +} + std::vector<fb::Offset<serializer::InputSlot>> SerializerVisitor::CreateInputSlots(const IConnectableLayer* layer) { - std::vector<fb::Offset <serializer::InputSlot>> inputSlots; + std::vector<fb::Offset<serializer::InputSlot>> inputSlots; // Get the InputSlots for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex) |