aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/INetwork.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/INetwork.hpp')
-rw-r--r--include/armnn/INetwork.hpp14
1 files changed, 11 insertions, 3 deletions
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index fefb2ebc2d..0289a90e71 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -135,6 +135,7 @@ struct OptimizerOptions
, m_ModelOptions()
, m_ProfilingEnabled(false)
, m_ExportEnabled(false)
+ , m_AllowExpandedDims(false)
{}
OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16, bool importEnabled,
@@ -147,6 +148,7 @@ struct OptimizerOptions
, m_ModelOptions(modelOptions)
, m_ProfilingEnabled(false)
, m_ExportEnabled(exportEnabled)
+ , m_AllowExpandedDims(false)
{
if (m_ReduceFp32ToFp16 && m_ReduceFp32ToBf16)
{
@@ -156,7 +158,8 @@ struct OptimizerOptions
OptimizerOptions(bool reduceFp32ToFp16, bool debug, bool reduceFp32ToBf16 = false,
ShapeInferenceMethod shapeInferenceMethod = armnn::ShapeInferenceMethod::ValidateOnly,
- bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false)
+ bool importEnabled = false, ModelOptions modelOptions = {}, bool exportEnabled = false,
+ bool allowExpandedDims = false)
: m_ReduceFp32ToFp16(reduceFp32ToFp16)
, m_Debug(debug)
, m_ReduceFp32ToBf16(reduceFp32ToBf16)
@@ -165,6 +168,7 @@ struct OptimizerOptions
, m_ModelOptions(modelOptions)
, m_ProfilingEnabled(false)
, m_ExportEnabled(exportEnabled)
+ , m_AllowExpandedDims(allowExpandedDims)
{
if (m_ReduceFp32ToFp16 && m_ReduceFp32ToBf16)
{
@@ -184,6 +188,7 @@ struct OptimizerOptions
stream << "\tImportEnabled: " << m_ImportEnabled << "\n";
stream << "\tExportEnabled: " << m_ExportEnabled << "\n";
stream << "\tProfilingEnabled: " << m_ProfilingEnabled << "\n";
+ stream << "\tAllowExpandedDims: " << m_AllowExpandedDims << "\n";
stream << "\tModelOptions: \n";
for (auto optionsGroup : m_ModelOptions)
@@ -231,6 +236,9 @@ struct OptimizerOptions
// Enable Export
bool m_ExportEnabled;
+
+ // When calculating tensor sizes dimensions of size == 1 will be ignored
+ bool m_AllowExpandedDims;
};
class IWorkloadFactory;
@@ -246,8 +254,8 @@ using CompiledBlobPtr = std::unique_ptr<void, CompiledBlobDeleter>;
class INetwork
{
public:
- static INetwork* CreateRaw(NetworkOptions networkOptions = {});
- static INetworkPtr Create(NetworkOptions networkOptions = {});
+ static INetwork* CreateRaw(const NetworkOptions& networkOptions = {});
+ static INetworkPtr Create(const NetworkOptions& networkOptions = {});
static void Destroy(INetwork* network);
Status PrintGraph();