aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/FullyConnectedLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/FullyConnectedLayerFixture.h')
-rw-r--r--tests/validation/fixtures/FullyConnectedLayerFixture.h63
1 files changed, 43 insertions, 20 deletions
diff --git a/tests/validation/fixtures/FullyConnectedLayerFixture.h b/tests/validation/fixtures/FullyConnectedLayerFixture.h
index 75bef144ad..e13c01d1e2 100644
--- a/tests/validation/fixtures/FullyConnectedLayerFixture.h
+++ b/tests/validation/fixtures/FullyConnectedLayerFixture.h
@@ -335,9 +335,9 @@ private:
void validate_with_tolerance(TensorType &target, SimpleTensor<half_float::half> &ref)
{
- constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
+ constexpr AbsoluteTolerance<float> abs_tolerance_f16(0.3f);
const RelativeTolerance<half_float::half> rel_tolerance_f16(half_float::half(0.2f));
- constexpr float tolerance_num_f16 = 0.07f;
+ constexpr float tolerance_num_f16 = 0.07f;
validate(AccessorType(target), ref, rel_tolerance_f16, tolerance_num_f16, abs_tolerance_f16);
}
@@ -360,36 +360,36 @@ public:
template <typename...>
void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
- DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped)
+ DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias, bool weights_reshaped, bool remove_bias = false)
{
_data_type = data_type;
- const bool is_quantized = is_data_type_quantized(data_type);
-
+ const bool is_quantized = is_data_type_quantized(data_type);
const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
- // Setup tensor meta-data
+ // Configure TensorInfo Objects
const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
- _src.allocator()->init(src_info);
+ const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+ TensorInfo bias_info(bias_shape, 1, bias_data_type);
+ TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
- TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
if(!constant_weights && weights_reshaped)
{
const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
wei_info.set_tensor_shape(tr_weights_shape);
}
wei_info.set_are_values_constant(constant_weights);
- _weights.allocator()->init(wei_info);
-
- TensorInfo bias_info(bias_shape, 1, bias_data_type);
bias_info.set_are_values_constant(constant_bias);
- _bias.allocator()->init(bias_info);
- const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
+ // Initialise Tensors
+ _src.allocator()->init(src_info);
+ _weights.allocator()->init(wei_info);
+ if(!remove_bias)
+ _bias.allocator()->init(bias_info);
_dst.allocator()->init(dst_info);
// Configure FC layer and mark the weights as non constant
@@ -401,12 +401,13 @@ public:
fc_info.transpose_weights = !weights_reshaped;
}
FunctionType fc;
- fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
+ fc.configure(&_src, &_weights, (remove_bias) ? nullptr : &_bias, &_dst, fc_info);
// Allocate all the tensors
_src.allocator()->allocate();
_weights.allocator()->allocate();
- _bias.allocator()->allocate();
+ if(!remove_bias)
+ _bias.allocator()->allocate();
_dst.allocator()->allocate();
// Run multiple iterations with different inputs
@@ -424,11 +425,20 @@ public:
fill(AccessorType(_weights), 1);
fill(weights, 1);
}
- if(constant_bias)
+ if(constant_bias && !remove_bias)
{
fill(AccessorType(_bias), 2);
fill(bias, 2);
}
+ // To remove bias, fill with 0
+ if(remove_bias && is_quantized)
+ {
+ library->fill_tensor_value(bias, 0);
+ }
+ else if(remove_bias)
+ {
+ library->fill_tensor_value(bias, (float)0.0);
+ }
for(int i = 0; i < num_iterations; ++i)
{
@@ -446,7 +456,7 @@ public:
fill(AccessorType(_weights), randomizer_offset + 1);
}
}
- if(!constant_bias)
+ if(!constant_bias && !remove_bias)
{
fill(AccessorType(_bias), randomizer_offset + 2);
}
@@ -462,7 +472,7 @@ public:
{
fill(weights, randomizer_offset + 1);
}
- if(!constant_bias)
+ if(!constant_bias && !remove_bias)
{
fill(bias, randomizer_offset + 2);
}
@@ -491,7 +501,20 @@ public:
DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, false, true, weights_reshaped);
+ dst_shape, data_type, activation_info, false, true, weights_reshaped, false);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class FullyConnectedDynamicNoBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
+ DataType data_type, ActivationLayerInfo activation_info, bool weights_reshaped)
+ {
+ FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
+ dst_shape, data_type, activation_info, false, true, weights_reshaped, true);
}
};
@@ -504,7 +527,7 @@ public:
DataType data_type, ActivationLayerInfo activation_info)
{
FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
- dst_shape, data_type, activation_info, true, false, false /* weights_reshaped (not used) */);
+ dst_shape, data_type, activation_info, true, false, false, false /* weights_reshaped (not used) */);
}
};
} // namespace validation