From f36f25619cc3a34c75e78637ed244a2ca54ab3f4 Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Thu, 14 Mar 2024 16:21:29 +0000
Subject: [ref model] Add acc_type to Conv Ops

This patch implements changes required by the new acc_type field
in ConvAttribute and TransposeConvAttribute

Signed-off-by: Tai Ly
Signed-off-by: Jeremy Johnson
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
---
 .gitignore                                         |   3 +
 ...v2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa | Bin 1676 -> 1708 bytes
 ...v2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa | Bin 1380 -> 1412 bytes
 ...v2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa | Bin 1484 -> 1516 bytes
 .../test_lstm_stateful_13x21x3_f32.tosa            | Bin 11920 -> 11956 bytes
 reference_model/src/graph_node.h                   |   4 +
 reference_model/src/ops/op_factory.cc              |  85 ++++----
 reference_model/src/ops/op_factory.h               |   8 +
 reference_model/src/ops/tensor_ops.cc              | 163 +++++++-------
 reference_model/src/ops/tensor_ops.h               |  16 +-
 thirdparty/serialization_lib                       |   2 +-
 verif/generator/tosa_arg_gen.py                    | 240 +++++++++++----------
 verif/generator/tosa_error_if.py                   |  34 ++-
 verif/generator/tosa_test_gen.py                   |  42 +++-
 verif/generator/tosa_utils.py                      |  14 +-
 15 files changed, 356 insertions(+), 255 deletions(-)

diff --git a/.gitignore b/.gitignore
index 941cf20..dddbcc0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,9 @@
 __pycache__/
 build/
 debug-build/
+conformance/
+conformance_build/
+conformance_large_files/
 .cache
 compile_commands.json
 dist/
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index 20e1333..dc8413a 100644
Binary files a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa and b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tf/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa
index d55d5d6..521c1a3 100644
Binary files a/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa and b/examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa differ
diff --git a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa
index 01d8375..73c1839 100644
Binary files a/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa and b/examples/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11/flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_qi8_st11_padSAME_dilat11.tosa differ
diff --git a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa
index deaca6e..fff67ef 100644
Binary files a/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa and b/examples/test_lstm_stateful_13x21x3_f32/flatbuffer-tflite/test_lstm_stateful_13x21x3_f32.tosa differ
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
index e10f132..c0dceda 100644
--- a/reference_model/src/graph_node.h
+++ b/reference_model/src/graph_node.h
@@ -41,6 +41,10 @@
 #define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3)                                    \
     template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>;
 
+#define DEF_INSTANTIATE_FOUR_TYPE(OP, DTYPE1, DTYPE2, DTYPE3, DTYPE4)                             \
+    template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2,              \
+                                     TOSA_REF_TYPE_##DTYPE3, TOSA_REF_TYPE_##DTYPE4>;
+
 #define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE)                            \
     template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>;
 
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
index 0f0013c..74315d7 100644
--- a/reference_model/src/ops/op_factory.cc
+++ b/reference_model/src/ops/op_factory.cc
@@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16);
             break;
         case Op_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_CONV3D:
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_DEPTHWISE_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
         case Op_FFT2D:
             DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32);
@@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt,
             DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64);
             break;
         case Op_TRANSPOSE_CONV2D:
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
-            DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+            // OP, attr_name, in_t, w_t, acc_t, out_t
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16);
+            DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16);
             break;
 
         // activation_funcs
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
index 1d20066..f1d1680 100644
--- a/reference_model/src/ops/op_factory.h
+++ b/reference_model/src/ops/op_factory.h
@@ -94,6 +94,14 @@
         return new OP<TOSA_REF_TYPE_##DTYPE, TOSA_REF_TYPE_##ACCUM_DTYPE>(sgt, attribute, id);    \
     }
 
+#define DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OP, ATTR_NAME, IN_DTYPE, W_DTYPE, ACC_DTYPE, OUT_DTYPE)  \
+    if (inputDTYPE == TOSA_REF_TYPE_##IN_DTYPE && weightDTYPE == TOSA_REF_TYPE_##W_DTYPE &&       \
+        outputDTYPE == TOSA_REF_TYPE_##OUT_DTYPE &&                                               \
+        ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_DTYPE)                             \
+    {                                                                                             \
+        return new OP<TOSA_REF_TYPE_##IN_DTYPE, TOSA_REF_TYPE_##W_DTYPE,                          \
+                      TOSA_REF_TYPE_##ACC_DTYPE, TOSA_REF_TYPE_##OUT_DTYPE>(sgt, attribute, id);  \
+    }
+
 // Statement-expression to evaluate accumulate attribute in-place
 #define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME)                                                      \
     ({                                                                                            \
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index 7bd249b..afd20e9 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -586,8 +586,10 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
+                                                             TosaAttributeBase* attribute_,
+                                                             uint64_t id_)
     : GraphNode(sgt_, Op_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -596,15 +598,15 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, Tosa
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -640,8 +642,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_height = this->input->getShape()[1];
@@ -793,8 +795,10 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
+                                                             TosaAttributeBase* attribute_,
+                                                             uint64_t id_)
     : GraphNode(sgt_, Op_CONV3D, id_)
 {
     setRequiredOperands(3, 1);
@@ -803,15 +807,15 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, Tosa
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv3d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -847,8 +851,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_depth  = this->input->getShape()[1];
@@ -1008,10 +1012,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
-                                                                     TosaAttributeBase* attribute_,
-                                                                     uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
+                                                                               TosaAttributeBase* attribute_,
+                                                                               uint64_t id_)
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1020,15 +1024,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
     INIT_ATTRIBUTE(Conv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpDepthwiseConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -1064,8 +1068,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_height = this->input->getShape()[1];
@@ -1903,10 +1907,10 @@ int OpRFFT2d<Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                                     TosaAttributeBase* attribute_,
-                                                                     uint64_t id_)
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
+                                                                               TosaAttributeBase* attribute_,
+                                                                               uint64_t id_)
    : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1915,15 +1919,15 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
         return 1;
@@ -2017,8 +2021,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
-int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
+int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval()
 {
     int in_batch  = this->input->getShape()[0];
     int in_height = this->input->getShape()[1];
@@ -2168,39 +2172,39 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16);
 
-// [in_t, weight_t, out_t]
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16);
-
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16);
-
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
+
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16);
+
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
 
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
@@ -2238,13 +2242,14 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2);
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
 
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16);
-DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16);
+// [in_t, weight_t, acc_t, out_t]
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16);
+DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
index e2bb811..2e65548 100644
--- a/reference_model/src/ops/tensor_ops.h
+++ b/reference_model/src/ops/tensor_ops.h
@@ -75,7 +75,7 @@ protected:
                            int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right);
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpConv2d : public GraphNode
 {
 public:
@@ -87,7 +87,7 @@ public:
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
@@ -105,7 +105,7 @@ protected:
     tosa::TosaConvAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpConv3d : public GraphNode
 {
 public:
@@ -117,7 +117,7 @@ public:
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 5>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 5>;
@@ -135,7 +135,7 @@ protected:
     tosa::TosaConvAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpDepthwiseConv2d : public GraphNode
 {
 public:
@@ -147,7 +147,7 @@ public:
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
@@ -294,7 +294,7 @@ protected:
     tosa::TosaRFFTAttribute* attribute;
 };
 
-template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype>
 class OpTransposeConv2d : public GraphNode
 {
 public:
@@ -306,7 +306,7 @@ public:
     using InEigenType     = typename GetEigenType<InDtype>::type;
     using WeightEigenType = typename GetEigenType<WeightDtype>::type;
-    using AccEigenType    = typename GetAccEigenType<OutDtype>::type;    // Note: different from GetEigenType
+    using AccEigenType    = typename GetAccEigenType<AccDtype>::type;    // Note: different from GetEigenType
     using OutEigenType    = typename GetEigenType<OutDtype>::type;
     using TIn             = Eigen::Tensor<InEigenType, 4>;
     using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib
index 0b6d7c2..ad78daa 160000
--- a/thirdparty/serialization_lib
+++ b/thirdparty/serialization_lib
@@ -1 +1 @@
-Subproject commit 0b6d7c271af1e6593e6a2cf14b32acea765f4b64
+Subproject commit ad78daaf0fa1e41742cbed314459c3dbbb483c20
diff --git a/verif/generator/tosa_arg_gen.py b/verif/generator/tosa_arg_gen.py
index 83487a1..ffa3683 100644
--- a/verif/generator/tosa_arg_gen.py
+++ b/verif/generator/tosa_arg_gen.py
@@ -1990,7 +1990,12 @@ class TosaArgGen:
         # Shape: (OFM channels), (KD), KH, KW, IFM channels
         filter_shape = shapeList[1]
 
-        accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes)
+        accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes)
+
+        if error_name == ErrorIf.WrongAccumulatorType:
+            accum_dtypes = (
+                [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16]
+            )
 
         # Op type checks
         conv3d = opName.startswith("conv3d")
@@ -2110,88 +2115,91 @@ class TosaArgGen:
             sparsity = 1
 
         n = 0
-        for s in sorted(list(strides)):
-            for p in sorted(list(paddings)):
-                for d in sorted(list(dilations)):
-                    if (
-                        n % sparsity == 0
-                        # the padded shape must exceed the dilation * kernel to get a positive
-                        # sized output shape
-                        and (ifm_shape[1] - 1 + p[0] + p[1]) > d[0] * (k_shape[0] - 1)
-                        and (ifm_shape[2] - 1 + p[2] + p[3]) > d[1] * (k_shape[1] - 1)
-                        and (
-                            k_rank < 3
-                            or (
-                                (ifm_shape[3] - 1 + p[4] + p[5])
-                                > d[2] * (k_shape[2] - 1)
-                            )
-                        )
-                    ):
-                        remainders = []
-                        outputs = []
-                        for index in range(k_rank):
-                            pad_offset = index * 2
-                            partial = (
-                                ifm_shape[index + 1]
-                                - 1
-                                + p[pad_offset]
-                                + p[pad_offset + 1]
-                                - (k_shape[index] - 1) * d[index]
-                            )
-                            remainders.append(partial % s[index])
-                            outputs.append((partial // s[index]) + 1)
-
+        for a in accum_dtypes:
+            for s in sorted(list(strides)):
+                for p in sorted(list(paddings)):
+                    for d in sorted(list(dilations)):
                         if (
-                            # the parameters must produce integer exact output
-                            error_name != ErrorIf.ConvOutputShapeNonInteger
-                            and max(remainders) == 0
-                        ) or (
-                            error_name == ErrorIf.ConvOutputShapeNonInteger
-                            and max(remainders) > 0
+                            n % sparsity == 0
+                            # the padded shape must exceed the dilation * kernel to get a positive
+                            # sized output shape
+                            and (ifm_shape[1] - 1 + p[0] + p[1])
+                            > d[0] * (k_shape[0] - 1)
+                            and (ifm_shape[2] - 1 + p[2] + p[3])
+                            > d[1] * (k_shape[1] - 1)
+                            and (
+                                k_rank < 3
+                                or (
+                                    (ifm_shape[3] - 1 + p[4] + p[5])
+                                    > d[2] * (k_shape[2] - 1)
+                                )
+                            )
                         ):
+                            remainders = []
+                            outputs = []
+                            for index in range(k_rank):
+                                pad_offset = index * 2
+                                partial = (
+                                    ifm_shape[index + 1]
+                                    - 1
+                                    + p[pad_offset]
+                                    + p[pad_offset + 1]
+                                    - (k_shape[index] - 1) * d[index]
+                                )
+                                remainders.append(partial % s[index])
+                                outputs.append((partial // s[index]) + 1)
+
                             if (
-                                max_dim_size is not None
-                                and max(outputs) >= max_dim_size
+                                # the parameters must produce integer exact output
+                                error_name != ErrorIf.ConvOutputShapeNonInteger
+                                and max(remainders) == 0
+                            ) or (
+                                error_name == ErrorIf.ConvOutputShapeNonInteger
+                                and max(remainders) > 0
                             ):
-                                # Test will consume too much memory - skip it
-                                continue
-
-                            # Compliance - number of dot product calculations
-                            if depthwise:
-                                # N*OH*OW*C*M
-                                dots = gtu.product(
-                                    (ifm_shape[0], *outputs, *filter_shape[2:])
-                                )
-                            else:
-                                # N*OH*OW*OC or N*OD*OH*OW*OC
-                                dots = gtu.product(
-                                    (ifm_shape[0], *outputs, filter_shape[0])
-                                )
-                            args_dict = {
-                                "acc_type": accum_dtype,
-                                "stride": s,
-                                "pad": p,
-                                "dilation": d,
-                                "kernel": k_shape,
-                                "ks": k_size,
-                                "dot_products": dots,
-                                "shape": ifm_shape,
-                            }
-
-                            # Support for larger values than 9 needs different delimiter
-                            delim = "" if max(s + p + d) <= 9 else "x"
-                            arg_list.append(
-                                (
-                                    "acc{}_st{}_pad{}_dilat{}".format(
-                                        testGen.typeStr(accum_dtype),
-                                        delim.join([str(x) for x in s]),
-                                        delim.join([str(x) for x in p]),
-                                        delim.join([str(x) for x in d]),
-                                    ),
-                                    args_dict,
+                                if (
+                                    max_dim_size is not None
+                                    and max(outputs) >= max_dim_size
+                                ):
+                                    # Test will consume too much memory - skip it
+                                    continue
+
+                                # Compliance - number of dot product calculations
+                                if depthwise:
+                                    # N*OH*OW*C*M
+                                    dots = gtu.product(
+                                        (ifm_shape[0], *outputs, *filter_shape[2:])
+                                    )
+                                else:
+                                    # N*OH*OW*OC or N*OD*OH*OW*OC
+                                    dots = gtu.product(
+                                        (ifm_shape[0], *outputs, filter_shape[0])
+                                    )
+                                args_dict = {
+                                    "acc_type": a,
"stride": s, + "pad": p, + "dilation": d, + "kernel": k_shape, + "ks": k_size, + "dot_products": dots, + "shape": ifm_shape, + } + + # Support for larger values than 9 needs different delimiter + delim = "" if max(s + p + d) <= 9 else "x" + arg_list.append( + ( + "acc{}_st{}_pad{}_dilat{}".format( + testGen.typeStr(a), + delim.join([str(x) for x in s]), + delim.join([str(x) for x in p]), + delim.join([str(x) for x in d]), + ), + args_dict, + ) ) - ) - n += 1 + n += 1 arg_list = TosaArgGen._add_data_generators( testGen, @@ -2216,7 +2224,7 @@ class TosaArgGen: # Pick some potentially correct output dtype if input type is incorrect accum_dtype = DType.INT32 else: - accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) + accum_dtype = dtypes[-1] # use output dtype as accum_dtype # Set up compliance info args_dict = { @@ -2303,7 +2311,12 @@ class TosaArgGen: ifm_shape = shapeList[0] filter_shape = shapeList[1] - accum_dtype = gtu.get_accum_dtype_from_tgTypes(dtypes) + accum_dtypes = gtu.get_accum_dtypes_from_tgTypes(dtypes) + + if error_name == ErrorIf.WrongAccumulatorType: + accum_dtypes = ( + [DType.BF16] if gtu.dtypeIsFloat(dtypes[0]) else [DType.INT16] + ) # Must be rank 4 if error_name != ErrorIf.WrongRank: @@ -2400,41 +2413,42 @@ class TosaArgGen: sparsity = 1 n = 0 - for s in sorted(list(strides)): - for p in sorted(list(paddings)): - if n % sparsity == 0: - # Determine the output shape - oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0] - ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1] - os = [ifm_shape[0], oh, ow, filter_shape[0]] - - # N*OH*OW*OC - dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0])) - args_dict = { - "acc_type": accum_dtype, - "stride": s, - "pad": p, - "kernel": k_shape, - "ks": k_size, - "dot_products": dots, - "shape": ifm_shape, - "out_shape": os, - } + for a in accum_dtypes: + for s in sorted(list(strides)): + for p in sorted(list(paddings)): + if n % sparsity == 0: + # Determine the output shape + oh = (ifm_shape[1] - 1) * s[0] + p[0] + p[1] + k_shape[0] + ow = (ifm_shape[2] - 1) * s[1] + p[2] + p[3] + k_shape[1] + os = [ifm_shape[0], oh, ow, filter_shape[0]] + + # N*OH*OW*OC + dots = gtu.product((ifm_shape[0], oh, ow, filter_shape[0])) + args_dict = { + "acc_type": a, + "stride": s, + "pad": p, + "kernel": k_shape, + "ks": k_size, + "dot_products": dots, + "shape": ifm_shape, + "out_shape": os, + } - # Support for larger values than 9 needs different delimiter - delim = "" if max(s + p) <= 9 else "x" - arg_list.append( - ( - "acc{}_st{}_pad{}_os{}".format( - testGen.typeStr(accum_dtype), - delim.join([str(x) for x in s]), - delim.join([str(x) for x in p]), - "x".join([str(x) for x in os]), - ), - args_dict, + # Support for larger values than 9 needs different delimiter + delim = "" if max(s + p) <= 9 else "x" + arg_list.append( + ( + "acc{}_st{}_pad{}_os{}".format( + testGen.typeStr(a), + delim.join([str(x) for x in s]), + delim.join([str(x) for x in p]), + "x".join([str(x) for x in os]), + ), + args_dict, + ) ) - ) - n += 1 + n += 1 arg_list = TosaArgGen._add_data_generators( testGen, diff --git a/verif/generator/tosa_error_if.py b/verif/generator/tosa_error_if.py index e557f06..916b4f9 100644 --- a/verif/generator/tosa_error_if.py +++ b/verif/generator/tosa_error_if.py @@ -649,9 +649,9 @@ class TosaErrorValidator: or input_dtype == DType.INT16 and output_dtype != DType.INT48 or input_dtype == DType.FP16 - and output_dtype not in (DType.FP16, DType.FP32) + and output_dtype != DType.FP16 or input_dtype == DType.BF16 - and output_dtype 
-            and output_dtype != DType.FP32
+            and output_dtype != DType.BF16
             or input_dtype == DType.FP32
             and output_dtype != DType.FP32
             or input_dtype == DType.FP8E4M3
@@ -2682,6 +2682,36 @@ class TosaErrorValidator:
             ):
                 error_result = True
 
+        elif op["op"] in {
+            Op.CONV2D,
+            Op.CONV3D,
+            Op.DEPTHWISE_CONV2D,
+            Op.TRANSPOSE_CONV2D,
+        }:
+            if input_dtype == DType.INT8 and accum_dtype != DType.INT32:
+                error_result = True
+            elif input_dtype == DType.INT16 and accum_dtype != DType.INT48:
+                error_result = True
+            elif (
+                input_dtype
+                in (
+                    DType.FP32,
+                    DType.BF16,
+                )
+                and accum_dtype != DType.FP32
+            ):
+                error_result = True
+            elif input_dtype == DType.FP16 and accum_dtype not in (
+                DType.FP16,
+                DType.FP32,
+            ):
+                error_result = True
+            elif (
+                input_dtype in (DType.FP8E4M3, DType.FP8E5M2)
+                and accum_dtype != DType.FP16
+            ):
+                error_result = True
+
         info_dict = {
             "error_name": error_name,
             "error_result": error_result,
diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py
index 7702753..c867070 100644
--- a/verif/generator/tosa_test_gen.py
+++ b/verif/generator/tosa_test_gen.py
@@ -896,6 +896,7 @@ class TosaTestGen:
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -903,7 +904,9 @@ class TosaTestGen:
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -981,6 +984,7 @@ class TosaTestGen:
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -988,7 +992,9 @@ class TosaTestGen:
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -1057,6 +1063,7 @@ class TosaTestGen:
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -1065,7 +1072,7 @@ class TosaTestGen:
 
         attr = ts.TosaSerializerAttribute()
         attr.TransposeConvAttribute(
-            out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound
+            out_pad, strides, output_shape, qinfo[0], qinfo[1], local_bound, accum_dtype
         )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -1143,6 +1150,7 @@ class TosaTestGen:
             input_shape=ifm.shape,
             weight_shape=filter.shape,
             output_shape=result_tensor.shape,
+            accum_dtype=accum_dtype,
         ):
             return None
 
@@ -1150,7 +1158,9 @@ class TosaTestGen:
         local_bound = False
 
         attr = ts.TosaSerializerAttribute()
-        attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], local_bound)
+        attr.ConvAttribute(
+            padding, strides, dilations, qinfo[0], qinfo[1], local_bound, accum_dtype
+        )
 
         self.ser.addOperator(op["op"], input_list, output_list, attr)
 
@@ -3385,6 +3395,7 @@ class TosaTestGen:
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3418,6 +3429,7 @@ class TosaTestGen:
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3452,6 +3464,7 @@ class TosaTestGen:
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
                 TosaErrorValidator.evConvOutputShapeNonInteger,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -3564,6 +3577,7 @@ class TosaTestGen:
                 TosaErrorValidator.evStrideSmallerOne,
                 TosaErrorValidator.evWrongRank,
                 TosaErrorValidator.evConvOutputShapeMismatch,
+                TosaErrorValidator.evWrongAccumulatorType,
             ),
             "data_gen": {
                 "fp": (gtu.DataGenType.DOT_PRODUCT,),
@@ -5289,6 +5303,18 @@ class OutputShaper:
 
         return ser.addOutput(shape, outputDType)
 
+    @staticmethod
+    def _get_conv_output_type(input_dtype):
+        if input_dtype in (DType.FP16, DType.BF16, DType.FP32):
+            return input_dtype
+        elif input_dtype in (DType.FP8E4M3, DType.FP8E5M2):
+            return DType.FP16
+        elif input_dtype in (DType.INT8, DType.INT4):
+            return DType.INT32
+        elif input_dtype in (DType.INT16,):
+            return DType.INT48
+        assert False, f"Unsupported convolution data type {input_dtype}"
+
     @staticmethod
     def conv2dOp(
         ser, rng, ifm, filter, accum_dtype, strides, padding, dilations, error_name=None
@@ -5329,7 +5355,7 @@ class OutputShaper:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5393,7 +5419,7 @@ class OutputShaper:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5444,7 +5470,7 @@ class OutputShaper:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
@@ -5958,7 +5984,7 @@ class OutputShaper:
             # Pick some potentially correct output dtype if input type is incorrect
             out_dtype = DType.INT32
         else:
-            out_dtype = accum_dtype
+            out_dtype = OutputShaper._get_conv_output_type(ifm.dtype)
 
         if error_name == ErrorIf.WrongOutputType:
             if ifm.dtype == DType.FP16:
diff --git a/verif/generator/tosa_utils.py b/verif/generator/tosa_utils.py
index cfe7cc6..4a4f6bb 100644
--- a/verif/generator/tosa_utils.py
+++ b/verif/generator/tosa_utils.py
@@ -164,10 +164,18 @@ def product(shape):
     return value
 
 
-def get_accum_dtype_from_tgTypes(dtypes):
-    # Get accumulate data-type from the test generator's defined types
+def get_accum_dtypes_from_tgTypes(dtypes):
+    # Get accumulate data-types from the test generator's defined types
     assert isinstance(dtypes, list) or isinstance(dtypes, tuple)
-    return dtypes[-1]
+    input_dtype = dtypes[0]
+    output_dtype = dtypes[-1]
+    # by default, accum_dtypes contains only output_dtype
+    accum_dtypes = [output_dtype]
+    if input_dtype == DType.FP16 and output_dtype == DType.FP16:
+        accum_dtypes = [DType.FP16, DType.FP32]
+    elif output_dtype == DType.BF16:
+        accum_dtypes = [DType.FP32]
+    return accum_dtypes
 
 
 def get_wrong_output_type(op_name, rng, input_dtype):
--
cgit v1.2.1