aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/NEON/kernels/convolution/common/tensor.hpp')
-rw-r--r--arm_compute/core/NEON/kernels/convolution/common/tensor.hpp69
1 files changed, 35 insertions, 34 deletions
diff --git a/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp b/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
index 6567eeb23d..ad0a677a8f 100644
--- a/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
+++ b/arm_compute/core/NEON/kernels/convolution/common/tensor.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -54,6 +54,18 @@ struct Tensor4DShape
{
}
+ inline int index(const int n, const int i, const int j, const int c) const
+ {
+ if (this->ordering == NHWC)
+ {
+ return ((n*this->n_rows + i)*this->n_cols + j)*this->n_channels + c;
+ }
+ else // NCHW
+ {
+ return ((n*this->n_channels + c)*this->n_rows + i)*this->n_cols + j;
+ }
+ }
+
inline int size() const
{
return n_batches * n_rows * n_cols * n_channels;
@@ -94,6 +106,18 @@ struct KernelShape
{
}
+ inline int index(int oc, int i, int j, int ic) const
+ {
+ if (this->ordering == HWIO)
+ {
+ return ((i*this->n_cols + j)*this->n_input_channels + ic)*this->n_output_channels + oc;
+ }
+ else // OIHW
+ {
+ return ((oc*this->n_input_channels + ic)*this->n_rows + i)*this->n_cols + j;
+ }
+ }
+
inline int size(void) const
{
return n_output_channels * n_rows * n_cols * n_input_channels;
@@ -127,7 +151,16 @@ class Tensor4D final
return shape.size() * sizeof(T);
}
- inline T& element(int, int, int, int) const;
+ /* Extract an element of the tensor.
+ *
+ * If the shape is a Tensor4DShape then the index is given as batch, row,
+ * column and channel. If the shape is a KernelShape then the index is
+ * given as output channel, row, column and input channel.
+ */
+ inline T& element(const int a, const int b, const int c, const int d) const
+ {
+ return _data[shape.index(a, b, c, d)];
+ }
inline void Clear() {
Fill(static_cast<T>(0));
@@ -143,35 +176,3 @@ class Tensor4D final
private:
T* const _data;
};
-
-
-template <>
-inline float& Tensor4D<Tensor4DShape, float>::element(int n, int i, int j, int c) const
-{
- int index;
- if (shape.ordering == NHWC)
- {
- index = ((n*shape.n_rows + i)*shape.n_cols + j)*shape.n_channels + c;
- }
- else // NCHW
- {
- index = ((n*shape.n_channels + c)*shape.n_rows + i)*shape.n_cols + j;
- }
- return _data[index];
-}
-
-
-template <>
-inline float& Tensor4D<KernelShape, float>::element(int oc, int i, int j, int ic) const
-{
- int index;
- if (shape.ordering == HWIO)
- {
- index = ((i*shape.n_cols + j)*shape.n_input_channels + ic)*shape.n_output_channels + oc;
- }
- else // OIHW
- {
- index = ((oc*shape.n_input_channels + ic)*shape.n_rows + i)*shape.n_cols + j;
- }
- return _data[index];
-}