author    James Ward <james.ward@arm.com>    2023-01-18 14:51:25 +0000
committer Eric Kunze <eric.kunze@arm.com>    2023-02-02 22:03:25 +0000
commit    d34b3fc5eeef48ecc781a02433ce022a28e3373c (patch)
tree      13aa36aa89c618e56eb2f51915a172ff8e4276d9
parent    512c1caa8b6d494de81f3ac83a6ebb96e1e0f8e0 (diff)
Remove accumulator attributes from all but AVG_POOL2D
Signed-off-by: James Ward <james.ward@arm.com>
Change-Id: If67f503a1848967bc1671646c3011d055b622c52
 reference_model/src/graph_node.h         |  20
 reference_model/src/operators.cc         |  42
 reference_model/src/ops/image.cc         |  14
 reference_model/src/ops/op_factory.cc    |  82
 reference_model/src/ops/op_factory.h     |  12
 reference_model/src/ops/template_types.h |  33
 reference_model/src/ops/tensor_ops.cc    | 250
 reference_model/src/ops/tensor_ops.h     |  60
 scripts/operator_api/generate_api.py     |   2
 thirdparty/serialization_lib             |   0
 verif/generator/tosa_test_gen.py         |  14
 11 files changed, 240 insertions(+), 289 deletions(-)
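
The shape of the change: the serialized attributes for the convolution family, FULLY_CONNECTED and MATMUL lose their accum_dtype field (only AVG_POOL2D keeps an accumulator attribute), and the reference-model operator templates are re-keyed from the accumulator data type to the output data type, deriving the accumulator internally. A minimal before/after sketch of the attribute construction, assuming the updated TosaConvAttribute constructor pinned by the serialization_lib submodule bump below:

    const int32_t input_zp  = client_input_zp;
    const int32_t weight_zp = client_weight_zp;
    // before: the accumulator type was serialized alongside the zero points
    // TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, tosa::DType::DType_FP32);
    // after: the accumulator is implied by the operator's output data type
    TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
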
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index b227d17..a9a336b 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -24,9 +24,6 @@
#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
-#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, RANK, DTYPE, ACCUM_DTYPE) \
- template class TosaReference::OP<RANK, DType_##DTYPE, DType_##ACCUM_DTYPE>;
-
#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
@@ -38,15 +35,12 @@
#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
-#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
- template class TosaReference::OP<DType_##DTYPE, DType_##ACCUM_DTYPE>;
-
#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
-#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, DTYPE1, DTYPE2, ACCUM_DTYPE) \
- template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>;
+#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
+ template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>;
-#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, OP_TYPE) \
+#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \
template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2, OP_TYPE>;
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
@@ -66,14 +60,6 @@
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
-#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 1, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 2, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 3, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 4, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 5, DTYPE, ACCUM_DTYPE) \
- DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 6, DTYPE, ACCUM_DTYPE)
-
#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
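
For orientation, an illustrative expansion of the new generic macro (not a line from the patch):

    // DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32) token-pastes DType_
    // onto all three arguments, so the third parameter is now an output DType:
    template class TosaReference::OpConv2d<DType_INT8, DType_INT8, DType_INT32>;
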
diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc
index af348ca..a627322 100644
--- a/reference_model/src/operators.cc
+++ b/reference_model/src/operators.cc
@@ -168,10 +168,9 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype);
+ const int32_t input_zp = client_input_zp;
+ const int32_t weight_zp = client_weight_zp;
+ TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -219,10 +218,9 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[6]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[3]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[3]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype);
+ const int32_t input_zp = client_input_zp;
+ const int32_t weight_zp = client_weight_zp;
+ TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -270,10 +268,9 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[4]);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[2]);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype);
+ const int32_t input_zp = client_input_zp;
+ const int32_t weight_zp = client_weight_zp;
+ TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -313,10 +310,9 @@ extern "C"
tosa_tensor_t client_output)
{
// Create operator attributes
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaFullyConnectedAttribute attr(input_zp, weight_zp, accum_dtype);
+ const int32_t input_zp = client_input_zp;
+ const int32_t weight_zp = client_weight_zp;
+ TosaFullyConnectedAttribute attr(input_zp, weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
@@ -352,10 +348,9 @@ extern "C"
tosa_tensor_t client_output)
{
// Create operator attributes
- const int32_t a_zp = client_a_zp;
- const int32_t b_zp = client_b_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaMatMulAttribute attr(a_zp, b_zp, accum_dtype);
+ const int32_t a_zp = client_a_zp;
+ const int32_t b_zp = client_b_zp;
+ TosaMatMulAttribute attr(a_zp, b_zp);
// Create tensors
tosa::TosaSerializationTensor* a = translate_client_tensor(client_a, "a");
@@ -446,10 +441,9 @@ extern "C"
const std::vector<int32_t> pad(&client_pad[0], &client_pad[0] + client_pad_len);
const std::vector<int32_t> stride(&client_stride[0], &client_stride[2]);
const std::vector<int32_t> dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len);
- const int32_t input_zp = client_input_zp;
- const int32_t weight_zp = client_weight_zp;
- const tosa::DType accum_dtype = tosa::DType::DType_FP32;
- TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype);
+ const int32_t input_zp = client_input_zp;
+ const int32_t weight_zp = client_weight_zp;
+ TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp);
// Create tensors
tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input");
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
index a1a4474..90427e4 100644
--- a/reference_model/src/ops/image.cc
+++ b/reference_model/src/ops/image.cc
@@ -236,10 +236,10 @@ int OpResize<InDtype, OutDtype, resize_t>::eval()
}
// template explicit instantiation
-DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16);
-DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT32, int16_t);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT8, int16_t);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT48, int16_t);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16);
+DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float);
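
The _RESIZE variant exists because OpResize's third template parameter is a raw C++ arithmetic type rather than a tosa::DType, as the OpResize<InDtype, OutDtype, resize_t>::eval() signature above shows; token-pasting DType_ onto it would produce a nonexistent identifier. An illustrative expansion:

    // DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half)
    // leaves the third argument untouched:
    template class TosaReference::OpResize<DType_FP16, DType_FP16, half_float::half>;
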
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 76cf666..b1a405a 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -63,48 +63,48 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32);
break;
case Op_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
break;
case Op_CONV3D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
break;
case Op_DEPTHWISE_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
break;
case Op_FULLY_CONNECTED:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
break;
case Op_MATMUL:
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32);
- DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, BF16, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48);
break;
case Op_MAX_POOL2D:
DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16);
@@ -117,13 +117,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32);
break;
case Op_TRANSPOSE_CONV2D:
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32);
- DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
+ DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
break;
// activation_funcs
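
A sketch of what the new MATMUL dispatch reduces to, using the DEF_FACTORY_TWO_TYPE_IN_OUT expansion defined in op_factory.h below (illustrative; local names follow the newOp() factory):

    // before, the accumulator was read back out of the attribute, e.g.
    //   inputDType == DType_INT16 && ACCUM_FROM_ATTRIBUTE(MatMul) == DType_INT48
    // after, the output tensor's dtype selects the instantiation directly:
    if (inputDType == DType_INT16 && outputDType == DType_INT48)
    {
        return new OpMatMul<DType_INT16, DType_INT48>(sgt, attribute, id);
    }
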
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index f4177db..9117df4 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -74,6 +74,12 @@
return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
}
+#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(sgt, attribute, id); \
+ }
+
#define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \
if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \
&& ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \
@@ -81,6 +87,12 @@
return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##ACCUM_DTYPE>(sgt, attribute, id); \
} \
+#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2, DType_##DTYPE3>(sgt, attribute, id); \
+ }
+
// Statement-expression to evaluate accumulate attribute in-place
#define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \
({ \
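
Similarly, an illustrative expansion of DEF_FACTORY_THREE_TYPE for one conv configuration (sketch, assuming the surrounding newOp() locals inputDType, weightDType and outputDType):

    if (inputDType == DType_INT16 && weightDType == DType_INT8 && outputDType == DType_INT48)
    {
        return new OpConv2d<DType_INT16, DType_INT8, DType_INT48>(sgt, attribute, id);
    }

Note that ACCUM_FROM_ATTRIBUTE and the remaining *_ONE_ACCUM factory macros survive: AVG_POOL2D still carries an accumulator attribute and still dispatches through them, as the Op_AVG_POOL2D case above shows.
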
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
index 647ca84..6b28502 100644
--- a/reference_model/src/ops/template_types.h
+++ b/reference_model/src/ops/template_types.h
@@ -275,39 +275,6 @@ struct GetQMax<DType_INT48>
static constexpr int64_t value = (1L << 47) - 1;
};
-template <DType TIn1, DType TIn2>
-struct GetAccDType;
-template <>
-struct GetAccDType<DType_INT8, DType_INT4>
-{
- static constexpr DType value = DType_INT32;
-};
-template <>
-struct GetAccDType<DType_INT8, DType_INT8>
-{
- static constexpr DType value = DType_INT32;
-};
-template <>
-struct GetAccDType<DType_INT16, DType_INT8>
-{
- static constexpr DType value = DType_INT48;
-};
-template <>
-struct GetAccDType<DType_INT16, DType_INT16>
-{
- static constexpr DType value = DType_INT48;
-};
-template <>
-struct GetAccDType<DType_FP16, DType_FP16>
-{
- static constexpr DType value = DType_FP16;
-};
-template <>
-struct GetAccDType<DType_FP32, DType_FP32>
-{
- static constexpr DType value = DType_FP32;
-};
-
}; // namespace TosaReference
#endif
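
The deleted GetAccDType trait mapped an (input, weight) pair to an accumulator DType at compile time. It is now obsolete because tensor_ops.h (below) derives both the accumulator and storage element types from the single OutDtype template parameter:

    using AccEigenType = typename GetAccEigenType<OutDtype>::type; // wider compute type
    using OutEigenType = typename GetEigenType<OutDtype>::type;    // storage type
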
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index dff9e08..4663c47 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -541,8 +541,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV2D, id_)
@@ -553,15 +553,15 @@ OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -577,7 +577,7 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
printNodeValidationError("OpConv2d: bias tensor must be rank 1");
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -597,8 +597,8 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -634,14 +634,12 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
-
DEBUG_INFO(OP,
"perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
- "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
+ "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+ pad_bottom, pad_left, pad_right);
// GEMM-conv2d, left matrix is input, right matrix is weight
Eigen::array<Eigen::Index, 2> im2col_input_dims;
@@ -717,7 +715,7 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
// reshape back to [N, H, W, C]
this->output->getTensor() = biased_output.reshape(col2im_output_dims);
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -726,8 +724,8 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_CONV3D, id_)
@@ -738,15 +736,15 @@ OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -762,7 +760,7 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
printNodeValidationError("OpConv3d: bias tensor must be rank 1");
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpConv3d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -782,8 +780,8 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_depth = this->input->getShape()[1];
@@ -827,15 +825,13 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
int dilation_h = this->attribute->dilation()[1];
int dilation_w = this->attribute->dilation()[2];
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
-
DEBUG_INFO(
OP,
"perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], "
- "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s",
+ "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]",
in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels,
out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h,
- dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+ dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right);
Eigen::array<std::pair<int32_t, int32_t>, 5> pad;
pad[0] = std::make_pair(0, 0);
@@ -907,7 +903,7 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
}
}
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -916,8 +912,8 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
@@ -928,15 +924,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTra
INIT_ATTRIBUTE(Conv);
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -952,7 +948,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpDepthwiseConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -972,8 +968,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1010,14 +1006,12 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
int dilation_h = this->attribute->dilation()[0];
int dilation_w = this->attribute->dilation()[1];
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
-
DEBUG_INFO(OP,
"perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top,
- pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]);
+ pad_bottom, pad_left, pad_right);
Eigen::array<std::pair<int32_t, int32_t>, 4> pad;
pad[0] = std::make_pair(0, 0);
@@ -1083,7 +1077,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
}
}
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1092,8 +1086,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
@@ -1104,15 +1098,15 @@ OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTrave
INIT_ATTRIBUTE(FullyConnected);
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1138,7 +1132,7 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpFullyConnected: Output data type not supported for this configuration of operator");
output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
@@ -1149,8 +1143,8 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1177,7 +1171,7 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1185,8 +1179,8 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval()
return GraphNode::eval();
}
-template <DType Dtype, DType AccDtype>
-OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
+template <DType Dtype, DType OutDtype>
+OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_MATMUL, id_)
@@ -1197,15 +1191,15 @@ OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_,
INIT_ATTRIBUTE(MatMul);
}
-template <DType Dtype, DType AccDtype>
-OpMatMul<Dtype, AccDtype>::~OpMatMul()
+template <DType Dtype, DType OutDtype>
+OpMatMul<Dtype, OutDtype>::~OpMatMul()
{
if (attribute)
delete attribute;
}
-template <DType Dtype, DType AccDtype>
-int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
+template <DType Dtype, DType OutDtype>
+int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1215,7 +1209,7 @@ int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpMatMul: Output data type not supported for this configuration of operator");
a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -1266,8 +1260,8 @@ int OpMatMul<Dtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType Dtype, DType AccDtype>
-int OpMatMul<Dtype, AccDtype>::eval()
+template <DType Dtype, DType OutDtype>
+int OpMatMul<Dtype, OutDtype>::eval()
{
typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
@@ -1312,7 +1306,7 @@ int OpMatMul<Dtype, AccDtype>::eval()
}
}
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1587,8 +1581,8 @@ int OpRFFT2d<Dtype>::eval()
return GraphNode::eval();
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
TosaAttributeBase* attribute_,
uint64_t id_)
: GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
@@ -1599,15 +1593,15 @@ OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTra
INIT_ATTRIBUTE(TransposeConv);
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
{
if (attribute)
delete attribute;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
{
if (validateRequiredOperands())
return 1;
@@ -1617,7 +1611,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 1;
}
- ERROR_IF(outputs[0]->getDtype() != AccDtype,
+ ERROR_IF(outputs[0]->getDtype() != OutDtype,
"OpTransposeConv2d: Output data type not supported for this configuration of operator");
input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
@@ -1701,8 +1695,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes()
return 0;
}
-template <DType InDtype, DType WeightDtype, DType AccDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
+template <DType InDtype, DType WeightDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
{
int in_batch = this->input->getShape()[0];
int in_height = this->input->getShape()[1];
@@ -1729,8 +1723,6 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
int stride_h = this->attribute->stride()[0];
int stride_w = this->attribute->stride()[1];
- tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
-
ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels,
in_channels);
@@ -1741,10 +1733,10 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
DEBUG_INFO(OP,
"perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
- "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s",
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]",
in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels,
out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top,
- out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]);
+ out_pad_bottom, out_pad_left, out_pad_right);
TIn input_val = this->input->getTensor();
TWeight weight_val = this->weight->getTensor();
@@ -1803,7 +1795,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval()
}
}
- if (AccDtype == DType_INT48)
+ if (OutDtype == DType_INT48)
{
this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1819,52 +1811,52 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32);
-
- // [in_t, weight_t, acc_t]
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48);
-
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48);
-
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48);
-
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48);
-
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32);
-DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+
+ // [in_t, weight_t, out_t]
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
@@ -1874,10 +1866,10 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
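
The INT48 saturation step is unaffected by the re-keying: in every instantiated configuration the old accumulator DType column equals the new output DType column (compare the [in_t, weight_t, acc_t] and [in_t, weight_t, out_t] instantiation lists above), so testing OutDtype == DType_INT48 selects exactly the same cases. A compile-time restatement of the upper bound involved (sketch, using the GetQMax specialization shown in template_types.h):

    static_assert(GetQMax<DType_INT48>::value == (1LL << 47) - 1, "INT48 max");
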
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index ed9a55c..0d2b3eb 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -74,7 +74,7 @@ protected:
ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
};
-template <DType InDtype, DType WeightDtype, DType AccDtype>
+template <DType InDtype, DType WeightDtype, DType OutDtype>
class OpConv2d : public GraphNode
{
public:
@@ -86,15 +86,15 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
using TBias = Eigen::Tensor<OutEigenType, 1>;
using TOut = Eigen::Tensor<OutEigenType, 4>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* input;
@@ -104,7 +104,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType AccDtype>
+template <DType InDtype, DType WeightDtype, DType OutDtype>
class OpConv3d : public GraphNode
{
public:
@@ -116,15 +116,15 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 5>;
using TWeight = Eigen::Tensor<WeightEigenType, 5>;
using TBias = Eigen::Tensor<OutEigenType, 1>;
using TOut = Eigen::Tensor<OutEigenType, 5>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* input;
@@ -134,7 +134,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType AccDtype>
+template <DType InDtype, DType WeightDtype, DType OutDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
@@ -146,15 +146,15 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
using TBias = Eigen::Tensor<OutEigenType, 1>;
using TOut = Eigen::Tensor<OutEigenType, 4>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* input;
@@ -164,7 +164,7 @@ protected:
tosa::TosaConvAttribute* attribute;
};
-template <DType InDtype, DType WeightDtype, DType AccDtype>
+template <DType InDtype, DType WeightDtype, DType OutDtype>
class OpFullyConnected : public GraphNode
{
public:
@@ -176,15 +176,15 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 2>;
using TWeight = Eigen::Tensor<WeightEigenType, 2>;
using TBias = Eigen::Tensor<OutEigenType, 1>;
using TOut = Eigen::Tensor<OutEigenType, 2>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* input;
@@ -195,7 +195,7 @@ protected:
tosa::TosaFullyConnectedAttribute* attribute;
};
-template <DType Dtype, DType AccDtype>
+template <DType Dtype, DType OutDtype>
class OpMatMul : public GraphNode
{
public:
@@ -206,14 +206,14 @@ public:
virtual int eval() final;
using InEigenType = typename GetEigenType<Dtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 3>;
using TOut = Eigen::Tensor<OutEigenType, 3>;
using TInRank2 = Eigen::Tensor<InEigenType, 2>;
using TAccRank2 = Eigen::Tensor<AccEigenType, 2>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* a;
@@ -269,7 +269,7 @@ protected:
TosaReference::TensorTemplate<TOut>* out_imag;
};
-template <DType InDtype, DType WeightDtype, DType AccDtype>
+template <DType InDtype, DType WeightDtype, DType OutDtype>
class OpTransposeConv2d : public GraphNode
{
public:
@@ -281,15 +281,15 @@ public:
using InEigenType = typename GetEigenType<InDtype>::type;
using WeightEigenType = typename GetEigenType<WeightDtype>::type;
- using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType
- using OutEigenType = typename GetEigenType<AccDtype>::type;
+ using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
using TIn = Eigen::Tensor<InEigenType, 4>;
using TWeight = Eigen::Tensor<WeightEigenType, 4>;
using TBias = Eigen::Tensor<OutEigenType, 1>;
using TOut = Eigen::Tensor<OutEigenType, 4>;
- static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
- static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+ static constexpr int64_t AccQMin = GetQMin<OutDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<OutDtype>::value;
protected:
TosaReference::TensorTemplate<TIn>* input;
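
Taken together, a single template argument now pins the whole output-side typing of an operator. A usage sketch (Conv2dInt16 is a hypothetical alias, consistent with the class definitions above):

    using Conv2dInt16 = TosaReference::OpConv2d<DType_INT16, DType_INT8, DType_INT48>;
    // Conv2dInt16::OutEigenType   : storage element type, from GetEigenType<DType_INT48>
    // Conv2dInt16::AccEigenType   : compute element type, from GetAccEigenType<DType_INT48>
    // Conv2dInt16::AccQMin/AccQMax: clamp bounds, from GetQMin/GetQMax<DType_INT48>
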
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 1f89f74..671d902 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -9,6 +9,8 @@ from xml.dom import minidom
from jinja2 import Environment
from jinja2 import FileSystemLoader
+# Note: main script designed to be run from the scripts/operator_api/ directory
+
def getTosaArgTypes(tosaXml):
"""
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
-Subproject c15f7d52aa4f360eba2344449baa418b7608ac7
+Subproject 80905bba37ce55e8db293b1405a78b63dc1855c
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index fddf942..5f9e2c1 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -690,7 +690,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -762,7 +762,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -825,9 +825,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.TransposeConvAttribute(
- out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype
- )
+ attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -898,7 +896,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype)
+ attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -947,7 +945,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype)
+ attr.FullyConnectedAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens
@@ -989,7 +987,7 @@ class TosaTestGen:
return None
attr = ts.TosaSerializerAttribute()
- attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype)
+ attr.MatMulAttribute(qinfo[0], qinfo[1])
self.ser.addOperator(op["op"], input_list, output_list, attr)
return result_tens