path: root/tests/validation/fixtures/FullyConnectedLayerFixture.h
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/FullyConnectedLayerFixture.h  126
1 file changed, 20 insertions, 106 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index ccd9182ae9..7d767642f3 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -273,7 +273,7 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture
+class FullyConnectedWithDynamicWeightsFixture : public framework::Fixture
{
private:
template <typename U>
@@ -289,16 +289,6 @@ private:
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
library->fill(tensor, distribution, i);
}
- else if(_data_type == DataType::QASYMM8)
- {
- std::uniform_int_distribution<uint8_t> distribution(0, 30);
- library->fill(tensor, distribution, i);
- }
- else if(_data_type == DataType::S32)
- {
- std::uniform_int_distribution<int32_t> distribution(-50, 50);
- library->fill(tensor, distribution, i);
- }
else
{
library->fill_tensor_uniform(tensor, i);
@@ -334,11 +324,6 @@ private:
constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
}
- else if(_data_type == DataType::QASYMM8)
- {
- constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(1);
- validate(AccessorType(target), ref, tolerance_qasymm8);
- }
else
{
validate(AccessorType(target), ref);
@@ -346,51 +331,32 @@ private:
}
public:
- using TDecay = typename std::decay<T>::type;
- using TBias = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type;
-
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
+ DataType data_type, ActivationLayerInfo activation_info)
{
_data_type = data_type;
- const bool is_quantized = is_data_type_quantized(data_type);
-
- const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
-
- const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
- const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
- const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
-
// Setup tensor meta-data
- const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
+ TensorInfo src_info(src_shape, 1, data_type);
_src.allocator()->init(src_info);
- TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
- if(!constant_weights)
- {
- const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
- wei_info.set_tensor_shape(tr_weights_shape);
- }
- wei_info.set_are_values_constant(constant_weights);
+ TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
+ TensorInfo wei_info(tr_weights_shape, 1, data_type);
_weights.allocator()->init(wei_info);
- TensorInfo bias_info(bias_shape, 1, bias_data_type);
- bias_info.set_are_values_constant(constant_bias);
+ TensorInfo bias_info(bias_shape, 1, data_type);
_bias.allocator()->init(bias_info);
- const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+ TensorInfo dst_info(dst_shape, 1, data_type);
_dst.allocator()->init(dst_info);
// Configure FC layer and mark the weights as non constant
FullyConnectedLayerInfo fc_info;
- fc_info.activation_info = activation_info;
- if(!constant_weights)
- {
- fc_info.are_weights_reshaped = true;
- fc_info.transpose_weights = false;
- }
+ fc_info.activation_info = activation_info;
+ fc_info.are_weights_reshaped = true;
+ fc_info.transpose_weights = false;
+ fc_info.constant_weights = false;
FunctionType fc;
fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
@@ -403,55 +369,29 @@ public:
// Run multiple iterations with different inputs
constexpr int num_iterations = 5;
int randomizer_offset = 0;
-
- // Create reference tensors
- SimpleTensor<T> src{ src_shape, data_type, 1, src_qinfo };
- SimpleTensor<T> weights{ weights_shape, data_type, 1, weights_qinfo };
- SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
-
- // Fill weights and/or bias if they remain constant
- if(constant_weights)
- {
- fill(AccessorType(_weights), 1);
- fill(weights, 1);
- }
- if(constant_bias)
- {
- fill(AccessorType(_bias), 2);
- fill(bias, 2);
- }
-
for(int i = 0; i < num_iterations; ++i)
{
// Run target
{
fill(AccessorType(_src), randomizer_offset);
- if(!constant_weights)
- {
- fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
- }
- if(!constant_bias)
- {
- fill(AccessorType(_bias), randomizer_offset + 2);
- }
+ fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
+ fill(AccessorType(_bias), randomizer_offset + 2);
fc.run();
}
// Run reference and compare
{
+ SimpleTensor<T> src{ src_shape, data_type };
+ SimpleTensor<T> weights{ weights_shape, data_type };
+ SimpleTensor<T> bias{ bias_shape, data_type };
+
// Fill reference
fill(src, randomizer_offset);
- if(!constant_weights)
- {
- fill(weights, randomizer_offset + 1);
- }
- if(!constant_bias)
- {
- fill(bias, randomizer_offset + 2);
- }
+ fill(weights, randomizer_offset + 1);
+ fill(bias, randomizer_offset + 2);
- auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape), activation_info, dst_qinfo);
+ auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape), activation_info);
// Validate
validate_with_tolerance(_dst, dst);
@@ -465,32 +405,6 @@ private:
TensorType _src{}, _weights{}, _bias{}, _dst{};
DataType _data_type{ DataType::UNKNOWN };
};
-
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
-{
-public:
- template <typename...>
- void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info)
- {
- FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, false, true);
- }
-};
-
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
-{
-public:
- template <typename...>
- void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info)
- {
- FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, true, false);
- }
-};
} // namespace validation
} // namespace test
} // namespace arm_compute
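
Note: the sketch below is a minimal, standalone illustration (not Compute Library code) of the test pattern the simplified FullyConnectedWithDynamicWeightsFixture now follows on every iteration: re-fill source, weights and bias, run the function under test, and validate the result against a naive reference within the F32 absolute tolerance. All helper names (fully_connected_ref, fill) are hypothetical and exist only for this example; in the real fixture the "target" is the configured FC function and the reference comes from reference::fully_connected_layer.

// Standalone sketch of the dynamic-weights test loop (illustrative only).
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// Naive fully connected layer: dst[o] = bias[o] + sum_i src[i] * weights[o * in + i]
static std::vector<float> fully_connected_ref(const std::vector<float> &src,
                                              const std::vector<float> &weights,
                                              const std::vector<float> &bias,
                                              std::size_t in, std::size_t out)
{
    std::vector<float> dst(out, 0.0f);
    for(std::size_t o = 0; o < out; ++o)
    {
        float acc = bias[o];
        for(std::size_t i = 0; i < in; ++i)
        {
            acc += src[i] * weights[o * in + i];
        }
        dst[o] = acc;
    }
    return dst;
}

int main()
{
    constexpr std::size_t in             = 8;
    constexpr std::size_t out            = 4;
    constexpr int         num_iterations = 5; // mirrors the fixture's multi-run loop

    std::mt19937                          gen(0);
    std::uniform_real_distribution<float> dist(-1.0f, 1.0f);
    auto fill = [&](std::vector<float> &t)
    {
        for(auto &v : t)
        {
            v = dist(gen);
        }
    };

    std::vector<float> src(in), weights(in * out), bias(out);

    for(int iter = 0; iter < num_iterations; ++iter)
    {
        // New source, weights and bias on every iteration: this is what makes the
        // weights "dynamic" from the point of view of the function under test.
        fill(src);
        fill(weights);
        fill(bias);

        // In the real fixture the target is the configured FC function (fc.run());
        // here the reference is called twice to keep the sketch self-contained.
        auto target    = fully_connected_ref(src, weights, bias, in, out);
        auto reference = fully_connected_ref(src, weights, bias, in, out);

        for(std::size_t o = 0; o < out; ++o)
        {
            assert(std::fabs(target[o] - reference[o]) <= 0.0001f); // abs_tolerance_f32
        }
    }
    return 0;
}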