aboutsummaryrefslogtreecommitdiff
path: root/tests/Utils.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/Utils.h')
-rw-r--r--tests/Utils.h10
1 files changed, 6 insertions, 4 deletions
diff --git a/tests/Utils.h b/tests/Utils.h
index ea70fffe3a..3bb6060951 100644
--- a/tests/Utils.h
+++ b/tests/Utils.h
@@ -520,14 +520,15 @@ inline bool is_in_valid_region(const ValidRegion &valid_region, Coordinates coor
* @param[in] num_channels (Optional) Number of channels.
* @param[in] quantization_info (Optional) Quantization info for asymmetric quantized types.
* @param[in] data_layout (Optional) Data layout. Default is NCHW.
+ * @param[in] ctx (Optional) Pointer to the runtime context.
*
* @return Initialized tensor of given type.
*/
template <typename T>
inline T create_tensor(const TensorShape &shape, DataType data_type, int num_channels = 1,
- QuantizationInfo quantization_info = QuantizationInfo(), DataLayout data_layout = DataLayout::NCHW)
+ QuantizationInfo quantization_info = QuantizationInfo(), DataLayout data_layout = DataLayout::NCHW, IRuntimeContext *ctx = nullptr)
{
- T tensor;
+ T tensor(ctx);
TensorInfo info(shape, num_channels, data_type);
info.set_quantization_info(quantization_info);
info.set_data_layout(data_layout);
@@ -540,15 +541,16 @@ inline T create_tensor(const TensorShape &shape, DataType data_type, int num_cha
*
* @param[in] shape Tensor shape.
* @param[in] format Format type.
+ * @param[in] ctx (Optional) Pointer to the runtime context.
*
* @return Initialized tensor of given type.
*/
template <typename T>
-inline T create_tensor(const TensorShape &shape, Format format)
+inline T create_tensor(const TensorShape &shape, Format format, IRuntimeContext *ctx = nullptr)
{
TensorInfo info(shape, format);
- T tensor;
+ T tensor(ctx);
tensor.allocator()->init(info);
return tensor;