From a4d748b08accce06fab93e2d2b96e499b35ae89b Mon Sep 17 00:00:00 2001
From: Tai Ly
Date: Tue, 28 Mar 2023 22:06:56 +0000
Subject: [reference model] Add precise mode

This adds the --precise_mode=1 option to tosa_reference_model, which
causes the reference model to convert all floating point tensors to
FP64 tensors and compute all operators accordingly.

Also adds an optional -p argument to the test runners
tosa_verif_run_tests.py and tosa_verif_framework_compiler_runner.py to
run tests in precise mode.

Signed-off-by: Tai Ly
Change-Id: I156055216ad61710096497a8fa1a653be2a602a3
---
 reference_model/src/ops/tensor_ops.cc | 203 +++++++++++++++++++---------------
 1 file changed, 112 insertions(+), 91 deletions(-)

diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
index b3845df..f8fd323 100644
--- a/reference_model/src/ops/tensor_ops.cc
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -116,14 +116,14 @@ int check_pool2d_attribute(tosa::TosaPoolAttribute* attribute,
 }
 
 int check_conv_attribute(tosa::TosaConvAttribute* attribute,
-                         uint32_t conv_dimension,
-                         std::vector<int32_t> input_shape,
-                         std::vector<int32_t> output_shape,
-                         std::vector<int32_t> weights,
-                         uint32_t offset_kernel,
-                         DType InDtype,
-                         DType WeightDtype,
-                         std::string& msg)
+                         uint32_t conv_dimension,
+                         std::vector<int32_t> input_shape,
+                         std::vector<int32_t> output_shape,
+                         std::vector<int32_t> weights,
+                         uint32_t offset_kernel,
+                         TOSA_REF_TYPE InDtype,
+                         TOSA_REF_TYPE WeightDtype,
+                         std::string& msg)
 {
     if (attribute->pad().size() != (2 * conv_dimension))
     {
@@ -226,11 +226,13 @@ int check_conv_attribute(tosa::TosaConvAttribute* attribute,
         return 1;
     }
 
-    if (InDtype != DType_INT8 && attribute->input_zp() != 0) {
+    if (InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0)
+    {
         msg = "Input zero point must be zero for non-int8 data";
         return 1;
     }
-    if (WeightDtype != DType_INT8 && attribute->weight_zp() != 0) {
+    if (WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0)
+    {
         msg = "Weight zero point must be zero for non-int8 data";
         return 1;
     }
@@ -318,7 +320,7 @@ int check_fft_shape(const std::vector<int32_t>& in_real,
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -330,14 +332,14 @@ OpArgMax<Rank, Dtype>::OpArgMax(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Axis);
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 OpArgMax<Rank, Dtype>::~OpArgMax()
 {
     if (attribute)
         delete attribute;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArgMax<Rank, Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -355,7 +357,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
         return 1;
     }
 
-    if (outputs[0]->getDtype() != DType_INT32)
+    if (outputs[0]->getDtype() != TOSA_REF_TYPE_INT32)
     {
         printNodeValidationError("OpArgMax: Output data type not supported for this configuration of operator");
         return 1;
@@ -400,7 +402,7 @@ int OpArgMax<Rank, Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <int Rank, DType Dtype>
+template <int Rank, TOSA_REF_TYPE Dtype>
 int OpArgMax<Rank, Dtype>::eval()
 {
     Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
@@ -410,7 +412,7 @@ int OpArgMax<Rank, Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
                                           TosaAttributeBase* attribute_,
                                           uint64_t id_)
@@ -422,14 +424,14 @@ OpAvgPool2d<Dtype, AccDtype>::OpAvgPool2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Pool);
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 OpAvgPool2d<Dtype, AccDtype>::~OpAvgPool2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -449,8 +451,10 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
     in  = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
     out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    ERROR_IF(Dtype != DType_INT8 && attribute->input_zp() != 0, "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(Dtype != DType_INT8 && attribute->output_zp() != 0, "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpAvgPool2d: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->output_zp() != 0,
+             "OpAvgPool2d: Output zeropoint must be zero for non int8_t data");
 
     std::string msg;
     if (check_pool2d_attribute(attribute, in->getShape(), out->getShape(), msg))
@@ -466,8 +470,9 @@ int OpAvgPool2d<Dtype, AccDtype>::checkTensorAttributes()
 // This calculates the number of padding elements used for each location along an axis
 // Average pooling only divides by the number of elements used, not including padding.
 // This function uses left/right, but is also used for vertical padding with top/bottom
-template <DType Dtype, DType AccDtype>
-ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(
+    int in_size, int out_size, int kernel_size, int stride, int32_t pad_left, int32_t pad_right)
 {
     ETensor1<int32_t> result(out_size);
 
@@ -495,7 +500,7 @@ ETensor1<int32_t> OpAvgPool2d<Dtype, AccDtype>::calculate_div_map_1d(int in_size
 
 // assuming input and output tensor have same scales like tflite reference
 // so no need to scale input and output
-template <DType Dtype, DType AccDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE AccDtype>
 int OpAvgPool2d<Dtype, AccDtype>::eval()
 {
     int in_batch = this->in->getShape()[0];
@@ -531,7 +536,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
     LEVEL_CHECK(pad_left <= tosa_level.MAX_KERNEL, "pad_left should be smaller than or equal to MAX_KERNEL");
     LEVEL_CHECK(pad_right <= tosa_level.MAX_KERNEL, "pad_right should be smaller than or equal to MAX_KERNEL");
 
-    tosa::DType accum_dtype = (tosa::DType)this->attribute->accum_dtype();
+    TOSA_REF_TYPE accum_dtype = ConvertDType(this->attribute->accum_dtype());
 
     DEBUG_INFO(OP,
                "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
@@ -556,7 +561,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
     pad[3] = std::make_pair(0, 0);
 
     ETensor4<InEigenType> input_val = this->in->getTensor();
-    if (Dtype == DType_INT8)
+    if (Dtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
     }
@@ -604,7 +609,8 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
             dm2_h.contract(dm2_w, contract_dims)
                 .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
                 .broadcast(bcast);
-    if (Dtype != DType_FP32 && Dtype != DType_FP16 && Dtype != DType_BF16)
+    if (Dtype != TOSA_REF_TYPE_FP32 && Dtype != TOSA_REF_TYPE_FP16 && Dtype != TOSA_REF_TYPE_BF16 &&
+        Dtype != TOSA_REF_TYPE_FP64)
     {
         try
         {
@@ -632,7 +638,7 @@ int OpAvgPool2d<Dtype, AccDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
                                                    TosaAttributeBase* attribute_,
                                                    uint64_t id_)
@@ -644,14 +650,14 @@ OpConv2d<InDtype, WeightDtype, OutDtype>::OpConv2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv2d<InDtype, WeightDtype, OutDtype>::~OpConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -688,7 +694,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch = this->input->getShape()[0];
@@ -781,7 +787,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
 
     TIn input_val = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -817,7 +823,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
     // reshape back to [N, H, W, C]
     this->output->getTensor() = biased_output.reshape(col2im_output_dims);
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -826,7 +832,7 @@ int OpConv2d<InDtype, WeightDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
                                                    TosaAttributeBase* attribute_,
                                                    uint64_t id_)
@@ -838,14 +844,14 @@ OpConv3d<InDtype, WeightDtype, OutDtype>::OpConv3d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpConv3d<InDtype, WeightDtype, OutDtype>::~OpConv3d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -882,7 +888,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch = this->input->getShape()[0];
@@ -959,7 +965,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
 
     TIn input_val = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1020,7 +1026,7 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1029,10 +1035,10 @@ int OpConv3d<InDtype, WeightDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTraverser* sgt_,
-                                                 TosaAttributeBase* attribute_,
-                                                 uint64_t id_)
+                                                                     TosaAttributeBase* attribute_,
+                                                                     uint64_t id_)
     : GraphNode(sgt_, Op_DEPTHWISE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1041,14 +1047,14 @@ OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::OpDepthwiseConv2d(SubgraphTra
     INIT_ATTRIBUTE(Conv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::~OpDepthwiseConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1085,7 +1091,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch = this->input->getShape()[0];
@@ -1149,7 +1155,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
 
     TIn input_val = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
    {
         input_val = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1205,7 +1211,7 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1214,10 +1220,10 @@ int OpDepthwiseConv2d<InDtype, WeightDtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTraverser* sgt_,
-                                                 TosaAttributeBase* attribute_,
-                                                 uint64_t id_)
+                                                                   TosaAttributeBase* attribute_,
+                                                                   uint64_t id_)
     : GraphNode(sgt_, Op_FULLY_CONNECTED, id_)
 {
     setRequiredOperands(3, 1);
@@ -1226,14 +1232,14 @@ OpFullyConnected<InDtype, WeightDtype, OutDtype>::OpFullyConnected(SubgraphTrave
     INIT_ATTRIBUTE(FullyConnected);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpFullyConnected<InDtype, WeightDtype, OutDtype>::~OpFullyConnected()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1265,13 +1271,15 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 
     output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
 
-    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
+    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpFullyConnected: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+             "OpFullyConnected: Weight zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1289,7 +1297,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
 
     TIn input_val = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -1299,7 +1307,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
         input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims).template cast<OutEigenType>() +
         this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1307,7 +1315,7 @@ int OpFullyConnected<InDtype, WeightDtype, OutDtype>::eval()
     }
     return GraphNode::eval();
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
                                     TosaAttributeBase* attribute_,
                                     uint64_t id_)
@@ -1319,14 +1327,14 @@ OpMatMul<Dtype, OutDtype>::OpMatMul(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(MatMul);
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 OpMatMul<Dtype, OutDtype>::~OpMatMul()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1382,13 +1390,15 @@ int OpMatMul<Dtype, OutDtype>::checkTensorAttributes()
     }
     W = b->getShape()[2];
 
-    ERROR_IF(Dtype != DType_INT8 && attribute->a_zp() != 0, "OpMatMul: A zeropoint must be zero for non int8_t data");
-    ERROR_IF(Dtype != DType_INT8 && attribute->b_zp() != 0, "OpMatMul: B zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->a_zp() != 0,
+             "OpMatMul: A zeropoint must be zero for non int8_t data");
+    ERROR_IF(Dtype != TOSA_REF_TYPE_INT8 && attribute->b_zp() != 0,
+             "OpMatMul: B zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType Dtype, DType OutDtype>
+template <TOSA_REF_TYPE Dtype, TOSA_REF_TYPE OutDtype>
 int OpMatMul<Dtype, OutDtype>::eval()
 {
     typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
@@ -1396,7 +1406,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
 
     TIn a_val = this->a->getTensor();
     TIn b_val = this->b->getTensor();
-    if (Dtype == DType_INT8)
+    if (Dtype == TOSA_REF_TYPE_INT8)
     {
         a_val = a_val - (InEigenType)attribute->a_zp();
         b_val = b_val - (InEigenType)attribute->b_zp();
@@ -1434,7 +1444,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -1443,7 +1453,7 @@ int OpMatMul<Dtype, OutDtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
                                 TosaAttributeBase* attribute_,
                                 uint64_t id_)
@@ -1455,14 +1465,14 @@ OpMaxPool2d<Dtype>::OpMaxPool2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(Pool);
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpMaxPool2d<Dtype>::~OpMaxPool2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpMaxPool2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1493,7 +1503,7 @@ int OpMaxPool2d<Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpMaxPool2d<Dtype>::eval()
 {
     int in_batch = this->in->getShape()[0];
@@ -1586,10 +1596,8 @@ int OpMaxPool2d<Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
-OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
-                        TosaAttributeBase* attribute_,
-                        uint64_t id_)
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_, TosaAttributeBase* attribute_, uint64_t id_)
     : GraphNode(sgt_, Op_FFT2D, id_)
 {
     setRequiredOperands(2, 2);
@@ -1598,14 +1606,14 @@ OpFFT2d<Dtype>::OpFFT2d(SubgraphTraverser* sgt_,
     INIT_ATTRIBUTE(FFT);
 }
 
-template <DType Dtype>
-OpFFT2d<Dtype>::~OpFFT2d() {
+template <TOSA_REF_TYPE Dtype>
+OpFFT2d<Dtype>::~OpFFT2d()
+{
     if (attribute)
         delete attribute;
 }
 
-
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpFFT2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1643,7 +1651,7 @@ int OpFFT2d<Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpFFT2d<Dtype>::eval()
 {
     int in_real_batch = this->in_real->getShape()[0];
@@ -1709,7 +1717,7 @@ int OpFFT2d<Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
                           TosaAttributeBase* attribute_,
                           uint64_t id_)
@@ -1719,11 +1727,11 @@ OpRFFT2d<Dtype>::OpRFFT2d(SubgraphTraverser* sgt_,
     setRequiredRank(3);
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 OpRFFT2d<Dtype>::~OpRFFT2d()
 {}
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpRFFT2d<Dtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1759,7 +1767,7 @@ int OpRFFT2d<Dtype>::checkTensorAttributes()
     return 0;
 }
 
-template <DType Dtype>
+template <TOSA_REF_TYPE Dtype>
 int OpRFFT2d<Dtype>::eval()
 {
     int32_t in_batch = in->getShape()[0];
@@ -1815,10 +1823,10 @@ int OpRFFT2d<Dtype>::eval()
     return GraphNode::eval();
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTraverser* sgt_,
-                                                 TosaAttributeBase* attribute_,
-                                                 uint64_t id_)
+                                                                     TosaAttributeBase* attribute_,
+                                                                     uint64_t id_)
     : GraphNode(sgt_, Op_TRANSPOSE_CONV2D, id_)
 {
     setRequiredOperands(3, 1);
@@ -1827,14 +1835,14 @@ OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::OpTransposeConv2d(SubgraphTra
     INIT_ATTRIBUTE(TransposeConv);
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::~OpTransposeConv2d()
 {
     if (attribute)
         delete attribute;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
 {
     if (validateRequiredOperands())
@@ -1923,13 +1931,15 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::checkTensorAttributes()
         return 1;
     }
 
-    ERROR_IF(InDtype != DType_INT8 && attribute->input_zp() != 0, "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
-    ERROR_IF(WeightDtype != DType_INT8 && attribute->weight_zp() != 0, "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
+    ERROR_IF(InDtype != TOSA_REF_TYPE_INT8 && attribute->input_zp() != 0,
+             "OpTransposeConv2d: Input zeropoint must be zero for non int8_t data");
+    ERROR_IF(WeightDtype != TOSA_REF_TYPE_INT8 && attribute->weight_zp() != 0,
+             "OpTransposeConv2d: Weight zeropoint must be zero for non int8_t data");
 
     return 0;
 }
 
-template <DType InDtype, DType WeightDtype, DType OutDtype>
+template <TOSA_REF_TYPE InDtype, TOSA_REF_TYPE WeightDtype, TOSA_REF_TYPE OutDtype>
 int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
 {
     int in_batch = this->input->getShape()[0];
@@ -1985,7 +1995,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
 
     TIn input_val = this->input->getTensor();
     TWeight weight_val = this->weight->getTensor();
-    if (InDtype == DType_INT8 || WeightDtype == DType_INT8)
+    if (InDtype == TOSA_REF_TYPE_INT8 || WeightDtype == TOSA_REF_TYPE_INT8)
     {
         input_val = input_val - (InEigenType)attribute->input_zp();
         weight_val = weight_val - (WeightEigenType)attribute->weight_zp();
@@ -2040,7 +2050,7 @@ int OpTransposeConv2d<InDtype, WeightDtype, OutDtype>::eval()
         }
     }
 
-    if (OutDtype == DType_INT48)
+    if (OutDtype == TOSA_REF_TYPE_INT48)
     {
         this->output->getTensor() = this->output->getTensor().cwiseMax((OutEigenType)AccQMin);
         this->output->getTensor() = this->output->getTensor().cwiseMin((OutEigenType)AccQMax);
@@ -2055,6 +2065,7 @@ DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, BF16);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP32);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT8);
 DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FP64);
 
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP16, FP32);
@@ -2062,6 +2073,7 @@ DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP32, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT8, INT32);
 DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, INT16, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpAvgPool2d, FP64, FP64);
 
 // [in_t, weight_t, out_t]
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP16, FP16, FP16);
@@ -2071,6 +2083,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv2d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP16, FP16, FP32);
@@ -2079,6 +2092,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP32, FP32, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpConv3d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpConv3d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP16, FP16, FP32);
@@ -2087,8 +2101,10 @@ DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpDepthwiseConv2d, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpFFT2d, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP16, FP16, FP32);
@@ -2097,6 +2113,7 @@ DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP32, FP32, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpFullyConnected, FP64, FP64, FP64);
 
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT8, INT32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, INT16, INT48);
@@ -2104,14 +2121,17 @@ DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP16);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP16, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, BF16, FP32);
 DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP32, FP32);
+DEF_INSTANTIATE_TWO_TYPE(OpMatMul, FP64, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, BF16);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP32);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT8);
 DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FP64);
 
 DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP32);
+DEF_INSTANTIATE_ONE_TYPE(OpRFFT2d, FP64);
 
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP16);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP16, FP16, FP32);
@@ -2120,3 +2140,4 @@ DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP32, FP32, FP32);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT4, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT8, INT8, INT32);
 DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, INT16, INT8, INT48);
+DEF_INSTANTIATE_THREE_TYPE(OpTransposeConv2d, FP64, FP64, FP64);
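
For context, the ConvertDType() call introduced in the AvgPool2d hunk is what ties
--precise_mode to the new FP64 instantiations above: it maps the serialized
tosa::DType of a tensor onto the reference model's internal TOSA_REF_TYPE, and
in precise mode it widens every floating point type to FP64 so the FP64
operator templates are selected. A minimal sketch of such a mapping, assuming a
g_func_config.precise_mode flag set by --precise_mode=1 (the flag name, the
TOSA_REF_TYPE_UNKNOWN fallback, and the switch structure below are assumptions
inferred from this patch, not copied from the actual dtype header):

    TOSA_REF_TYPE ConvertDType(const tosa::DType dtype)
    {
        if (g_func_config.precise_mode)    // assumed flag backing --precise_mode=1
        {
            // In precise mode, all floating point tensors become FP64 tensors,
            // so every float operator is evaluated in double precision.
            switch (dtype)
            {
                case tosa::DType_FP16:
                case tosa::DType_BF16:
                case tosa::DType_FP32:
                    return TOSA_REF_TYPE_FP64;
                default:
                    break;
            }
        }
        // Outside precise mode the mapping is one-to-one,
        // e.g. tosa::DType_INT8 -> TOSA_REF_TYPE_INT8.
        switch (dtype)
        {
            case tosa::DType_FP16:
                return TOSA_REF_TYPE_FP16;
            case tosa::DType_BF16:
                return TOSA_REF_TYPE_BF16;
            case tosa::DType_FP32:
                return TOSA_REF_TYPE_FP32;
            case tosa::DType_INT8:
                return TOSA_REF_TYPE_INT8;
            // ... remaining enum values map the same way ...
            default:
                return TOSA_REF_TYPE_UNKNOWN;    // assumed sentinel value
        }
    }

With a helper along these lines, running the reference model with
--precise_mode=1 (or the test runners with -p, as the commit message describes)
exercises the FP64 template instantiations added at the bottom of this file.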