aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2019-09-10 17:20:34 +0100
committerMichalis Spyrou <michalis.spyrou@arm.com>2019-09-26 10:17:30 +0000
commit1a569a30a2f456ff1a3e0a665201e1c3ab92df80 (patch)
tree9d68934f461579edefbe65246f6ee435aaa18808
parentf1cf394ae882e6e8fb2e0986f88d2548b82a85bb (diff)
downloadComputeLibrary-1a569a30a2f456ff1a3e0a665201e1c3ab92df80.tar.gz
COMPMID-2161 [NEON] Create IWeightManager class
Change-Id: I1a9a46da2f98e896b825099151b56d1d8271dd31 Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com> Reviewed-on: https://review.mlplatform.org/c/1915 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--arm_compute/graph/GraphContext.h38
-rw-r--r--arm_compute/graph/IDeviceBackend.h8
-rw-r--r--arm_compute/graph/Types.h1
-rw-r--r--arm_compute/graph/backends/CL/CLDeviceBackend.h1
-rw-r--r--arm_compute/graph/backends/FunctionHelpers.h4
-rw-r--r--arm_compute/graph/backends/GLES/GCDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/NEON/NEDeviceBackend.h3
-rw-r--r--arm_compute/graph/backends/Utils.h16
-rw-r--r--arm_compute/graph/frontend/Layers.h29
-rw-r--r--arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h4
-rw-r--r--arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h3
-rw-r--r--arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h2
-rw-r--r--arm_compute/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.h3
-rw-r--r--arm_compute/runtime/ITransformWeights.h117
-rw-r--r--arm_compute/runtime/IWeightsManager.h85
-rw-r--r--arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h48
-rw-r--r--arm_compute/runtime/NEON/functions/NEDeconvolutionLayer.h2
-rw-r--r--arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h85
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMM.h6
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h12
-rw-r--r--arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h77
-rw-r--r--arm_compute/runtime/NEON/functions/NERNNLayer.h4
-rw-r--r--arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h5
-rw-r--r--src/graph/GraphContext.cpp29
-rw-r--r--src/graph/backends/CL/CLDeviceBackend.cpp5
-rw-r--r--src/graph/backends/GLES/GCDeviceBackend.cpp5
-rw-r--r--src/graph/backends/NEON/NEDeviceBackend.cpp17
-rw-r--r--src/graph/backends/NEON/NEFunctionFactory.cpp1
-rw-r--r--src/runtime/CL/functions/CLFullyConnectedLayer.cpp2
-rw-r--r--src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp2
-rw-r--r--src/runtime/IWeightsManager.cpp128
-rw-r--r--src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp7
-rw-r--r--src/runtime/NEON/functions/NEDeconvolutionLayer.cpp20
-rw-r--r--src/runtime/NEON/functions/NEFullyConnectedLayer.cpp89
-rw-r--r--src/runtime/NEON/functions/NEGEMM.cpp16
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp163
-rw-r--r--src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp41
-rw-r--r--src/runtime/NEON/functions/NERNNLayer.cpp10
-rw-r--r--src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp2
39 files changed, 938 insertions, 155 deletions
diff --git a/arm_compute/graph/GraphContext.h b/arm_compute/graph/GraphContext.h
index 21ba6df785..0eb9e81175 100644
--- a/arm_compute/graph/GraphContext.h
+++ b/arm_compute/graph/GraphContext.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,6 +27,7 @@
#include "arm_compute/graph/Types.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include <map>
#include <memory>
@@ -45,6 +46,13 @@ struct MemoryManagerContext
IAllocator *allocator = { nullptr }; /**< Backend allocator to use */
};
+/** Contains structs required for weights management */
+struct WeightsManagerContext
+{
+ Target target = { Target::UNSPECIFIED }; /**< Target */
+ std::shared_ptr<arm_compute::IWeightsManager> wm = { nullptr }; /**< Weights manager */
+};
+
/** Graph context **/
class GraphContext final
{
@@ -77,7 +85,7 @@ public:
*
* @param[in] memory_ctx Memory manage context
*
- * @return If the insertion succeeded else false
+ * @return True if the insertion succeeded else false
*/
bool insert_memory_management_ctx(MemoryManagerContext &&memory_ctx);
/** Gets a memory manager context for a given target
@@ -92,12 +100,34 @@ public:
* @return Memory manager contexts
*/
std::map<Target, MemoryManagerContext> &memory_managers();
+ /** Inserts a weights manager context
+ *
+ * @param[in] weights_ctx Weights manager context
+ *
+ * @return True if the insertion succeeded else false
+ */
+ bool insert_weights_management_ctx(WeightsManagerContext &&weights_ctx);
+
+ /** Gets a weights manager context for a given target
+ *
+ * @param[in] target To retrieve the weights management context
+ *
+ * @return Management context for the target if exists else nullptr
+ */
+ WeightsManagerContext *weights_management_ctx(Target target);
+
+ /** Gets the weights managers map
+ *
+ * @return Weights manager contexts
+ */
+ std::map<Target, WeightsManagerContext> &weights_managers();
/** Finalizes memory managers in graph context */
void finalize();
private:
- GraphConfig _config; /**< Graph configuration */
- std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ GraphConfig _config; /**< Graph configuration */
+ std::map<Target, MemoryManagerContext> _memory_managers; /**< Memory managers for each target */
+ std::map<Target, WeightsManagerContext> _weights_managers; /**< Weights managers for each target */
};
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/IDeviceBackend.h b/arm_compute/graph/IDeviceBackend.h
index 358d26af81..cf54976c28 100644
--- a/arm_compute/graph/IDeviceBackend.h
+++ b/arm_compute/graph/IDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,6 +28,7 @@
#include "arm_compute/graph/Types.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include <memory>
@@ -112,6 +113,11 @@ public:
* @return Memory manager
*/
virtual std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) = 0;
+ /** Create a backend weights manager
+ *
+ * @return Weights manager
+ */
+ virtual std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() = 0;
};
} // namespace backends
} // namespace graph
diff --git a/arm_compute/graph/Types.h b/arm_compute/graph/Types.h
index 8b97708a63..63b1c94ac8 100644
--- a/arm_compute/graph/Types.h
+++ b/arm_compute/graph/Types.h
@@ -78,6 +78,7 @@ class TensorDescriptor;
struct GraphConfig
{
bool use_function_memory_manager{ true }; /**< Use a memory manager to manage per-funcion auxilary memory */
+ bool use_function_weights_manager{ true }; /**< Use a weights manager to manage transformed weights */
bool use_transition_memory_manager{ true }; /**< Use a memory manager to manager transition buffer memory */
bool use_tuner{ false }; /**< Use a tuner in tunable backends */
CLTunerMode tuner_mode{ CLTunerMode::EXHAUSTIVE }; /**< Tuner mode to be used by the CL tuner */
diff --git a/arm_compute/graph/backends/CL/CLDeviceBackend.h b/arm_compute/graph/backends/CL/CLDeviceBackend.h
index afe01fff70..8569cf1f34 100644
--- a/arm_compute/graph/backends/CL/CLDeviceBackend.h
+++ b/arm_compute/graph/backends/CL/CLDeviceBackend.h
@@ -67,6 +67,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
int _context_count; /**< Counts how many contexts are currently using the backend */
diff --git a/arm_compute/graph/backends/FunctionHelpers.h b/arm_compute/graph/backends/FunctionHelpers.h
index dd833061a9..10f8c0c5c7 100644
--- a/arm_compute/graph/backends/FunctionHelpers.h
+++ b/arm_compute/graph/backends/FunctionHelpers.h
@@ -827,7 +827,9 @@ std::unique_ptr<IFunction> create_fully_connected_layer(FullyConnectedLayerNode
ARM_COMPUTE_ERROR_ON(output == nullptr);
// Create and configure function
- auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(get_memory_manager(ctx, TargetInfo::TargetType));
+ auto wm = get_weights_manager(ctx, TargetInfo::TargetType);
+ auto mm = get_memory_manager(ctx, TargetInfo::TargetType);
+ auto func = support::cpp14::make_unique<FullyConnectedLayerFunction>(mm, wm.get());
func->configure(input, weights, biases, output, fc_info);
const bool is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
diff --git a/arm_compute/graph/backends/GLES/GCDeviceBackend.h b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
index ca2d3734eb..83a7458c98 100644
--- a/arm_compute/graph/backends/GLES/GCDeviceBackend.h
+++ b/arm_compute/graph/backends/GLES/GCDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,6 +52,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
bool _initialized; /**< Flag that specifies if the backend has been default initialized */
diff --git a/arm_compute/graph/backends/NEON/NEDeviceBackend.h b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
index abc17d9e83..9891170fbd 100644
--- a/arm_compute/graph/backends/NEON/NEDeviceBackend.h
+++ b/arm_compute/graph/backends/NEON/NEDeviceBackend.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -51,6 +51,7 @@ public:
std::unique_ptr<arm_compute::IFunction> configure_node(INode &node, GraphContext &ctx) override;
Status validate_node(INode &node) override;
std::shared_ptr<arm_compute::IMemoryManager> create_memory_manager(MemoryManagerAffinity affinity) override;
+ std::shared_ptr<arm_compute::IWeightsManager> create_weights_manager() override;
private:
Allocator _allocator; /**< NEON backend allocator */
diff --git a/arm_compute/graph/backends/Utils.h b/arm_compute/graph/backends/Utils.h
index c7a50d93c6..2ca97ff5c5 100644
--- a/arm_compute/graph/backends/Utils.h
+++ b/arm_compute/graph/backends/Utils.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,6 +26,7 @@
#include "arm_compute/graph/GraphContext.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
namespace arm_compute
{
@@ -90,6 +91,19 @@ inline std::shared_ptr<IMemoryManager> get_memory_manager(GraphContext &ctx, Tar
bool enabled = ctx.config().use_function_memory_manager && (ctx.memory_management_ctx(target) != nullptr);
return enabled ? ctx.memory_management_ctx(target)->intra_mm : nullptr;
}
+
+/** Returns the weights manager for a given target
+ *
+ * @param[in] ctx Graph context containing weight management metadata
+ * @param[in] target Target to retrieve the weights manager from
+ *
+ * @return The weights manager for the given target else false
+ */
+inline std::shared_ptr<IWeightsManager> get_weights_manager(GraphContext &ctx, Target target)
+{
+ bool enabled = ctx.config().use_function_weights_manager && (ctx.weights_management_ctx(target) != nullptr);
+ return enabled ? ctx.weights_management_ctx(target)->wm : nullptr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute
diff --git a/arm_compute/graph/frontend/Layers.h b/arm_compute/graph/frontend/Layers.h
index 27a0cd3026..120997a8b4 100644
--- a/arm_compute/graph/frontend/Layers.h
+++ b/arm_compute/graph/frontend/Layers.h
@@ -66,6 +66,31 @@ private:
ITensorAccessorUPtr _accessor;
};
+/** Constant Layer */
+class ConstantLayer final : public ILayer
+{
+public:
+ /** Construct a constant layer.
+ *
+ * @param[in] desc Description of input tensor.
+ * @param[in] accessor Accessor to get input tensor data from.
+ */
+ ConstantLayer(TensorDescriptor desc, ITensorAccessorUPtr accessor)
+ : _desc(desc), _accessor(std::move(accessor))
+ {
+ }
+
+ NodeID create_layer(IStream &s) override
+ {
+ NodeParams common_params = { name(), s.hints().target_hint };
+ return GraphBuilder::add_const_node(s.graph(), common_params, _desc, std::move(_accessor));
+ }
+
+private:
+ TensorDescriptor _desc;
+ ITensorAccessorUPtr _accessor;
+};
+
/** Output Layer */
class OutputLayer final : public ILayer
{
@@ -635,8 +660,8 @@ public:
* @param[in] out_quant_info (Optional) Output quantization info
*/
FullyConnectedLayer(unsigned int num_outputs,
- SubStream &&sub_stream_weights,
- SubStream &&sub_stream_bias,
+ SubStream sub_stream_weights,
+ SubStream sub_stream_bias,
const FullyConnectedLayerInfo fc_info = FullyConnectedLayerInfo(),
const QuantizationInfo weights_quant_info = QuantizationInfo(),
const QuantizationInfo out_quant_info = QuantizationInfo())
diff --git a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
index 9bfade4894..43abb6769b 100644
--- a/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
+++ b/arm_compute/runtime/CL/functions/CLConvertFullyConnectedWeights.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,6 +41,8 @@ public:
* @param[out] output The converted weights tensor. Shape and Data Type: Same as @p input.
* @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer).
* @param[in] data_layout The data layout the weights have been trained in.
+ *
+ * @return A status
*/
void configure(const ICLTensor *input, ICLTensor *output, const TensorShape &original_input_shape, DataLayout data_layout);
/** Static function to check if given info will lead to a valid configuration of @ref CLConvertFullyConnectedWeights
diff --git a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
index 7cf7d951b6..d54304ed77 100644
--- a/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
+++ b/arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h
@@ -34,6 +34,7 @@
#include "arm_compute/runtime/CL/functions/CLGEMM.h"
#include "arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h"
#include "arm_compute/runtime/CL/functions/CLGEMMLowpOutputStage.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
namespace arm_compute
@@ -76,7 +77,7 @@ class CLFullyConnectedLayer : public IFunction
{
public:
/** Constructor */
- CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
CLFullyConnectedLayer(const CLFullyConnectedLayer &) = delete;
/** Default move constructor */
diff --git a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
index d29a31a530..0b27c824d9 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMConvolutionLayer.h
@@ -94,7 +94,7 @@ private:
class CLGEMMConvolutionLayer : public IFunction
{
public:
- /** Default constructor
+ /** Constructor
*
* @param[in] memory_manager (Optional) Memory manager.
*/
diff --git a/arm_compute/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.h b/arm_compute/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.h
index 6fcebd63b4..3a13e659f9 100644
--- a/arm_compute/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.h
+++ b/arm_compute/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.h
@@ -30,6 +30,7 @@
#include "arm_compute/core/GLES_COMPUTE/kernels/GCTransposeKernel.h"
#include "arm_compute/runtime/GLES_COMPUTE/GCTensor.h"
#include "arm_compute/runtime/GLES_COMPUTE/IGCSimpleFunction.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
namespace arm_compute
@@ -64,7 +65,7 @@ class GCFullyConnectedLayer : public IFunction
{
public:
/** Constructor */
- GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
GCFullyConnectedLayer(const GCFullyConnectedLayer &) = delete;
/** Default move constructor */
diff --git a/arm_compute/runtime/ITransformWeights.h b/arm_compute/runtime/ITransformWeights.h
new file mode 100644
index 0000000000..6376c30088
--- /dev/null
+++ b/arm_compute/runtime/ITransformWeights.h
@@ -0,0 +1,117 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_ITRANSFORMWEIGHTS_H__
+#define __ARM_COMPUTE_ITRANSFORMWEIGHTS_H__
+
+#include <atomic>
+
+namespace arm_compute
+{
+// Forward declarations
+class ITensor;
+
+/** Weights tensor transform interface
+ * In order to identify the different reshape functions, each reshape function has
+ * to generate a unique id. We use the following conversion using an unsigned 32bit value:
+ *
+ * Lower two bits store the target:
+ * 00 -> NEON
+ * 01 -> CL
+ * 10 -> GLES
+ * 11 -> Unused
+ *
+ * Five bits store the id of the reshape function:
+ * 00000 -> FullyConnectedLayerReshapeWeights
+ * 00001 -> ConvertFullyConnectedWeights
+ * 00010 -> ConvolutionLayerReshapeWeights
+ * 00011 -> DepthwiseConvolutionLayerReshapeWeights
+ * 00100 -> GEMMReshapeLHSMatrixKernel
+ * 00101 -> GEMMReshapeRHSMatrixKernel
+ *
+ * Rest of the bits are used for identifying special cases such as assembly functions and extra
+ * arguments in the reshape kernels.
+ *
+ * */
+class ITransformWeights
+{
+public:
+ /** Default Constructor */
+ ITransformWeights() = default;
+ /** Default Destructor */
+ virtual ~ITransformWeights() = default;
+ /** Prevent instances of this class to be copy constructed */
+ ITransformWeights(const ITransformWeights &) = delete;
+ /** Prevent instances of this class to be copied */
+ ITransformWeights &operator=(const ITransformWeights &) = delete;
+ /** Allow instances of this class to be move constructed */
+ ITransformWeights(ITransformWeights &&) = default;
+ /** Allow instances of this class to be moved */
+ ITransformWeights &operator=(ITransformWeights &&) = default;
+
+ /** Get a pointer to the transformed weights
+ *
+ * @return The pointer to the transformed ITensor weights
+ */
+ virtual ITensor *get_weights() = 0;
+ /** Function that returns a unique id of the reshape function
+ *
+ * @return The computed unique id
+ */
+ virtual uint32_t uid() = 0;
+ /** Run the transformation function */
+ virtual void run() = 0;
+ /** Release transformed weights memory */
+ virtual void release() = 0;
+ /** Increase the object's refcount */
+ void increase_refcount()
+ {
+ ++_num_refcount;
+ }
+
+ /** Decrease the object's refcount and return the updated value
+ *
+ * @return The updated refcount
+ * */
+ int32_t decrease_refcount()
+ {
+ return --_num_refcount;
+ }
+
+ /** Function that returns a flag on whether the weights are reshaped or not
+ *
+ * @return True if the function is reshaped
+ */
+ bool is_reshape_run()
+ {
+ return _reshape_run;
+ }
+
+protected:
+ std::atomic<int32_t> _num_refcount{ 0 };
+ bool _reshape_run{ false };
+};
+
+} // arm_compute
+
+#endif /*__ARM_COMPUTE_ITRANSFORMWEIGHTS_H__ */ \ No newline at end of file
diff --git a/arm_compute/runtime/IWeightsManager.h b/arm_compute/runtime/IWeightsManager.h
new file mode 100644
index 0000000000..2d61b89bc6
--- /dev/null
+++ b/arm_compute/runtime/IWeightsManager.h
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#ifndef __ARM_COMPUTE_IWEIGHTSMANAGER_H__
+#define __ARM_COMPUTE_IWEIGHTSMANAGER_H__
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/runtime/ITransformWeights.h"
+
+#include <map>
+
+namespace arm_compute
+{
+/** Weights manager interface to handle weights transformations */
+class IWeightsManager
+{
+public:
+ /** Constructor */
+ IWeightsManager();
+ /** Default Destructor */
+ virtual ~IWeightsManager() = default;
+ /** Prevent instances of this class to be copy constructed */
+ IWeightsManager(const IWeightsManager &) = delete;
+ /** Prevent instances of this class to be copied */
+ IWeightsManager &operator=(const IWeightsManager &) = delete;
+ /** Allow instances of this class to be move constructed */
+ IWeightsManager(IWeightsManager &&) = default;
+ /** Allow instances of this class to be moved */
+ IWeightsManager &operator=(IWeightsManager &&) = default;
+
+ /** Start managing a weights tensor
+ *
+ * @param[in] weights Pointer to the weights tensor to be managed
+ * @param[in] parent Parent node in case where the weights are coming from a previous reshape function
+ */
+ void manage(const ITensor *weights, ITransformWeights *parent = nullptr);
+ /** Run the reshape function.
+ *
+ * @param[in] weights Pointer to the weights tensor we want to reshape
+ * @param[in] weights_transform Weights transformation object
+ *
+ * @return The reshaped tensor
+ */
+ ITensor *run(const ITensor *weights, ITransformWeights *weights_transform);
+ /** Acquire the requested reshape tensor of the selected weights
+ *
+ * @param[in] weights Pointer to the weights tensor to be managed
+ * @param[in] weights_transform Weights transformation object
+ */
+ ITensor *acquire(const ITensor *weights, ITransformWeights *weights_transform);
+ /** Check if the weights are managed
+ *
+ * @param[in] weights Pointer to the weights tensor we want to check if managed
+ *
+ * @return True if the weights tensor is managed else false
+ */
+ bool are_weights_managed(const ITensor *weights);
+
+private:
+ std::map<const ITensor *, std::vector<ITransformWeights *>> _managed_weights;
+ std::map<const ITensor *, ITransformWeights *> _managed_weights_parents;
+};
+} // arm_compute
+#endif /*__ARM_COMPUTE_IWEIGHTSMANAGER_H__ */ \ No newline at end of file
diff --git a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h
index 8f261421e6..50a86bd7c4 100644
--- a/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h
+++ b/arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -26,7 +26,9 @@
#include "arm_compute/core/NEON/kernels/NEConvertFullyConnectedWeightsKernel.h"
#include "arm_compute/runtime/IFunction.h"
+#include "arm_compute/runtime/ITransformWeights.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "arm_compute/runtime/Tensor.h"
namespace arm_compute
{
@@ -52,6 +54,8 @@ public:
* @param[in] output The converted weights tensor info. Shape and Data Type: Same as @p input.
* @param[in] original_input_shape Shape of the original input tensor (the one entering fully connected layer).
* @param[in] data_layout The data layout the weights have been trained in.
+ *
+ * @return A Status
*/
static Status validate(const ITensorInfo *input, const ITensorInfo *output, const TensorShape &original_input_shape, DataLayout data_layout);
@@ -61,5 +65,45 @@ public:
private:
NEConvertFullyConnectedWeightsKernel _kernel;
};
-}
+
+namespace weights_transformations
+{
+/** Basic function to run @ref NEConvertFullyConnectedWeightsKernel. */
+class NEConvertFullyConnectedWeightsManaged : public ITransformWeights
+{
+public:
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true;
+ }
+
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ ITensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ uint32_t uid() override
+ {
+ return _uid;
+ }
+
+ void configure(const ITensor *input, const TensorShape &original_input_shape, DataLayout data_layout)
+ {
+ _func.configure(input, &_output, original_input_shape, data_layout);
+ }
+
+private:
+ static constexpr uint32_t _uid = 0x4;
+ Tensor _output{};
+ NEConvertFullyConnectedWeights _func{};
+};
+} // namespace weights_transformations
+} // namespace arm_compute
#endif /* __ARM_COMPUTE_NECONVERTFULLYCONNECTEDWEIGHTS_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEDeconvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEDeconvolutionLayer.h
index 360bb23f22..6880bbba6b 100644
--- a/arm_compute/runtime/NEON/functions/NEDeconvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEDeconvolutionLayer.h
@@ -73,7 +73,7 @@ namespace arm_compute
class NEDeconvolutionLayer : public IFunction
{
public:
- /** Default constructor */
+ /** Constructor */
NEDeconvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
diff --git a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
index 56ce274572..b80e0e49e0 100644
--- a/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -63,6 +63,46 @@ public:
static Status validate(const ITensorInfo *input, const ITensorInfo *output);
};
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref NEFullyConnectedLayerReshapeWeights */
+class NEFullyConnectedLayerReshapeWeightsManaged : public ITransformWeights
+{
+public:
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true;
+ }
+
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ ITensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ uint32_t uid() override
+ {
+ return _uid;
+ }
+
+ void configure(const ITensor *input)
+ {
+ _func.configure(input, &_output);
+ }
+
+private:
+ static constexpr uint32_t _uid = 0x0;
+ Tensor _output{};
+ NEFullyConnectedLayerReshapeWeights _func{};
+};
+} // namespace weights_transformations
+
/** Basic function to compute a Fully Connected layer on NEON. This function calls the following NEON kernels:
* -# @ref NEIm2ColKernel (called when the input comes from a convolutional layer)
* -# @ref NEFullyConnectedLayerReshapeWeights (if @p are_weights_reshaped is set to false and transpose_weights is set to true ) (called once)
@@ -75,7 +115,7 @@ class NEFullyConnectedLayer : public IFunction
{
public:
/** Constructor */
- NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
NEFullyConnectedLayer(const NEFullyConnectedLayer &) = delete;
/** Default move constructor */
@@ -128,25 +168,28 @@ private:
void configure_conv_fc(const ITensor *input, const ITensor *weights, ITensor *output);
void configure_mm(const ITensor *input, const ITensor *weights, ITensor *output);
- MemoryGroup _memory_group;
- NEFlattenLayerKernel _flatten_kernel;
- NEConvertFullyConnectedWeights _convert_weights;
- NEFullyConnectedLayerReshapeWeights _reshape_weights_function;
- NEGEMM _mm_gemm;
- NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
- NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
- NEGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel;
- Tensor _flatten_output;
- Tensor _gemmlowp_output;
- Tensor _converted_weights_output;
- Tensor _reshape_weights_output;
- const ITensor *_original_weights;
- bool _are_weights_converted;
- bool _are_weights_reshaped;
- bool _is_fc_after_conv;
- bool _accumulate_biases;
- bool _is_quantized;
- bool _is_prepared;
+ MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
+ NEFlattenLayerKernel _flatten_kernel;
+ NEConvertFullyConnectedWeights _convert_weights;
+ weights_transformations::NEConvertFullyConnectedWeightsManaged _convert_weights_managed;
+ NEFullyConnectedLayerReshapeWeights _reshape_weights_function;
+ weights_transformations::NEFullyConnectedLayerReshapeWeightsManaged _reshape_weights_managed_function;
+ NEGEMM _mm_gemm;
+ NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
+ NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
+ NEGEMMMatrixAccumulateBiasesKernel _accumulate_biases_kernel;
+ Tensor _flatten_output;
+ Tensor _gemmlowp_output;
+ Tensor _converted_weights_output;
+ Tensor _reshape_weights_output;
+ const ITensor *_original_weights;
+ bool _are_weights_converted;
+ bool _are_weights_reshaped;
+ bool _is_fc_after_conv;
+ bool _accumulate_biases;
+ bool _is_quantized;
+ bool _is_prepared;
};
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEFULLYCONNECTEDLAYER_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMM.h b/arm_compute/runtime/NEON/functions/NEGEMM.h
index 7f9e3181bc..d947be1ef9 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMM.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMM.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -31,6 +31,7 @@
#include "arm_compute/core/NEON/kernels/NEGEMMTranspose1xWKernel.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h"
#include "arm_compute/runtime/Tensor.h"
@@ -51,7 +52,7 @@ class NEGEMM : public IFunction
{
public:
/** Constructor */
- NEGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
+ NEGEMM(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
NEGEMM(const NEGEMM &) = delete;
/** Default move constructor */
@@ -96,6 +97,7 @@ public:
private:
MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
NEGEMMInterleave4x4Kernel _interleave_kernel;
NEGEMMTranspose1xWKernel _transpose_kernel;
NEGEMMMatrixMultiplyKernel _mm_kernel;
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
index ec4f700034..83e495e695 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h
@@ -27,6 +27,7 @@
#include "arm_compute/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h"
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
@@ -38,9 +39,8 @@ namespace arm_compute
class NEGEMMAssemblyDispatch : public IFunction
{
public:
- /** Default constructor */
- NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
-
+ /** Constructor */
+ NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copy constructed */
NEGEMMAssemblyDispatch(const NEGEMMAssemblyDispatch &) = delete;
/** Prevent instances of this class from being copied */
@@ -79,8 +79,9 @@ private:
/** Interface for the arm_gemm fallback */
std::unique_ptr<IFallback> _arm_gemm;
- MemoryGroup _memory_group; /**< Function memory group */
- std::shared_ptr<IMemoryManager> _memory_manager; /**< Copy of the memory manager used to create the memory group to be used when instantiating new functions */
+ MemoryGroup _memory_group; /**< Function memory group */
+ std::shared_ptr<IMemoryManager> _memory_manager; /**< Copy of the memory manager used to create the memory group to be used when instantiating new functions */
+ IWeightsManager *_weights_manager; /**< Pointer to the weights manager */
public:
/** If supported create an ACL function else fallback to the arm_gemm function.
*
@@ -117,6 +118,5 @@ public:
void prepare() override;
void run() override;
};
-
} // namespace arm_compute
#endif /* __ARM_COMPUTE_NEGEMMASSEMBLYDISPATCH_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index ace924f146..dccc35f0af 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -32,6 +32,7 @@
#include "arm_compute/core/NEON/kernels/NEIm2ColKernel.h"
#include "arm_compute/core/NEON/kernels/NEWeightsReshapeKernel.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
#include "arm_compute/runtime/NEON/functions/NEGEMM.h"
@@ -54,6 +55,14 @@ class NEConvolutionLayerReshapeWeights : public IFunction
public:
/** Constructor */
NEConvolutionLayerReshapeWeights();
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEConvolutionLayerReshapeWeights(const NEConvolutionLayerReshapeWeights &) = delete;
+ /** Default move constructor */
+ NEConvolutionLayerReshapeWeights(NEConvolutionLayerReshapeWeights &&) = default;
+ /** Prevent instances of this class from being copied (As this class contains pointers) */
+ NEConvolutionLayerReshapeWeights &operator=(const NEConvolutionLayerReshapeWeights &) = delete;
+ /** Default move assignment operator */
+ NEConvolutionLayerReshapeWeights &operator=(NEConvolutionLayerReshapeWeights &&) = default;
/** Set the input and output tensors.
*
* @param[in] weights Weights tensor. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM]. Data type supported: QASYMM8/F16/F32.
@@ -78,6 +87,52 @@ private:
NEWeightsReshapeKernel _weights_reshape_kernel;
};
+namespace weights_transformations
+{
+/** Basic function to manage the reshape weights generated from @ref NEConvolutionLayerReshapeWeights */
+class NEConvolutionLayerReshapeWeightsTransform : public ITransformWeights
+{
+public:
+ void configure(const ITensor *input, const ITensor *biases)
+ {
+ _bias_bit = (biases != nullptr) ? 1 : 0;
+ _func.configure(input, biases, &_output);
+ }
+
+ void run() override
+ {
+ _output.allocator()->allocate();
+ _func.run();
+ _reshape_run = true;
+ }
+
+ ITensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ uint32_t uid() override
+ {
+ return ((0x8) | (_bias_bit << 7));
+ }
+
+ bool is_reshape_run()
+ {
+ return _reshape_run;
+ }
+
+private:
+ Tensor _output{};
+ NEConvolutionLayerReshapeWeights _func{};
+ int32_t _bias_bit{ 0 };
+};
+} // namespace weights_transformations
+
/** Basic function to compute the convolution layer. This function calls the following NEON kernels/functions:
*
* -# @ref NEIm2ColKernel
@@ -92,7 +147,7 @@ class NEGEMMConvolutionLayer : public IFunction
{
public:
/** Constructor */
- NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr);
+ NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
/** Prevent instances of this class from being copied (As this class contains pointers) */
NEGEMMConvolutionLayer(const NEGEMMConvolutionLayer &) = delete;
/** Default move constructor */
@@ -187,15 +242,17 @@ private:
static Status validate_gemm3d(const ITensorInfo *input_info, const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col);
private:
- MemoryGroup _memory_group;
- NEConvolutionLayerReshapeWeights _reshape_weights;
- NEIm2ColKernel _im2col_kernel;
- NEGEMM _mm_gemm;
- NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
- NECol2ImKernel _col2im_kernel;
- NEActivationLayer _activationlayer_function;
- NEArithmeticAdditionKernel _add_bias_kernel;
- NEReshapeLayer _reshape_layer;
+ MemoryGroup _memory_group;
+ IWeightsManager *_weights_manager;
+ NEConvolutionLayerReshapeWeights _reshape_weights;
+ weights_transformations::NEConvolutionLayerReshapeWeightsTransform _reshape_weights_managed;
+ NEIm2ColKernel _im2col_kernel;
+ NEGEMM _mm_gemm;
+ NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
+ NECol2ImKernel _col2im_kernel;
+ NEActivationLayer _activationlayer_function;
+ NEArithmeticAdditionKernel _add_bias_kernel;
+ NEReshapeLayer _reshape_layer;
const ITensor *_original_weights;
diff --git a/arm_compute/runtime/NEON/functions/NERNNLayer.h b/arm_compute/runtime/NEON/functions/NERNNLayer.h
index ec394392de..978c445927 100644
--- a/arm_compute/runtime/NEON/functions/NERNNLayer.h
+++ b/arm_compute/runtime/NEON/functions/NERNNLayer.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -86,7 +86,7 @@ private:
NEGEMM _gemm_state_f;
NEArithmeticAdditionKernel _add_kernel;
NEActivationLayerKernel _activation_kernel;
- NEFullyConnectedLayer _fully_connected_kernel;
+ NEFullyConnectedLayer _fully_connected;
NECopyKernel _copy_kernel;
Tensor _fully_connected_out;
Tensor _gemm_output;
diff --git a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
index ad89e1fbec..d3dda9a95f 100644
--- a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
+++ b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h
@@ -32,6 +32,7 @@
#include "arm_compute/runtime/IFunction.h"
#include "arm_compute/runtime/IMemoryManager.h"
#include "arm_compute/runtime/IScheduler.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/Tensor.h"
@@ -94,8 +95,8 @@ public:
class NEGEMMInterleavedWrapper : public IFunction
{
public:
- NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager = nullptr);
- ~NEGEMMInterleavedWrapper() = default;
+ NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
+ ~NEGEMMInterleavedWrapper() = default;
NEGEMMInterleavedWrapper(const NEGEMMInterleavedWrapper &) = delete;
NEGEMMInterleavedWrapper &operator=(const NEGEMMInterleavedWrapper &) = delete;
diff --git a/src/graph/GraphContext.cpp b/src/graph/GraphContext.cpp
index 037b40b68b..c959d5e35c 100644
--- a/src/graph/GraphContext.cpp
+++ b/src/graph/GraphContext.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -32,13 +32,14 @@ namespace arm_compute
namespace graph
{
GraphContext::GraphContext()
- : _config(), _memory_managers()
+ : _config(), _memory_managers(), _weights_managers()
{
}
GraphContext::~GraphContext()
{
_memory_managers.clear();
+ _weights_managers.clear();
release_default_graph_context(*this);
}
@@ -74,6 +75,30 @@ std::map<Target, MemoryManagerContext> &GraphContext::memory_managers()
return _memory_managers;
}
+bool GraphContext::insert_weights_management_ctx(WeightsManagerContext &&weights_managers)
+{
+ Target target = weights_managers.target;
+
+ if(target != Target::NEON || _weights_managers.find(target) != std::end(_weights_managers))
+ {
+ return false;
+ }
+
+ _weights_managers[target] = std::move(weights_managers);
+
+ return true;
+}
+
+WeightsManagerContext *GraphContext::weights_management_ctx(Target target)
+{
+ return (_weights_managers.find(target) != std::end(_weights_managers)) ? &_weights_managers[target] : nullptr;
+}
+
+std::map<Target, WeightsManagerContext> &GraphContext::weights_managers()
+{
+ return _weights_managers;
+}
+
void GraphContext::finalize()
{
const size_t num_pools = 1;
diff --git a/src/graph/backends/CL/CLDeviceBackend.cpp b/src/graph/backends/CL/CLDeviceBackend.cpp
index 9971e4fc9a..9b7c879b2a 100644
--- a/src/graph/backends/CL/CLDeviceBackend.cpp
+++ b/src/graph/backends/CL/CLDeviceBackend.cpp
@@ -204,6 +204,11 @@ std::shared_ptr<arm_compute::IMemoryManager> CLDeviceBackend::create_memory_mana
return mm;
}
+
+std::shared_ptr<arm_compute::IWeightsManager> CLDeviceBackend::create_weights_manager()
+{
+ return nullptr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute
diff --git a/src/graph/backends/GLES/GCDeviceBackend.cpp b/src/graph/backends/GLES/GCDeviceBackend.cpp
index 058f7792e7..83e2436ddb 100644
--- a/src/graph/backends/GLES/GCDeviceBackend.cpp
+++ b/src/graph/backends/GLES/GCDeviceBackend.cpp
@@ -154,6 +154,11 @@ std::shared_ptr<arm_compute::IMemoryManager> GCDeviceBackend::create_memory_mana
return mm;
}
+
+std::shared_ptr<arm_compute::IWeightsManager> GCDeviceBackend::create_weights_manager()
+{
+ return nullptr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute
diff --git a/src/graph/backends/NEON/NEDeviceBackend.cpp b/src/graph/backends/NEON/NEDeviceBackend.cpp
index f94cd97cd2..017b4f0f24 100644
--- a/src/graph/backends/NEON/NEDeviceBackend.cpp
+++ b/src/graph/backends/NEON/NEDeviceBackend.cpp
@@ -37,6 +37,7 @@
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/runtime/Allocator.h"
#include "arm_compute/runtime/BlobLifetimeManager.h"
+#include "arm_compute/runtime/IWeightsManager.h"
#include "arm_compute/runtime/MemoryGroup.h"
#include "arm_compute/runtime/MemoryManagerOnDemand.h"
#include "arm_compute/runtime/OffsetLifetimeManager.h"
@@ -90,6 +91,16 @@ void NEDeviceBackend::setup_backend_context(GraphContext &ctx)
ctx.insert_memory_management_ctx(std::move(mm_ctx));
}
+
+ // Create function level weights manager
+ if(ctx.weights_management_ctx(Target::NEON) == nullptr)
+ {
+ WeightsManagerContext wm_ctx;
+ wm_ctx.target = Target::NEON;
+ wm_ctx.wm = create_weights_manager();
+
+ ctx.insert_weights_management_ctx(std::move(wm_ctx));
+ }
}
bool NEDeviceBackend::is_backend_supported()
@@ -159,6 +170,12 @@ std::shared_ptr<arm_compute::IMemoryManager> NEDeviceBackend::create_memory_mana
return mm;
}
+
+std::shared_ptr<arm_compute::IWeightsManager> NEDeviceBackend::create_weights_manager()
+{
+ auto weights_mgr = std::make_shared<IWeightsManager>();
+ return weights_mgr;
+}
} // namespace backends
} // namespace graph
} // namespace arm_compute
diff --git a/src/graph/backends/NEON/NEFunctionFactory.cpp b/src/graph/backends/NEON/NEFunctionFactory.cpp
index 852de549fa..45e9727133 100644
--- a/src/graph/backends/NEON/NEFunctionFactory.cpp
+++ b/src/graph/backends/NEON/NEFunctionFactory.cpp
@@ -115,6 +115,7 @@ std::unique_ptr<IFunction> create_convolution_layer<NEConvolutionLayerFunctions,
std::shared_ptr<IMemoryManager> mm = get_memory_manager(ctx, Target::NEON);
std::unique_ptr<IFunction> func;
std::string func_name;
+
if(conv_algorithm == ConvolutionMethod::Direct)
{
std::tie(func, func_name) = create_named_memory_managed_function<NEDirectConvolutionLayer>(
diff --git a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
index c5da649e30..0452a236c5 100644
--- a/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
+++ b/src/runtime/CL/functions/CLFullyConnectedLayer.cpp
@@ -76,7 +76,7 @@ Status CLFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
return CLTransposeKernel::validate(input, output);
}
-CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
+CLFullyConnectedLayer::CLFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(memory_manager), _convert_weights(), _flatten_layer(), _reshape_weights_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _gemmlowp_output_stage(),
_accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _are_weights_converted(true), _are_weights_reshaped(true),
_is_fc_after_conv(true), _accumulate_biases(false), _is_quantized(false), _is_prepared(false), _original_weights(nullptr)
diff --git a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
index a208545a99..4ccda88279 100644
--- a/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
+++ b/src/runtime/GLES_COMPUTE/functions/GCFullyConnectedLayer.cpp
@@ -38,7 +38,7 @@ void GCFullyConnectedLayerReshapeWeights::configure(const IGCTensor *input, IGCT
_kernel = std::move(k);
}
-GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
+GCFullyConnectedLayer::GCFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(std::move(memory_manager)), _im2col_kernel(), _reshape_weights_kernel(), _mm_kernel(), _accumulate_biases_kernel(), _im2col_output(), _reshape_weights_output(),
_original_weights(nullptr), _are_weights_reshaped(true), _is_fc_after_conv(true), _accumulate_biases(false)
{
diff --git a/src/runtime/IWeightsManager.cpp b/src/runtime/IWeightsManager.cpp
new file mode 100644
index 0000000000..6dfb925fe6
--- /dev/null
+++ b/src/runtime/IWeightsManager.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/IWeightsManager.h"
+
+namespace arm_compute
+{
+IWeightsManager::IWeightsManager()
+ : _managed_weights(), _managed_weights_parents()
+{
+}
+
+void IWeightsManager::manage(const ITensor *weights, ITransformWeights *parent)
+{
+ if(!are_weights_managed(weights))
+ {
+ _managed_weights[weights];
+ }
+
+ // In case the weights are an output of a previous reshape function
+ // store the parent's link
+ if(parent != nullptr)
+ {
+ if(_managed_weights_parents.find(weights) == _managed_weights_parents.end())
+ {
+ _managed_weights_parents[weights] = parent;
+ }
+ }
+}
+
+ITensor *IWeightsManager::run(const ITensor *weights, ITransformWeights *weights_transform)
+{
+ ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot run function. Weights are not managed");
+
+ // Find if I have the same weights with weights transform. If I do, don't run the reshape
+ auto item = _managed_weights.find(weights);
+ bool perform_run{ true };
+ ITensor *weights_tensor{ nullptr };
+
+ // Check if I already have the requested transform and I have run the reshape function
+ for(auto it : item->second)
+ {
+ if(it->is_reshape_run() && (it->uid() == weights_transform->uid()))
+ {
+ weights_tensor = it->get_weights();
+ perform_run = false;
+ break;
+ }
+ }
+
+ if(perform_run)
+ {
+ weights_transform->run();
+ weights_tensor = weights_transform->get_weights();
+ }
+
+ // Check if we can release memory from parent
+ auto parent_item = _managed_weights_parents.find(weights);
+ if(parent_item != _managed_weights_parents.end())
+ {
+ int32_t refcount = parent_item->second->decrease_refcount();
+ if(refcount == 0)
+ {
+ parent_item->second->release();
+ }
+ }
+
+ return weights_tensor;
+}
+
+bool IWeightsManager::are_weights_managed(const ITensor *weights)
+{
+ return (_managed_weights.find(weights) != _managed_weights.end());
+}
+
+ITensor *IWeightsManager::acquire(const ITensor *weights, ITransformWeights *weights_transform)
+{
+ ARM_COMPUTE_ERROR_ON_MSG(!are_weights_managed(weights), "Cannot acquire weights. Weights are not managed");
+
+ ITensor *transformed_weights{ nullptr };
+ auto item = _managed_weights.find(weights);
+
+ // Check if I already have the requested transform. If I do,
+ // increase the refcount of the transformed weights object and
+ // reuse the tensor
+ for(auto it : item->second)
+ {
+ if(it->uid() == weights_transform->uid())
+ {
+ transformed_weights = it->get_weights();
+ it->increase_refcount();
+ break;
+ }
+ }
+
+ if(transformed_weights == nullptr)
+ {
+ transformed_weights = weights_transform->get_weights();
+ weights_transform->increase_refcount();
+ item->second.emplace_back(weights_transform);
+ }
+
+ // Manage the weights and store link to the parent node
+ manage(transformed_weights, weights_transform);
+
+ return transformed_weights;
+}
+} // namespace arm_compute
diff --git a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
index b5b159ae31..f65c035da6 100644
--- a/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
+++ b/src/runtime/NEON/functions/NEConvertFullyConnectedWeights.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,8 +23,8 @@
*/
#include "arm_compute/runtime/NEON/functions/NEConvertFullyConnectedWeights.h"
-using namespace arm_compute;
-
+namespace arm_compute
+{
NEConvertFullyConnectedWeights::NEConvertFullyConnectedWeights()
: _kernel()
{
@@ -46,3 +46,4 @@ void NEConvertFullyConnectedWeights::run()
{
NEScheduler::get().schedule(&_kernel, Window::DimZ);
}
+} // namespace arm_compute \ No newline at end of file
diff --git a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
index bbb91b4651..0411b41220 100644
--- a/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDeconvolutionLayer.cpp
@@ -91,10 +91,10 @@ Status NEDeconvolutionLayer::validate(const ITensorInfo *input, const ITensorInf
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(Window::DimZ) != output_shape.z(), "Output's depth is invalid.");
}
- unsigned int deconv_pad_x = 0;
- unsigned int deconv_pad_y = 0;
- const unsigned int stride_x = info.stride().first;
- const unsigned int stride_y = info.stride().second;
+ unsigned int deconv_pad_x = 0;
+ unsigned int deconv_pad_y = 0;
+ const unsigned int stride_x = info.stride().first;
+ const unsigned int stride_y = info.stride().second;
const TensorShape scale_out_shape = compute_deconvolution_upsampled_shape(*input, *weights, stride_x, stride_y, out_dims, deconv_pad_x, deconv_pad_y);
TensorInfo scale_out_info(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(scale_out_shape));
const PadStrideInfo conv_info(1, 1, 0, 0, 0, 0, DimensionRoundingType::CEIL);
@@ -127,8 +127,8 @@ void NEDeconvolutionLayer::configure(ITensor *input, const ITensor *weights, con
const unsigned int pad_right = info.pad_right();
const unsigned int pad_top = info.pad_top();
const unsigned int pad_bottom = info.pad_bottom();
- const unsigned int stride_x = info.stride().first;
- const unsigned int stride_y = info.stride().second;
+ const unsigned int stride_x = info.stride().first;
+ const unsigned int stride_y = info.stride().second;
const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
@@ -166,14 +166,14 @@ void NEDeconvolutionLayer::configure(ITensor *input, const ITensor *weights, con
unsigned int deconv_pad_right = pad_left > pad_right ? pad_left - pad_right : 0;
deconv_pad_x -= deconv_pad_left + deconv_pad_right;
ARM_COMPUTE_ERROR_ON((deconv_pad_x % 2) != 0);
- deconv_pad_left += deconv_pad_x / 2;
+ deconv_pad_left += deconv_pad_x / 2;
deconv_pad_right += deconv_pad_x / 2;
unsigned int deconv_pad_top = pad_bottom > pad_top ? pad_bottom - pad_top : 0;
unsigned int deconv_pad_bottom = pad_top > pad_bottom ? pad_top - pad_bottom : 0;
deconv_pad_y -= deconv_pad_top + deconv_pad_bottom;
ARM_COMPUTE_ERROR_ON((deconv_pad_y % 2) != 0);
- deconv_pad_top += deconv_pad_y / 2;
+ deconv_pad_top += deconv_pad_y / 2;
deconv_pad_bottom += deconv_pad_y / 2;
TensorInfo scale_out_info(scale_out_shape, 1, _permuted_input.info()->data_type(), _permuted_input.info()->quantization_info());
@@ -212,14 +212,14 @@ void NEDeconvolutionLayer::configure(ITensor *input, const ITensor *weights, con
unsigned int deconv_pad_right = pad_left > pad_right ? pad_left - pad_right : 0;
deconv_pad_x -= deconv_pad_left + deconv_pad_right;
ARM_COMPUTE_ERROR_ON((deconv_pad_x % 2) != 0);
- deconv_pad_left += deconv_pad_x / 2;
+ deconv_pad_left += deconv_pad_x / 2;
deconv_pad_right += deconv_pad_x / 2;
unsigned int deconv_pad_top = pad_bottom > pad_top ? pad_bottom - pad_top : 0;
unsigned int deconv_pad_bottom = pad_top > pad_bottom ? pad_top - pad_bottom : 0;
deconv_pad_y -= deconv_pad_top + deconv_pad_bottom;
ARM_COMPUTE_ERROR_ON((deconv_pad_y % 2) != 0);
- deconv_pad_top += deconv_pad_y / 2;
+ deconv_pad_top += deconv_pad_y / 2;
deconv_pad_bottom += deconv_pad_y / 2;
TensorInfo scale_out_info(scale_out_shape, 1, input->info()->data_type(), input->info()->quantization_info());
diff --git a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
index 12a5a1d724..7adc3bca9e 100644
--- a/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
+++ b/src/runtime/NEON/functions/NEFullyConnectedLayer.cpp
@@ -74,10 +74,11 @@ Status NEFullyConnectedLayerReshapeWeights::validate(const ITensorInfo *input, c
return NETransposeKernel::validate(input, output);
}
-NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _flatten_kernel(), _convert_weights(), _reshape_weights_function(), _mm_gemm(), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(),
- _flatten_output(), _gemmlowp_output(), _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false),
- _is_fc_after_conv(false), _accumulate_biases(false), _is_quantized(false), _is_prepared(false)
+NEFullyConnectedLayer::NEFullyConnectedLayer(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+ : _memory_group(std::move(memory_manager)), _weights_manager(weights_manager), _flatten_kernel(), _convert_weights(), _convert_weights_managed(), _reshape_weights_function(),
+ _reshape_weights_managed_function(), _mm_gemm(nullptr, weights_manager), _mm_gemmlowp(), _gemmlowp_output_stage(), _accumulate_biases_kernel(), _flatten_output(), _gemmlowp_output(),
+ _converted_weights_output(), _reshape_weights_output(), _original_weights(nullptr), _are_weights_converted(true), _are_weights_reshaped(false), _is_fc_after_conv(false), _accumulate_biases(false),
+ _is_quantized(false), _is_prepared(false)
{
}
@@ -155,6 +156,11 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
_is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
_original_weights = weights;
+ if(_weights_manager)
+ {
+ _weights_manager->manage(weights);
+ }
+
// Configure gemmlowp output
if(_is_quantized)
{
@@ -194,21 +200,39 @@ void NEFullyConnectedLayer::configure(const ITensor *input, const ITensor *weigh
// Reshape weights if needed
if(!_are_weights_reshaped)
{
- // Reshape the weights
- _reshape_weights_function.configure(weights, &_reshape_weights_output);
- weights_to_use = &_reshape_weights_output;
+ if(_weights_manager && _weights_manager->are_weights_managed(weights))
+ {
+ _reshape_weights_managed_function.configure(weights);
+ weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed_function);
+ }
+ else
+ {
+ // Reshape the weights
+ _reshape_weights_function.configure(weights, &_reshape_weights_output);
+ weights_to_use = &_reshape_weights_output;
+ }
}
// Convert weights if needed
if(_is_fc_after_conv && (input->info()->data_layout() != fc_info.weights_trained_layout))
{
- // Convert weights
- _convert_weights.configure(weights_to_use,
- &_converted_weights_output,
- input->info()->tensor_shape(),
- fc_info.weights_trained_layout);
+ if(_weights_manager && _weights_manager->are_weights_managed(weights_to_use))
+ {
+ _convert_weights_managed.configure(weights_to_use,
+ input->info()->tensor_shape(),
+ fc_info.weights_trained_layout);
+ weights_to_use = _weights_manager->acquire(weights, &_convert_weights_managed);
+ }
+ else
+ {
+ // Convert weights
+ _convert_weights.configure(weights_to_use,
+ &_converted_weights_output,
+ input->info()->tensor_shape(),
+ fc_info.weights_trained_layout);
- weights_to_use = &_converted_weights_output;
+ weights_to_use = &_converted_weights_output;
+ }
_are_weights_converted = false;
}
@@ -381,7 +405,10 @@ void NEFullyConnectedLayer::prepare()
{
if(!_is_prepared)
{
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+ if(!_weights_manager)
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
+ }
auto release_unused = [](Tensor * w)
{
@@ -397,20 +424,38 @@ void NEFullyConnectedLayer::prepare()
// Reshape of the weights (happens only once)
if(!_are_weights_reshaped)
{
- // Run reshape weights kernel and mark weights as unused
- _reshape_weights_output.allocator()->allocate();
- _reshape_weights_function.run();
-
- cur_weights->mark_as_unused();
- cur_weights = &_reshape_weights_output;
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+ {
+ cur_weights->mark_as_unused();
+ cur_weights = _weights_manager->run(cur_weights, &_reshape_weights_managed_function);
+ }
+ else
+ {
+ // Reshape of the weights (happens only once)
+ if(!_are_weights_reshaped)
+ {
+ // Run reshape weights kernel and mark weights as unused
+ _reshape_weights_output.allocator()->allocate();
+ _reshape_weights_function.run();
+ }
+ cur_weights->mark_as_unused();
+ cur_weights = &_reshape_weights_output;
+ }
_are_weights_reshaped = true;
}
// Convert weights if needed (happens only once)
if(!_are_weights_converted)
{
- _converted_weights_output.allocator()->allocate();
- _convert_weights.run();
+ if(_weights_manager && _weights_manager->are_weights_managed(cur_weights))
+ {
+ _weights_manager->run(cur_weights, &_convert_weights_managed);
+ }
+ else
+ {
+ _converted_weights_output.allocator()->allocate();
+ _convert_weights.run();
+ }
cur_weights->mark_as_unused();
_are_weights_converted = true;
diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp
index 37d0e09fc9..df92b7999c 100644
--- a/src/runtime/NEON/functions/NEGEMM.cpp
+++ b/src/runtime/NEON/functions/NEGEMM.cpp
@@ -42,9 +42,9 @@ using namespace arm_compute::misc::shape_calculator;
namespace arm_compute
{
-NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager), _ma_kernel(), _tmp_a(), _tmp_b(), _original_b(nullptr),
- _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
+NEGEMM::NEGEMM(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+ : _memory_group(memory_manager), _weights_manager(weights_manager), _interleave_kernel(), _transpose_kernel(), _mm_kernel(), _asm_glue(memory_manager, weights_manager), _ma_kernel(), _tmp_a(),
+ _tmp_b(), _original_b(nullptr), _run_vector_matrix_multiplication(false), _run_addition(false), _reshape_b_only_on_first_run(false), _is_prepared(false)
{
}
@@ -276,13 +276,19 @@ void NEGEMM::prepare()
{
if(_asm_glue.is_configured())
{
- ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b))
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ }
_asm_glue.prepare();
}
else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication && !_asm_glue.is_configured())
{
- ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ if(!_weights_manager || !_weights_manager->are_weights_managed(_original_b))
+ {
+ ARM_COMPUTE_ERROR_ON(!_original_b->is_used());
+ }
_tmp_b.allocator()->allocate();
NEScheduler::get().schedule(&_transpose_kernel, Window::DimY);
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 2a4498b0a9..956ded55d2 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -38,7 +38,8 @@ namespace
std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info,
const ITensor *a, const ITensor *b, ITensor *d,
float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager,
+ IWeightsManager *weights_manager)
{
// Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure()
@@ -50,7 +51,7 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
{
return nullptr;
}
- auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager);
+ auto function = support::cpp14::make_unique<NEGEMMInterleavedWrapper>(memory_manager, weights_manager);
function->configure(a, b, d, alpha, beta, gemm_info);
return std::move(function);
}
@@ -73,25 +74,95 @@ std::unique_ptr<IFunction> create_function_all_types(const arm_gemm::KernelDescr
}
}
+template <typename TypeInput, typename TypeOutput>
+class FallbackTransform : public ITransformWeights
+{
+public:
+ void run() override
+ {
+ _output.allocator()->allocate();
+ ARM_COMPUTE_ERROR_ON(_output.buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(_output.buffer(), _in1_ptr, _ldb, _multi_stride_b);
+ _reshape_run = true;
+ }
+
+ void release() override
+ {
+ _output.allocator()->free();
+ }
+
+ ITensor *get_weights() override
+ {
+ return &_output;
+ }
+
+ uint32_t uid() override
+ {
+ uint32_t id = (_B_pretranspose_size | 0x80000000);
+ return id;
+ }
+
+ void configure(size_t B_pretranspose_size, unsigned int alignment)
+ {
+ _output.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ _B_pretranspose_size = B_pretranspose_size;
+ }
+
+ void set_pretranspose(ITensor *tensor)
+ {
+ if(!_reshape_run)
+ {
+ _gemm_kernel_asm->set_pretransposed_B_data(tensor->buffer());
+ }
+ }
+
+ void set_args(const int ldb, const TypeInput *in1_ptr, const int multi_stride_b, std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> gemm_kernel_asm)
+ {
+ _ldb = ldb;
+ _in1_ptr = in1_ptr;
+ _multi_stride_b = multi_stride_b;
+ _gemm_kernel_asm = gemm_kernel_asm;
+ }
+
+private:
+ Tensor _output{};
+ int _ldb{};
+ const TypeInput *_in1_ptr{};
+ int _multi_stride_b{};
+ size_t _B_pretranspose_size{};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+};
+
/** Fallback in case ACL doesn't have a function */
template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public NEGEMMAssemblyDispatch::IFallback
{
public:
+ /** Destructor */
+ ~Fallback()
+ {
+ // Release memory if we have allocated the memory ourselves
+ if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+ {
+ delete _pretranspose;
+ }
+ }
+
/** Initialise the functions's input and output.
*
- * @param[in] a Input tensor containing the Matrix A.
- * @param[in] b Input tensor containing the Matrix B.
- * @param[in] c Input tensor containing the Matrix C.
- * @param[out] d Output tensor to store the result of matrix multiplication.
- * @param[in] args Matrix multiplication information.
- * @param[in] gemm_info GEMM meta-data
- * @param[in] memory_group Memory group to be used by the function.
- * @param[in] os Output stage meta-data.
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[in] c Input tensor containing the Matrix C.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] args Matrix multiplication information.
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] memory_group Memory group to be used by the function.
+ * @param[in] weights_manager Weights manager to be used by the function.
+ * @param[in] os Output stage meta-data.
*/
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
- MemoryGroup &memory_group, const OutputStage &os = {});
+ MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
// Inherited methods overridden:
void run() override;
@@ -108,7 +179,7 @@ private:
void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
/** Assembly Gemm kernel */
- std::unique_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
/** Optimised NEON kernel */
std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
/** Input A */
@@ -130,20 +201,25 @@ private:
/** GEMM workspace */
Tensor _workspace{};
/** Pre-transpose tensor */
- Tensor _pretranspose{};
+ ITensor *_pretranspose{ nullptr };
/** Prepared flag */
bool _is_prepared{ false };
/** GEMM meta-data */
GEMMInfo _gemm_info{};
+ /** Weights manager */
+ IWeightsManager *_weights_manager{ nullptr };
+ /** Weights transform object */
+ FallbackTransform<TypeInput, TypeOutput> _weights_transform{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
arm_gemm::GemmArgs<TypeOutput> args, const GEMMInfo &gemm_info,
- MemoryGroup &memory_group, const OutputStage &os)
+ MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
arm_gemm::GemmConfig gemm_cfg;
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
+ _weights_manager = weights_manager;
if(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMV_BATCHED)
{
gemm_cfg.filter = gemm_kernel_info.name;
@@ -190,7 +266,16 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
- _pretranspose.allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ if(weights_manager && _weights_manager->are_weights_managed(b))
+ {
+ _weights_transform.configure(B_pretranspose_size, alignment);
+ _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+ }
+ else
+ {
+ _pretranspose = new Tensor();
+ static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
+ }
}
}
@@ -208,14 +293,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
- _pretranspose.allocator()->allocate();
- ARM_COMPUTE_ERROR_ON(_pretranspose.buffer() == nullptr);
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- _gemm_kernel_asm->pretranspose_B_array(_pretranspose.buffer(), in1_ptr, ldb, multi_stride_b);
- _b->mark_as_unused();
+ if(_weights_manager && _weights_manager->are_weights_managed(_b))
+ {
+ _weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
+ _weights_manager->run(_b, &_weights_transform);
+
+ // If we didn't run the reshape function, set the pretransposed buffer
+ if(!_weights_transform.is_reshape_run())
+ {
+ _weights_transform.set_pretranspose(_pretranspose);
+ }
+ }
+ else
+ {
+ static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
+ ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
+ _gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
+ _b->mark_as_unused();
+ }
}
_is_prepared = true;
@@ -294,7 +393,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -304,14 +403,14 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
// Try to create an ACL function:
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+ acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
// If we still don't have an ACL function:
if(acl_function == nullptr)
{
//Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group);
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
arm_gemm = std::move(fallback);
}
}
@@ -319,7 +418,7 @@ void create_function_or_arm_gemm(std::unique_ptr<IFunction> &acl_function, std::
template <typename TypeInput, typename TypeOutput>
void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function, std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info,
- std::shared_ptr<IMemoryManager> memory_manager)
+ std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
{
INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -339,22 +438,22 @@ void create_function_or_arm_gemm_quant(std::unique_ptr<IFunction> &acl_function,
// Try to create an ACL function:
const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput>(args, gemm_requant_info);
- acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager));
+ acl_function = create_function_all_types(gemm_kernel_info, a, b, d, alpha, beta, gemm_info, std::move(memory_manager), weights_manager);
// If we still don't have an ACL function:
if(acl_function == nullptr)
{
// Fallback onto arm_gemm function if ACL doesn't support this method.
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::ARequantizeLayer32>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, gemm_requant_info);
+ fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
}
} //namespace
-NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager)
- : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager)
+NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
+ : _function(nullptr), _arm_gemm(nullptr), _memory_group(memory_manager), _memory_manager(memory_manager), _weights_manager(weights_manager)
{
}
@@ -390,27 +489,27 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
switch(a->info()->data_type())
{
case DataType::F32:
- create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<float, float>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->info()->data_type() == DataType::S32)
{
- create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<uint8_t, uint32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
}
else
{
- create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm_quant<uint8_t, uint8_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
}
break;
case DataType::S8:
- create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<int8_t, int32_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#endif /* __aarch64__ */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager);
+ create_function_or_arm_gemm<float16_t, float16_t>(_function, _arm_gemm, _memory_group, a, b, c, d, alpha, beta, gemm_info, _memory_manager, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
diff --git a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
index e94c8933ae..a39e4c5125 100644
--- a/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEGEMMConvolutionLayer.cpp
@@ -50,7 +50,6 @@ void NEConvolutionLayerReshapeWeights::configure(const ITensor *weights, const I
ARM_COMPUTE_ERROR_THROW_ON(NEConvolutionLayerReshapeWeights::validate(weights->info(),
(biases != nullptr) ? biases->info() : nullptr,
output->info()));
-
const bool append_biases = (biases != nullptr) && !is_data_type_quantized_asymmetric(weights->info()->data_type());
const ITensor *biases_to_use = (append_biases) ? biases : nullptr;
@@ -89,10 +88,10 @@ void NEConvolutionLayerReshapeWeights::run()
NEScheduler::get().schedule(&_weights_reshape_kernel, 3);
}
-NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager)
- : _memory_group(memory_manager), _reshape_weights(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager), _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(),
- _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(), _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false),
- _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
+NEGEMMConvolutionLayer::NEGEMMConvolutionLayer(const std::shared_ptr<IMemoryManager> &memory_manager, IWeightsManager *weights_manager)
+ : _memory_group(memory_manager), _weights_manager(weights_manager), _reshape_weights(), _reshape_weights_managed(), _im2col_kernel(), _mm_gemm(memory_manager), _mm_gemmlowp(memory_manager),
+ _col2im_kernel(), _activationlayer_function(), _add_bias_kernel(), _reshape_layer(), _original_weights(nullptr), _im2col_output(), _weights_reshaped(), _gemm_output(), _tmp_output(),
+ _data_layout(DataLayout::NCHW), _append_bias(false), _skip_im2col(false), _skip_col2im(false), _is_quantized(false), _is_activationlayer_enabled(false), _is_prepared(false)
{
}
@@ -309,7 +308,18 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
// _weights_reshaped will be auto configured in the kernel.
// Just append biases and do not transpose 1xW as it will be reshaped in NEGEMM
- _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+ const ITensor *weights_to_use = weights;
+
+ if(_weights_manager && _weights_manager->are_weights_managed(weights))
+ {
+ _reshape_weights_managed.configure(weights, biases_to_use);
+ weights_to_use = _weights_manager->acquire(weights, &_reshape_weights_managed);
+ }
+ else
+ {
+ _reshape_weights.configure(weights, biases_to_use, &_weights_reshaped);
+ weights_to_use = &_weights_reshaped;
+ }
// Create tensor to store im2col reshaped inputs
if(!_skip_im2col)
@@ -351,7 +361,7 @@ void NEGEMMConvolutionLayer::configure(const ITensor *input, const ITensor *weig
// Configure GEMM
// In case we need to skip col2im, GEMM3D (gemm_3d_depth != 0) must be called in order to avoid reshaping the output matrix
const unsigned int gemm_3d_depth = _skip_col2im ? conv_h : 0;
- configure_mm(gemm_input_to_use, &_weights_reshaped, biases, gemm_output_to_use, act_info, gemm_3d_depth);
+ configure_mm(gemm_input_to_use, weights_to_use, biases, gemm_output_to_use, act_info, gemm_3d_depth);
if(!_skip_im2col)
{
@@ -493,7 +503,7 @@ Status NEGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
ARM_COMPUTE_RETURN_ON_ERROR(NEConvolutionLayerReshapeWeights::validate(weights, biases_to_use, nullptr));
weights_reshaped_info = TensorInfo(compute_weights_reshaped_shape(*weights, (append_bias && !skip_im2col)), 1, data_type);
weights_reshaped_info.set_quantization_info(weights->quantization_info());
- weights_to_use = &weights_reshaped_info;
+ weights_to_use = &weights_reshaped_info;
if(!skip_im2col)
{
@@ -603,10 +613,17 @@ void NEGEMMConvolutionLayer::prepare()
{
ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
- // Run weights reshaping and mark original weights tensor as unused
- _weights_reshaped.allocator()->allocate();
- _reshape_weights.run();
- _original_weights->mark_as_unused();
+ if(_weights_manager && _weights_manager->are_weights_managed(_original_weights))
+ {
+ _weights_manager->run(_original_weights, &_reshape_weights_managed);
+ }
+ else
+ {
+ // Run weights reshaping and mark original weights tensor as unused
+ _weights_reshaped.allocator()->allocate();
+ _reshape_weights.run();
+ _original_weights->mark_as_unused();
+ }
// Prepare GEMM
_is_quantized ? _mm_gemmlowp.prepare() : _mm_gemm.prepare();
diff --git a/src/runtime/NEON/functions/NERNNLayer.cpp b/src/runtime/NEON/functions/NERNNLayer.cpp
index 9ca7ded3be..67f4064632 100644
--- a/src/runtime/NEON/functions/NERNNLayer.cpp
+++ b/src/runtime/NEON/functions/NERNNLayer.cpp
@@ -34,8 +34,8 @@
namespace arm_compute
{
NERNNLayer::NERNNLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected_kernel(), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(),
- _is_prepared(false)
+ : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected(memory_manager), _copy_kernel(), _fully_connected_out(), _gemm_output(),
+ _add_output(), _is_prepared(false)
{
}
@@ -81,7 +81,7 @@ void NERNNLayer::configure(const ITensor *input, const ITensor *weights, const I
// Manage intermediate buffers and configure
_memory_group.manage(&_fully_connected_out);
- _fully_connected_kernel.configure(input, weights, bias, &_fully_connected_out);
+ _fully_connected.configure(input, weights, bias, &_fully_connected_out);
_memory_group.manage(&_gemm_output);
_gemm_state_f.configure(hidden_state, recurrent_weights, nullptr, &_gemm_output, 1.f, 0.f);
@@ -106,7 +106,7 @@ void NERNNLayer::run()
MemoryGroupResourceScope scope_mg(_memory_group);
- _fully_connected_kernel.run();
+ _fully_connected.run();
_gemm_state_f.run();
@@ -121,7 +121,7 @@ void NERNNLayer::prepare()
{
if(!_is_prepared)
{
- _fully_connected_kernel.prepare();
+ _fully_connected.prepare();
_gemm_state_f.prepare();
_is_prepared = true;
diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
index ac809fa142..41d7d1ff76 100644
--- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
+++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp
@@ -180,7 +180,7 @@ public:
}
};
-NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager)
+NEGEMMInterleavedWrapper::NEGEMMInterleavedWrapper(std::shared_ptr<IMemoryManager> memory_manager, IWeightsManager *weights_manager)
: _memory_group(std::move(memory_manager))
{
}