// // Copyright © 2019 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include "BackendId.hpp" #include #include namespace armnn { struct BackendOptions; using NetworkOptions = std::vector; using ModelOptions = std::vector; using BackendCapabilities = BackendOptions; /// Struct for the users to pass backend specific options struct BackendOptions { private: template struct CheckAllowed { static const bool value = std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; public: /// Very basic type safe variant class Var { public: /// Constructors explicit Var(int i) : m_Vals(i), m_Type(VarTypes::Integer) {}; explicit Var(unsigned int u) : m_Vals(u), m_Type(VarTypes::UnsignedInteger) {}; explicit Var(float f) : m_Vals(f), m_Type(VarTypes::Float) {}; explicit Var(bool b) : m_Vals(b), m_Type(VarTypes::Boolean) {}; explicit Var(const char* s) : m_Vals(s), m_Type(VarTypes::String) {}; explicit Var(std::string s) : m_Vals(s), m_Type(VarTypes::String) {}; /// Disallow implicit conversions from types not explicitly allowed below. template Var(DisallowedType) { static_assert(CheckAllowed::value, "Type is not allowed for Var."); assert(false && "Unreachable code"); } /// Copy Construct Var(const Var& other) : m_Type(other.m_Type) { switch(m_Type) { case VarTypes::String: { new (&m_Vals.s) std::string(other.m_Vals.s); break; } default: { DoOp(other, [](auto& a, auto& b) { a = b; }); break; } } } /// Copy operator Var& operator=(const Var& other) { // Destroy existing string if (m_Type == VarTypes::String) { Destruct(m_Vals.s); } m_Type = other.m_Type; switch(m_Type) { case VarTypes::String: { new (&m_Vals.s) std::string(other.m_Vals.s); break; } default: { DoOp(other, [](auto& a, auto& b) { a = b; }); break; } } return *this; }; /// Type getters bool IsBool() const { return m_Type == VarTypes::Boolean; } bool IsInt() const { return m_Type == VarTypes::Integer; } bool IsUnsignedInt() const { return m_Type == VarTypes::UnsignedInteger; } bool IsFloat() const { return m_Type == VarTypes::Float; } bool IsString() const { return m_Type == VarTypes::String; } /// Value getters bool AsBool() const { assert(IsBool()); return m_Vals.b; } int AsInt() const { assert(IsInt()); return m_Vals.i; } unsigned int AsUnsignedInt() const { assert(IsUnsignedInt()); return m_Vals.u; } float AsFloat() const { assert(IsFloat()); return m_Vals.f; } std::string AsString() const { assert(IsString()); return m_Vals.s; } std::string ToString() { if (IsBool()) { return AsBool() ? "true" : "false"; } else if (IsInt()) { return std::to_string(AsInt()); } else if (IsUnsignedInt()) { return std::to_string(AsUnsignedInt()); } else if (IsFloat()) { return std::to_string(AsFloat()); } else if (IsString()) { return AsString(); } else { throw armnn::InvalidArgumentException("Unknown data type for string conversion"); } } /// Destructor ~Var() { DoOp(*this, [this](auto& a, auto&) { Destruct(a); }); } private: template void DoOp(const Var& other, Func func) { if (other.IsBool()) { func(m_Vals.b, other.m_Vals.b); } else if (other.IsInt()) { func(m_Vals.i, other.m_Vals.i); } else if (other.IsUnsignedInt()) { func(m_Vals.u, other.m_Vals.u); } else if (other.IsFloat()) { func(m_Vals.f, other.m_Vals.f); } else if (other.IsString()) { func(m_Vals.s, other.m_Vals.s); } } template void Destruct(Destructable& d) { if (std::is_destructible::value) { d.~Destructable(); } } private: /// Types which can be stored enum class VarTypes { Boolean, Integer, Float, String, UnsignedInteger }; /// Union of potential type values. union Vals { int i; unsigned int u; float f; bool b; std::string s; Vals(){} ~Vals(){} explicit Vals(int i) : i(i) {}; explicit Vals(unsigned int u) : u(u) {}; explicit Vals(float f) : f(f) {}; explicit Vals(bool b) : b(b) {}; explicit Vals(const char* s) : s(std::string(s)) {} explicit Vals(std::string s) : s(s) {} }; Vals m_Vals; VarTypes m_Type; }; struct BackendOption { public: BackendOption(std::string name, bool value) : m_Name(name), m_Value(value) {} BackendOption(std::string name, int value) : m_Name(name), m_Value(value) {} BackendOption(std::string name, unsigned int value) : m_Name(name), m_Value(value) {} BackendOption(std::string name, float value) : m_Name(name), m_Value(value) {} BackendOption(std::string name, std::string value) : m_Name(name), m_Value(value) {} BackendOption(std::string name, const char* value) : m_Name(name), m_Value(value) {} template BackendOption(std::string, DisallowedType) : m_Value(0) { static_assert(CheckAllowed::value, "Type is not allowed for BackendOption."); assert(false && "Unreachable code"); } BackendOption(const BackendOption& other) = default; BackendOption(BackendOption&& other) = default; BackendOption& operator=(const BackendOption& other) = default; BackendOption& operator=(BackendOption&& other) = default; ~BackendOption() = default; std::string GetName() const { return m_Name; } Var GetValue() const { return m_Value; } private: std::string m_Name; ///< Name of the option Var m_Value; ///< Value of the option. (Bool, int, Float, String) }; explicit BackendOptions(BackendId backend) : m_TargetBackend(backend) {} BackendOptions(BackendId backend, std::initializer_list options) : m_TargetBackend(backend) , m_Options(options) {} BackendOptions(const BackendOptions& other) = default; BackendOptions(BackendOptions&& other) = default; BackendOptions& operator=(const BackendOptions& other) = default; BackendOptions& operator=(BackendOptions&& other) = default; void AddOption(BackendOption&& option) { m_Options.push_back(option); } void AddOption(const BackendOption& option) { m_Options.push_back(option); } const BackendId& GetBackendId() const noexcept { return m_TargetBackend; } size_t GetOptionCount() const noexcept { return m_Options.size(); } const BackendOption& GetOption(size_t idx) const { return m_Options[idx]; } private: /// The id for the backend to which the options should be passed. BackendId m_TargetBackend; /// The array of options to pass to the backend context std::vector m_Options; }; template void ParseOptions(const std::vector& options, BackendId backend, F f) { for (auto optionsGroup : options) { if (optionsGroup.GetBackendId() == backend) { for (size_t i=0; i < optionsGroup.GetOptionCount(); i++) { const BackendOptions::BackendOption option = optionsGroup.GetOption(i); f(option.GetName(), option.GetValue()); } } } } inline bool ParseBooleanBackendOption(const armnn::BackendOptions::Var& value, bool defaultValue) { if (value.IsBool()) { return value.AsBool(); } return defaultValue; } inline std::string ParseStringBackendOption(const armnn::BackendOptions::Var& value, std::string defaultValue) { if (value.IsString()) { return value.AsString(); } return defaultValue; } inline int ParseIntBackendOption(const armnn::BackendOptions::Var& value, int defaultValue) { if (value.IsInt()) { return value.AsInt(); } return defaultValue; } } //namespace armnn