diff options
author | Tai Ly <tai.ly@arm.com> | 2024-03-14 16:21:29 +0000 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2024-03-20 00:02:15 +0000 |
commit | f36f25619cc3a34c75e78637ed244a2ca54ab3f4 (patch) | |
tree | b1aa6a7314ef598561f0259c4d614a4169451031 /reference_model | |
parent | 0a6d1deef02f2bd76b3068d615565f20c46075a5 (diff) | |
download | reference_model-f36f25619cc3a34c75e78637ed244a2ca54ab3f4.tar.gz |
[ref model] Add acc_type to Conv Ops
This patch implements changes required by the new acc_type field in
ConvAttribute and TransposeConvAttribute
Signed-off-by: Tai Ly <tai.ly@arm.com>
Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: Ib13dbeec4d8920e0ddbcca02b727e7277f2c8d62
Diffstat (limited to 'reference_model')
-rw-r--r-- | reference_model/src/graph_node.h | 4 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.cc | 85 | ||||
-rw-r--r-- | reference_model/src/ops/op_factory.h | 8 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 163 | ||||
-rw-r--r-- | reference_model/src/ops/tensor_ops.h | 16 |
5 files changed, 148 insertions, 128 deletions
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h index e10f132..c0dceda 100644 --- a/reference_model/src/graph_node.h +++ b/reference_model/src/graph_node.h @@ -41,6 +41,10 @@ #define DEF_INSTANTIATE_THREE_TYPE(OP, DTYPE1, DTYPE2, DTYPE3) \ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>; +#define DEF_INSTANTIATE_FOUR_TYPE(OP, DTYPE1, DTYPE2, DTYPE3, DTYPE4) \ + template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3, \ + TOSA_REF_TYPE_##DTYPE4>; + #define DEF_INSTANTIATE_THREE_TYPE_RESIZE(OP, DTYPE1, DTYPE2, OP_TYPE) \ template class TosaReference::OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, OP_TYPE>; diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc index 0f0013c..74315d7 100644 --- a/reference_model/src/ops/op_factory.cc +++ b/reference_model/src/ops/op_factory.cc @@ -70,41 +70,43 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, Pool, FP8E5M2, FP16); break; case Op_CONV2D: - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_CONV3D: - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpConv3d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpConv3d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_DEPTHWISE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpDepthwiseConv2d, Conv, FP8E5M2, FP8E5M2, FP16, FP16); break; case Op_FFT2D: DEF_FACTORY_ONE_TYPE(OpFFT2d, FP32); @@ -148,16 +150,17 @@ GraphNode* OpFactory::newOp(SubgraphTraverser* sgt, DEF_FACTORY_ONE_TYPE(OpRFFT2d, FP64); break; case Op_TRANSPOSE_CONV2D: - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); - DEF_FACTORY_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); + // OP, attr_name, in_t, w_t, acc_t, out_t + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP16, FP16, FP32, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, BF16, BF16, FP32, BF16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP32, FP32, FP32, FP32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT4, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT8, INT8, INT32, INT32); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, INT16, INT8, INT48, INT48); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP64, FP64, FP64, FP64); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E4M3, FP8E4M3, FP16, FP16); + DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OpTransposeConv2d, TransposeConv, FP8E5M2, FP8E5M2, FP16, FP16); break; // activation_funcs diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h index 1d20066..f1d1680 100644 --- a/reference_model/src/ops/op_factory.h +++ b/reference_model/src/ops/op_factory.h @@ -94,6 +94,14 @@ return new OP<TOSA_REF_TYPE_##DTYPE1, TOSA_REF_TYPE_##DTYPE2, TOSA_REF_TYPE_##DTYPE3>(sgt, attribute, id); \ } +#define DEF_FACTORY_THREE_TYPE_ONE_ACCUM(OP, ATTR_NAME, IN_DTYPE, W_DTYPE, ACC_DTYPE, OUT_DTYPE) \ + if (inputDTYPE == TOSA_REF_TYPE_##IN_DTYPE && weightDTYPE == TOSA_REF_TYPE_##W_DTYPE && \ + outputDTYPE == TOSA_REF_TYPE_##OUT_DTYPE && ACCUM_FROM_ATTRIBUTE(ATTR_NAME) == TOSA_REF_TYPE_##ACC_DTYPE) \ + { \ + return new OP<TOSA_REF_TYPE_##IN_DTYPE, TOSA_REF_TYPE_##W_DTYPE, TOSA_REF_TYPE_##ACC_DTYPE, \ + TOSA_REF_TYPE_##OUT_DTYPE>(sgt, attribute, id); \ + } + // Statement-expression to evaluate accumulate attribute in-place #define ACCUM_FROM_ATTRIBUTE(ATTRIBUTE_NAME) \ ({ \ diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index 7bd249b..afd20e9 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -586,8 +586,10 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() return GraphNode::eval(); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) { setRequiredOperands(3, 1); @@ -596,15 +598,15 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, Tosa INIT_ATTRIBUTE(Conv); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv2d() { if (attribute) delete attribute; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -640,8 +642,8 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -793,8 +795,10 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() return GraphNode::eval(); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) { setRequiredOperands(3, 1); @@ -803,15 +807,15 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, Tosa INIT_ATTRIBUTE(Conv); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpConv3d() { if (attribute) delete attribute; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -847,8 +851,8 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpConv3d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -1008,10 +1012,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() return GraphNode::eval(); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1020,15 +1024,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1064,8 +1068,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1903,10 +1907,10 @@ int OpRFFT2d<Dtype>::eval() return GraphNode::eval(); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, - TosaAttributeBase* attribute_, - uint64_t id_) +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, + TosaAttributeBase* attribute_, + uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) { setRequiredOperands(3, 1); @@ -1915,15 +1919,15 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -2017,8 +2021,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() return 0; } -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> -int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> +int OpTransposeConv2d<InDtype, WeightDtype, AccDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -2168,39 +2172,39 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E4M3, FP16); DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP8E5M2, FP16); -// [in_t, weight_t, out_t] -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16); - -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16); - -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16); +// [in_t, weight_t, acc_t, out_t] +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv2d, FP8E5M2, FP8E5M2, FP16, FP16); + +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpConv3d, FP8E5M2, FP8E5M2, FP16, FP16); + +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpDepthwiseConv2d, FP8E5M2, FP8E5M2, FP16, FP16); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64); @@ -2238,13 +2242,14 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP8E5M2); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16); -DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16); +// [in_t, weight_t, acc_t, out_t] +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP16, FP16, FP32, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, BF16, BF16, FP32, BF16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP32, FP32, FP32, FP32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT4, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT8, INT8, INT32, INT32); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, INT16, INT8, INT48, INT48); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP64, FP64, FP64, FP64); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E4M3, FP8E4M3, FP16, FP16); +DEF_INSTANTIATE_FOUR_TYPE(OpTransposeConv2d, FP8E5M2, FP8E5M2, FP16, FP16); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h index e2bb811..2e65548 100644 --- a/reference_model/src/ops/tensor_ops.h +++ b/reference_model/src/ops/tensor_ops.h @@ -75,7 +75,7 @@ protected: int in_size, int out_size, int kernel_size, int stride, int32_t padding_left, int32_t padding_right); }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv2d : public GraphNode { public: @@ -87,7 +87,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -105,7 +105,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpConv3d : public GraphNode { public: @@ -117,7 +117,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 5>; using TWeight = Eigen::Tensor<WeightEigenType, 5>; @@ -135,7 +135,7 @@ protected: tosa::TosaConvAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpDepthwiseConv2d : public GraphNode { public: @@ -147,7 +147,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; @@ -294,7 +294,7 @@ protected: tosa::TosaRFFTAttribute* attribute; }; -template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype> +template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE AccDtype, TOSA_REF_TYPE OutDtype> class OpTransposeConv2d : public GraphNode { public: @@ -306,7 +306,7 @@ public: using InEigenType = typename GetEigenType<InDtype>::type; using WeightEigenType = typename GetEigenType<WeightDtype>::type; - using AccEigenType = typename GetAccEigenType<OutDtype>::type; // Note: different from GetEigenType + using AccEigenType = typename GetAccEigenType<AccDtype>::type; // Note: different from GetEigenType using OutEigenType = typename GetEigenType<OutDtype>::type; using TIn = Eigen::Tensor<InEigenType, 4>; using TWeight = Eigen::Tensor<WeightEigenType, 4>; |