aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-05-17 15:17:48 +0100
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-06-12 14:41:18 +0000
commit3fcf3dcf7b6ffc613468ccaca580bde495677440 (patch)
tree567c85806318a868b0029f969cac823dd48fe44a /tests/validation/fixtures
parent48cfd5f7895f13167e4e9cd974dbc1e983e04ed7 (diff)
downloadComputeLibrary-3fcf3dcf7b6ffc613468ccaca580bde495677440.tar.gz
Add multi-sketch support for dynamic fusion
* Tensors are owned by workload context instead of workload sketch so that they can be used by multiple sketches. * Add an integration test for multi-sketch case. Resolves: COMPMID-6148 Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Change-Id: I37d0de5ac103fb2a85020aa1c26e49eb304f47b7 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9706 Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h12
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h24
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h12
-rw-r--r--tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h56
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h8
-rw-r--r--tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h8
11 files changed, 80 insertions, 80 deletions
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
index b15de71707..bea1d9bf4b 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DepthwiseConv2dFixture.h
@@ -126,14 +126,14 @@ protected:
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo input_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
- TensorInfo weight_info = sketch.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
- TensorInfo bias_info = sketch.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo input_info = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
+ TensorInfo weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
+ TensorInfo bias_info = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
+ TensorInfo dst_info = context.create_tensor_info();
ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, dwc_conv2d_attr);
GpuOutput::create_op(sketch, ans_info, &dst_info);
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
index d9ce4dff18..81dfc2b8e2 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/DirectConv2dFixture.h
@@ -115,14 +115,14 @@ protected:
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo input_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
- TensorInfo weight_info = sketch.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
- TensorInfo bias_info = sketch.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo input_info = context.create_tensor_info(TensorInfo(input_shape, 1, _data_type, _data_layout));
+ TensorInfo weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, _data_type, _data_layout));
+ TensorInfo bias_info = context.create_tensor_info(TensorInfo(bias_shape, 1, _data_type, _data_layout));
+ TensorInfo dst_info = context.create_tensor_info();
ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, conv2d_attr);
GpuOutput::create_op(sketch, ans_info, &dst_info);
@@ -256,14 +256,14 @@ protected:
permute(output_shape, PermutationVector(2U, 0U, 1U));
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- auto input_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, data_type, data_layout));
- auto weight_info = sketch.create_tensor_info(TensorInfo(weights_shape, 1, data_type, data_layout));
- auto bias_info = sketch.create_tensor_info(TensorInfo(bias_shape, 1, bias_data_type, data_layout));
- auto dst_info = sketch.create_tensor_info();
+ auto input_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type, data_layout));
+ auto weight_info = context.create_tensor_info(TensorInfo(weights_shape, 1, data_type, data_layout));
+ auto bias_info = context.create_tensor_info(TensorInfo(bias_shape, 1, bias_data_type, data_layout));
+ auto dst_info = context.create_tensor_info();
ITensorInfo *ans_info = FunctionType::create_op(sketch, &input_info, &weight_info, &bias_info, conv2d_attr);
GpuOutput::create_op(sketch, ans_info, &dst_info);
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
index b0680c0e4a..22deff1f24 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/ElementwiseBinaryFixture.h
@@ -99,13 +99,13 @@ protected:
{
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Fuse first element wise binary Op
- TensorInfo lhs_info = sketch.create_tensor_info(TensorInfo(shape0, 1, _data_type));
- TensorInfo rhs_info = sketch.create_tensor_info(TensorInfo(shape1, 1, _data_type));
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
+ TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
+ TensorInfo dst_info = context.create_tensor_info();
TensorInfo rhs_info_fuse;
@@ -113,7 +113,7 @@ protected:
if(_fuse)
{
- rhs_info_fuse = sketch.create_tensor_info(TensorInfo(shape2, 1, _data_type));
+ rhs_info_fuse = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse);
GpuOutput::create_op(sketch, ans2_info, &dst_info);
}
diff --git a/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h b/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
index efb5cf1e74..249f57aceb 100644
--- a/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/gpu/cl/Pool2dFixture.h
@@ -91,12 +91,12 @@ protected:
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- auto input_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, data_type, DataLayout::NHWC));
- auto dst_info = sketch.create_tensor_info();
+ auto input_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type, DataLayout::NHWC));
+ auto dst_info = context.create_tensor_info();
// Create Pool2dSettings
GpuPool2dSettings pool_settings = GpuPool2dSettings().mixed_precision(mixed_precision);
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
index 9656c497ea..3fb2cc2b7c 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ActivationFixture.h
@@ -102,12 +102,12 @@ protected:
{
// Create a new workload sketch
CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- GpuWorkloadContext gpu_ctx{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo src_info = sketch.create_tensor_info(TensorInfo(shape, 1, _data_type));
- TensorInfo dst_info = sketch.create_tensor_info(TensorInfo(shape, 1, _data_type));
+ TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+ TensorInfo dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
ITensorInfo *ans_0_info = FunctionType::create_op(sketch, &src_info, args...);
if(_fuse)
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
index cd39ec0a06..8a8e2b0c9a 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/CastFixture.h
@@ -112,12 +112,12 @@ protected:
{
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo src_info = sketch.create_tensor_info(TensorInfo(shape, 1, dt_in, DataLayout::NCHW)); // layout is not important
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, dt_in, DataLayout::NCHW)); // layout is not important
+ TensorInfo dst_info = context.create_tensor_info();
CastAttributes attributes;
attributes.convert_policy(policy).data_type(dt_out);
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
index a1fd22582f..cafd28f7b4 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ClampFixture.h
@@ -104,12 +104,12 @@ protected:
{
// Create a new workload sketch
CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- GpuWorkloadContext gpu_ctx{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ GpuWorkloadContext context{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo src_info = sketch.create_tensor_info(TensorInfo(shape, 1, _data_type));
- TensorInfo dst_info = sketch.create_tensor_info(TensorInfo(shape, 1, _data_type));
+ TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
+ TensorInfo dst_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type));
ITensorInfo *ans_0_info = FunctionType::create_op(sketch, &src_info, attributes);
if(_fuse)
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
index 0530707c38..a0d6bc6ed5 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/MulFixture.h
@@ -1,26 +1,26 @@
/*
-* Copyright (c) 2023 Arm Limited.
-*
-* SPDX-License-Identifier: MIT
-*
-* Permission is hereby granted, free of charge, to any person obtaining a copy
-* of this software and associated documentation files (the "Software"), to
-* deal in the Software without restriction, including without limitation the
-* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
-* sell copies of the Software, and to permit persons to whom the Software is
-* furnished to do so, subject to the following conditions:
-*
-* The above copyright notice and this permission notice shall be included in all
-* copies or substantial portions of the Software.
-*
-* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-* SOFTWARE.
-*/
+ * Copyright (c) 2023 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
#ifndef TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE
#define TESTS_VALIDATION_FIXTURES_DYNAMIC_FUSION_OPERATORS_MULFIXTURE
@@ -75,13 +75,13 @@ protected:
{
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Fuse first multiplication op
- TensorInfo lhs_info = sketch.create_tensor_info(TensorInfo(shape0, 1, _data_type));
- TensorInfo rhs_info = sketch.create_tensor_info(TensorInfo(shape1, 1, _data_type));
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo lhs_info = context.create_tensor_info(TensorInfo(shape0, 1, _data_type));
+ TensorInfo rhs_info = context.create_tensor_info(TensorInfo(shape1, 1, _data_type));
+ TensorInfo dst_info = context.create_tensor_info();
TensorInfo rhs_info_fuse;
@@ -89,7 +89,7 @@ protected:
if(_fuse)
{
- rhs_info_fuse = sketch.create_tensor_info(TensorInfo(shape2, 1, _data_type));
+ rhs_info_fuse = context.create_tensor_info(TensorInfo(shape2, 1, _data_type));
ITensorInfo *ans2_info = FunctionType::create_op(sketch, ans_info, &rhs_info_fuse);
GpuOutput::create_op(sketch, ans2_info, &dst_info);
}
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
index e0b62d093f..88c04de35a 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ReshapeFixture.h
@@ -71,12 +71,12 @@ protected:
// Create a new workload sketch
auto cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- auto gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ auto context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo src_info = sketch.create_tensor_info(TensorInfo(input_shape, 1, data_type));
- TensorInfo dst_info = sketch.create_tensor_info(TensorInfo(output_shape, 1, data_type));
+ TensorInfo src_info = context.create_tensor_info(TensorInfo(input_shape, 1, data_type));
+ TensorInfo dst_info = context.create_tensor_info(TensorInfo(output_shape, 1, data_type));
ReshapeAttributes attributes;
attributes.shape(output_shape);
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
index 581a3e8947..62ef053dca 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/ResizeFixture.h
@@ -137,13 +137,13 @@ protected:
// Create a new workload sketch
CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
// Create sketch tensors
- TensorInfo src_info = sketch.create_tensor_info(TensorInfo(shape, 1, _data_type, _data_layout));
+ TensorInfo src_info = context.create_tensor_info(TensorInfo(shape, 1, _data_type, _data_layout));
src_info.set_quantization_info(_input_quantization_info);
- TensorInfo dst_info = sketch.create_tensor_info();
+ TensorInfo dst_info = context.create_tensor_info();
ResizeAttributes attributes;
attributes.align_corners(_align_corners).sampling_policy(_sampling_policy).interpolation_policy(_interpolation_policy).output_width(_output_width).output_height(_output_height);
diff --git a/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h b/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
index 38177114e6..0f50e8e12f 100644
--- a/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
+++ b/tests/validation/fixtures/dynamic_fusion/operators/SoftmaxFixture.h
@@ -82,13 +82,13 @@ protected:
{
// Create a new workload sketch
CLCompileContext cl_compile_ctx = CLKernelLibrary::get().get_compile_context();
- GpuWorkloadContext gpu_ctx = GpuWorkloadContext{ &cl_compile_ctx };
- GpuWorkloadSketch sketch{ &gpu_ctx };
+ GpuWorkloadContext context = GpuWorkloadContext{ &cl_compile_ctx };
+ GpuWorkloadSketch sketch{ &context };
SoftmaxAttributes softmax_attr{};
softmax_attr.axis(axis).beta(beta).is_log_softmax(is_log);
- TensorInfo src_info = sketch.create_tensor_info(shape, 1, data_type);
- TensorInfo dst_info = sketch.create_tensor_info(shape, 1, data_type);
+ TensorInfo src_info = context.create_tensor_info(shape, 1, data_type);
+ TensorInfo dst_info = context.create_tensor_info(shape, 1, data_type);
FunctionType::create_op(sketch, &src_info, &dst_info, softmax_attr);
// Configure runtime