about | summary | refs | log | tree | commit | diff
path: root/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h')
-rw-r--r-- arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h 71
1 files changed, 66 insertions, 5 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
index b6831e3ca9..5d6cd02398 100644
--- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
+++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -25,8 +25,13 @@
#define __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__
#include "arm_compute/core/CPP/CPPTypes.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
+#include "arm_compute/core/WindowIterator.h"
namespace arm_compute
{
@@ -76,7 +81,7 @@ public:
};
/** Type specialisations of @ref NEGEMMInterleavedTransformAWrapper */
-template <typename To, bool use_dot = false>
+template <typename strategy>
class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTransformAWrapper
{
public:
@@ -88,11 +93,67 @@ public:
* @param[in] block_walker Window representing the layout of the matrix's blocks
* @param[in] params M, N, K sizes.
*/
- void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params);
+ void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params &params)
+ { // Only records the arguments in member variables; no transformation work is done here.
+ _a = a; // source tensor, read later by transform()
+ _transformed_a = transformed_a; // destination tensor, accessed later by transform()
+ _transpose_a = transpose_a; // forwarded unchanged to PrepareA() in transform()
+ _Ksize = params.K; // used by create_workloads() to clamp kmax
+ _Msize = params.M; // used by transform() to clamp last_m and, for NHWC, to derive the batch stride
+ _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension
+ }
// Inherited methods overridden:
- void transform(const TransformAWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override;
- void create_workloads(std::vector<TransformAWorkload> &workloads) override;
+ void transform(const TransformAWorkload &wl, const ThreadInfo &info, const Window &batch_window, const Coordinates &start_offset, const Coordinates &end_offset) override
+ { // Runs strategy::transforms.PrepareA() for workload wl over the part of batch_window delimited by start_offset / end_offset.
+ strategy strat(info.cpu_info); // strategy instance built from the calling thread's CPU info
+ TensorAccessor<typename strategy::operand_type> a(*_a); // typed accessor over the source tensor
+ TensorAccessor<typename strategy::operand_type> transformed_a(*_transformed_a); // typed accessor over the destination tensor
+
+ if(_a->info()->data_layout() == DataLayout::NHWC)
+ {
+ // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
+ // the relevant multiple of the row stride.
+ const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize; // y stride in bytes times M
+ a.set_stride(2, nhwc_batch_stride); // override the stride of dimension 2 on the accessor only
+ }
+
+ unsigned int last_m = 0; // end of the current row range; written by on_new_row_size below
+ //TODO: Create a new iterate_1D( DimY);
+ int last_y = -1; // last y handled; starts at -1 so the first coordinate visited is expected to differ
+ auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id)
+ { // per-coordinate callback; NOTE(review): presumably invoked after on_new_row_size for the same row - confirm in WindowIterator.h
+ if(id.y() != last_y) // act only when y changes, i.e. once per distinct y value in a run
+ {
+ last_y = id.y();
+ unsigned int batch = id.y(); // y coordinate is used as the batch index
+ unsigned int first_m = id.x(); // x coordinate is used as the first M index of the range
+
+ if(first_m >= last_m) // empty range [first_m, last_m): nothing to transform
+ return;
+
+ strat.transforms.PrepareA(transformed_a(0, first_m, batch), // output accessed at (0, first_m, batch)
+ a(0, 0, batch, wl._multi), // input accessed at (0, 0, batch, multi)
+ a.stride(1), first_m, last_m, wl._k0, wl._kmax, _transpose_a); // stride of dimension 1, M range, K range from the workload, transpose flag
+ }
+ });
+ auto on_new_row_size = [&](unsigned int start, unsigned int end)
+ { // row-size callback handed to iterate_2D; 'start' is not used
+ last_m = std::min(end, _Msize); // clamp the end of the range to M
+ };
+ window_iterator.iterate_2D(on_new_row_size); // drives both callbacks over the window
+ }
+ void create_workloads(std::vector<TransformAWorkload> &workloads) override
+ { // Appends one TransformAWorkload to 'workloads' per coordinate visited in _k_multi_window.
+ execute_window_loop(_k_multi_window, [&](const Coordinates & id)
+ {
+ const unsigned int k0 = id.x(); // x coordinate is the start of the K block
+ const unsigned int multi = id.y(); // y coordinate is the multi index
+ const unsigned int kmax = std::min(k0 + _k_multi_window.x().step(), _Ksize); // end of the K block: one window step past k0, clamped to K
+
+ workloads.push_back(TransformAWorkload(k0, kmax, multi)); // existing entries in 'workloads' are kept
+ });
+ }
private:
const ITensor *_a