aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorPablo Tello <pablo.tello@arm.com>2019-03-04 14:14:02 +0000
committerGian Marco Iodice <gianmarco.iodice@arm.com>2019-03-19 10:18:10 +0000
commit3dd5b6884a65c06bcb9d15589ee2dc2978e3b336 (patch)
treee45ccae66b69c8db853ac883080c1c6358a57aec /tests/validation/fixtures
parent2f7c149f36fa3e6296aba6de666962947f032558 (diff)
downloadComputeLibrary-3dd5b6884a65c06bcb9d15589ee2dc2978e3b336.tar.gz
COMPMID-1933: Implement NEHeightConcatenateLayer.
Added support to concactenate tensors along the Y axis in NEConcatenateLayer. Change-Id: Ib714bfcf9954cc35918efa7d52fc9164bb08bdf6 Signed-off-by: Pablo Tello <pablo.tello@arm.com> Reviewed-on: https://review.mlplatform.org/c/841 Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/ConcatenateLayerFixture.h (renamed from tests/validation/fixtures/WidthConcatenateLayerFixture.h)46
-rw-r--r--tests/validation/fixtures/LSTMLayerFixture.h4
2 files changed, 31 insertions, 19 deletions
diff --git a/tests/validation/fixtures/WidthConcatenateLayerFixture.h b/tests/validation/fixtures/ConcatenateLayerFixture.h
index 47a03ed865..db09957c09 100644
--- a/tests/validation/fixtures/WidthConcatenateLayerFixture.h
+++ b/tests/validation/fixtures/ConcatenateLayerFixture.h
@@ -33,7 +33,7 @@
#include "tests/framework/Asserts.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/Helpers.h"
-#include "tests/validation/reference/WidthConcatenateLayer.h"
+#include "tests/validation/reference/ConcatenateLayer.h"
#include <random>
@@ -44,11 +44,11 @@ namespace test
namespace validation
{
template <typename TensorType, typename ITensorType, typename AccessorType, typename FunctionType, typename T>
-class WidthConcatenateLayerValidationFixture : public framework::Fixture
+class ConcatenateLayerValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, DataType data_type)
+ void setup(TensorShape shape, DataType data_type, unsigned int axis)
{
// Create input shapes
std::mt19937 gen(library->seed());
@@ -78,12 +78,12 @@ public:
{
// Decrease the dimension by a small percentage. Don't increase
// as that could make tensor too large.
- s.set(0, s[0] + 2 * static_cast<int>(s[0] * change_dis(gen)));
+ s.set(axis, s[axis] + 2 * static_cast<int>(s[axis] * change_dis(gen)));
}
}
- _target = compute_target(shapes, qinfo, data_type);
- _reference = compute_reference(shapes, qinfo, data_type);
+ _target = compute_target(shapes, qinfo, data_type, axis);
+ _reference = compute_reference(shapes, qinfo, data_type, axis);
}
protected:
@@ -93,7 +93,7 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
+ TensorType compute_target(const std::vector<TensorShape> &shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type, unsigned int axis)
{
std::vector<TensorType> srcs;
std::vector<ITensorType *> src_ptrs;
@@ -107,13 +107,26 @@ protected:
src_ptrs.emplace_back(&srcs.back());
}
- TensorShape dst_shape = misc::shape_calculator::calculate_width_concatenate_shape(src_ptrs);
-
- TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]);
+ const TensorShape dst_shape = misc::shape_calculator::calculate_concatenate_shape(src_ptrs, axis);
+ TensorType dst = create_tensor<TensorType>(dst_shape, data_type, 1, qinfo[shapes.size()]);
// Create and configure function
- FunctionType width_concat;
- width_concat.configure(src_ptrs, &dst);
+ FunctionType concat;
+ switch(axis)
+ {
+ case 0:
+ concat.configure(src_ptrs, &dst, DataLayoutDimension::WIDTH);
+ break;
+ case 1:
+ concat.configure(src_ptrs, &dst, DataLayoutDimension::HEIGHT);
+ break;
+ case 2:
+ concat.configure(src_ptrs, &dst, DataLayoutDimension::CHANNEL);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Not supported");
+ break;
+ }
for(auto &src : srcs)
{
@@ -140,12 +153,12 @@ protected:
}
// Compute function
- width_concat.run();
+ concat.run();
return dst;
}
- SimpleTensor<T> compute_reference(std::vector<TensorShape> shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type)
+ SimpleTensor<T> compute_reference(const std::vector<TensorShape> &shapes, const std::vector<QuantizationInfo> &qinfo, DataType data_type, unsigned int axis)
{
std::vector<SimpleTensor<T>> srcs;
@@ -156,10 +169,9 @@ protected:
fill(srcs.back(), j);
}
- const TensorShape dst_shape = calculate_width_concatenate_shape(shapes);
+ const TensorShape dst_shape = calculate_concatenate_shape(shapes, axis);
SimpleTensor<T> dst{ dst_shape, data_type, 1, qinfo[shapes.size()] };
-
- return reference::widthconcatenate_layer<T>(srcs, dst);
+ return reference::concatenate_layer<T>(srcs, dst, axis);
}
TensorType _target{};
diff --git a/tests/validation/fixtures/LSTMLayerFixture.h b/tests/validation/fixtures/LSTMLayerFixture.h
index b30f1e534b..2cf83b8b3d 100644
--- a/tests/validation/fixtures/LSTMLayerFixture.h
+++ b/tests/validation/fixtures/LSTMLayerFixture.h
@@ -29,11 +29,11 @@
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/ArithmeticOperations.h"
+#include "tests/validation/reference/ConcatenateLayer.h"
#include "tests/validation/reference/FullyConnectedLayer.h"
#include "tests/validation/reference/GEMM.h"
#include "tests/validation/reference/PixelWiseMultiplication.h"
#include "tests/validation/reference/Transpose.h"
-#include "tests/validation/reference/WidthConcatenateLayer.h"
namespace arm_compute
{
@@ -415,7 +415,7 @@ protected:
scratch_inputs.emplace_back(std::move(cell_state_out));
scratch_inputs.emplace_back(std::move(forget_gate));
scratch_inputs.emplace_back(std::move(output));
- scratch = reference::widthconcatenate_layer(scratch_inputs, scratch);
+ scratch = reference::concatenate_layer(scratch_inputs, scratch, Window::DimX);
_reference_scratch = std::move(scratch);
return output_state_out;
}