aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2019-09-10 17:20:34 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2019-09-26 10:17:30 +0000
commit1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (patch)
tree9d68934f461579edefbe65246f6ee435aaa18808 /arm_compute/graph
parentf1cf394ae882e6e8fb2e0986f88d2548b82a85bb (diff)
downloadComputeLibrary-1a569a30a2f456ff1a3e0a665201e1c3ab92df80.tar.gz
COMPMID-2161 [NEON] Create IWeightManager class
Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31 Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/1915 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/graph')
-rw-r--r--arm_compute/graph/GraphContext.h38
-rw-r--r--arm_compute/graph/IDeviceBackend.h8
-rw-r--r--arm_compute/graph/Types.h1
-rw-r--r--arm_compute/graph/backends/CL/CLDeviceBackend.h1
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
-rw-r--r--arm_compute/graph/backends/GLES/GCDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/NEON/NEDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/Utils.h16
-rw-r--r--arm_compute/graph/frontend/Layers.h29
9 files changed, 92 insertions, 11 deletions
diff --git a/arm_compute/graph/GraphContext.h b/arm_compute/graph/GraphContext.h
index 21ba6df785..0eb9e81175 100644
--- a/arm_compute/graph/GraphContext.h
+++ b/arm_compute/graph/GraphContext.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/graph/Types.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include <map>
#include <memory>
@@ -45,6 +46,13 @@ struct MemoryManagerContext
IAllocator *allocator = { nullptr }; /**< Backend allocator to use */
};
+/** Contains structs required for weights management */
+struct WeightsManagerContext
+{
+ Target target = { Target::UNSPECIFIED }; /**< Target */
+ std::shared_ptr<arm_compute::IWeightsManager> wm = { nullptr }; /**< Weights manager */
+};
+
/** Graph context **/
class GraphContext final
{
@@ -77,7 +85,7 @@ public:
*
* @param[in] memory_ctx Memory manage context
*
- * @return If the insertion succeeded else false
+ * @return True if the insertion succeeded else false
*/
bool insert_memory_management_ctx(MemoryManagerContext &&memory_ctx);
/** Gets a memory manager context for a given target
@@ -92,12 +100,34 @@ public:
* @return Memory manager contexts
*/
std::map<Target, MemoryManagerContext> &memory_managers();
+ /** Inserts a weights manager context
+ *
+ * @param[in] weights_ctx Weights manager context
+ *
+ * @return True if the insertion succeeded else false
+ */
+ bool insert_weights_management_ctx(WeightsManagerContext &&weights_ctx);
+
+ /** Gets a weights manager context for a given target
+ *
+ * @param[in] target To retrieve the weights management context
+ *
+ * @return Management context for the target if exists else nullptr
+ */
+ WeightsManagerContext *weights_management_ctx(Target target);
+
+ /** Gets the weights managers map
+ *
+ * @return Weights manager contexts
+ */
+ std::map<Target, WeightsManagerContext> &weights_managers();
/** Finalizes memory managers in graph context */
void finalize();
private:
- GraphConfig _config; /**< Graph configuration */
- std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ GraphConfig _config; /**< Graph configuration */
+ std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ std::map<Target, WeightsManagerContext> _weights_managers; /**< Weights managers for each target */
};
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/IDeviceBackend.h b/arm_compute/graph/IDeviceBackend.h
index 358d26af81..cf54976c28 100644
--- a/arm_compute/graph/IDeviceBackend.h
+++ b/arm_compute/graph/IDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,6 +28,7 @@
#include "arm_compute/graph/Types.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include <memory>
@@ -112,6 +113,11 @@ public:
* @return Memory manager
*/
virtual std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) = 0;
+ /** Create a backend weights manager
+ *
+ * @return Weights manager
+ */
+ virtual std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() = 0;
};
} // namespace backends
} // namespace graph
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index 8b97708a63..63b1c94ac8 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -78,6 +78,7 @@ class TensorDescriptor;
struct GraphConfig
{
bool use_function_memory_manager{ true }; /**< Use a memory manager to manage per-funcion auxilary memory */
+ bool use_function_weights_manager{ true }; /**< Use a weights manager to manage transformed weights */
bool use_transition_memory_manager{ true }; /**< Use a memory manager to manager transition buffer memory */
bool use_tuner{ false }; /**< Use a tuner in tunable backends */
CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */
diff --git a/arm_compute/graph/backends/CL/CLDeviceBackend.h b/arm_compute/graph/backends/CL/CLDeviceBackend.h
index afe01fff70..8569cf1f34 100644
--- a/arm_compute/graph/backends/CL/CLDeviceBackend.h
+++ b/arm_compute/graph/backends/CL/CLDeviceBackend.h
@@ -67,6 +67,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
int _context_count; /**< Counts how many contexts are currently using the backend */
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index dd833061a9..10f8c0c5c7 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -827,7 +827,9 @@ std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode
ARM_COMPUTE_ERROR_ON(output == nullptr);
// Create and configure function
- auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType));
+ auto wm = get_weights_manager(ctx, TargetInfo::TargetType);
+ auto mm = get_memory_manager(ctx, TargetInfo::TargetType);
+ auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(mm, wm.get());
func->configure(input, weights, biases, output, fc_info);
const bool is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
diff --git a/arm_compute/graph/backends/GLES/GCDeviceBackend.h b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
index ca2d3734eb..83a7458c98 100644
--- a/arm_compute/graph/backends/GLES/GCDeviceBackend.h
+++ b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,6 +52,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
bool _initialized; /**< Flag that specifies if the backend has been default initialized */
diff --git a/arm_compute/graph/backends/NEON/NEDeviceBackend.h b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
index abc17d9e83..9891170fbd 100644
--- a/arm_compute/graph/backends/NEON/NEDeviceBackend.h
+++ b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
Allocator _allocator; /**< NEON backend allocator */
diff --git a/arm_compute/graph/backends/Utils.h b/arm_compute/graph/backends/Utils.h
index c7a50d93c6..2ca97ff5c5 100644
--- a/arm_compute/graph/backends/Utils.h
+++ b/arm_compute/graph/backends/Utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
namespace arm_compute
{
@@ -90,6 +91,19 @@ inline std::shared_ptr<IMemoryManager> get_memory_manager(GraphContext &ctx, Tar
bool enabled = ctx.config().use_function_memory_manager && (ctx.memory_management_ctx(target) != nullptr);
return enabled ? ctx.memory_management_ctx(target)->intra_mm : nullptr;
}
+
+/** Returns the weights manager for a given target
+ *
+ * @param[in] ctx Graph context containing weight management metadata
+ * @param[in] target Target to retrieve the weights manager from
+ *
+ * @return The weights manager for the given target else false
+ */
+inline std::shared_ptr<IWeightsManager> get_weights_manager(GraphContext &ctx, Target target)
+{
+ bool enabled = ctx.config().use_function_weights_manager && (ctx.weights_management_ctx(target) != nullptr);
+ return enabled ? ctx.weights_management_ctx(target)->wm : nullptr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 27a0cd3026..120997a8b4 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -66,6 +66,31 @@ private:
ITensorAccessorUPtr _accessor;
};
+/** Constant Layer */
+class ConstantLayer final : public ILayer
+{
+public:
+ /** Construct a constant layer.
+ *
+ * @param[in] desc Description of input tensor.
+ * @param[in] accessor Accessor to get input tensor data from.
+ */
+ ConstantLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
+ : _desc(desc), _accessor(std::move(accessor))
+ {
+ }
+
+ NodeID create_layer(IStream &s) override
+ {
+ NodeParams common_params = { name(), s.hints().target_hint };
+ return GraphBuilder::add_const_node(s.graph(), common_params, _desc, std::move(_accessor));
+ }
+
+private:
+ TensorDescriptor _desc;
+ ITensorAccessorUPtr _accessor;
+};
+
/** Output Layer */
class OutputLayer final : public ILayer
{
@@ -635,8 +660,8 @@ public:
* @param[in] out_quant_info (Optional) Output quantization info
*/
FullyConnectedLayer(unsigned int num_outputs,
- SubStream &&sub_stream_weights,
- SubStream &&sub_stream_bias,
+ SubStream sub_stream_weights,
+ SubStream sub_stream_bias,
const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
const QuantizationInfo weights_quant_info = QuantizationInfo(),
const QuantizationInfo out_quant_info = QuantizationInfo())