ArmNN
 24.05
ArmnnDriver.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022-2024 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <android-base/logging.h>
9 #include <nnapi/IBuffer.h>
10 #include <nnapi/IDevice.h>
11 #include <nnapi/IPreparedModel.h>
12 #include <nnapi/OperandTypes.h>
13 #include <nnapi/Result.h>
14 #include <nnapi/Types.h>
15 #include <nnapi/Validation.h>
16 
17 #include "ArmnnDevice.hpp"
18 #include "ArmnnDriverImpl.hpp"
19 #include "Converter.hpp"
20 
21 #include "ArmnnDriverImpl.hpp"
23 
24 #include <armnn/Version.hpp>
25 #include <log/log.h>
26 namespace armnn_driver
27 {
28 
29 //using namespace android::nn;
30 
31 class ArmnnDriver : public IDevice
32 {
33 private:
34  std::unique_ptr<ArmnnDevice> m_Device;
35 public:
36  ARMNN_DEPRECATED_MSG_REMOVAL_DATE("The Shim and support library will be removed from Arm NN in 24.08", "24.08")
38  {
39  try
40  {
41  VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
42  m_Device = std::unique_ptr<ArmnnDevice>(new ArmnnDevice(std::move(options)));
43  }
45  {
46  VLOG(DRIVER) << "ArmnnDevice failed to initialise: " << ex.what();
47  }
48  catch (...)
49  {
50  VLOG(DRIVER) << "ArmnnDevice failed to initialise with an unknown error";
51  }
52  }
53 
54 public:
55 
56  const std::string& getName() const override
57  {
58  VLOG(DRIVER) << "ArmnnDriver::getName()";
59  static const std::string name = "arm-armnn-sl";
60  return name;
61  }
62 
63  const std::string& getVersionString() const override
64  {
65  VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
66  static const std::string versionString = ARMNN_VERSION;
67  return versionString;
68  }
69 
70  Version getFeatureLevel() const override
71  {
72  VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
73  return kVersionFeatureLevel6;
74  }
75 
76  DeviceType getType() const override
77  {
78  VLOG(DRIVER) << "ArmnnDriver::getType()";
79  return DeviceType::CPU;
80  }
81 
82  const std::vector<Extension>& getSupportedExtensions() const override
83  {
84  VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
85  static const std::vector<Extension> extensions = {};
86  return extensions;
87  }
88 
89  const Capabilities& getCapabilities() const override
90  {
91  VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()";
92  return ArmnnDriverImpl::GetCapabilities(m_Device->m_Runtime);
93  }
94 
95  std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
96  {
97  VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
98  unsigned int numberOfCachedModelFiles = 0;
99  for (auto& backend : m_Device->m_Options.GetBackends())
100  {
101  numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
102  VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = "
103  << std::to_string(numberOfCachedModelFiles);
104  }
105  return std::make_pair(numberOfCachedModelFiles, 1ul);
106  }
107 
108  GeneralResult<void> wait() const override
109  {
110  VLOG(DRIVER) << "ArmnnDriver::wait()";
111  return {};
112  }
113 
114  GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
115  {
116  VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";
117  if (m_Device.get() == nullptr)
118  {
119  return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
120  }
121 
122  std::stringstream ss;
123  ss << "ArmnnDriverImpl::getSupportedOperations()";
124  std::string fileName;
125  std::string timestamp;
126  if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty())
127  {
128  ss << " : "
129  << m_Device->m_Options.GetRequestInputsAndOutputsDumpDir()
130  << "/"
131  // << GetFileTimestamp()
132  << "_getSupportedOperations.txt";
133  }
134  VLOG(DRIVER) << ss.str().c_str();
135 
136  if (!m_Device->m_Options.GetRequestInputsAndOutputsDumpDir().empty())
137  {
138  //dump the marker file
139  std::ofstream fileStream;
140  fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
141  if (fileStream.good())
142  {
143  fileStream << timestamp << std::endl;
144  fileStream << timestamp << std::endl;
145  }
146  fileStream.close();
147  }
148 
149  std::vector<bool> result;
150  if (!m_Device->m_Runtime)
151  {
152  return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
153  }
154 
155  // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
156  if (const auto result = validate(model); !result.ok())
157  {
158  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
159  }
160 
161  // Attempt to convert the model to an ArmNN input network (INetwork).
162  ModelToINetworkTransformer modelConverter(m_Device->m_Options.GetBackends(),
163  model,
164  m_Device->m_Options.GetForcedUnsupportedOperations());
165 
166  if (modelConverter.GetConversionResult() != ConversionResult::Success
168  {
169  return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
170  }
171 
172  // Check each operation if it was converted successfully and copy the flags
173  // into the result (vector<bool>) that we need to return to Android.
174  result.reserve(model.main.operations.size());
175  for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
176  {
177  bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
178  result.push_back(operationSupported);
179  }
180 
181  return result;
182  }
183 
184  GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
185  ExecutionPreference preference,
186  Priority priority,
187  OptionalTimePoint deadline,
188  const std::vector<SharedHandle>& modelCache,
189  const std::vector<SharedHandle>& dataCache,
190  const CacheToken& token,
191  const std::vector<android::nn::TokenValuePair>& hints,
192  const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
193  {
194  VLOG(DRIVER) << "ArmnnDriver::prepareModel()";
195 
196  if (m_Device.get() == nullptr)
197  {
198  return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
199  }
200  // Validate arguments.
201  if (const auto result = validate(model); !result.ok()) {
202  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
203  }
204  if (const auto result = validate(preference); !result.ok()) {
205  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
206  << "Invalid ExecutionPreference: " << result.error();
207  }
208  if (const auto result = validate(priority); !result.ok()) {
209  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
210  }
211 
212  // Check if deadline has passed.
213  if (hasDeadlinePassed(deadline)) {
214  return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
215  }
217  return ArmnnDriverImpl::PrepareArmnnModel(m_Device->m_Runtime,
218  m_Device->m_ClTunedParameters,
219  m_Device->m_Options,
220  model,
221  modelCache,
222  dataCache,
223  token,
224  model.relaxComputationFloat32toFloat16 && m_Device->m_Options.GetFp16Enabled(),
225  priority);
227  }
228 
229  GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
230  const std::vector<SharedHandle>& modelCache,
231  const std::vector<SharedHandle>& dataCache,
232  const CacheToken& token) const override
233  {
234  VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";
235  if (m_Device.get() == nullptr)
236  {
237  return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
238  }
239  // Check if deadline has passed.
240  if (hasDeadlinePassed(deadline)) {
241  return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
242  }
243 
246  m_Device->m_Runtime,
247  m_Device->m_ClTunedParameters,
248  m_Device->m_Options,
249  modelCache,
250  dataCache,
251  token,
252  m_Device->m_Options.GetFp16Enabled());
254  }
255 
256  GeneralResult<SharedBuffer> allocate(const BufferDesc&,
257  const std::vector<SharedPreparedModel>&,
258  const std::vector<BufferRole>&,
259  const std::vector<BufferRole>&) const override
260  {
261  VLOG(DRIVER) << "ArmnnDriver::allocate()";
262  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
263  }
264 };
265 
266 } // namespace armnn_driver
armnn_driver::ArmnnDriver::getSupportedOperations
GeneralResult< std::vector< bool > > getSupportedOperations(const Model &model) const override
Definition: ArmnnDriver.hpp:114
armnn_driver::ArmnnDriverImpl::PrepareArmnnModel
static GeneralResult< SharedPreparedModel > PrepareArmnnModel(const armnn::IRuntimePtr &runtime, const armnn::IGpuAccTunedParametersPtr &clTunedParameters, const DriverOptions &options, const Model &model, const std::vector< SharedHandle > &modelCacheHandle, const std::vector< SharedHandle > &dataCacheHandle, const CacheToken &token, bool float32ToFloat16=false, Priority priority=Priority::MEDIUM)
Definition: ArmnnDriverImpl.cpp:99
armnn_driver::ArmnnDriver::getVersionString
const std::string & getVersionString() const override
Definition: ArmnnDriver.hpp:63
ArmnnDriverImpl.hpp
armnn_driver::ArmnnDriver::getNumberOfCacheFilesNeeded
std::pair< uint32_t, uint32_t > getNumberOfCacheFilesNeeded() const override
Definition: ArmnnDriver.hpp:95
ARMNN_NO_DEPRECATE_WARN_BEGIN
#define ARMNN_NO_DEPRECATE_WARN_BEGIN
Definition: Deprecated.hpp:33
armnn_driver::ArmnnDriver
Definition: ArmnnDriver.hpp:31
armnn_driver::ArmnnDevice
Definition: ArmnnDevice.hpp:15
armnn_driver::ConversionResult::UnsupportedFeature
@ UnsupportedFeature
armnn::Exception::what
virtual const char * what() const noexcept override
Definition: Exceptions.cpp:32
ModelToINetworkTransformer.hpp
Version.hpp
armnn_driver
Helper classes.
Definition: ArmnnDevice.cpp:37
armnn_driver::ArmnnDriver::prepareModel
GeneralResult< SharedPreparedModel > prepareModel(const Model &model, ExecutionPreference preference, Priority priority, OptionalTimePoint deadline, const std::vector< SharedHandle > &modelCache, const std::vector< SharedHandle > &dataCache, const CacheToken &token, const std::vector< android::nn::TokenValuePair > &hints, const std::vector< android::nn::ExtensionNameAndPrefix > &extensionNameToPrefix) const override
Definition: ArmnnDriver.hpp:184
armnn_driver::Model
::android::nn::Model Model
Helper classes.
Definition: ConversionUtils.hpp:45
armnn_driver::ArmnnDriver::prepareModelFromCache
GeneralResult< SharedPreparedModel > prepareModelFromCache(OptionalTimePoint deadline, const std::vector< SharedHandle > &modelCache, const std::vector< SharedHandle > &dataCache, const CacheToken &token) const override
Definition: ArmnnDriver.hpp:229
armnn_driver::DriverOptions
Definition: DriverOptions.hpp:17
armnn::InvalidArgumentException
Definition: Exceptions.hpp:80
armnn_driver::ArmnnDriver::wait
GeneralResult< void > wait() const override
Definition: ArmnnDriver.hpp:108
armnn_driver::ArmnnDriver::allocate
GeneralResult< SharedBuffer > allocate(const BufferDesc &, const std::vector< SharedPreparedModel > &, const std::vector< BufferRole > &, const std::vector< BufferRole > &) const override
Definition: ArmnnDriver.hpp:256
armnn_driver::ModelToINetworkTransformer::IsOperationSupported
bool IsOperationSupported(uint32_t operationIndex) const
Definition: ModelToINetworkTransformer.cpp:196
armnn_driver::ArmnnDriver::getCapabilities
const Capabilities & getCapabilities() const override
Definition: ArmnnDriver.hpp:89
ARMNN_VERSION
#define ARMNN_VERSION
ARMNN_VERSION: "X.Y.Z" where: X = Major version number Y = Minor version number Z = Patch version num...
Definition: Version.hpp:22
armnn_driver::ArmnnDriver::getName
const std::string & getName() const override
Definition: ArmnnDriver.hpp:56
armnn_driver::ArmnnDriverImpl::GetCapabilities
static const Capabilities & GetCapabilities(const armnn::IRuntimePtr &runtime)
Definition: ArmnnDriverImpl.cpp:554
ARMNN_NO_DEPRECATE_WARN_END
#define ARMNN_NO_DEPRECATE_WARN_END
Definition: Deprecated.hpp:34
ArmnnDevice.hpp
ARMNN_DEPRECATED_MSG_REMOVAL_DATE
#define ARMNN_DEPRECATED_MSG_REMOVAL_DATE(message, removed_in_release)
Definition: Deprecated.hpp:44
armnn_driver::ConversionResult::Success
@ Success
armnn::GetNumberOfCacheFiles
unsigned int GetNumberOfCacheFiles(const armnn::BackendId &backend)
Returns the number of cached files if backend supports caching.
Definition: BackendHelper.cpp:130
armnn_driver::ArmnnDriver::getSupportedExtensions
const std::vector< Extension > & getSupportedExtensions() const override
Definition: ArmnnDriver.hpp:82
armnn_driver::ArmnnDriver::getType
DeviceType getType() const override
Definition: ArmnnDriver.hpp:76
armnn_driver::ArmnnDriver::getFeatureLevel
Version getFeatureLevel() const override
Definition: ArmnnDriver.hpp:70
armnn_driver::ArmnnDriverImpl::PrepareArmnnModelFromCache
static GeneralResult< SharedPreparedModel > PrepareArmnnModelFromCache(const armnn::IRuntimePtr &runtime, const armnn::IGpuAccTunedParametersPtr &clTunedParameters, const DriverOptions &options, const std::vector< SharedHandle > &modelCacheHandle, const std::vector< SharedHandle > &dataCacheHandle, const CacheToken &token, bool float32ToFloat16=false)
Definition: ArmnnDriverImpl.cpp:338
Converter.hpp
armnn_driver::ModelToINetworkTransformer::GetConversionResult
ConversionResult GetConversionResult() const
Definition: ModelToINetworkTransformer.hpp:37
armnn_driver::ModelToINetworkTransformer
Definition: ModelToINetworkTransformer.hpp:30