diff options
Diffstat (limited to '1.2')
-rw-r--r-- | 1.2/ArmnnDriver.hpp | 4 | ||||
-rw-r--r-- | 1.2/HalPolicy.cpp | 29 | ||||
-rw-r--r-- | 1.2/HalPolicy.hpp | 4 |
3 files changed, 35 insertions, 2 deletions
diff --git a/1.2/ArmnnDriver.hpp b/1.2/ArmnnDriver.hpp index 5227272f..a350d3f4 100644 --- a/1.2/ArmnnDriver.hpp +++ b/1.2/ArmnnDriver.hpp @@ -129,8 +129,8 @@ public: Return<void> getType(getType_cb cb) { ALOGV("hal_1_2::ArmnnDriver::getType()"); - - cb(V1_0::ErrorStatus::NONE, V1_2::DeviceType::CPU); + const auto device_type = hal_1_2::HalPolicy::GetDeviceTypeFromOptions(this->m_Options); + cb(V1_0::ErrorStatus::NONE, device_type); return Void(); } diff --git a/1.2/HalPolicy.cpp b/1.2/HalPolicy.cpp index fb6c31ce..79d117ae 100644 --- a/1.2/HalPolicy.cpp +++ b/1.2/HalPolicy.cpp @@ -4,6 +4,8 @@ // #include "HalPolicy.hpp" +#include "DriverOptions.hpp" + namespace armnn_driver { @@ -17,6 +19,33 @@ namespace } // anonymous namespace +HalPolicy::DeviceType HalPolicy::GetDeviceTypeFromOptions(const DriverOptions& options) +{ + // Query backends list from the options + auto backends = options.GetBackends(); + // Return first backend + if(backends.size()>0) + { + const auto &first_backend = backends[0]; + if(first_backend.IsCpuAcc()||first_backend.IsCpuRef()) + { + return V1_2::DeviceType::CPU; + } + else if(first_backend.IsGpuAcc()) + { + return V1_2::DeviceType::GPU; + } + else + { + return V1_2::DeviceType::ACCELERATOR; + } + } + else + { + return V1_2::DeviceType::CPU; + } +} + bool HalPolicy::ConvertOperation(const Operation& operation, const Model& model, ConversionData& data) { switch (operation.type) diff --git a/1.2/HalPolicy.hpp b/1.2/HalPolicy.hpp index a348abe0..0662e1be 100644 --- a/1.2/HalPolicy.hpp +++ b/1.2/HalPolicy.hpp @@ -16,6 +16,7 @@ namespace V1_2 = ::android::hardware::neuralnetworks::V1_2; namespace armnn_driver { +class DriverOptions; namespace hal_1_2 { @@ -31,6 +32,9 @@ public: using ExecutionCallback = V1_2::IExecutionCallback; using getSupportedOperations_cb = V1_2::IDevice::getSupportedOperations_1_2_cb; using ErrorStatus = V1_0::ErrorStatus; + using DeviceType = V1_2::DeviceType; + + static DeviceType GetDeviceTypeFromOptions(const DriverOptions& options); static bool ConvertOperation(const Operation& operation, const Model& model, ConversionData& data); |