diff options
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/ConcatenateLayerFixture.h (renamed from tests/validation/fixtures/WidthConcatenateLayerFixture.h) | 46 | ||||
-rw-r--r-- | tests/validation/fixtures/LSTMLayerFixture.h | 4 |
2 files changed, 31 insertions, 19 deletions
diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/ConcatenateLayerFixture.h index 47a03ed865..db09957c09 100644 --- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h +++ b/tests/validation/fixtures/ConcatenateLayerFixture.h @@ -33,7 +33,7 @@ #include "tests/framework/Asserts.h" #include "tests/framework/Fixture.h" #include "tests/validation/Helpers.h" -#include "tests/validation/reference/WidthConcatenateLayer.h" +#include "tests/validation/reference/ConcatenateLayer.h" #include <random> @@ -44,11 +44,11 @@ namespace test namespace validation { template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T> -class WidthConcatenateLayerValidationFixture : public framework::Fixture +class ConcatenateLayerValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, DataType data_type) + void setup(TensorShape shape, DataType data_type, unsigned int axis) { // Create input shapes std::mt19937 gen(library->seed()); @@ -78,12 +78,12 @@ public: { // Decrease the dimension by a small percentage. Don't increase // as that could make tensor too large. - s.set(0, s[0] + 2 * static_cast<int>(s[0] * change_dis(gen))); + s.set(axis, s[axis] + 2 * static_cast<int>(s[axis] * change_dis(gen))); } } - _target = compute_target(shapes, qinfo, data_type); - _reference = compute_reference(shapes, qinfo, data_type); + _target = compute_target(shapes, qinfo, data_type, axis); + _reference = compute_reference(shapes, qinfo, data_type, axis); } protected: @@ -93,7 +93,7 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type) + TensorType compute_target(const std::vector<TensorShape> &shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type, unsigned int axis) { std::vector<TensorType> srcs; std::vector<ITensorType *> src_ptrs; @@ -107,13 +107,26 @@ protected: src_ptrs.emplace_back(&srcs.back()); } - TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs); - - TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]); + const TensorShape dst_shape = misc::shape_calculator::calculate_concatenate_shape(src_ptrs, axis); + TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]); // Create and configure function - FunctionType width_concat; - width_concat.configure(src_ptrs, &dst); + FunctionType concat; + switch(axis) + { + case 0: + concat.configure(src_ptrs, &dst, DataLayoutDimension::WIDTH); + break; + case 1: + concat.configure(src_ptrs, &dst, DataLayoutDimension::HEIGHT); + break; + case 2: + concat.configure(src_ptrs, &dst, DataLayoutDimension::CHANNEL); + break; + default: + ARM_COMPUTE_ERROR("Not supported"); + break; + } for(auto &src : srcs) { @@ -140,12 +153,12 @@ protected: } // Compute function - width_concat.run(); + concat.run(); return dst; } - SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type) + SimpleTensor<T> compute_reference(const std::vector<TensorShape> &shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type, unsigned int axis) { std::vector<SimpleTensor<T>> srcs; @@ -156,10 +169,9 @@ protected: fill(srcs.back(), j); } - const TensorShape dst_shape = calculate_width_concatenate_shape(shapes); + const TensorShape dst_shape = calculate_concatenate_shape(shapes, axis); SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] }; - - return reference::widthconcatenate_layer<T>(srcs, dst); + return reference::concatenate_layer<T>(srcs, dst, axis); } TensorType _target{}; diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h index b30f1e534b..2cf83b8b3d 100644 --- a/tests/validation/fixtures/LSTMLayerFixture.h +++ b/tests/validation/fixtures/LSTMLayerFixture.h @@ -29,11 +29,11 @@ #include "tests/framework/Fixture.h" #include "tests/validation/reference/ActivationLayer.h" #include "tests/validation/reference/ArithmeticOperations.h" +#include "tests/validation/reference/ConcatenateLayer.h" #include "tests/validation/reference/FullyConnectedLayer.h" #include "tests/validation/reference/GEMM.h" #include "tests/validation/reference/PixelWiseMultiplication.h" #include "tests/validation/reference/Transpose.h" -#include "tests/validation/reference/WidthConcatenateLayer.h" namespace arm_compute { @@ -415,7 +415,7 @@ protected: scratch_inputs.emplace_back(std::move(cell_state_out)); scratch_inputs.emplace_back(std::move(forget_gate)); scratch_inputs.emplace_back(std::move(output)); - scratch = reference::widthconcatenate_layer(scratch_inputs, scratch); + scratch = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX); _reference_scratch = std::move(scratch); return output_state_out; } |