From 80fbcd5f4d7b362360963af1df0121aa6b561576 Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Thu, 7 Jan 2021 13:28:47 +0000 Subject: IVGCVSW-5483 'Implement Loading and Saving to File' * Implemented Serialization and Deserialization of CLContext. * Fixed flatbuffers android-nn-driver dependency. !android-nn-driver:4772 Signed-off-by: Matthew Sloyan Signed-off-by: Sadik Armagan Change-Id: If806f050535ffaa70922ba0f1ffe7bb10f902329 --- Android.mk | 2 + src/backends/cl/CMakeLists.txt | 20 +++ src/backends/cl/ClContextDeserializer.cpp | 80 ++++++++++ src/backends/cl/ClContextDeserializer.hpp | 41 +++++ src/backends/cl/ClContextSchema.fbs | 21 +++ src/backends/cl/ClContextSchema_generated.h | 185 ++++++++++++++++++++++ src/backends/cl/ClContextSerializer.cpp | 57 +++++++ src/backends/cl/ClContextSerializer.hpp | 35 ++++ src/backends/cl/ClWorkloadFactory.cpp | 13 +- src/backends/cl/backend.mk | 3 + src/backends/cl/test/CMakeLists.txt | 1 + src/backends/cl/test/ClContextSerializerTests.cpp | 138 ++++++++++++++++ 12 files changed, 594 insertions(+), 2 deletions(-) create mode 100644 src/backends/cl/ClContextDeserializer.cpp create mode 100644 src/backends/cl/ClContextDeserializer.hpp create mode 100644 src/backends/cl/ClContextSchema.fbs create mode 100644 src/backends/cl/ClContextSchema_generated.h create mode 100644 src/backends/cl/ClContextSerializer.cpp create mode 100644 src/backends/cl/ClContextSerializer.hpp create mode 100644 src/backends/cl/test/ClContextSerializerTests.cpp diff --git a/Android.mk b/Android.mk index d683c2312f..df0bb040a0 100644 --- a/Android.mk +++ b/Android.mk @@ -238,6 +238,7 @@ LOCAL_SRC_FILES := \ src/profiling/backends/BackendProfiling.cpp LOCAL_STATIC_LIBRARIES := \ + libflatbuffers-framework \ arm_compute_library LOCAL_SHARED_LIBRARIES := \ @@ -422,6 +423,7 @@ endif LOCAL_STATIC_LIBRARIES := \ libneuralnetworks_common \ libboost_unit_test_framework \ + libflatbuffers-framework \ arm_compute_library LOCAL_WHOLE_STATIC_LIBRARIES := libarmnn diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt index 4b5890af50..bfb99dde96 100644 --- a/src/backends/cl/CMakeLists.txt +++ b/src/backends/cl/CMakeLists.txt @@ -4,7 +4,23 @@ # if(ARMCOMPUTECL) + find_program(FLATC flatc + HINTS ${FLATC_DIR} + DOC "Path to 'flatc', the flatbuffers compiler") + if (NOT FLATC) + message(SEND_ERROR "flatc not found. Specify the full path of the flatc executable with -DFLATC=") + endif() + + add_custom_command( + # Generate an ClContextSchema_generated.h file if it doesn't exist, or update it when necessary otherwise + OUTPUT ClContextSchema_generated.h DEPENDS ClContextSchema.fbs + COMMAND ${FLATC} -o ${CMAKE_CURRENT_BINARY_DIR} --cpp ${CMAKE_CURRENT_SOURCE_DIR}/ClContextSchema.fbs + #COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/ClContextSchema_generated.h + # ${CMAKE_CURRENT_SOURCE_DIR}/ClContextSchema_generated.h + ) + list(APPEND armnnClBackend_sources + ClContextSchema_generated.h ClBackend.cpp ClBackend.hpp ClBackendContext.cpp @@ -14,6 +30,10 @@ if(ARMCOMPUTECL) ClBackendModelContext.hpp ClContextControl.cpp ClContextControl.hpp + ClContextDeserializer.hpp + ClContextDeserializer.cpp + ClContextSerializer.hpp + ClContextSerializer.cpp ClLayerSupport.cpp ClLayerSupport.hpp ClRegistryInitializer.cpp diff --git a/src/backends/cl/ClContextDeserializer.cpp b/src/backends/cl/ClContextDeserializer.cpp new file mode 100644 index 0000000000..8a1b585d47 --- /dev/null +++ b/src/backends/cl/ClContextDeserializer.cpp @@ -0,0 +1,80 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClContextDeserializer.hpp" +#include "ClContextSchema_generated.h" + +#include +#include + +#include + +#include + +#include +#include +#include +#include + +namespace armnn +{ + +void ClContextDeserializer::Deserialize(arm_compute::CLCompileContext& clCompileContext, + cl::Context& context, + cl::Device& device, + const std::string& filePath) +{ + std::ifstream inputFileStream(filePath, std::ios::binary); + std::vector binaryContent; + while (inputFileStream) + { + char input; + inputFileStream.get(input); + if (inputFileStream) + { + binaryContent.push_back(static_cast(input)); + } + } + inputFileStream.close(); + DeserializeFromBinary(clCompileContext, context, device, binaryContent); +} + +void ClContextDeserializer::DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext, + cl::Context& context, + cl::Device& device, + const std::vector& binaryContent) +{ + if (binaryContent.data() == nullptr) + { + throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}", + CHECK_LOCATION().AsString())); + } + + size_t binaryContentSize = binaryContent.size(); + flatbuffers::Verifier verifier(binaryContent.data(), binaryContentSize); + if (verifier.VerifyBuffer() == false) + { + throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn " + "flatbuffers format. size:{0} {1}", + binaryContentSize, + CHECK_LOCATION().AsString())); + } + auto clContext = GetClContext(binaryContent.data()); + + for (Program const* program : *clContext->programs()) + { + auto programName = program->name()->c_str(); + auto programBinary = program->binary(); + std::vector binary(programBinary->begin(), programBinary->begin() + programBinary->size()); + + cl::Program::Binaries binaries{ binary }; + std::vector devices {device}; + cl::Program theProgram(context, devices, binaries); + theProgram.build(); + clCompileContext.add_built_program(programName, theProgram); + } +} + +} // namespace armnn diff --git a/src/backends/cl/ClContextDeserializer.hpp b/src/backends/cl/ClContextDeserializer.hpp new file mode 100644 index 0000000000..e3a9b9deb4 --- /dev/null +++ b/src/backends/cl/ClContextDeserializer.hpp @@ -0,0 +1,41 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + +class ClContextDeserializer +{ +public: + ClContextDeserializer() = default; + ~ClContextDeserializer() = default; + + /// Deserializes the CLCompileContext built-in programs from a binary file + /// @param [in] clCompileContext The CLCompileContext to be serialized + /// @param [in] context The CL Kernel context built-in program will be created from + /// @param [in] device The CL Kernel device built-in program will be created from + /// @param [in] filePath The serialized file + void Deserialize(arm_compute::CLCompileContext& clCompileContext, + cl::Context& context, + cl::Device& device, + const std::string& filePath); + + /// Deserializes the CLCompileContext built-in programs from binary file contents + /// @param [in] clCompileContext The CLCompileContext to be serialized + /// @param [in] context The CL Kernel context built-in program will be created from + /// @param [in] device The CL Kernel device built-in program will be created from + /// @param [in] filePath The serialized file + void DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext, + cl::Context& context, + cl::Device& device, + const std::vector& binaryContent); + +}; + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClContextSchema.fbs b/src/backends/cl/ClContextSchema.fbs new file mode 100644 index 0000000000..c517d8039a --- /dev/null +++ b/src/backends/cl/ClContextSchema.fbs @@ -0,0 +1,21 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +namespace armnn; + +file_identifier "ARMN"; + +file_extension "armnn"; + +table ClContext { + programs:[Program]; +} + +table Program { + name:string; + binary:[ubyte]; +} + +root_type ClContext; \ No newline at end of file diff --git a/src/backends/cl/ClContextSchema_generated.h b/src/backends/cl/ClContextSchema_generated.h new file mode 100644 index 0000000000..88b759f78f --- /dev/null +++ b/src/backends/cl/ClContextSchema_generated.h @@ -0,0 +1,185 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +// automatically generated by the FlatBuffers compiler, do not modify + +#ifndef FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_ +#define FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace armnn { + +struct ClContext; +struct ClContextBuilder; + +struct Program; +struct ProgramBuilder; + +struct ClContext FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ClContextBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_PROGRAMS = 4 + }; + const flatbuffers::Vector> *programs() const { + return GetPointer> *>(VT_PROGRAMS); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_PROGRAMS) && + verifier.VerifyVector(programs()) && + verifier.VerifyVectorOfTables(programs()) && + verifier.EndTable(); + } +}; + +struct ClContextBuilder { + typedef ClContext Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_programs(flatbuffers::Offset>> programs) { + fbb_.AddOffset(ClContext::VT_PROGRAMS, programs); + } + explicit ClContextBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ClContextBuilder &operator=(const ClContextBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateClContext( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> programs = 0) { + ClContextBuilder builder_(_fbb); + builder_.add_programs(programs); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateClContextDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *programs = nullptr) { + auto programs__ = programs ? _fbb.CreateVector>(*programs) : 0; + return armnn::CreateClContext( + _fbb, + programs__); +} + +struct Program FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ProgramBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_BINARY = 6 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector *binary() const { + return GetPointer *>(VT_BINARY); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_BINARY) && + verifier.VerifyVector(binary()) && + verifier.EndTable(); + } +}; + +struct ProgramBuilder { + typedef Program Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(Program::VT_NAME, name); + } + void add_binary(flatbuffers::Offset> binary) { + fbb_.AddOffset(Program::VT_BINARY, binary); + } + explicit ProgramBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ProgramBuilder &operator=(const ProgramBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateProgram( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + flatbuffers::Offset> binary = 0) { + ProgramBuilder builder_(_fbb); + builder_.add_binary(binary); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateProgramDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + const std::vector *binary = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto binary__ = binary ? _fbb.CreateVector(*binary) : 0; + return armnn::CreateProgram( + _fbb, + name__, + binary__); +} + +inline const armnn::ClContext *GetClContext(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const armnn::ClContext *GetSizePrefixedClContext(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline const char *ClContextIdentifier() { + return "ARMN"; +} + +inline bool ClContextBufferHasIdentifier(const void *buf) { + return flatbuffers::BufferHasIdentifier( + buf, ClContextIdentifier()); +} + +inline bool VerifyClContextBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(ClContextIdentifier()); +} + +inline bool VerifySizePrefixedClContextBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(ClContextIdentifier()); +} + +inline const char *ClContextExtension() { + return "armnn"; +} + +inline void FinishClContextBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root, ClContextIdentifier()); +} + +inline void FinishSizePrefixedClContextBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root, ClContextIdentifier()); +} + +} // namespace armnn + +#endif // FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_ diff --git a/src/backends/cl/ClContextSerializer.cpp b/src/backends/cl/ClContextSerializer.cpp new file mode 100644 index 0000000000..db89203a25 --- /dev/null +++ b/src/backends/cl/ClContextSerializer.cpp @@ -0,0 +1,57 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClContextSerializer.hpp" +#include "ClContextSchema_generated.h" + +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace armnn +{ + +void ClContextSerializer::Serialize(const arm_compute::CLCompileContext& clCompileContext) +{ + // Get map of built programs from clCompileContext + std::map builtProgramsMap = clCompileContext.get_built_programs(); + if (builtProgramsMap.empty()) + { + ARMNN_LOG(warning) << "There are no built programs to be serialised."; + return; + } + + // Create Flatbuffer CL Programs + std::vector> clPrograms; + for(const auto& program : builtProgramsMap) + { + std::vector> binaries = program.second.getInfo(); + clPrograms.push_back(CreateProgram(m_FlatBufferBuilder, + m_FlatBufferBuilder.CreateString(program.first), + m_FlatBufferBuilder.CreateVector(binaries[0]))); + } + + // Create Flatbuffer CLContext + auto clContext = CreateClContext(m_FlatBufferBuilder, m_FlatBufferBuilder.CreateVector(clPrograms)); + + m_FlatBufferBuilder.Finish(clContext); +} + +bool ClContextSerializer::SaveSerializedToStream(std::ostream& stream) +{ + // Write to a stream + auto bytesToWrite = armnn::numeric_cast(m_FlatBufferBuilder.GetSize()); + stream.write(reinterpret_cast(m_FlatBufferBuilder.GetBufferPointer()), bytesToWrite); + return !stream.bad(); +} + +} // namespace armnn diff --git a/src/backends/cl/ClContextSerializer.hpp b/src/backends/cl/ClContextSerializer.hpp new file mode 100644 index 0000000000..71e2b1f1c9 --- /dev/null +++ b/src/backends/cl/ClContextSerializer.hpp @@ -0,0 +1,35 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +#include + +namespace armnn +{ + +class ClContextSerializer +{ +public: + ClContextSerializer() = default; + ~ClContextSerializer() = default; + + /// Serializes the CLCompileContext built-in programs + /// @param [in] clCompileContext The CLCompileContext to be serialized. + void Serialize(const arm_compute::CLCompileContext& clCompileContext); + + /// Serializes the ClContext to the stream. + /// @param [stream] the stream to save to + /// @return true if ClContext is Serialized to the Stream, false otherwise + bool SaveSerializedToStream(std::ostream& stream); + +private: + /// FlatBufferBuilder to create the CLContext FlatBuffers. + flatbuffers::FlatBufferBuilder m_FlatBufferBuilder; +}; + +} // namespace armnn \ No newline at end of file diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 5a5cb89204..d65b26314e 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -5,6 +5,8 @@ #include "ClWorkloadFactory.hpp" #include "ClBackendId.hpp" #include "ClBackendModelContext.hpp" +#include "ClContextDeserializer.hpp" +#include "ClContextSerializer.hpp" #include @@ -28,6 +30,7 @@ #include #include +#include namespace armnn { @@ -68,7 +71,11 @@ void ClWorkloadFactory::AfterWorkloadsCreated() auto filePath = modelOptions->GetCachedNetworkFilePath(); if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath)) { - /// Saving will be implemented within IVGCVSW-5483 story. + // Serialize ClContext to the file specified + ClContextSerializer serializer; + serializer.Serialize(m_CLCompileContext); + std::ofstream file(filePath, std::ios::out | std::ios::binary); + serializer.SaveSerializedToStream(file); } } } @@ -121,7 +128,9 @@ void ClWorkloadFactory::InitializeCLCompileContext() && fs::is_regular_file(filePath) && !(modelOptions->SaveCachedNetwork())) { - /// Loading will be implemented within IVGCVSW-5483 story. + // Deserialize binary file and load into m_CLCompileContext + ClContextDeserializer deserializer; + deserializer.Deserialize(m_CLCompileContext, context, device, filePath); } } } diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk index 52295ccc4f..9514750563 100644 --- a/src/backends/cl/backend.mk +++ b/src/backends/cl/backend.mk @@ -18,6 +18,8 @@ BACKEND_SOURCES := \ ClBackendContext.cpp \ ClBackendModelContext.cpp \ ClContextControl.cpp \ + ClContextDeserializer.cpp \ + ClContextSerializer.cpp \ ClLayerSupport.cpp \ ClRegistryInitializer.cpp \ ClTensorHandleFactory.cpp \ @@ -97,6 +99,7 @@ ifeq ($(ARMNN_COMPUTE_CL_ENABLED),1) # Include the source files for the CL backend tests BACKEND_TEST_SOURCES := \ + test/ClContextSerializerTests.cpp \ test/ClCreateWorkloadTests.cpp \ test/ClEndToEndTests.cpp \ test/ClJsonPrinterTests.cpp \ diff --git a/src/backends/cl/test/CMakeLists.txt b/src/backends/cl/test/CMakeLists.txt index 2cf00d106b..a2950d9d44 100644 --- a/src/backends/cl/test/CMakeLists.txt +++ b/src/backends/cl/test/CMakeLists.txt @@ -5,6 +5,7 @@ list(APPEND armnnClBackendUnitTests_sources ClContextControlFixture.hpp + ClContextSerializerTests.cpp ClCreateWorkloadTests.cpp ClEndToEndTests.cpp ClJsonPrinterTests.cpp diff --git a/src/backends/cl/test/ClContextSerializerTests.cpp b/src/backends/cl/test/ClContextSerializerTests.cpp new file mode 100644 index 0000000000..1fc0fb9205 --- /dev/null +++ b/src/backends/cl/test/ClContextSerializerTests.cpp @@ -0,0 +1,138 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include + +#include + +#include + +#include + +namespace +{ + +armnn::INetworkPtr CreateNetwork() +{ + // Builds up the structure of the network. + armnn::INetworkPtr net(armnn::INetwork::Create()); + + armnn::IConnectableLayer* input = net->AddInputLayer(0, "input"); + armnn::IConnectableLayer* softmax = net->AddSoftmaxLayer(armnn::SoftmaxDescriptor(), "softmax"); + armnn::IConnectableLayer* output = net->AddOutputLayer(0, "output"); + + input->GetOutputSlot(0).Connect(softmax->GetInputSlot(0)); + softmax->GetOutputSlot(0).Connect(output->GetInputSlot(0)); + + // Sets the input and output tensors + armnn::TensorInfo inputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 10000.0f, 1); + input->GetOutputSlot(0).SetTensorInfo(inputTensorInfo); + + armnn::TensorInfo outputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 1.0f/255.0f, 0); + softmax->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + + return net; +} + +void RunInference(armnn::NetworkId& netId, armnn::IRuntimePtr& runtime, std::vector& outputData) +{ + // Creates structures for input & output. + std::vector inputData + { + 1, 10, 3, 200, 5 // Some inputs - one of which is sufficiently larger than the others to saturate softmax. + }; + + armnn::InputTensors inputTensors + { + {0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())} + }; + + armnn::OutputTensors outputTensors + { + {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())} + }; + + // Run inference. + runtime->EnqueueWorkload(netId, inputTensors, outputTensors); +} + +std::vector ReadBinaryFile(const std::string& binaryFileName) +{ + std::ifstream input(binaryFileName, std::ios::binary); + return std::vector(std::istreambuf_iterator(input), {}); +} + +} // anonymous namespace + +BOOST_FIXTURE_TEST_SUITE(ClContextSerializer, ClContextControlFixture) + +BOOST_AUTO_TEST_CASE(ClContextSerializerTest) +{ + // Get tmp directory and create blank file. + fs::path filePath = armnnUtils::Filesystem::NamedTempFile("Armnn-CachedNetworkFileTest-TempFile.bin"); + std::string const filePathString{filePath.string()}; + std::ofstream file { filePathString }; + + // Create runtime in which test will run + armnn::IRuntime::CreationOptions options; + armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options)); + + std::vector backends = {armnn::Compute::GpuAcc}; + + // Create two networks. + // net1 will serialize and save context to file. + // net2 will deserialize context saved from net1 and load. + armnn::INetworkPtr net1 = CreateNetwork(); + armnn::INetworkPtr net2 = CreateNetwork(); + + // Add specific optimizerOptions to each network. + armnn::OptimizerOptions optimizerOptions1; + armnn::OptimizerOptions optimizerOptions2; + armnn::BackendOptions modelOptions1("GpuAcc", + {{"SaveCachedNetwork", true}, {"CachedNetworkFilePath", filePathString}}); + armnn::BackendOptions modelOptions2("GpuAcc", + {{"SaveCachedNetwork", false}, {"CachedNetworkFilePath", filePathString}}); + optimizerOptions1.m_ModelOptions.push_back(modelOptions1); + optimizerOptions2.m_ModelOptions.push_back(modelOptions2); + + armnn::IOptimizedNetworkPtr optNet1 = armnn::Optimize( + *net1, backends, runtime->GetDeviceSpec(), optimizerOptions1); + armnn::IOptimizedNetworkPtr optNet2 = armnn::Optimize( + *net2, backends, runtime->GetDeviceSpec(), optimizerOptions2); + BOOST_CHECK(optNet1); + BOOST_CHECK(optNet2); + + // Cached file should be empty until net1 is loaded into runtime. + BOOST_TEST(fs::is_empty(filePathString)); + + // Load net1 into the runtime. + armnn::NetworkId netId1; + BOOST_TEST(runtime->LoadNetwork(netId1, std::move(optNet1)) == armnn::Status::Success); + + // File should now exist and not be empty. It has been serialized. + BOOST_TEST(fs::exists(filePathString)); + std::vector dataSerialized = ReadBinaryFile(filePathString); + BOOST_TEST(dataSerialized.size() != 0); + + // Load net2 into the runtime using file and deserialize. + armnn::NetworkId netId2; + BOOST_TEST(runtime->LoadNetwork(netId2, std::move(optNet2)) == armnn::Status::Success); + + // Run inference and get output data. + std::vector outputData1(5); + RunInference(netId1, runtime, outputData1); + + std::vector outputData2(5); + RunInference(netId2, runtime, outputData2); + + // Compare outputs from both networks. + BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(), + outputData2.begin(), outputData2.end()); + + // Remove temp file created. + fs::remove(filePath); +} + +BOOST_AUTO_TEST_SUITE_END() -- cgit v1.2.1