diff options
Diffstat (limited to 'include/armnn/INetwork.hpp')
-rw-r--r-- | include/armnn/INetwork.hpp | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index fefb2ebc2d..0289a90e71 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -135,6 +135,7 @@ struct OptimizerOptions , m_ModelOptions() , m_ProfilingEnabled(false) , m_ExportEnabled(false) + , m_AllowExpandedDims(false) {} OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled, @@ -147,6 +148,7 @@ struct OptimizerOptions , m_ModelOptions(modelOptions) , m_ProfilingEnabled(false) , m_ExportEnabled(exportEnabled) + , m_AllowExpandedDims(false) { if (m_ReduceFp32ToFp16 && m_ReduceFp32ToBf16) { @@ -156,7 +158,8 @@ struct OptimizerOptions OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false, ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly, - bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false) + bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false, + bool allowExpandedDims = false) : m_ReduceFp32ToFp16(reduceFp32ToFp16) , m_Debug(debug) , m_ReduceFp32ToBf16(reduceFp32ToBf16) @@ -165,6 +168,7 @@ struct OptimizerOptions , m_ModelOptions(modelOptions) , m_ProfilingEnabled(false) , m_ExportEnabled(exportEnabled) + , m_AllowExpandedDims(allowExpandedDims) { if (m_ReduceFp32ToFp16 && m_ReduceFp32ToBf16) { @@ -184,6 +188,7 @@ struct OptimizerOptions stream << "\tImportEnabled: " << m_ImportEnabled << "\n"; stream << "\tExportEnabled: " << m_ExportEnabled << "\n"; stream << "\tProfilingEnabled: " << m_ProfilingEnabled << "\n"; + stream << "\tAllowExpandedDims: " << m_AllowExpandedDims << "\n"; stream << "\tModelOptions: \n"; for (auto optionsGroup : m_ModelOptions) @@ -231,6 +236,9 @@ struct OptimizerOptions // Enable Export bool m_ExportEnabled; + + // When calculating tensor sizes dimensions of size == 1 will be ignored + bool m_AllowExpandedDims; }; class IWorkloadFactory; @@ -246,8 +254,8 @@ using CompiledBlobPtr = std::unique_ptr<void, CompiledBlobDeleter>; class INetwork { public: - static INetwork* CreateRaw(NetworkOptions networkOptions = {}); - static INetworkPtr Create(NetworkOptions networkOptions = {}); + static INetwork* CreateRaw(const NetworkOptions& networkOptions = {}); + static INetworkPtr Create(const NetworkOptions& networkOptions = {}); static void Destroy(INetwork* network); Status PrintGraph(); |