From 8f397a1efed11e17e9f8cb12b53a72b7e32ab978 Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Fri, 17 Jun 2022 15:38:22 +0100 Subject: IVGCVSW-6989 "Merged experimental/armnn_shim_sl" * Updated Serializer CMakeLists.txt to build armnnSerializerObj * Added constant tensors as input support to SL Signed-off-by: Sadik Armagan Change-Id: I22f6cf50147d99a01f7fe70d7446b114a4c57af3 --- shim/sl/canonical/ArmnnDriver.hpp | 247 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 shim/sl/canonical/ArmnnDriver.hpp (limited to 'shim/sl/canonical/ArmnnDriver.hpp') diff --git a/shim/sl/canonical/ArmnnDriver.hpp b/shim/sl/canonical/ArmnnDriver.hpp new file mode 100644 index 0000000000..877faa667e --- /dev/null +++ b/shim/sl/canonical/ArmnnDriver.hpp @@ -0,0 +1,247 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ArmnnDevice.hpp" +#include "ArmnnDriverImpl.hpp" +#include "Converter.hpp" + +#include "ArmnnDriverImpl.hpp" +#include "ModelToINetworkTransformer.hpp" + +#include +namespace armnn_driver +{ + +//using namespace android::nn; + +class ArmnnDriver : public ArmnnDevice, public IDevice +{ +public: + + ArmnnDriver(DriverOptions options) + : ArmnnDevice(std::move(options)) + { + VLOG(DRIVER) << "ArmnnDriver::ArmnnDriver()"; + } + ~ArmnnDriver() + { + VLOG(DRIVER) << "ArmnnDriver::~ArmnnDriver()"; + // Unload the networks + for (auto& netId : ArmnnDriverImpl::GetLoadedNetworks()) + { + m_Runtime->UnloadNetwork(netId); + } + ArmnnDriverImpl::ClearNetworks(); + } + +public: + + const std::string& getName() const override + { + VLOG(DRIVER) << "ArmnnDriver::getName()"; + static const std::string name = "arm-armnn-sl"; + return name; + } + + const std::string& getVersionString() const override + { + VLOG(DRIVER) << "ArmnnDriver::getVersionString()"; + static const std::string versionString = "ArmNN"; + return versionString; + } + + Version getFeatureLevel() const override + { + VLOG(DRIVER) << "ArmnnDriver::getFeatureLevel()"; + return kVersionFeatureLevel5; + } + + DeviceType getType() const override + { + VLOG(DRIVER) << "ArmnnDriver::getType()"; + return DeviceType::CPU; + } + + const std::vector& getSupportedExtensions() const override + { + VLOG(DRIVER) << "ArmnnDriver::getSupportedExtensions()"; + static const std::vector extensions = {}; + return extensions; + } + + const Capabilities& getCapabilities() const override + { + VLOG(DRIVER) << "ArmnnDriver::GetCapabilities()"; + return ArmnnDriverImpl::GetCapabilities(m_Runtime); + } + + std::pair getNumberOfCacheFilesNeeded() const override + { + VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded()"; + unsigned int numberOfCachedModelFiles = 0; + for (auto& backend : m_Options.GetBackends()) + { + numberOfCachedModelFiles += GetNumberOfCacheFiles(backend); + VLOG(DRIVER) << "ArmnnDriver::getNumberOfCacheFilesNeeded() = " << std::to_string(numberOfCachedModelFiles); + } + return std::make_pair(numberOfCachedModelFiles, 1ul); + } + + GeneralResult wait() const override + { + VLOG(DRIVER) << "ArmnnDriver::wait()"; + return {}; + } + + GeneralResult> getSupportedOperations(const Model& model) const override + { + VLOG(DRIVER) << "ArmnnDriver::getSupportedOperations()"; + + std::stringstream ss; + ss << "ArmnnDriverImpl::getSupportedOperations()"; + std::string fileName; + std::string timestamp; + if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty()) + { + ss << " : " + << m_Options.GetRequestInputsAndOutputsDumpDir() + << "/" + // << GetFileTimestamp() + << "_getSupportedOperations.txt"; + } + VLOG(DRIVER) << ss.str().c_str(); + + if (!m_Options.GetRequestInputsAndOutputsDumpDir().empty()) + { + //dump the marker file + std::ofstream fileStream; + fileStream.open(fileName, std::ofstream::out | std::ofstream::trunc); + if (fileStream.good()) + { + fileStream << timestamp << std::endl; + fileStream << timestamp << std::endl; + } + fileStream.close(); + } + + std::vector result; + if (!m_Runtime) + { + return NN_ERROR(ErrorStatus::DEVICE_UNAVAILABLE) << "Device Unavailable!"; + } + + // Run general model validation, if this doesn't pass we shouldn't analyse the model anyway. + if (const auto result = validate(model); !result.ok()) + { + return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model!"; + } + + // Attempt to convert the model to an ArmNN input network (INetwork). + ModelToINetworkTransformer modelConverter(m_Options.GetBackends(), + model, + m_Options.GetForcedUnsupportedOperations()); + + if (modelConverter.GetConversionResult() != ConversionResult::Success + && modelConverter.GetConversionResult() != ConversionResult::UnsupportedFeature) + { + return NN_ERROR(ErrorStatus::GENERAL_FAILURE) << "Conversion Error!"; + } + + // Check each operation if it was converted successfully and copy the flags + // into the result (vector) that we need to return to Android. + result.reserve(model.main.operations.size()); + for (uint32_t operationIdx = 0; operationIdx < model.main.operations.size(); ++operationIdx) + { + bool operationSupported = modelConverter.IsOperationSupported(operationIdx); + result.push_back(operationSupported); + } + + return result; + } + + GeneralResult prepareModel(const Model& model, + ExecutionPreference preference, + Priority priority, + OptionalTimePoint deadline, + const std::vector& modelCache, + const std::vector& dataCache, + const CacheToken& token, + const std::vector& hints, + const std::vector& extensionNameToPrefix) const override + { + VLOG(DRIVER) << "ArmnnDriver::prepareModel()"; + + // Validate arguments. + if (const auto result = validate(model); !result.ok()) { + return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Model: " << result.error(); + } + if (const auto result = validate(preference); !result.ok()) { + return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) + << "Invalid ExecutionPreference: " << result.error(); + } + if (const auto result = validate(priority); !result.ok()) { + return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "Invalid Priority: " << result.error(); + } + + // Check if deadline has passed. + if (hasDeadlinePassed(deadline)) { + return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); + } + + return ArmnnDriverImpl::PrepareArmnnModel(m_Runtime, + m_ClTunedParameters, + m_Options, + model, + modelCache, + dataCache, + token, + model.relaxComputationFloat32toFloat16 && m_Options.GetFp16Enabled(), + priority); + } + + GeneralResult prepareModelFromCache(OptionalTimePoint deadline, + const std::vector& modelCache, + const std::vector& dataCache, + const CacheToken& token) const override + { + VLOG(DRIVER) << "ArmnnDriver::prepareModelFromCache()"; + + // Check if deadline has passed. + if (hasDeadlinePassed(deadline)) { + return NN_ERROR(ErrorStatus::MISSED_DEADLINE_PERSISTENT); + } + + return ArmnnDriverImpl::PrepareArmnnModelFromCache( + m_Runtime, + m_ClTunedParameters, + m_Options, + modelCache, + dataCache, + token, + m_Options.GetFp16Enabled()); + } + + GeneralResult allocate(const BufferDesc&, + const std::vector&, + const std::vector&, + const std::vector&) const override + { + VLOG(DRIVER) << "ArmnnDriver::allocate()"; + return NN_ERROR(ErrorStatus::INVALID_ARGUMENT) << "ArmnnDriver::allocate -- does not support allocate."; + } +}; + +} // namespace armnn_driver -- cgit v1.2.1