diff options
Diffstat (limited to 'src/armnnSerializer')
-rw-r--r-- | src/armnnSerializer/ArmnnSchema.fbs | 2 | ||||
-rw-r--r-- | src/armnnSerializer/Serializer.cpp | 6 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTestUtils.cpp | 2 | ||||
-rw-r--r-- | src/armnnSerializer/test/SerializerTests.cpp | 30 |
4 files changed, 38 insertions, 2 deletions
diff --git a/src/armnnSerializer/ArmnnSchema.fbs b/src/armnnSerializer/ArmnnSchema.fbs index 76a1c12787..995a6013c1 100644 --- a/src/armnnSerializer/ArmnnSchema.fbs +++ b/src/armnnSerializer/ArmnnSchema.fbs @@ -106,6 +106,8 @@ table ConstTensor { table InputSlot { index:uint; connection:Connection; + isOverridden:bool; + overriddenTensorInfo:TensorInfo; } table OutputSlot { diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp index cf098eb5f8..e10b66f51d 100644 --- a/src/armnnSerializer/Serializer.cpp +++ b/src/armnnSerializer/Serializer.cpp @@ -1974,12 +1974,16 @@ std::vector<fb::Offset<serializer::InputSlot>> // Get the Connection for the InputSlot const IOutputSlot* connection = inputSlot.GetConnection(); + bool isOverridden = inputSlot.IsTensorInfoOverridden(); + + flatbuffers::Offset<TensorInfo> overriddenTensorInfo = CreateTensorInfo(inputSlot.GetTensorInfo()); // Create FlatBuffer Connection serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()), connection->CalculateIndexOnOwner()); // Create FlatBuffer InputSlot - inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn)); + inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn, isOverridden, + overriddenTensorInfo)); } return inputSlots; } diff --git a/src/armnnSerializer/test/SerializerTestUtils.cpp b/src/armnnSerializer/test/SerializerTestUtils.cpp index 187384777d..c0f90f9edf 100644 --- a/src/armnnSerializer/test/SerializerTestUtils.cpp +++ b/src/armnnSerializer/test/SerializerTestUtils.cpp @@ -49,7 +49,7 @@ void LayerVerifierBase::VerifyNameAndConnections(const armnn::IConnectableLayer* const armnn::IOutputSlot* connectedOutput = layer->GetInputSlot(i).GetConnection(); CHECK(connectedOutput); - const armnn::TensorInfo& connectedInfo = connectedOutput->GetTensorInfo(); + const armnn::TensorInfo& connectedInfo = layer->GetInputSlot(i).GetTensorInfo(); CHECK(connectedInfo.GetShape() == m_InputTensorInfos[i].GetShape()); CHECK(GetDataTypeName(connectedInfo.GetDataType()) == GetDataTypeName(m_InputTensorInfos[i].GetDataType())); diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp index 7a988ee134..49971d2ecc 100644 --- a/src/armnnSerializer/test/SerializerTests.cpp +++ b/src/armnnSerializer/test/SerializerTests.cpp @@ -3008,4 +3008,34 @@ TEST_CASE("SerializeDeserializeNonLinearNetwork") deserializedNetwork->ExecuteStrategy(verifier); } +TEST_CASE("SerializeOverriddenSlot") +{ + const std::string layerName("subtraction"); + const armnn::TensorInfo info({ 1, 4 }, armnn::DataType::Float32); + const armnn::TensorInfo incompatibleInfo({ 4, 1 }, armnn::DataType::Float32); + + armnn::INetworkPtr network = armnn::INetwork::Create(); + armnn::IConnectableLayer* const inputLayer0 = network->AddInputLayer(0); + armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(1); + armnn::IConnectableLayer* const subtractionLayer = network->AddElementwiseBinaryLayer(armnn::BinaryOperation::Sub, + layerName.c_str()); + armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0); + + inputLayer0->GetOutputSlot(0).Connect(subtractionLayer->GetInputSlot(0)); + inputLayer1->GetOutputSlot(0).Connect(subtractionLayer->GetInputSlot(1)); + subtractionLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0)); + + inputLayer0->GetOutputSlot(0).SetTensorInfo(info); + inputLayer1->GetOutputSlot(0).SetTensorInfo(incompatibleInfo); + subtractionLayer->GetInputSlot(1).SetTensorInfo(info); + subtractionLayer->GetOutputSlot(0).SetTensorInfo(info); + + armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network)); + CHECK(deserializedNetwork); + + LayerVerifierBase verifier(layerName, {info, info}, {info}); + deserializedNetwork->ExecuteStrategy(verifier); +} + + } |