diff options
Diffstat (limited to 'reference_model/src/ops/tensor_ops.cc')
-rw-r--r-- | reference_model/src/ops/tensor_ops.cc | 250 |
1 files changed, 121 insertions, 129 deletions
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc index dff9e08..4663c47 100644 --- a/reference_model/src/ops/tensor_ops.cc +++ b/reference_model/src/ops/tensor_ops.cc @@ -541,8 +541,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_, +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV2D, id_) @@ -553,15 +553,15 @@ OpConv2d<InDtype, WeightDtype, AccDtype>::OpConv2d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpConv2d<InDtype, WeightDtype, AccDtype>::~OpConv2d() +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -577,7 +577,7 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() printNodeValidationError("OpConv2d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); @@ -597,8 +597,8 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpConv2d<InDtype, WeightDtype, AccDtype>::eval() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -634,14 +634,12 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO(OP, "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " - "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", + "stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + pad_bottom, pad_left, pad_right); // GEMM-conv2d, left matrix is input, right matrix is weight Eigen::array<Eigen::Index, 2> im2col_input_dims; @@ -717,7 +715,7 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval() // reshape back to [N, H, W, C] this->output->getTensor() = biased_output.reshape(col2im_output_dims); - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -726,8 +724,8 @@ int OpConv2d<InDtype, WeightDtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_, +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_CONV3D, id_) @@ -738,15 +736,15 @@ OpConv3d<InDtype, WeightDtype, AccDtype>::OpConv3d(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpConv3d<InDtype, WeightDtype, AccDtype>::~OpConv3d() +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -762,7 +760,7 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() printNodeValidationError("OpConv3d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpConv3d: Output data type not supported for this configuration of operator"); input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); @@ -782,8 +780,8 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpConv3d<InDtype, WeightDtype, AccDtype>::eval() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpConv3d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_depth = this->input->getShape()[1]; @@ -827,15 +825,13 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval() int dilation_h = this->attribute->dilation()[1]; int dilation_w = this->attribute->dilation()[2]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO( OP, "perform OpConv3d, input.shape=[%d,%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d,%d], output.shape=[%d,%d,%d,%d,%d], " - "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d], accum_dtype=%s", + "stride=[%d,%d,%d], dilation=[%d,%d,%d], pad=[%d,%d,%d,%d,%d,%d]", in_batch, in_depth, in_height, in_width, in_channels, f_out_channels, f_depth, f_height, f_width, f_in_channels, out_batch, out_depth, out_height, out_width, out_channels, stride_d, stride_h, stride_w, dilation_d, dilation_h, - dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + dilation_w, pad_d0, pad_d1, pad_top, pad_bottom, pad_left, pad_right); Eigen::array<std::pair<int32_t, int32_t>, 5> pad; pad[0] = std::make_pair(0, 0); @@ -907,7 +903,7 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -916,8 +912,8 @@ int OpConv3d<InDtype, WeightDtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_) @@ -928,15 +924,15 @@ OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::OpDepthwiseConv2d(SubgraphTra INIT_ATTRIBUTE(Conv); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::~OpDepthwiseConv2d() +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -952,7 +948,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpDepthwiseConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); @@ -972,8 +968,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1010,14 +1006,12 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval() int dilation_h = this->attribute->dilation()[0]; int dilation_w = this->attribute->dilation()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - DEBUG_INFO(OP, "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d], accum_dtype=%s", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, pad_top, - pad_bottom, pad_left, pad_right, EnumNamesDType()[accum_dtype]); + pad_bottom, pad_left, pad_right); Eigen::array<std::pair<int32_t, int32_t>, 4> pad; pad[0] = std::make_pair(0, 0); @@ -1083,7 +1077,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1092,8 +1086,8 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTraverser* sgt_, +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_FULLY_CONNECTED, id_) @@ -1104,15 +1098,15 @@ OpFullyConnected<InDtype, WeightDtype, AccDtype>::OpFullyConnected(SubgraphTrave INIT_ATTRIBUTE(FullyConnected); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpFullyConnected<InDtype, WeightDtype, AccDtype>::~OpFullyConnected() +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1138,7 +1132,7 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpFullyConnected: Output data type not supported for this configuration of operator"); output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); @@ -1149,8 +1143,8 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval() { typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } }; @@ -1177,7 +1171,7 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval() input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1185,8 +1179,8 @@ int OpFullyConnected<InDtype, WeightDtype, AccDtype>::eval() return GraphNode::eval(); } -template <DType Dtype, DType AccDtype> -OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_, +template <DType Dtype, DType OutDtype> +OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_MATMUL, id_) @@ -1197,15 +1191,15 @@ OpMatMul<Dtype, AccDtype>::OpMatMul(SubgraphTraverser* sgt_, INIT_ATTRIBUTE(MatMul); } -template <DType Dtype, DType AccDtype> -OpMatMul<Dtype, AccDtype>::~OpMatMul() +template <DType Dtype, DType OutDtype> +OpMatMul<Dtype, OutDtype>::~OpMatMul() { if (attribute) delete attribute; } -template <DType Dtype, DType AccDtype> -int OpMatMul<Dtype, AccDtype>::checkTensorAttributes() +template <DType Dtype, DType OutDtype> +int OpMatMul<Dtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1215,7 +1209,7 @@ int OpMatMul<Dtype, AccDtype>::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpMatMul: Output data type not supported for this configuration of operator"); a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); @@ -1266,8 +1260,8 @@ int OpMatMul<Dtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType Dtype, DType AccDtype> -int OpMatMul<Dtype, AccDtype>::eval() +template <DType Dtype, DType OutDtype> +int OpMatMul<Dtype, OutDtype>::eval() { typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } }; @@ -1312,7 +1306,7 @@ int OpMatMul<Dtype, AccDtype>::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1587,8 +1581,8 @@ int OpRFFT2d<Dtype>::eval() return GraphNode::eval(); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_) : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_) @@ -1599,15 +1593,15 @@ OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::OpTransposeConv2d(SubgraphTra INIT_ATTRIBUTE(TransposeConv); } -template <DType InDtype, DType WeightDtype, DType AccDtype> -OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::~OpTransposeConv2d() +template <DType InDtype, DType WeightDtype, DType OutDtype> +OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d() { if (attribute) delete attribute; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes() { if (validateRequiredOperands()) return 1; @@ -1617,7 +1611,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 1; } - ERROR_IF(outputs[0]->getDtype() != AccDtype, + ERROR_IF(outputs[0]->getDtype() != OutDtype, "OpTransposeConv2d: Output data type not supported for this configuration of operator"); input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); @@ -1701,8 +1695,8 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::checkTensorAttributes() return 0; } -template <DType InDtype, DType WeightDtype, DType AccDtype> -int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval() +template <DType InDtype, DType WeightDtype, DType OutDtype> +int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval() { int in_batch = this->input->getShape()[0]; int in_height = this->input->getShape()[1]; @@ -1729,8 +1723,6 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval() int stride_h = this->attribute->stride()[0]; int stride_w = this->attribute->stride()[1]; - tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype(); - ERROR_IF(in_batch != out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); ERROR_IF(f_in_channels != in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", f_in_channels, in_channels); @@ -1741,10 +1733,10 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval() DEBUG_INFO(OP, "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " - "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d], accum_dtype=%s", + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], out_pad=[%d,%d,%d,%d]", in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch, out_height, out_width, out_channels, stride_h, stride_w, out_pad_top, - out_pad_bottom, out_pad_left, out_pad_right, EnumNamesDType()[accum_dtype]); + out_pad_bottom, out_pad_left, out_pad_right); TIn input_val = this->input->getTensor(); TWeight weight_val = this->weight->getTensor(); @@ -1803,7 +1795,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, AccDtype>::eval() } } - if (AccDtype == DType_INT48) + if (OutDtype == DType_INT48) { this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin); this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax); @@ -1819,52 +1811,52 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8); DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, BF16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, FP32, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT8, INT32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpAvgPool2d, INT16, INT32); - - // [in_t, weight_t, acc_t] -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv2d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpConv3d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpDepthwiseConv2d, INT16, INT8, INT48); - -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpFullyConnected, INT16, INT8, INT48); - -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT8, INT32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, INT16, INT48); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP16); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, BF16, FP32); -DEF_INSTANTIATE_ONE_TYPE_ONE_ACCUM(OpMatMul, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32); + + // [in_t, weight_t, out_t] +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48); + +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48); + +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32); +DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16); DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16); @@ -1874,10 +1866,10 @@ DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP16); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP16, FP16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, BF16, BF16, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, FP32, FP32, FP32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT4, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT8, INT8, INT32); -DEF_INSTANTIATE_TWO_TYPE_ONE_ACCUM(OpTransposeConv2d, INT16, INT8, INT48); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, BF16, BF16, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32); +DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48); |