aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/GraphContext.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/GraphContext.h')
-rw-r--r--arm_compute/graph/GraphContext.h38
1 files changed, 34 insertions, 4 deletions
diff --git a/arm_compute/graph/GraphContext.h b/arm_compute/graph/GraphContext.h
index 21ba6df785..0eb9e81175 100644
--- a/arm_compute/graph/GraphContext.h
+++ b/arm_compute/graph/GraphContext.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/graph/Types.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include <map>
#include <memory>
@@ -45,6 +46,13 @@ struct MemoryManagerContext
IAllocator *allocator = { nullptr }; /**< Backend allocator to use */
};
+/** Contains structs required for weights management */
+struct WeightsManagerContext
+{
+ Target target = { Target::UNSPECIFIED }; /**< Target */
+ std::shared_ptr<arm_compute::IWeightsManager> wm = { nullptr }; /**< Weights manager */
+};
+
/** Graph context **/
class GraphContext final
{
@@ -77,7 +85,7 @@ public:
*
* @param[in] memory_ctx Memory manage context
*
- * @return If the insertion succeeded else false
+ * @return True if the insertion succeeded else false
*/
bool insert_memory_management_ctx(MemoryManagerContext &&memory_ctx);
/** Gets a memory manager context for a given target
@@ -92,12 +100,34 @@ public:
* @return Memory manager contexts
*/
std::map<Target, MemoryManagerContext> &memory_managers();
+ /** Inserts a weights manager context
+ *
+ * @param[in] weights_ctx Weights manager context
+ *
+ * @return True if the insertion succeeded else false
+ */
+ bool insert_weights_management_ctx(WeightsManagerContext &&weights_ctx);
+
+ /** Gets a weights manager context for a given target
+ *
+ * @param[in] target To retrieve the weights management context
+ *
+ * @return Management context for the target if exists else nullptr
+ */
+ WeightsManagerContext *weights_management_ctx(Target target);
+
+ /** Gets the weights managers map
+ *
+ * @return Weights manager contexts
+ */
+ std::map<Target, WeightsManagerContext> &weights_managers();
/** Finalizes memory managers in graph context */
void finalize();
private:
- GraphConfig _config; /**< Graph configuration */
- std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ GraphConfig _config; /**< Graph configuration */
+ std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ std::map<Target, WeightsManagerContext> _weights_managers; /**< Weights managers for each target */
};
} // namespace graph
} // namespace arm_compute