diff options
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ConcatenateLayer.cpp (renamed from tests/validation/reference/WidthConcatenateLayer.cpp) | 43 | ||||
-rw-r--r-- | tests/validation/reference/ConcatenateLayer.h (renamed from tests/validation/reference/WidthConcatenateLayer.h) | 10 |
2 files changed, 41 insertions, 12 deletions
diff --git a/tests/validation/reference/WidthConcatenateLayer.cpp b/tests/validation/reference/ConcatenateLayer.cpp index 38543393ce..1440878829 100644 --- a/tests/validation/reference/WidthConcatenateLayer.cpp +++ b/tests/validation/reference/ConcatenateLayer.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -21,9 +21,10 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#include "WidthConcatenateLayer.h" +#include "ConcatenateLayer.h" #include "tests/validation/Helpers.h" +#include "tests/validation/reference/Permute.h" namespace arm_compute { @@ -33,24 +34,22 @@ namespace validation { namespace reference { +namespace +{ template <typename T> SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst) { // Create reference std::vector<TensorShape> shapes; - for(const auto &src : srcs) { shapes.emplace_back(src.shape()); } - // Compute reference int width_offset = 0; const int width_out = dst.shape().x(); - // Set output tensor to 0 std::fill_n(dst.data(), dst.num_elements(), 0); - for(const auto &src : srcs) { ARM_COMPUTE_ERROR_ON(width_offset >= width_out); @@ -89,13 +88,43 @@ SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, } width_offset += width; } - return dst; } template SimpleTensor<float> widthconcatenate_layer(const std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst); template SimpleTensor<half> widthconcatenate_layer(const std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst); template SimpleTensor<uint8_t> widthconcatenate_layer(const std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst); +} // namespace + +template <typename T> +SimpleTensor<T> concatenate_layer(std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst, unsigned int axis) +{ + switch(axis) + { + case Window::DimX: + { + return widthconcatenate_layer(srcs, dst); + } + case Window::DimY: + { + for(auto &t : srcs) + { + t = reference::permute<T>(t, PermutationVector(1U, 0U)); + } + dst = reference::permute<T>(dst, PermutationVector(1U, 0U)); + return reference::permute<T>(widthconcatenate_layer(srcs, dst), PermutationVector(1U, 0U)); + } + default: + { + ARM_COMPUTE_ERROR("Not supported"); + return dst; + } + } +} + +template SimpleTensor<float> concatenate_layer(std::vector<SimpleTensor<float>> &srcs, SimpleTensor<float> &dst, unsigned int axis); +template SimpleTensor<half> concatenate_layer(std::vector<SimpleTensor<half>> &srcs, SimpleTensor<half> &dst, unsigned int axis); +template SimpleTensor<uint8_t> concatenate_layer(std::vector<SimpleTensor<uint8_t>> &srcs, SimpleTensor<uint8_t> &dst, unsigned int axis); } // namespace reference } // namespace validation } // namespace test diff --git a/tests/validation/reference/WidthConcatenateLayer.h b/tests/validation/reference/ConcatenateLayer.h index 0f1f428f10..14fd097eee 100644 --- a/tests/validation/reference/WidthConcatenateLayer.h +++ b/tests/validation/reference/ConcatenateLayer.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2019 ARM Limited. + * Copyright (c) 2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef __ARM_COMPUTE_TEST_WIDTHCONCATENATE_LAYER_H__ -#define __ARM_COMPUTE_TEST_WIDTHCONCATENATE_LAYER_H__ +#ifndef __ARM_COMPUTE_TEST_CONCATENATE_LAYER_H__ +#define __ARM_COMPUTE_TEST_CONCATENATE_LAYER_H__ #include "tests/SimpleTensor.h" @@ -37,9 +37,9 @@ namespace validation namespace reference { template <typename T> -SimpleTensor<T> widthconcatenate_layer(const std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst); +SimpleTensor<T> concatenate_layer(std::vector<SimpleTensor<T>> &srcs, SimpleTensor<T> &dst, unsigned int axis); } // namespace reference } // namespace validation } // namespace test } // namespace arm_compute -#endif /* __ARM_COMPUTE_TEST_WIDTHCONCATENATE_LAYER_H__ */ +#endif /* __ARM_COMPUTE_TEST_CONCATENATE_LAYER_H__ */ |