aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2021-01-07 13:28:47 +0000
committerMatthew Sloyan <matthew.sloyan@arm.com>2021-01-11 17:03:54 +0000
commit80fbcd5f4d7b362360963af1df0121aa6b561576 (patch)
tree64c8d2588e55aad2813f6b07e40f87ac3b8e8ce1
parenta20b3129aa1c450ccf867c7b63844e8391753730 (diff)
downloadarmnn-80fbcd5f4d7b362360963af1df0121aa6b561576.tar.gz
IVGCVSW-5483 'Implement Loading and Saving to File'
* Implemented Serialization and Deserialization of CLContext. * Fixed flatbuffers android-nn-driver dependency. !android-nn-driver:4772 Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: If806f050535ffaa70922ba0f1ffe7bb10f902329
-rw-r--r--Android.mk2
-rw-r--r--src/backends/cl/CMakeLists.txt20
-rw-r--r--src/backends/cl/ClContextDeserializer.cpp80
-rw-r--r--src/backends/cl/ClContextDeserializer.hpp41
-rw-r--r--src/backends/cl/ClContextSchema.fbs21
-rw-r--r--src/backends/cl/ClContextSchema_generated.h185
-rw-r--r--src/backends/cl/ClContextSerializer.cpp57
-rw-r--r--src/backends/cl/ClContextSerializer.hpp35
-rw-r--r--src/backends/cl/ClWorkloadFactory.cpp13
-rw-r--r--src/backends/cl/backend.mk3
-rw-r--r--src/backends/cl/test/CMakeLists.txt1
-rw-r--r--src/backends/cl/test/ClContextSerializerTests.cpp138
12 files changed, 594 insertions, 2 deletions
diff --git a/Android.mk b/Android.mk
index d683c2312f..df0bb040a0 100644
--- a/Android.mk
+++ b/Android.mk
@@ -238,6 +238,7 @@ LOCAL_SRC_FILES := \
src/profiling/backends/BackendProfiling.cpp
LOCAL_STATIC_LIBRARIES := \
+ libflatbuffers-framework \
arm_compute_library
LOCAL_SHARED_LIBRARIES := \
@@ -422,6 +423,7 @@ endif
LOCAL_STATIC_LIBRARIES := \
libneuralnetworks_common \
libboost_unit_test_framework \
+ libflatbuffers-framework \
arm_compute_library
LOCAL_WHOLE_STATIC_LIBRARIES := libarmnn
diff --git a/src/backends/cl/CMakeLists.txt b/src/backends/cl/CMakeLists.txt
index 4b5890af50..bfb99dde96 100644
--- a/src/backends/cl/CMakeLists.txt
+++ b/src/backends/cl/CMakeLists.txt
@@ -4,7 +4,23 @@
#
if(ARMCOMPUTECL)
+ find_program(FLATC flatc
+ HINTS ${FLATC_DIR}
+ DOC "Path to 'flatc', the flatbuffers compiler")
+ if (NOT FLATC)
+ message(SEND_ERROR "flatc not found. Specify the full path of the flatc executable with -DFLATC=<flatc path>")
+ endif()
+
+ add_custom_command(
+ # Generate an ClContextSchema_generated.h file if it doesn't exist, or update it when necessary otherwise
+ OUTPUT ClContextSchema_generated.h DEPENDS ClContextSchema.fbs
+ COMMAND ${FLATC} -o ${CMAKE_CURRENT_BINARY_DIR} --cpp ${CMAKE_CURRENT_SOURCE_DIR}/ClContextSchema.fbs
+ #COMMAND cp ${CMAKE_CURRENT_BINARY_DIR}/ClContextSchema_generated.h
+ # ${CMAKE_CURRENT_SOURCE_DIR}/ClContextSchema_generated.h
+ )
+
list(APPEND armnnClBackend_sources
+ ClContextSchema_generated.h
ClBackend.cpp
ClBackend.hpp
ClBackendContext.cpp
@@ -14,6 +30,10 @@ if(ARMCOMPUTECL)
ClBackendModelContext.hpp
ClContextControl.cpp
ClContextControl.hpp
+ ClContextDeserializer.hpp
+ ClContextDeserializer.cpp
+ ClContextSerializer.hpp
+ ClContextSerializer.cpp
ClLayerSupport.cpp
ClLayerSupport.hpp
ClRegistryInitializer.cpp
diff --git a/src/backends/cl/ClContextDeserializer.cpp b/src/backends/cl/ClContextDeserializer.cpp
new file mode 100644
index 0000000000..8a1b585d47
--- /dev/null
+++ b/src/backends/cl/ClContextDeserializer.cpp
@@ -0,0 +1,80 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClContextDeserializer.hpp"
+#include "ClContextSchema_generated.h"
+
+#include <armnn/Exceptions.hpp>
+#include <armnn/utility/NumericCast.hpp>
+
+#include <flatbuffers/flexbuffers.h>
+
+#include <fmt/format.h>
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+namespace armnn
+{
+
+void ClContextDeserializer::Deserialize(arm_compute::CLCompileContext& clCompileContext,
+ cl::Context& context,
+ cl::Device& device,
+ const std::string& filePath)
+{
+ std::ifstream inputFileStream(filePath, std::ios::binary);
+ std::vector<std::uint8_t> binaryContent;
+ while (inputFileStream)
+ {
+ char input;
+ inputFileStream.get(input);
+ if (inputFileStream)
+ {
+ binaryContent.push_back(static_cast<std::uint8_t>(input));
+ }
+ }
+ inputFileStream.close();
+ DeserializeFromBinary(clCompileContext, context, device, binaryContent);
+}
+
+void ClContextDeserializer::DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext,
+ cl::Context& context,
+ cl::Device& device,
+ const std::vector<uint8_t>& binaryContent)
+{
+ if (binaryContent.data() == nullptr)
+ {
+ throw InvalidArgumentException(fmt::format("Invalid (null) binary content {}",
+ CHECK_LOCATION().AsString()));
+ }
+
+ size_t binaryContentSize = binaryContent.size();
+ flatbuffers::Verifier verifier(binaryContent.data(), binaryContentSize);
+ if (verifier.VerifyBuffer<ClContext>() == false)
+ {
+ throw ParseException(fmt::format("Buffer doesn't conform to the expected Armnn "
+ "flatbuffers format. size:{0} {1}",
+ binaryContentSize,
+ CHECK_LOCATION().AsString()));
+ }
+ auto clContext = GetClContext(binaryContent.data());
+
+ for (Program const* program : *clContext->programs())
+ {
+ auto programName = program->name()->c_str();
+ auto programBinary = program->binary();
+ std::vector<uint8_t> binary(programBinary->begin(), programBinary->begin() + programBinary->size());
+
+ cl::Program::Binaries binaries{ binary };
+ std::vector<cl::Device> devices {device};
+ cl::Program theProgram(context, devices, binaries);
+ theProgram.build();
+ clCompileContext.add_built_program(programName, theProgram);
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/cl/ClContextDeserializer.hpp b/src/backends/cl/ClContextDeserializer.hpp
new file mode 100644
index 0000000000..e3a9b9deb4
--- /dev/null
+++ b/src/backends/cl/ClContextDeserializer.hpp
@@ -0,0 +1,41 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <arm_compute/core/CL/CLCompileContext.h>
+
+namespace armnn
+{
+
+class ClContextDeserializer
+{
+public:
+ ClContextDeserializer() = default;
+ ~ClContextDeserializer() = default;
+
+ /// Deserializes the CLCompileContext built-in programs from a binary file
+ /// @param [in] clCompileContext The CLCompileContext to be serialized
+ /// @param [in] context The CL Kernel context built-in program will be created from
+ /// @param [in] device The CL Kernel device built-in program will be created from
+ /// @param [in] filePath The serialized file
+ void Deserialize(arm_compute::CLCompileContext& clCompileContext,
+ cl::Context& context,
+ cl::Device& device,
+ const std::string& filePath);
+
+ /// Deserializes the CLCompileContext built-in programs from binary file contents
+ /// @param [in] clCompileContext The CLCompileContext to be serialized
+ /// @param [in] context The CL Kernel context built-in program will be created from
+ /// @param [in] device The CL Kernel device built-in program will be created from
+ /// @param [in] filePath The serialized file
+ void DeserializeFromBinary(arm_compute::CLCompileContext& clCompileContext,
+ cl::Context& context,
+ cl::Device& device,
+ const std::vector<uint8_t>& binaryContent);
+
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClContextSchema.fbs b/src/backends/cl/ClContextSchema.fbs
new file mode 100644
index 0000000000..c517d8039a
--- /dev/null
+++ b/src/backends/cl/ClContextSchema.fbs
@@ -0,0 +1,21 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+namespace armnn;
+
+file_identifier "ARMN";
+
+file_extension "armnn";
+
+table ClContext {
+ programs:[Program];
+}
+
+table Program {
+ name:string;
+ binary:[ubyte];
+}
+
+root_type ClContext; \ No newline at end of file
diff --git a/src/backends/cl/ClContextSchema_generated.h b/src/backends/cl/ClContextSchema_generated.h
new file mode 100644
index 0000000000..88b759f78f
--- /dev/null
+++ b/src/backends/cl/ClContextSchema_generated.h
@@ -0,0 +1,185 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+// automatically generated by the FlatBuffers compiler, do not modify
+
+#ifndef FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_
+#define FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_
+
+#include "flatbuffers/flatbuffers.h"
+
+namespace armnn {
+
+struct ClContext;
+struct ClContextBuilder;
+
+struct Program;
+struct ProgramBuilder;
+
+struct ClContext FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ClContextBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_PROGRAMS = 4
+ };
+ const flatbuffers::Vector<flatbuffers::Offset<armnn::Program>> *programs() const {
+ return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<armnn::Program>> *>(VT_PROGRAMS);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_PROGRAMS) &&
+ verifier.VerifyVector(programs()) &&
+ verifier.VerifyVectorOfTables(programs()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ClContextBuilder {
+ typedef ClContext Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_programs(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<armnn::Program>>> programs) {
+ fbb_.AddOffset(ClContext::VT_PROGRAMS, programs);
+ }
+ explicit ClContextBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ClContextBuilder &operator=(const ClContextBuilder &);
+ flatbuffers::Offset<ClContext> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<ClContext>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<ClContext> CreateClContext(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<armnn::Program>>> programs = 0) {
+ ClContextBuilder builder_(_fbb);
+ builder_.add_programs(programs);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<ClContext> CreateClContextDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<flatbuffers::Offset<armnn::Program>> *programs = nullptr) {
+ auto programs__ = programs ? _fbb.CreateVector<flatbuffers::Offset<armnn::Program>>(*programs) : 0;
+ return armnn::CreateClContext(
+ _fbb,
+ programs__);
+}
+
+struct Program FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef ProgramBuilder Builder;
+ enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE {
+ VT_NAME = 4,
+ VT_BINARY = 6
+ };
+ const flatbuffers::String *name() const {
+ return GetPointer<const flatbuffers::String *>(VT_NAME);
+ }
+ const flatbuffers::Vector<uint8_t> *binary() const {
+ return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_BINARY);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyOffset(verifier, VT_NAME) &&
+ verifier.VerifyString(name()) &&
+ VerifyOffset(verifier, VT_BINARY) &&
+ verifier.VerifyVector(binary()) &&
+ verifier.EndTable();
+ }
+};
+
+struct ProgramBuilder {
+ typedef Program Table;
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_name(flatbuffers::Offset<flatbuffers::String> name) {
+ fbb_.AddOffset(Program::VT_NAME, name);
+ }
+ void add_binary(flatbuffers::Offset<flatbuffers::Vector<uint8_t>> binary) {
+ fbb_.AddOffset(Program::VT_BINARY, binary);
+ }
+ explicit ProgramBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ ProgramBuilder &operator=(const ProgramBuilder &);
+ flatbuffers::Offset<Program> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<Program>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<Program> CreateProgram(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::String> name = 0,
+ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> binary = 0) {
+ ProgramBuilder builder_(_fbb);
+ builder_.add_binary(binary);
+ builder_.add_name(name);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<Program> CreateProgramDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const char *name = nullptr,
+ const std::vector<uint8_t> *binary = nullptr) {
+ auto name__ = name ? _fbb.CreateString(name) : 0;
+ auto binary__ = binary ? _fbb.CreateVector<uint8_t>(*binary) : 0;
+ return armnn::CreateProgram(
+ _fbb,
+ name__,
+ binary__);
+}
+
+inline const armnn::ClContext *GetClContext(const void *buf) {
+ return flatbuffers::GetRoot<armnn::ClContext>(buf);
+}
+
+inline const armnn::ClContext *GetSizePrefixedClContext(const void *buf) {
+ return flatbuffers::GetSizePrefixedRoot<armnn::ClContext>(buf);
+}
+
+inline const char *ClContextIdentifier() {
+ return "ARMN";
+}
+
+inline bool ClContextBufferHasIdentifier(const void *buf) {
+ return flatbuffers::BufferHasIdentifier(
+ buf, ClContextIdentifier());
+}
+
+inline bool VerifyClContextBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifyBuffer<armnn::ClContext>(ClContextIdentifier());
+}
+
+inline bool VerifySizePrefixedClContextBuffer(
+ flatbuffers::Verifier &verifier) {
+ return verifier.VerifySizePrefixedBuffer<armnn::ClContext>(ClContextIdentifier());
+}
+
+inline const char *ClContextExtension() {
+ return "armnn";
+}
+
+inline void FinishClContextBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<armnn::ClContext> root) {
+ fbb.Finish(root, ClContextIdentifier());
+}
+
+inline void FinishSizePrefixedClContextBuffer(
+ flatbuffers::FlatBufferBuilder &fbb,
+ flatbuffers::Offset<armnn::ClContext> root) {
+ fbb.FinishSizePrefixed(root, ClContextIdentifier());
+}
+
+} // namespace armnn
+
+#endif // FLATBUFFERS_GENERATED_CLCONTEXTSCHEMA_ARMNN_H_
diff --git a/src/backends/cl/ClContextSerializer.cpp b/src/backends/cl/ClContextSerializer.cpp
new file mode 100644
index 0000000000..db89203a25
--- /dev/null
+++ b/src/backends/cl/ClContextSerializer.cpp
@@ -0,0 +1,57 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClContextSerializer.hpp"
+#include "ClContextSchema_generated.h"
+
+#include <armnn/Exceptions.hpp>
+#include <armnn/Logging.hpp>
+#include <armnn/utility/NumericCast.hpp>
+
+#include <fmt/format.h>
+
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+namespace armnn
+{
+
+void ClContextSerializer::Serialize(const arm_compute::CLCompileContext& clCompileContext)
+{
+ // Get map of built programs from clCompileContext
+ std::map<std::string, cl::Program> builtProgramsMap = clCompileContext.get_built_programs();
+ if (builtProgramsMap.empty())
+ {
+ ARMNN_LOG(warning) << "There are no built programs to be serialised.";
+ return;
+ }
+
+ // Create Flatbuffer CL Programs
+ std::vector<flatbuffers::Offset<armnn::Program>> clPrograms;
+ for(const auto& program : builtProgramsMap)
+ {
+ std::vector<std::vector<uint8_t>> binaries = program.second.getInfo<CL_PROGRAM_BINARIES>();
+ clPrograms.push_back(CreateProgram(m_FlatBufferBuilder,
+ m_FlatBufferBuilder.CreateString(program.first),
+ m_FlatBufferBuilder.CreateVector(binaries[0])));
+ }
+
+ // Create Flatbuffer CLContext
+ auto clContext = CreateClContext(m_FlatBufferBuilder, m_FlatBufferBuilder.CreateVector(clPrograms));
+
+ m_FlatBufferBuilder.Finish(clContext);
+}
+
+bool ClContextSerializer::SaveSerializedToStream(std::ostream& stream)
+{
+ // Write to a stream
+ auto bytesToWrite = armnn::numeric_cast<std::streamsize>(m_FlatBufferBuilder.GetSize());
+ stream.write(reinterpret_cast<const char*>(m_FlatBufferBuilder.GetBufferPointer()), bytesToWrite);
+ return !stream.bad();
+}
+
+} // namespace armnn
diff --git a/src/backends/cl/ClContextSerializer.hpp b/src/backends/cl/ClContextSerializer.hpp
new file mode 100644
index 0000000000..71e2b1f1c9
--- /dev/null
+++ b/src/backends/cl/ClContextSerializer.hpp
@@ -0,0 +1,35 @@
+//
+// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <arm_compute/core/CL/CLCompileContext.h>
+
+#include <flatbuffers/flatbuffers.h>
+
+namespace armnn
+{
+
+class ClContextSerializer
+{
+public:
+ ClContextSerializer() = default;
+ ~ClContextSerializer() = default;
+
+ /// Serializes the CLCompileContext built-in programs
+ /// @param [in] clCompileContext The CLCompileContext to be serialized.
+ void Serialize(const arm_compute::CLCompileContext& clCompileContext);
+
+ /// Serializes the ClContext to the stream.
+ /// @param [stream] the stream to save to
+ /// @return true if ClContext is Serialized to the Stream, false otherwise
+ bool SaveSerializedToStream(std::ostream& stream);
+
+private:
+ /// FlatBufferBuilder to create the CLContext FlatBuffers.
+ flatbuffers::FlatBufferBuilder m_FlatBufferBuilder;
+};
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index 5a5cb89204..d65b26314e 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -5,6 +5,8 @@
#include "ClWorkloadFactory.hpp"
#include "ClBackendId.hpp"
#include "ClBackendModelContext.hpp"
+#include "ClContextDeserializer.hpp"
+#include "ClContextSerializer.hpp"
#include <Layer.hpp>
@@ -28,6 +30,7 @@
#include <arm_compute/runtime/CL/CLScheduler.h>
#include <Filesystem.hpp>
+#include <fstream>
namespace armnn
{
@@ -68,7 +71,11 @@ void ClWorkloadFactory::AfterWorkloadsCreated()
auto filePath = modelOptions->GetCachedNetworkFilePath();
if (filePath != "" && fs::exists(filePath) && fs::is_regular_file(filePath))
{
- /// Saving will be implemented within IVGCVSW-5483 story.
+ // Serialize ClContext to the file specified
+ ClContextSerializer serializer;
+ serializer.Serialize(m_CLCompileContext);
+ std::ofstream file(filePath, std::ios::out | std::ios::binary);
+ serializer.SaveSerializedToStream(file);
}
}
}
@@ -121,7 +128,9 @@ void ClWorkloadFactory::InitializeCLCompileContext()
&& fs::is_regular_file(filePath)
&& !(modelOptions->SaveCachedNetwork()))
{
- /// Loading will be implemented within IVGCVSW-5483 story.
+ // Deserialize binary file and load into m_CLCompileContext
+ ClContextDeserializer deserializer;
+ deserializer.Deserialize(m_CLCompileContext, context, device, filePath);
}
}
}
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 52295ccc4f..9514750563 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -18,6 +18,8 @@ BACKEND_SOURCES := \
ClBackendContext.cpp \
ClBackendModelContext.cpp \
ClContextControl.cpp \
+ ClContextDeserializer.cpp \
+ ClContextSerializer.cpp \
ClLayerSupport.cpp \
ClRegistryInitializer.cpp \
ClTensorHandleFactory.cpp \
@@ -97,6 +99,7 @@ ifeq ($(ARMNN_COMPUTE_CL_ENABLED),1)
# Include the source files for the CL backend tests
BACKEND_TEST_SOURCES := \
+ test/ClContextSerializerTests.cpp \
test/ClCreateWorkloadTests.cpp \
test/ClEndToEndTests.cpp \
test/ClJsonPrinterTests.cpp \
diff --git a/src/backends/cl/test/CMakeLists.txt b/src/backends/cl/test/CMakeLists.txt
index 2cf00d106b..a2950d9d44 100644
--- a/src/backends/cl/test/CMakeLists.txt
+++ b/src/backends/cl/test/CMakeLists.txt
@@ -5,6 +5,7 @@
list(APPEND armnnClBackendUnitTests_sources
ClContextControlFixture.hpp
+ ClContextSerializerTests.cpp
ClCreateWorkloadTests.cpp
ClEndToEndTests.cpp
ClJsonPrinterTests.cpp
diff --git a/src/backends/cl/test/ClContextSerializerTests.cpp b/src/backends/cl/test/ClContextSerializerTests.cpp
new file mode 100644
index 0000000000..1fc0fb9205
--- /dev/null
+++ b/src/backends/cl/test/ClContextSerializerTests.cpp
@@ -0,0 +1,138 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <Filesystem.hpp>
+
+#include <cl/test/ClContextControlFixture.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <fstream>
+
+namespace
+{
+
+armnn::INetworkPtr CreateNetwork()
+{
+ // Builds up the structure of the network.
+ armnn::INetworkPtr net(armnn::INetwork::Create());
+
+ armnn::IConnectableLayer* input = net->AddInputLayer(0, "input");
+ armnn::IConnectableLayer* softmax = net->AddSoftmaxLayer(armnn::SoftmaxDescriptor(), "softmax");
+ armnn::IConnectableLayer* output = net->AddOutputLayer(0, "output");
+
+ input->GetOutputSlot(0).Connect(softmax->GetInputSlot(0));
+ softmax->GetOutputSlot(0).Connect(output->GetInputSlot(0));
+
+ // Sets the input and output tensors
+ armnn::TensorInfo inputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 10000.0f, 1);
+ input->GetOutputSlot(0).SetTensorInfo(inputTensorInfo);
+
+ armnn::TensorInfo outputTensorInfo(armnn::TensorShape({1, 5}), armnn::DataType::QAsymmU8, 1.0f/255.0f, 0);
+ softmax->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
+
+ return net;
+}
+
+void RunInference(armnn::NetworkId& netId, armnn::IRuntimePtr& runtime, std::vector<uint8_t>& outputData)
+{
+ // Creates structures for input & output.
+ std::vector<uint8_t> inputData
+ {
+ 1, 10, 3, 200, 5 // Some inputs - one of which is sufficiently larger than the others to saturate softmax.
+ };
+
+ armnn::InputTensors inputTensors
+ {
+ {0, armnn::ConstTensor(runtime->GetInputTensorInfo(netId, 0), inputData.data())}
+ };
+
+ armnn::OutputTensors outputTensors
+ {
+ {0, armnn::Tensor(runtime->GetOutputTensorInfo(netId, 0), outputData.data())}
+ };
+
+ // Run inference.
+ runtime->EnqueueWorkload(netId, inputTensors, outputTensors);
+}
+
+std::vector<char> ReadBinaryFile(const std::string& binaryFileName)
+{
+ std::ifstream input(binaryFileName, std::ios::binary);
+ return std::vector<char>(std::istreambuf_iterator<char>(input), {});
+}
+
+} // anonymous namespace
+
+BOOST_FIXTURE_TEST_SUITE(ClContextSerializer, ClContextControlFixture)
+
+BOOST_AUTO_TEST_CASE(ClContextSerializerTest)
+{
+ // Get tmp directory and create blank file.
+ fs::path filePath = armnnUtils::Filesystem::NamedTempFile("Armnn-CachedNetworkFileTest-TempFile.bin");
+ std::string const filePathString{filePath.string()};
+ std::ofstream file { filePathString };
+
+ // Create runtime in which test will run
+ armnn::IRuntime::CreationOptions options;
+ armnn::IRuntimePtr runtime(armnn::IRuntime::Create(options));
+
+ std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
+
+ // Create two networks.
+ // net1 will serialize and save context to file.
+ // net2 will deserialize context saved from net1 and load.
+ armnn::INetworkPtr net1 = CreateNetwork();
+ armnn::INetworkPtr net2 = CreateNetwork();
+
+ // Add specific optimizerOptions to each network.
+ armnn::OptimizerOptions optimizerOptions1;
+ armnn::OptimizerOptions optimizerOptions2;
+ armnn::BackendOptions modelOptions1("GpuAcc",
+ {{"SaveCachedNetwork", true}, {"CachedNetworkFilePath", filePathString}});
+ armnn::BackendOptions modelOptions2("GpuAcc",
+ {{"SaveCachedNetwork", false}, {"CachedNetworkFilePath", filePathString}});
+ optimizerOptions1.m_ModelOptions.push_back(modelOptions1);
+ optimizerOptions2.m_ModelOptions.push_back(modelOptions2);
+
+ armnn::IOptimizedNetworkPtr optNet1 = armnn::Optimize(
+ *net1, backends, runtime->GetDeviceSpec(), optimizerOptions1);
+ armnn::IOptimizedNetworkPtr optNet2 = armnn::Optimize(
+ *net2, backends, runtime->GetDeviceSpec(), optimizerOptions2);
+ BOOST_CHECK(optNet1);
+ BOOST_CHECK(optNet2);
+
+ // Cached file should be empty until net1 is loaded into runtime.
+ BOOST_TEST(fs::is_empty(filePathString));
+
+ // Load net1 into the runtime.
+ armnn::NetworkId netId1;
+ BOOST_TEST(runtime->LoadNetwork(netId1, std::move(optNet1)) == armnn::Status::Success);
+
+ // File should now exist and not be empty. It has been serialized.
+ BOOST_TEST(fs::exists(filePathString));
+ std::vector<char> dataSerialized = ReadBinaryFile(filePathString);
+ BOOST_TEST(dataSerialized.size() != 0);
+
+ // Load net2 into the runtime using file and deserialize.
+ armnn::NetworkId netId2;
+ BOOST_TEST(runtime->LoadNetwork(netId2, std::move(optNet2)) == armnn::Status::Success);
+
+ // Run inference and get output data.
+ std::vector<uint8_t> outputData1(5);
+ RunInference(netId1, runtime, outputData1);
+
+ std::vector<uint8_t> outputData2(5);
+ RunInference(netId2, runtime, outputData2);
+
+ // Compare outputs from both networks.
+ BOOST_CHECK_EQUAL_COLLECTIONS(outputData1.begin(), outputData1.end(),
+ outputData2.begin(), outputData2.end());
+
+ // Remove temp file created.
+ fs::remove(filePath);
+}
+
+BOOST_AUTO_TEST_SUITE_END()