aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp')
-rw-r--r-- src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp 32
1 files changed, 31 insertions, 1 deletions
diff --git a/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
index b1fe66cea2..9805fd354f 100644
--- a/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
+++ b/src/core/NEON/kernels/arm_conv/depthwise/working_space.hpp
@@ -217,7 +217,7 @@ class InputBufferElement
template <typename StratType, typename OutputStage>
static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &args)
{
- return sizeof(T) * args.depthwise_args.input_channels;
+ return sizeof(T) * args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier; // size in bytes: one T per (input channel, channel multiplier) pair
}
template <class WorkspaceType, typename StratType, typename OutputStage>
@@ -278,6 +278,36 @@ class OutputArrayElement
};
+/* Intermediate array to store results of premultiplication.
+ * Used as input to the kernel instead of the original input array.
+ *
+ * Workspace element exposing two static hooks:
+ *   - get_element_size(args): number of bytes this element needs, computed
+ *     as sizeof(T) * cols * rows * channels from the fields of `args`.
+ *   - initialise(ws, buffer, args): stores `buffer` (viewed as T*) in
+ *     ws->intermediate_buffer and returns the address that lies
+ *     get_element_size(args) bytes past `buffer`.
+ *
+ * NOTE(review): the kernel extents are added to the column/row counts,
+ * presumably as an allowance for padding around the tile; confirm against
+ * the code that fills this buffer.
+ */
+template <typename T>
+class IntermediateBufferElement
+{
+public:
+ struct Workspace
+ {
+ T *intermediate_buffer; // set by initialise(); points at the start of this element's storage
+ };
+
+ template <typename StratType, typename OutputStage>
+ static size_t get_element_size(const WorkspaceArgs<StratType, OutputStage> &args)
+ {
+ auto cols = args.depthwise_args.input_cols + args.depthwise_args.kernel_cols; // input column count plus kernel column count
+ auto rows = args.strategy->get_input_rows() + args.depthwise_args.kernel_rows; // strategy-reported input rows plus kernel row count
+ auto channels = args.depthwise_args.input_channels * args.depthwise_args.channel_multiplier; // input channels scaled by the channel multiplier
+ return sizeof(T) * cols * rows * channels; // total size in bytes
+ }
+
+ template <class WorkspaceType, typename StratType, typename OutputStage>
+ static void *initialise(WorkspaceType *ws, void *buffer, const WorkspaceArgs<StratType, OutputStage> &args)
+ {
+ ws->intermediate_buffer = reinterpret_cast<T*>(buffer); // this element's storage begins at `buffer`
+ return reinterpret_cast<char *>(buffer) + get_element_size(args); // byte-wise advance past this element's storage
+ }
+};
+
+
/* Container for requantization parameters.
*
* This removes the distinction between per-layer and per-channel