diff options
Diffstat (limited to 'tests/Utils.h')
-rw-r--r-- | tests/Utils.h | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tests/Utils.h b/tests/Utils.h index ea70fffe3a..3bb6060951 100644 --- a/tests/Utils.h +++ b/tests/Utils.h @@ -520,14 +520,15 @@ inline bool is_in_valid_region(const ValidRegion &valid_region, Coordinates coor * @param[in] num_channels (Optional) Number of channels. * @param[in] quantization_info (Optional) Quantization info for asymmetric quantized types. * @param[in] data_layout (Optional) Data layout. Default is NCHW. + * @param[in] ctx (Optional) Pointer to the runtime context. * * @return Initialized tensor of given type. */ template <typename T> inline T create_tensor(const TensorShape &shape, DataType data_type, int num_channels = 1, - QuantizationInfo quantization_info = QuantizationInfo(), DataLayout data_layout = DataLayout::NCHW) + QuantizationInfo quantization_info = QuantizationInfo(), DataLayout data_layout = DataLayout::NCHW, IRuntimeContext *ctx = nullptr) { - T tensor; + T tensor(ctx); TensorInfo info(shape, num_channels, data_type); info.set_quantization_info(quantization_info); info.set_data_layout(data_layout); @@ -540,15 +541,16 @@ inline T create_tensor(const TensorShape &shape, DataType data_type, int num_cha * * @param[in] shape Tensor shape. * @param[in] format Format type. + * @param[in] ctx (Optional) Pointer to the runtime context. * * @return Initialized tensor of given type. */ template <typename T> -inline T create_tensor(const TensorShape &shape, Format format) +inline T create_tensor(const TensorShape &shape, Format format, IRuntimeContext *ctx = nullptr) { TensorInfo info(shape, format); - T tensor; + T tensor(ctx); tensor.allocator()->init(info); return tensor; |