aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/ArithmeticOperationsFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/ArithmeticOperationsFixture.h')
-rw-r--r--tests/validation/fixtures/ArithmeticOperationsFixture.h66
1 files changed, 39 insertions, 27 deletions
diff --git a/tests/validation/fixtures/ArithmeticOperationsFixture.h b/tests/validation/fixtures/ArithmeticOperationsFixture.h
index fbce864a89..5f97826b4a 100644
--- a/tests/validation/fixtures/ArithmeticOperationsFixture.h
+++ b/tests/validation/fixtures/ArithmeticOperationsFixture.h
@@ -48,10 +48,11 @@ public:
template <typename...>
void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
- QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info)
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool in_place)
{
_op = op;
_act_info = act_info;
+ _in_place = in_place;
_target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
_reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
}
@@ -67,26 +68,27 @@ protected:
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
- TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
+ TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
+ TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
+ TensorType *dst_to_use = _in_place ? &ref_src1 : &dst;
// Create and configure function
FunctionType arith_op;
- arith_op.configure(&ref_src1, &ref_src2, &dst, convert_policy, _act_info);
+ arith_op.configure(&ref_src1, &ref_src2, dst_to_use, convert_policy, _act_info);
ARM_COMPUTE_EXPECT(ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(dst_to_use->info()->is_resizable(), framework::LogLevel::ERRORS);
// Allocate tensors
ref_src1.allocator()->allocate();
ref_src2.allocator()->allocate();
- dst.allocator()->allocate();
+ dst_to_use->allocator()->allocate();
ARM_COMPUTE_EXPECT(!ref_src1.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(!ref_src2.info()->is_resizable(), framework::LogLevel::ERRORS);
- ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst_to_use->info()->is_resizable(), framework::LogLevel::ERRORS);
// Fill tensors
fill(AccessorType(ref_src1), 0);
@@ -95,6 +97,10 @@ protected:
// Compute function
arith_op.run();
+ if(_in_place)
+ {
+ return ref_src1;
+ }
return dst;
}
@@ -102,23 +108,28 @@ protected:
DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
+ // current in-place implementation only supports same metadata of input and output tensors.
+ // By ignoring output quantization information here, we can make test cases implementation much simpler.
+ QuantizationInfo output_qinfo = _in_place ? qinfo0 : qinfo_out;
+
// Create reference
SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 };
SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 };
- SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out };
+ SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, output_qinfo };
// Fill reference
fill(ref_src1, 0);
fill(ref_src2, 1);
auto result = reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst, convert_policy);
- return _act_info.enabled() ? reference::activation_layer(result, _act_info, qinfo_out) : result;
+ return _act_info.enabled() ? reference::activation_layer(result, _act_info, output_qinfo) : result;
}
TensorType _target{};
SimpleTensor<T> _reference{};
reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
ActivationLayerInfo _act_info{};
+ bool _in_place{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -129,7 +140,7 @@ public:
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -141,7 +152,7 @@ public:
void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -153,7 +164,7 @@ public:
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -165,7 +176,7 @@ public:
void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -179,7 +190,7 @@ public:
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo());
+ output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
}
};
@@ -188,11 +199,11 @@ class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperati
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -201,11 +212,12 @@ class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOp
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
+ bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
}
};
@@ -214,11 +226,11 @@ class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGeneric
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo());
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -227,11 +239,11 @@ class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGe
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
+ void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type, convert_policy,
- QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info);
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
}
};
@@ -241,12 +253,12 @@ class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperati
public:
template <typename...>
void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
- QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
+ QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
data_type0, data_type1, output_data_type,
- convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo());
+ convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
}
};
@@ -256,11 +268,11 @@ class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public Arithmet
public:
template <typename...>
void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
- ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
+ ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
{
ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
data_type0, data_type1, output_data_type, convert_policy,
- qinfo0, qinfo1, qinfo_out, ActivationLayerInfo());
+ qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
}
};
} // namespace validation