diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-03-09 14:13:49 +0000 |
commit | 4fcda0101ec3d110c1d6d7bee5c83416b645528a (patch) | |
tree | c9a70aeb2887006160c1b3d265c27efadb7bdbae /tests/InferenceModel.hpp | |
download | armnn-4fcda0101ec3d110c1d6d7bee5c83416b645528a.tar.gz |
Release 18.02
Change-Id: Id3c11dc5ee94ef664374a988fcc6901e9a232fa6
Diffstat (limited to 'tests/InferenceModel.hpp')
-rw-r--r-- | tests/InferenceModel.hpp | 149 |
1 files changed, 149 insertions, 0 deletions
diff --git a/tests/InferenceModel.hpp b/tests/InferenceModel.hpp new file mode 100644 index 0000000000..c390ccdc2f --- /dev/null +++ b/tests/InferenceModel.hpp @@ -0,0 +1,149 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// See LICENSE file in the project root for full license information. +// +#pragma once + +#include "armnn/ArmNN.hpp" + +#include <boost/log/trivial.hpp> +#include <boost/format.hpp> +#include <boost/program_options.hpp> + +#include <map> +#include <string> + +template<typename TContainer> +inline armnn::InputTensors MakeInputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& input, + const TContainer& inputTensorData) +{ + if (inputTensorData.size() != input.second.GetNumElements()) + { + throw armnn::Exception(boost::str(boost::format("Input tensor has incorrect size. Expected %1% elements " + "but got %2%.") % input.second.GetNumElements() % inputTensorData.size())); + } + return { { input.first, armnn::ConstTensor(input.second, inputTensorData.data()) } }; +} + +template<typename TContainer> +inline armnn::OutputTensors MakeOutputTensors(const std::pair<armnn::LayerBindingId, armnn::TensorInfo>& output, + TContainer& outputTensorData) +{ + if (outputTensorData.size() != output.second.GetNumElements()) + { + throw armnn::Exception("Output tensor has incorrect size"); + } + return { { output.first, armnn::Tensor(output.second, outputTensorData.data()) } }; +} + +template <typename IParser, typename TDataType> +class InferenceModel +{ +public: + using DataType = TDataType; + + struct CommandLineOptions + { + std::string m_ModelDir; + armnn::Compute m_ComputeDevice; + }; + + static void AddCommandLineOptions(boost::program_options::options_description& desc, CommandLineOptions& options) + { + namespace po = boost::program_options; + + desc.add_options() + ("model-dir,m", po::value<std::string>(&options.m_ModelDir)->required(), + "Path to directory containing model files (.caffemodel/.prototxt)") + ("compute,c", po::value<armnn::Compute>(&options.m_ComputeDevice)->default_value(armnn::Compute::CpuAcc), + "Which device to run layers on by default. Possible choices: CpuAcc, CpuRef, GpuAcc"); + } + + struct Params + { + std::string m_ModelPath; + std::string m_InputBinding; + std::string m_OutputBinding; + const armnn::TensorShape* m_InputTensorShape; + armnn::Compute m_ComputeDevice; + bool m_IsModelBinary; + + Params() + : m_InputTensorShape(nullptr) + , m_ComputeDevice(armnn::Compute::CpuRef) + , m_IsModelBinary(true) + { + } + }; + + + InferenceModel(const Params& params) + : m_Runtime(armnn::IRuntime::Create(params.m_ComputeDevice)) + { + const std::string& modelPath = params.m_ModelPath; + + // Create a network from a file on disk + auto parser(IParser::Create()); + + std::map<std::string, armnn::TensorShape> inputShapes; + if (params.m_InputTensorShape) + { + inputShapes[params.m_InputBinding] = *params.m_InputTensorShape; + } + std::vector<std::string> requestedOutputs{ params.m_OutputBinding }; + + // Handle text and binary input differently by calling the corresponding parser function + armnn::INetworkPtr network = (params.m_IsModelBinary ? + parser->CreateNetworkFromBinaryFile(modelPath.c_str(), inputShapes, requestedOutputs) : + parser->CreateNetworkFromTextFile(modelPath.c_str(), inputShapes, requestedOutputs)); + + m_InputBindingInfo = parser->GetNetworkInputBindingInfo(params.m_InputBinding); + m_OutputBindingInfo = parser->GetNetworkOutputBindingInfo(params.m_OutputBinding); + + armnn::IOptimizedNetworkPtr optNet = + armnn::Optimize(*network, m_Runtime->GetDeviceSpec()); + + // Load the network into the runtime. + armnn::Status ret = m_Runtime->LoadNetwork(m_NetworkIdentifier, std::move(optNet)); + if (ret == armnn::Status::Failure) + { + throw armnn::Exception("IRuntime::LoadNetwork failed"); + } + } + + unsigned int GetOutputSize() const + { + return m_OutputBindingInfo.second.GetNumElements(); + } + + void Run(const std::vector<TDataType>& input, std::vector<TDataType>& output) + { + BOOST_ASSERT(output.size() == GetOutputSize()); + armnn::Status ret = m_Runtime->EnqueueWorkload(m_NetworkIdentifier, + MakeInputTensors(input), + MakeOutputTensors(output)); + if (ret == armnn::Status::Failure) + { + throw armnn::Exception("IRuntime::EnqueueWorkload failed"); + } + } + +private: + template<typename TContainer> + armnn::InputTensors MakeInputTensors(const TContainer& inputTensorData) + { + return ::MakeInputTensors(m_InputBindingInfo, inputTensorData); + } + + template<typename TContainer> + armnn::OutputTensors MakeOutputTensors(TContainer& outputTensorData) + { + return ::MakeOutputTensors(m_OutputBindingInfo, outputTensorData); + } + + armnn::NetworkId m_NetworkIdentifier; + armnn::IRuntimePtr m_Runtime; + + std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_InputBindingInfo; + std::pair<armnn::LayerBindingId, armnn::TensorInfo> m_OutputBindingInfo; +}; |