aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/FullyConnectedLayerFixture.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2018-07-16 17:20:38 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commita855af10a486c53c2271361cb87f349eca64b749 (patch)
treeb326b63bdcaf76c9620b1bbf22942d4683503a65 /tests/validation/fixtures/FullyConnectedLayerFixture.h
parent5a3ee4f708a9e1642b0211955ff905e7b67e831d (diff)
downloadComputeLibrary-a855af10a486c53c2271361cb87f349eca64b749.tar.gz
COMPMID-1401 Implement NEFullyConnectedLayer for QASYMM8
Change-Id: I0404df6d369855e2f458f2db8f26e81c80a1ee87 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/140148 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com> Tested-by: Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--tests/validation/fixtures/FullyConnectedLayerFixture.h54
1 files changed, 18 insertions, 36 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 18321480f8..49c3be0c2e 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -45,7 +45,7 @@ namespace test
{
namespace validation
{
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
class FullyConnectedLayerValidationGenericFixture : public framework::Fixture
{
public:
@@ -103,8 +103,8 @@ protected:
// -----------+-----------+---------------------------
// transpose | | ***
// -----------+-----------+---------------------------
- // !transpose | transpose | transpose &
- // | | transpose1xW (if required)
+ // !transpose | transpose | transpose
+ // | |
//
// ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same
if(!reshape_weights || !transpose_weights)
@@ -112,16 +112,6 @@ protected:
const size_t shape_x = reshaped_weights_shape.x();
reshaped_weights_shape.set(0, reshaped_weights_shape.y());
reshaped_weights_shape.set(1, shape_x);
-
- // Weights have to be passed reshaped
- // Transpose 1xW for batched version
- if(!reshape_weights && output_shape.y() > 1 && run_interleave)
- {
- const int transpose_width = 16 / data_size_from_type(_data_type);
- const float shape_x = reshaped_weights_shape.x();
- reshaped_weights_shape.set(0, reshaped_weights_shape.y() * transpose_width);
- reshaped_weights_shape.set(1, static_cast<unsigned int>(std::ceil(shape_x / transpose_width)));
- }
}
// Create tensors
@@ -170,14 +160,6 @@ protected:
// Transpose elementwise
tmp = transpose(tmp);
- // Reshape weights for batched runs
- if(!reshape_weights && output_shape.y() > 1 && run_interleave)
- {
- // Transpose with interleave
- const int interleave_size = 16 / tmp.element_size();
- tmp = transpose(tmp, interleave_size);
- }
-
AccessorType weights_accessor(weights);
for(int i = 0; i < tmp.num_elements(); ++i)
@@ -222,43 +204,43 @@ protected:
QuantizationInfo _quantization_info{};
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
-class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type)
{
- FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
- reshape_weights, data_type,
- QuantizationInfo());
+ FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
+ reshape_weights, data_type,
+ QuantizationInfo());
}
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
-class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedLayerValidationFixedPointFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type)
{
- FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
- reshape_weights, data_type,
- QuantizationInfo());
+ FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
+ reshape_weights, data_type,
+ QuantizationInfo());
}
};
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool run_interleave>
-class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
template <typename...>
void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type,
QuantizationInfo quantization_info)
{
- FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T, run_interleave>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
- reshape_weights, data_type,
- quantization_info);
+ FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
+ reshape_weights, data_type,
+ quantization_info);
}
};
} // namespace validation