aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/graph/backends
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/graph/backends')
-rw-r--r--arm_compute/graph/backends/CL/CLDeviceBackend.h1
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
-rw-r--r--arm_compute/graph/backends/GLES/GCDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/NEON/NEDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/Utils.h16
5 files changed, 23 insertions, 4 deletions
diff --git a/arm_compute/graph/backends/CL/CLDeviceBackend.h b/arm_compute/graph/backends/CL/CLDeviceBackend.h
index afe01fff70..8569cf1f34 100644
--- a/arm_compute/graph/backends/CL/CLDeviceBackend.h
+++ b/arm_compute/graph/backends/CL/CLDeviceBackend.h
@@ -67,6 +67,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
int _context_count; /**< Counts how many contexts are currently using the backend */
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index dd833061a9..10f8c0c5c7 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -827,7 +827,9 @@ std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode
ARM_COMPUTE_ERROR_ON(output == nullptr);
// Create and configure function
- auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType));
+ auto wm = get_weights_manager(ctx, TargetInfo::TargetType);
+ auto mm = get_memory_manager(ctx, TargetInfo::TargetType);
+ auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(mm, wm.get());
func->configure(input, weights, biases, output, fc_info);
const bool is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
diff --git a/arm_compute/graph/backends/GLES/GCDeviceBackend.h b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
index ca2d3734eb..83a7458c98 100644
--- a/arm_compute/graph/backends/GLES/GCDeviceBackend.h
+++ b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,6 +52,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
bool _initialized; /**< Flag that specifies if the backend has been default initialized */
diff --git a/arm_compute/graph/backends/NEON/NEDeviceBackend.h b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
index abc17d9e83..9891170fbd 100644
--- a/arm_compute/graph/backends/NEON/NEDeviceBackend.h
+++ b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
Allocator _allocator; /**< NEON backend allocator */
diff --git a/arm_compute/graph/backends/Utils.h b/arm_compute/graph/backends/Utils.h
index c7a50d93c6..2ca97ff5c5 100644
--- a/arm_compute/graph/backends/Utils.h
+++ b/arm_compute/graph/backends/Utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
namespace arm_compute
{
@@ -90,6 +91,19 @@ inline std::shared_ptr<IMemoryManager> get_memory_manager(GraphContext &ctx, Tar
bool enabled = ctx.config().use_function_memory_manager && (ctx.memory_management_ctx(target) != nullptr);
return enabled ? ctx.memory_management_ctx(target)->intra_mm : nullptr;
}
+
+/** Returns the weights manager for a given target
+ *
+ * @param[in] ctx Graph context containing weight management metadata
+ * @param[in] target Target to retrieve the weights manager from
+ *
+ * @return The weights manager for the given target else false
+ */
+inline std::shared_ptr<IWeightsManager> get_weights_manager(GraphContext &ctx, Target target)
+{
+ bool enabled = ctx.config().use_function_weights_manager && (ctx.weights_management_ctx(target) != nullptr);
+ return enabled ? ctx.weights_management_ctx(target)->wm : nullptr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute