ArmNN
 22.11
ArmnnDriver.hpp
Go to the documentation of this file.
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <android-base/logging.h>
9 #include <nnapi/IBuffer.h>
10 #include <nnapi/IDevice.h>
11 #include <nnapi/IPreparedModel.h>
12 #include <nnapi/OperandTypes.h>
13 #include <nnapi/Result.h>
14 #include <nnapi/Types.h>
15 #include <nnapi/Validation.h>
16 
17 #include "ArmnnDevice.hpp"
18 #include "ArmnnDriverImpl.hpp"
19 #include "Converter.hpp"
20 
21 #include "ArmnnDriverImpl.hpp"
23 
24 #include <armnn/Version.hpp>
25 #include <log/log.h>
26 namespace armnn_driver
27 {
28 
29 //using namespace android::nn;
30 
31 class ArmnnDriver : public ArmnnDevice, public IDevice
32 {
33 public:
34 
36  : ArmnnDevice(std::move(options))
37  {
38  VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()";
39  }
41  {
42  VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()";
43  }
44 
45 public:
46 
47  const std::string& getName() const override
48  {
49  VLOG(DRIVER) << "ArmnnDriver::getName()";
50  static const std::string name = "arm-armnn-sl";
51  return name;
52  }
53 
54  const std::string& getVersionString() const override
55  {
56  VLOG(DRIVER) << "ArmnnDriver::getVersionString()";
57  static const std::string versionString = ARMNN_VERSION;
58  return versionString;
59  }
60 
61  Version getFeatureLevel() const override
62  {
63  VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()";
64  return kVersionFeatureLevel6;
65  }
66 
67  DeviceType getType() const override
68  {
69  VLOG(DRIVER) << "ArmnnDriver::getType()";
70  return DeviceType::CPU;
71  }
72 
73  const std::vector<Extension>& getSupportedExtensions() const override
74  {
75  VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()";
76  static const std::vector<Extension> extensions = {};
77  return extensions;
78  }
79 
80  const Capabilities& getCapabilities() const override
81  {
82  VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()";
84  }
85 
86  std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override
87  {
88  VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()";
89  unsigned int numberOfCachedModelFiles = 0;
90  for (auto& backend : m_Options.GetBackends())
91  {
92  numberOfCachedModelFiles += GetNumberOfCacheFiles(backend);
93  VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " << std::to_string(numberOfCachedModelFiles);
94  }
95  return std::make_pair(numberOfCachedModelFiles, 1ul);
96  }
97 
98  GeneralResult<void> wait() const override
99  {
100  VLOG(DRIVER) << "ArmnnDriver::wait()";
101  return {};
102  }
103 
104  GeneralResult<std::vector<bool>> getSupportedOperations(const Model& model) const override
105  {
106  VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()";
107 
108  std::stringstream ss;
109  ss << "ArmnnDriverImpl::getSupportedOperations()";
110  std::string fileName;
111  std::string timestamp;
113  {
114  ss << " : "
116  << "/"
117  // << GetFileTimestamp()
118  << "_getSupportedOperations.txt";
119  }
120  VLOG(DRIVER) << ss.str().c_str();
121 
123  {
124  //dump the marker file
125  std::ofstream fileStream;
126  fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc);
127  if (fileStream.good())
128  {
129  fileStream << timestamp << std::endl;
130  fileStream << timestamp << std::endl;
131  }
132  fileStream.close();
133  }
134 
135  std::vector<bool> result;
136  if (!m_Runtime)
137  {
138  return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!";
139  }
140 
141  // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway.
142  if (const auto result = validate(model); !result.ok())
143  {
144  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!";
145  }
146 
147  // Attempt to convert the model to an ArmNN input network (INetwork).
149  model,
151 
152  if (modelConverter.GetConversionResult() != ConversionResult::Success
153  && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature)
154  {
155  return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!";
156  }
157 
158  // Check each operation if it was converted successfully and copy the flags
159  // into the result (vector<bool>) that we need to return to Android.
160  result.reserve(model.main.operations.size());
161  for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx)
162  {
163  bool operationSupported = modelConverter.IsOperationSupported(operationIdx);
164  result.push_back(operationSupported);
165  }
166 
167  return result;
168  }
169 
170  GeneralResult<SharedPreparedModel> prepareModel(const Model& model,
171  ExecutionPreference preference,
172  Priority priority,
173  OptionalTimePoint deadline,
174  const std::vector<SharedHandle>& modelCache,
175  const std::vector<SharedHandle>& dataCache,
176  const CacheToken& token,
177  const std::vector<android::nn::TokenValuePair>& hints,
178  const std::vector<android::nn::ExtensionNameAndPrefix>& extensionNameToPrefix) const override
179  {
180  VLOG(DRIVER) << "ArmnnDriver::prepareModel()";
181 
182  // Validate arguments.
183  if (const auto result = validate(model); !result.ok()) {
184  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error();
185  }
186  if (const auto result = validate(preference); !result.ok()) {
187  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT)
188  << "Invalid ExecutionPreference: " << result.error();
189  }
190  if (const auto result = validate(priority); !result.ok()) {
191  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error();
192  }
193 
194  // Check if deadline has passed.
195  if (hasDeadlinePassed(deadline)) {
196  return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
197  }
198 
201  m_Options,
202  model,
203  modelCache,
204  dataCache,
205  token,
206  model.relaxComputationFloat32toFloat16 && m_Options.GetFp16Enabled(),
207  priority);
208  }
209 
210  GeneralResult<SharedPreparedModel> prepareModelFromCache(OptionalTimePoint deadline,
211  const std::vector<SharedHandle>& modelCache,
212  const std::vector<SharedHandle>& dataCache,
213  const CacheToken& token) const override
214  {
215  VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()";
216 
217  // Check if deadline has passed.
218  if (hasDeadlinePassed(deadline)) {
219  return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT);
220  }
221 
223  m_Runtime,
225  m_Options,
226  modelCache,
227  dataCache,
228  token,
230  }
231 
232  GeneralResult<SharedBuffer> allocate(const BufferDesc&,
233  const std::vector<SharedPreparedModel>&,
234  const std::vector<BufferRole>&,
235  const std::vector<BufferRole>&) const override
236  {
237  VLOG(DRIVER) << "ArmnnDriver::allocate()";
238  return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate.";
239  }
240 };
241 
242 } // namespace armnn_driver
GeneralResult< void > wait() const override
Definition: ArmnnDriver.hpp:98
Version getFeatureLevel() const override
Definition: ArmnnDriver.hpp:61
GeneralResult< SharedPreparedModel > prepareModel(const Model &model, ExecutionPreference preference, Priority priority, OptionalTimePoint deadline, const std::vector< SharedHandle > &modelCache, const std::vector< SharedHandle > &dataCache, const CacheToken &token, const std::vector< android::nn::TokenValuePair > &hints, const std::vector< android::nn::ExtensionNameAndPrefix > &extensionNameToPrefix) const override
#define ARMNN_VERSION
ARMNN_VERSION: "X.Y.Z" where: X = Major version number, Y = Minor version number, Z = Patch version number.
Definition: Version.hpp:22
armnn::IGpuAccTunedParametersPtr m_ClTunedParameters
Definition: ArmnnDevice.hpp:24
const std::set< unsigned int > & GetForcedUnsupportedOperations() const
GeneralResult< std::vector< bool > > getSupportedOperations(const Model &model) const override
const std::string & getName() const override
Definition: ArmnnDriver.hpp:47
armnn::IRuntimePtr m_Runtime
Definition: ArmnnDevice.hpp:23
const std::vector< Extension > & getSupportedExtensions() const override
Definition: ArmnnDriver.hpp:73
::android::nn::Model Model
Helper classes.
unsigned int GetNumberOfCacheFiles(const armnn::BackendId &backend)
Returns the number of cached files if backend supports caching.
const std::string & getVersionString() const override
Definition: ArmnnDriver.hpp:54
const std::string & GetRequestInputsAndOutputsDumpDir() const
ArmnnDriver(DriverOptions options)
Definition: ArmnnDriver.hpp:35
bool IsOperationSupported(uint32_t operationIndex) const
GeneralResult< SharedPreparedModel > prepareModelFromCache(OptionalTimePoint deadline, const std::vector< SharedHandle > &modelCache, const std::vector< SharedHandle > &dataCache, const CacheToken &token) const override
const Capabilities & getCapabilities() const override
Definition: ArmnnDriver.hpp:80
GeneralResult< SharedBuffer > allocate(const BufferDesc &, const std::vector< SharedPreparedModel > &, const std::vector< BufferRole > &, const std::vector< BufferRole > &) const override
static GeneralResult< SharedPreparedModel > PrepareArmnnModelFromCache(const armnn::IRuntimePtr &runtime, const armnn::IGpuAccTunedParametersPtr &clTunedParameters, const DriverOptions &options, const std::vector< SharedHandle > &modelCacheHandle, const std::vector< SharedHandle > &dataCacheHandle, const CacheToken &token, bool float32ToFloat16=false)
static const Capabilities & GetCapabilities(const armnn::IRuntimePtr &runtime)
const std::vector< armnn::BackendId > & GetBackends() const
static GeneralResult< SharedPreparedModel > PrepareArmnnModel(const armnn::IRuntimePtr &runtime, const armnn::IGpuAccTunedParametersPtr &clTunedParameters, const DriverOptions &options, const Model &model, const std::vector< SharedHandle > &modelCacheHandle, const std::vector< SharedHandle > &dataCacheHandle, const CacheToken &token, bool float32ToFloat16=false, Priority priority=Priority::MEDIUM)
std::pair< uint32_t, uint32_t > getNumberOfCacheFilesNeeded() const override
Definition: ArmnnDriver.hpp:86
DeviceType getType() const override
Definition: ArmnnDriver.hpp:67
Helper classes.
Definition: ArmnnDevice.cpp:37