From d34b3fc5eeef48ecc781a02433ce022a28e3373c Mon Sep 17 00:00:00 2001 From: James Ward Date: Wed, 18 Jan 2023 14:51:25 +0000 Subject: Remove accumulator attributes from all but AVG_POOL2D Signed-off-by: James Ward Change-Id: If67f503a1848967bc1671646c3011d055b622c52 --- reference_model/src/graph_node.h | 20 +-- reference_model/src/operators.cc | 42 +++--- reference_model/src/ops/image.cc | 14 +- reference_model/src/ops/op_factory.cc | 82 +++++----- reference_model/src/ops/op_factory.h | 12 ++ reference_model/src/ops/template_types.h | 33 ---- reference_model/src/ops/tensor_ops.cc | 250 +++++++++++++++---------------- reference_model/src/ops/tensor_ops.h | 60 ++++---- scripts/operator_api/generate_api.py | 2 + thirdparty/serialization_lib | 2 +- verif/generator/tosa_test_gen.py | 14 +- 11 files changed, 241 insertions(+), 290 deletions(-) diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index b227d17..a9a336b 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -24,9 +24,6 @@ #define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP; -#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, RANK, DTYPE, ACCUM_DTYPE) \ - template class TosaReference::OP; - #define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ template class TosaReference::OP; @@ -38,15 +35,12 @@ #define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP; -#define DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \ - template class TosaReference::OP; - #define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP; -#define DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OP, DTYPE1, DTYPE2, ACCUM_DTYPE) \ - template class TosaReference::OP; +#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ + template class TosaReference::OP; -#define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, OP_TYPE) \ +#define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \ template class TosaReference::OP; #define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ @@ -66,14 +60,6 @@ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) -#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 1, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 2, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 3, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 4, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 5, DTYPE, ACCUM_DTYPE) \ - DEF_INSTANTIATE_ONE_RANK_ONE_TYPE_ONE_ACCUM(OP, 6, DTYPE, ACCUM_DTYPE) - #define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ diff --git a/reference_model/src/operators.cc b/reference_model/src/operators.cc index af348ca..a627322 100644 --- a/reference_model/src/operators.cc +++ b/reference_model/src/operators.cc @@ -168,10 +168,9 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[4]); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector dilation(&client_dilation[0], &client_dilation[2]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute 
attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -219,10 +218,9 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[6]); const std::vector stride(&client_stride[0], &client_stride[3]); const std::vector dilation(&client_dilation[0], &client_dilation[3]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -270,10 +268,9 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[4]); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector dilation(&client_dilation[0], &client_dilation[2]); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -313,10 +310,9 @@ extern "C" tosa_tensor_t client_output) { // Create operator attributes - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaFullyConnectedAttribute attr(input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaFullyConnectedAttribute attr(input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = translate_client_tensor(client_input, "input"); @@ -352,10 +348,9 @@ extern "C" tosa_tensor_t client_output) { // Create operator attributes - const int32_t a_zp = client_a_zp; - const int32_t b_zp = client_b_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaMatMulAttribute attr(a_zp, b_zp, accum_dtype); + const int32_t a_zp = client_a_zp; + const int32_t b_zp = client_b_zp; + TosaMatMulAttribute attr(a_zp, b_zp); // Create tensors tosa::TosaSerializationTensor* a = translate_client_tensor(client_a, "a"); @@ -446,10 +441,9 @@ extern "C" const std::vector pad(&client_pad[0], &client_pad[0] + client_pad_len); const std::vector stride(&client_stride[0], &client_stride[2]); const std::vector dilation(&client_dilation[0], &client_dilation[0] + client_dilation_len); - const int32_t input_zp = client_input_zp; - const int32_t weight_zp = client_weight_zp; - const tosa::DType accum_dtype = tosa::DType::DType_FP32; - TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp, accum_dtype); + const int32_t input_zp = client_input_zp; + const int32_t weight_zp = client_weight_zp; + TosaConvAttribute attr(pad, stride, dilation, input_zp, weight_zp); // Create tensors tosa::TosaSerializationTensor* input = 
translate_client_tensor(client_input, "input"); diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc index a1a4474..90427e4 100644 --- a/reference_model/src/ops/image.cc +++ b/reference_model/src/ops/image.cc @@ -236,10 +236,10 @@ int OpResize::eval() } // template explicit instantiation -DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT32, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, INT8, INT8, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT48, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, INT16, INT16, int16_t); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FP16, FP16, half_float::half); -DEF_INSTANTIATE_THREE_TYPE(OpResize, BF16, BF16, Eigen::bfloat16); -DEF_INSTANTIATE_THREE_TYPE(OpResize, FP32, FP32, float); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT32, int16_t); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT8, INT8, int16_t); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT48, int16_t); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, INT16, INT16, int16_t); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP16, FP16, half_float::half); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, BF16, BF16, Eigen::bfloat16); +DEF_INSTANTIATE_THREE_TYPE_RESIZE(OpResize, FP32, FP32, float); diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 76cf666..b1a405a 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -63,48 +63,48 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, INT16, INT32); break; case Op_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); break; case Op_CONV3D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32); - 
DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); break; case Op_FULLY_CONNECTED: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FullyConnected, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); break; case Op_MATMUL: - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP16); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP16, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, BF16, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, FP32, FP32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT8, INT32); - DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpMatMul, MatMul, INT16, INT48); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP16); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP16, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, BF16, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, FP32, FP32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT8, INT32); + DEF_FACTORY_TWO_TYPE_IN_OUT(OpMatMul, INT16, INT48); break; case Op_MAX_POOL2D: DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FP16); @@ -117,13 +117,13 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP32); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32); - DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); + 
DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); + DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); break; // activation_funcs diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index f4177db..9117df4 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -74,6 +74,12 @@ return new OP(sgt, attribute, id); \ } +#define DEF_FACTORY_TWO_TYPE_IN_OUT(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP(sgt, attribute, id); \ + } + #define DEF_FACTORY_TWO_TYPE_ONE_ACCUM(OP, ATTR_NAME, DTYPE1, DTYPE2, ACCUM_DTYPE) \ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 \ && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == DType_##ACCUM_DTYPE) \ @@ -81,6 +87,12 @@ return new OP(sgt, attribute, id); \ } \ +#define DEF_FACTORY_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ + if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2 && outputDType == DType_##DTYPE3) \ + { \ + return new OP(sgt, attribute, id); \ + } + // Statement-expression to evaluate accumulate attribute in-place #define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \ ({ \ diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h index 647ca84..6b28502 100644 --- a/reference_model/src/ops/template_types.h +++ b/reference_model/src/ops/template_types.h @@ -275,39 +275,6 @@ struct GetQMax static constexpr int64_t value = (1L << 47) - 1; }; -template -struct GetAccDType; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_INT32; -}; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_INT32; -}; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_INT48; -}; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_INT48; -}; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_FP16; -}; -template <> -struct GetAccDType -{ - static constexpr DType value = DType_FP32; -}; - }; // namespace TosaReference #endif diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index dff9e08..4663c47 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -541,8 +541,8 @@ int OpAvgPool2d::eval() return GraphNode::eval(); } -template -OpConv2d::OpConv2d(SubgraphTraverser* sgt_, +template +OpConv2d::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) @@ -553,15 +553,15 @@ OpConv2d::OpConv2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template -OpConv2d::~OpConv2d() +template +OpConv2d::~OpConv2d() { if (attribute) delete attribute; } -template -int OpConv2d::checkTensorAttributes() +template +int OpConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -577,7 +577,7 @@ int OpConv2d::checkTensorAttributes() printNodeValidationError("OpConv2d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); @@ -597,8 +597,8 @@ int OpConv2d::checkTensorAttributes() 
return 0; } -template -int OpConv2d::eval() +template +int OpConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -634,14 +634,12 @@ int OpConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO(OP, "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " - "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", + "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + pad_bottom, pad_left, pad_right); // GEMM-conv2d, left matrix is input, right matrix is weight Eigen::array im2col_input_dims; @@ -717,7 +715,7 @@ int OpConv2d::eval() // reshape back to [N, H, W, C] this->output->getTensor() = biased_output.reshape(col2im_output_dims); - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -726,8 +724,8 @@ int OpConv2d::eval() return GraphNode::eval(); } -template -OpConv3d::OpConv3d(SubgraphTraverser* sgt_, +template +OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) @@ -738,15 +736,15 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template -OpConv3d::~OpConv3d() +template +OpConv3d::~OpConv3d() { if (attribute) delete attribute; } -template -int OpConv3d::checkTensorAttributes() +template +int OpConv3d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -762,7 +760,7 @@ int OpConv3d::checkTensorAttributes() printNodeValidationError("OpConv3d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpConv3d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); @@ -782,8 +780,8 @@ int OpConv3d::checkTensorAttributes() return 0; } -template -int OpConv3d::eval() +template +int OpConv3d::eval() { int in_batch = this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -827,15 +825,13 @@ int OpConv3d::eval() int dilation_h = this->attribute->dilation()[1]; int dilation_w = this->attribute->dilation()[2]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO( OP, "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], " - "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s", + "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]", in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels, out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h, - dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right); Eigen::array, 5> pad; pad[0] = std::make_pair(0, 0); @@ -907,7 +903,7 @@ int OpConv3d::eval() } } - 
if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -916,8 +912,8 @@ int OpConv3d::eval() return GraphNode::eval(); } -template -OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, +template +OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) @@ -928,15 +924,15 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template -OpDepthwiseConv2d::~OpDepthwiseConv2d() +template +OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template -int OpDepthwiseConv2d::checkTensorAttributes() +template +int OpDepthwiseConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -952,7 +948,7 @@ int OpDepthwiseConv2d::checkTensorAttributes() printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); @@ -972,8 +968,8 @@ int OpDepthwiseConv2d::checkTensorAttributes() return 0; } -template -int OpDepthwiseConv2d::eval() +template +int OpDepthwiseConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1010,14 +1006,12 @@ int OpDepthwiseConv2d::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO(OP, "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + pad_bottom, pad_left, pad_right); Eigen::array, 4> pad; pad[0] = std::make_pair(0, 0); @@ -1083,7 +1077,7 @@ int OpDepthwiseConv2d::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1092,8 +1086,8 @@ int OpDepthwiseConv2d::eval() return GraphNode::eval(); } -template -OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, +template +OpFullyConnected::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) @@ -1104,15 +1098,15 @@ OpFullyConnected::OpFullyConnected(SubgraphTrave INIT_ATTRIBUTE(FullyConnected); } -template -OpFullyConnected::~OpFullyConnected() +template +OpFullyConnected::~OpFullyConnected() { if (attribute) delete attribute; } -template -int OpFullyConnected::checkTensorAttributes() +template +int OpFullyConnected::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1138,7 +1132,7 @@ int OpFullyConnected::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + 
ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpFullyConnected: Output data type not supported for this configuration of operator"); output = dynamic_cast*>(outputs[0]); @@ -1149,8 +1143,8 @@ int OpFullyConnected::checkTensorAttributes() return 0; } -template -int OpFullyConnected::eval() +template +int OpFullyConnected::eval() { typedef Eigen::Tensor::DimensionPair DimPair; Eigen::array dims{ { DimPair(1, 0) } }; @@ -1177,7 +1171,7 @@ int OpFullyConnected::eval() input_val.template cast().contract(weight_val.template cast(), dims).template cast() + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1185,8 +1179,8 @@ int OpFullyConnected::eval() return GraphNode::eval(); } -template -OpMatMul::OpMatMul(SubgraphTraverser* sgt_, +template +OpMatMul::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) @@ -1197,15 +1191,15 @@ OpMatMul::OpMatMul(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(MatMul); } -template -OpMatMul::~OpMatMul() +template +OpMatMul::~OpMatMul() { if (attribute) delete attribute; } -template -int OpMatMul::checkTensorAttributes() +template +int OpMatMul::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1215,7 +1209,7 @@ int OpMatMul::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpMatMul: Output data type not supported for this configuration of operator"); a = dynamic_cast*>(inputs[0]); @@ -1266,8 +1260,8 @@ int OpMatMul::checkTensorAttributes() return 0; } -template -int OpMatMul::eval() +template +int OpMatMul::eval() { typedef Eigen::Tensor::DimensionPair DimPair; Eigen::array dims{ { DimPair(1, 0) } }; @@ -1312,7 +1306,7 @@ int OpMatMul::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1587,8 +1581,8 @@ int OpRFFT2d::eval() return GraphNode::eval(); } -template -OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, +template +OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) @@ -1599,15 +1593,15 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template -OpTransposeConv2d::~OpTransposeConv2d() +template +OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; } -template -int OpTransposeConv2d::checkTensorAttributes() +template +int OpTransposeConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1617,7 +1611,7 @@ int OpTransposeConv2d::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTransposeConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast*>(inputs[0]); @@ -1701,8 +1695,8 @@ int OpTransposeConv2d::checkTensorAttributes() return 0; } -template -int OpTransposeConv2d::eval() +template +int OpTransposeConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1729,8 +1723,6 @@ int 
OpTransposeConv2d::eval() int stride_h = this->attribute->stride()[0]; int stride_w = this->attribute->stride()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels, in_channels); @@ -1741,10 +1733,10 @@ int OpTransposeConv2d::eval() DEBUG_INFO(OP, "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top, - out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]); + out_pad_bottom, out_pad_left, out_pad_right); TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); @@ -1803,7 +1795,7 @@ int OpTransposeConv2d::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1819,52 +1811,52 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); - - // [in_t, weight_t, acc_t] -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48); - -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); + + // [in_t, weight_t, out_t] +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); + +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); @@ -1874,10 +1866,10 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); 
-DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index ed9a55c..0d2b3eb 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -74,7 +74,7 @@ protected: ETensor1 calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template +template class OpConv2d : public GraphNode { public: @@ -86,15 +86,15 @@ public: using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TOut = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; @@ -104,7 +104,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpConv3d : public GraphNode { public: @@ -116,15 +116,15 @@ public: using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TOut = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; @@ -134,7 +134,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpDepthwiseConv2d : public GraphNode { public: @@ -146,15 +146,15 @@ public: using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename 
GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TOut = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; @@ -164,7 +164,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template +template class OpFullyConnected : public GraphNode { public: @@ -176,15 +176,15 @@ public: using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TOut = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; @@ -195,7 +195,7 @@ protected: tosa::TosaFullyConnectedAttribute* attribute; }; -template +template class OpMatMul : public GraphNode { public: @@ -206,14 +206,14 @@ public: virtual int eval() final; using InEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TOut = Eigen::Tensor; using TInRank2 = Eigen::Tensor; using TAccRank2 = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* a; @@ -269,7 +269,7 @@ protected: TosaReference::TensorTemplate* out_imag; }; -template +template class OpTransposeConv2d : public GraphNode { public: @@ -281,15 +281,15 @@ public: using InEigenType = typename GetEigenType::type; using WeightEigenType = typename GetEigenType::type; - using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType - using OutEigenType = typename GetEigenType::type; + using AccEigenType = typename GetAccEigenType::type; // Note: different from GetEigenType + using OutEigenType = typename GetEigenType::type; using TIn = Eigen::Tensor; using TWeight = Eigen::Tensor; using TBias = Eigen::Tensor; using TOut = Eigen::Tensor; - static constexpr int64_t AccQMin = GetQMin::value; - static constexpr int64_t AccQMax = GetQMax::value; + static constexpr int64_t AccQMin = GetQMin::value; + static constexpr int64_t AccQMax = GetQMax::value; protected: TosaReference::TensorTemplate* input; diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 1f89f74..671d902 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -9,6 +9,8 @@ from xml.dom import minidom from jinja2 import Environment from jinja2 import FileSystemLoader +# Note: main script designed to be run from 
the scripts/operator_api/ directory + def getTosaArgTypes(tosaXml): """ diff --git a/thirdparty/serialization_lib b/thirdparty/serialization_lib index c15f7d5..80905bb 160000 --- a/thirdparty/serialization_lib +++ b/thirdparty/serialization_lib @@ -1 +1 @@ -Subproject commit c15f7d52aa4f360eba2344449baa418b7608ac7c +Subproject commit 80905bba37ce55e8db293b1405a78b63dc1855cb diff --git a/verif/generator/tosa_test_gen.py b/verif/generator/tosa_test_gen.py index fddf942..5f9e2c1 100644 --- a/verif/generator/tosa_test_gen.py +++ b/verif/generator/tosa_test_gen.py @@ -690,7 +690,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -762,7 +762,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -825,9 +825,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.TransposeConvAttribute( - out_pad, stride, output_shape, qinfo[0], qinfo[1], accum_dtype - ) + attr.TransposeConvAttribute(out_pad, stride, output_shape, qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -898,7 +896,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1], accum_dtype) + attr.ConvAttribute(padding, strides, dilations, qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -947,7 +945,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.FullyConnectedAttribute(qinfo[0], qinfo[1], accum_dtype) + attr.FullyConnectedAttribute(qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens @@ -989,7 +987,7 @@ class TosaTestGen: return None attr = ts.TosaSerializerAttribute() - attr.MatMulAttribute(qinfo[0], qinfo[1], accum_dtype) + attr.MatMulAttribute(qinfo[0], qinfo[1]) self.ser.addOperator(op["op"], input_list, output_list, attr) return result_tens -- cgit v1.2.1
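
Editor's note on the change above: after this patch the convolution, fully-connected, and matmul ops are templated on the *output* dtype, and the accumulator type is derived at compile time (via the `GetAccEigenType` trait keyed on `OutDtype`, as in tensor_ops.h) instead of being read from a serialized `accum_dtype` attribute. AVG_POOL2D alone keeps the attribute because its output dtype equals its input dtype — an FP16 average pool may accumulate in FP16 or FP32 while still producing FP16 — so the accumulator cannot be inferred from tensor dtypes. The sketch below is a minimal standalone illustration of the pattern, not the reference model's real headers; `DType`, `GetAccEigenTypeSketch`, and `DotProduct` are simplified stand-ins.

```cpp
#include <cstdint>
#include <iostream>

// Simplified stand-ins for tosa::DType and the GetAccEigenType<> trait.
enum class DType { INT8, INT32, FP16, FP32 };

template <DType Out>
struct GetAccEigenTypeSketch;            // accumulator chosen from the output dtype
template <> struct GetAccEigenTypeSketch<DType::INT32> { using type = int32_t; };
template <> struct GetAccEigenTypeSketch<DType::FP32>  { using type = float;   };

// A toy "op" shaped like OpConv2d<InDtype, WeightDtype, OutDtype> after the
// patch: accumulate in AccEigenType, then narrow to the output element type.
template <DType InDtype, DType OutDtype>
struct DotProduct
{
    using AccEigenType = typename GetAccEigenTypeSketch<OutDtype>::type;

    template <typename InT, typename OutT>
    static OutT eval(const InT* a, const InT* b, int n)
    {
        AccEigenType acc = 0;
        for (int i = 0; i < n; ++i)
            acc += static_cast<AccEigenType>(a[i]) * static_cast<AccEigenType>(b[i]);
        return static_cast<OutT>(acc);   // no accum_dtype attribute consulted
    }
};

int main()
{
    int8_t a[3] = { 100, 100, 100 };
    int8_t b[3] = { 100, 100, 100 };
    // INT8 inputs with an INT32 output imply an int32_t accumulator.
    std::cout << DotProduct<DType::INT8, DType::INT32>::eval<int8_t, int32_t>(a, b, 3)
              << std::endl;              // prints 30000
    return 0;
}
```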
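Likewise, the factory dispatch that previously evaluated `ACCUM_FROM_ATTRIBUTE(...)` against the serialized attribute now keys purely on the graph's tensor dtypes. Below is a condensed, hypothetical rendering of the new `DEF_FACTORY_THREE_TYPE` mechanics; the real macro returns `new OP<...>(sgt, attribute, id)` inside `OpFactory::newOp`, and `newConv2d` here is an invented wrapper for illustration only.

```cpp
#include <memory>
#include <stdexcept>

enum class DType { INT8, INT16, INT32, INT48, FP16, BF16, FP32 };

struct GraphNode { virtual ~GraphNode() = default; };

template <DType In, DType Weight, DType Out>
struct OpConv2d : GraphNode {};

// Condensed form of DEF_FACTORY_THREE_TYPE: pick the instantiation from the
// input/weight/output dtypes alone; no accumulator attribute is consulted.
#define DEF_FACTORY_THREE_TYPE(OP, D1, D2, D3)                                            \
    if (inputDType == DType::D1 && weightDType == DType::D2 && outputDType == DType::D3)  \
        return std::make_unique<OP<DType::D1, DType::D2, DType::D3>>();

std::unique_ptr<GraphNode> newConv2d(DType inputDType, DType weightDType, DType outputDType)
{
    DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32)
    DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16)
    DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32)
    DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32)
    throw std::runtime_error("OpConv2d: unsupported dtype combination");
}
```

This works because, for these ops, every former (input, weight, accumulator) combination maps to a unique (input, weight, output) triple — e.g. FP16×FP16 with FP16 versus FP32 accumulation is already distinguished by the FP16 versus FP32 output — so the serialized attribute carried no extra information.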