From f36f25619cc3a34c75e78637ed244a2ca54ab3f4 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Thu, 14 Mar 2024 16:21:29 +0000 Subject: [ref model] Add acc_type to Conv Ops This patch implements changes required by the new acc_type field in ConvAttribute and TransposeConvAttribute Signed-off-by: Tai Ly Signed-off-by: Jeremy Johnson Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62 --- reference_model/src/ops/tensor_ops.cc | 163 ++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 79 deletions(-) (limited to 'reference_model/src/ops/tensor_ops.cc') diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 7bd249b..afd20e9 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -586,8 +586,10 @@ int OpAvgPool2d::eval() return GraphNode::eval(); } -template -OpConv2d::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) +template +OpConv2d::OpConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); @@ -596,15 +598,15 @@ OpConv2d::OpConv2d(SubgraphTraverser* sgt_, Tosa INIT_ATTRIBUTE(Conv); } -template -OpConv2d::~OpConv2d() +template +OpConv2d::~OpConv2d() { if (attribute) delete attribute; } -template -int OpConv2d::checkTensorAttributes() +template +int OpConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -640,8 +642,8 @@ int OpConv2d::checkTensorAttributes() return 0; } -template -int OpConv2d::eval() +template +int OpConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -793,8 +795,10 @@ int OpConv2d::eval() return GraphNode::eval(); } -template -OpConv3d::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) +template +OpConv3d::OpConv3d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) { setRequiredOperands(3, 1); @@ -803,15 +807,15 @@ OpConv3d::OpConv3d(SubgraphTraverser* sgt_, Tosa INIT_ATTRIBUTE(Conv); } -template -OpConv3d::~OpConv3d() +template +OpConv3d::~OpConv3d() { if (attribute) delete attribute; } -template -int OpConv3d::checkTensorAttributes() +template +int OpConv3d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -847,8 +851,8 @@ int OpConv3d::checkTensorAttributes() return 0; } -template -int OpConv3d::eval() +template +int OpConv3d::eval() { int in_batch = this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -1008,10 +1012,10 @@ int OpConv3d::eval() return GraphNode::eval(); } -template -OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template +OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1020,15 +1024,15 @@ OpDepthwiseConv2d::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template -OpDepthwiseConv2d::~OpDepthwiseConv2d() +template +OpDepthwiseConv2d::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template -int OpDepthwiseConv2d::checkTensorAttributes() +template +int OpDepthwiseConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1064,8 +1068,8 @@ int OpDepthwiseConv2d::checkTensorAttributes() return 0; } -template -int OpDepthwiseConv2d::eval() +template +int OpDepthwiseConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1903,10 +1907,10 @@ int OpRFFT2d::eval() return GraphNode::eval(); } -template -OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template +OpTransposeConv2d::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1915,15 +1919,15 @@ OpTransposeConv2d::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template -OpTransposeConv2d::~OpTransposeConv2d() +template +OpTransposeConv2d::~OpTransposeConv2d() { if (attribute) delete attribute; } -template -int OpTransposeConv2d::checkTensorAttributes() +template +int OpTransposeConv2d::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -2017,8 +2021,8 @@ int OpTransposeConv2d::checkTensorAttributes() return 0; } -template -int OpTransposeConv2d::eval() +template +int OpTransposeConv2d::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -2168,39 +2172,39 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16); -// [in_t, weight_t, out_t] -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); - -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); - -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); +// [in_t, weight_t, acc_t, out_t] +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16); + +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16); + +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); @@ -2238,13 +2242,14 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); +// [in_t, weight_t, acc_t, out_t] +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16); -- cgit v1.2.1